]> andersk Git - moira.git/blobdiff - lib/mr_connect.c
watch out for EINTR when accept()ing
[moira.git] / lib / mr_connect.c
index 7c8624decf2c4b0393763be4938c40b56779f757..3dcd4af40c528db1fbc11df1154042fea3ca7a88 100644 (file)
@@ -76,7 +76,6 @@ static char response[53] = "\0\0\0\061\0\0\0\003\0\001\001disposition\0server_id
 int mr_connect(char *server)
 {
   char *port, **pp, *sbuf = NULL;
-  struct hostent *shost;
 
   if (_mr_conn)
     return MR_ALREADY_CONNECTED;
@@ -98,18 +97,14 @@ int mr_connect(char *server)
   if (!server || (strlen(server) == 0))
     server = MOIRA_SERVER;
 
-  shost = gethostbyname(server);
-  if (!shost)
-    return MR_CANT_CONNECT;
-
   if (strchr(server, ':'))
     {
       int len = strcspn(server, ":");
       sbuf = malloc(len + 1);
       strncpy(sbuf, server, len);
-      sbuf[len - 1] = '\0';
-      server = sbuf;
+      sbuf[len] = '\0';
       port = strchr(server, ':') + 1;
+      server = sbuf;
     }
   else
     port = strchr(MOIRA_SERVER, ':') + 1;
@@ -119,8 +114,6 @@ int mr_connect(char *server)
   if (!_mr_conn)
     return MR_CANT_CONNECT;
 
-  /* stash hostname for later use */
-  mr_server_host = strdup(shost->h_name);
   return MR_SUCCESS;
 }
 
@@ -136,7 +129,7 @@ int mr_connect_internal(char *server, char *port)
     return 0;
 
   if (port[0] == '#')
-    target.sin_port = atoi(port + 1);
+    target.sin_port = htons(atoi(port + 1));
   else
     {
       struct servent *s;
@@ -170,7 +163,7 @@ int mr_connect_internal(char *server, char *port)
   for (size = 0; size < sizeof(actualresponse); size += more)
     {
       more = read(fd, actualresponse + size, sizeof(actualresponse) - size);
-      if (!more)
+      if (more <= 0)
        break;
     }
   if (size != sizeof(actualresponse))
@@ -184,6 +177,8 @@ int mr_connect_internal(char *server, char *port)
       return 0;
     }
 
+  mr_server_host = strdup(shost->h_name);
+
   /* You win */
   return fd;
 }
@@ -270,61 +265,88 @@ int mr_listen(char *port)
   return s;
 }
 
+/* mr_accept returns -1 on accept() error, 0 on bad connection,
+   or connection fd on success */
+
 int mr_accept(int s, struct sockaddr_in *sin)
 {
-  int conn, addrlen = sizeof(struct sockaddr_in);
-  char lbuf[4], *buf;
-  long len, size, more;
-
-  conn = accept(s, (struct sockaddr *)sin, &addrlen);
-  if (conn < 0)
-    return -1;
-
-  /* Now do mrgdb accept protocol */
-  /* XXX timeout */
+  int conn = -1, addrlen = sizeof(struct sockaddr_in), nread, status;
+  char *buf = NULL;
 
-  if (read(conn, lbuf, 4) != 4)
+  while (conn < 0)
     {
-      close(conn);
-      return -1;
+      conn = accept(s, (struct sockaddr *)sin, &addrlen);
+      if (conn < 0 && errno != EINTR)
+       return -1;
     }
-  getlong(lbuf, len);
 
-  buf = malloc(len);
-  if (!buf || len < 54)
+  do
+    status = mr_cont_accept(conn, &buf, &nread);
+  while (status == -1);
+
+  return status;
+}
+
+/* mr_cont_accept returns 0 if it has failed, an fd if it has succeeded,
+   or -1 if it is still making progress */
+
+int mr_cont_accept(int conn, char **buf, int *nread)
+{
+  long len, more;
+
+  if (!*buf)
     {
-      close(conn);
-      free(buf);
+      char lbuf[4];
+      if (read(conn, lbuf, 4) != 4)
+       {
+         close(conn);
+         return 0;
+       }
+      getlong(lbuf, len);
+      len += 4;
+
+      *buf = malloc(len);
+      if (!*buf || len < 58)
+       {
+         close(conn);
+         free(*buf);
+         return 0;
+       }
+      putlong(*buf, len);
+      *nread = 4;
       return -1;
     }
+  else
+    getlong(*buf, len);
 
-  for (size = 0; size < len; size += more)
-    {
-      more = read(conn, buf + size, len - size);
-      if (!more)
-       break;
-    }
-  if (size != len)
+  more = read(conn, *buf + *nread, len - *nread);
+
+  if (more == -1 && errno != EINTR)
     {
       close(conn);
-      free(buf);
+      free(*buf);
       return 0;
     }
 
-  if (memcmp(buf, challenge + 4, 34))
+  *nread += more;
+
+  if (*nread != len)
+    return -1;
+
+  if (memcmp(*buf + 4, challenge + 4, 34))
     {
       close(conn);
-      free(buf);
+      free(*buf);
       return 0;
     }
 
   /* good enough */
-  free(buf);
+  free(*buf);
 
   if (write(conn, response, sizeof(response)) != sizeof(response))
     {
       close(conn);
-      return -1;
+      return 0;
     }
   return conn;
 }
This page took 0.040043 seconds and 4 git commands to generate.