#include "mr_private.h"
#include <errno.h>
-#include <netinet/in.h>
#include <stdlib.h>
#include <string.h>
+#ifndef _WIN32
+#include <netinet/in.h>
+#endif /* _WIN32 */
+
+#ifdef HAVE_UNISTD_H
#include <unistd.h>
+#endif
RCSID("$Header$");
int mr_send(int fd, struct mr_params *params)
{
- u_long length, written;
+ u_long length;
+ int written;
int i, *argl;
char *buf, *p;
length = p - buf;
putlong(buf, length);
- written = write(fd, buf, length);
+ written = send(fd, buf, length, 0);
free(buf);
if (!params->mr_argl)
free(argl);
- if (written != length)
+ if (written != (int)length)
return MR_ABORTED;
else
return MR_SUCCESS;
int mr_receive(int fd, struct mr_params *reply)
{
- u_long length, data;
- ssize_t size, more;
- char *p;
- int i;
+ int status;
memset(reply, 0, sizeof(struct mr_params));
+ do
+ status = mr_cont_receive(fd, reply);
+ while (status == -1);
- size = read(fd, &data, 4);
- if (size != 4)
- return size ? MR_ABORTED : MR_NOT_CONNECTED;
- length = ntohl(data) - 4;
- reply->mr_flattened = malloc(length);
- if (!reply->mr_flattened)
- return ENOMEM;
+ return status;
+}
+
+/* Read some or all of a client response without failing badly if it
+ * turns out to be malformed. Returns MR_SUCCESS on success, an error
+ * code on failure, or -1 if the packet has not been completely
+ * received yet (call again with the same reply to continue).
+ */
- for (size = 0; size < length; size += more)
+int mr_cont_receive(int fd, struct mr_params *reply)
+{
+ u_long length, data;
+ int size, more;
+ char *p, *end;
+ int i;
+
+ if (!reply->mr_flattened)
{
- more = read(fd, reply->mr_flattened + size, length - size);
- if (!more)
- break;
+ char lbuf[4];
+
+ size = recv(fd, lbuf, 4, 0);
+ if (size != 4)
+ return size ? MR_ABORTED : MR_NOT_CONNECTED;
+ getlong(lbuf, length);
+ if (length > 8192)
+ return MR_INTERNAL;
+ reply->mr_flattened = malloc(length);
+ if (!reply->mr_flattened)
+ return ENOMEM;
+ memcpy(reply->mr_flattened, lbuf, 4);
+ reply->mr_filled = 4;
+
+ return -1;
}
- if (size != length)
+ else
+ getlong(reply->mr_flattened, length);
+
+ more = recv(fd, reply->mr_flattened + reply->mr_filled,
+ length - reply->mr_filled, 0);
+ if (more == -1)
{
mr_destroy_reply(*reply);
return MR_ABORTED;
}
- getlong(reply->mr_flattened, data);
+ reply->mr_filled += more;
+
+ if (reply->mr_filled != length)
+ return -1;
+
+ getlong(reply->mr_flattened + 4, data);
if (data != MR_VERSION_2)
{
mr_destroy_reply(*reply);
return MR_VERSION_MISMATCH;
}
- getlong(reply->mr_flattened + 4, reply->u.mr_status);
- getlong(reply->mr_flattened + 8, reply->mr_argc);
+ getlong(reply->mr_flattened + 8, reply->u.mr_status);
+ getlong(reply->mr_flattened + 12, reply->mr_argc);
+ if (reply->mr_argc > ((int)length - 16) / 8)
+ {
+ mr_destroy_reply(*reply);
+ return MR_INTERNAL;
+ }
reply->mr_argv = malloc(reply->mr_argc * sizeof(char *));
reply->mr_argl = malloc(reply->mr_argc * sizeof(int));
if (reply->mr_argc && (!reply->mr_argv || !reply->mr_argl))
return ENOMEM;
}
- for (i = 0, p = reply->mr_flattened + 12; i < reply->mr_argc; i++)
+ p = (char *)reply->mr_flattened + 16;
+ end = (char *)reply->mr_flattened + length;
+ for (i = 0; i < reply->mr_argc && p + 4 <= end; i++)
{
getlong(p, reply->mr_argl[i]);
+ if (p + 4 + reply->mr_argl[i] > end)
+ break;
reply->mr_argv[i] = p + 4;
p += 4 + reply->mr_argl[i] + (4 - reply->mr_argl[i] % 4) % 4;
}
+ if (i != reply->mr_argc)
+ {
+ mr_destroy_reply(*reply);
+ return MR_INTERNAL;
+ }
+
return MR_SUCCESS;
}