]> andersk Git - openssh.git/blobdiff - serverloop.c
- djm@cvs.openbsd.org 2001/12/20 22:50:24
[openssh.git] / serverloop.c
index 049ea4e463f6ba0be629f8ee34f69bde76a90d1b..0754fe76fd43de2e12cb8e913d31ad816a7b6fc1 100644 (file)
@@ -35,7 +35,7 @@
  */
 
 #include "includes.h"
-RCSID("$OpenBSD: serverloop.c,v 1.78 2001/10/04 15:05:40 markus Exp $");
+RCSID("$OpenBSD: serverloop.c,v 1.88 2001/12/20 22:50:24 djm Exp $");
 
 #include "xmalloc.h"
 #include "packet.h"
@@ -80,18 +80,62 @@ static int connection_in;   /* Connection to client (input). */
 static int connection_out;     /* Connection to client (output). */
 static int connection_closed = 0;      /* Connection to client closed. */
 static u_int buffer_high;      /* "Soft" max buffer size. */
+static int client_alive_timeouts = 0;
 
 /*
  * This SIGCHLD kludge is used to detect when the child exits.  The server
  * will exit after that, as soon as forwarded connections have terminated.
  */
 
-static volatile int child_terminated;  /* The child has terminated. */
+static volatile sig_atomic_t child_terminated = 0;     /* The child has terminated. */
 
 /* prototypes */
 static void server_init_dispatch(void);
 
-int client_alive_timeouts = 0;
+/*
+ * we write to this pipe if a SIGCHLD is caught in order to avoid
+ * the race between select() and child_terminated
+ */
+static int notify_pipe[2];
+static void
+notify_setup(void)
+{
+       if (pipe(notify_pipe) < 0) {
+               error("pipe(notify_pipe) failed %s", strerror(errno));
+       } else if ((fcntl(notify_pipe[0], F_SETFD, 1) == -1) ||
+           (fcntl(notify_pipe[1], F_SETFD, 1) == -1)) {
+               error("fcntl(notify_pipe, F_SETFD) failed %s", strerror(errno));
+               close(notify_pipe[0]);
+               close(notify_pipe[1]);
+       } else {
+               set_nonblock(notify_pipe[0]);
+               set_nonblock(notify_pipe[1]);
+               return;
+       }
+       notify_pipe[0] = -1;    /* read end */
+       notify_pipe[1] = -1;    /* write end */
+}
+static void
+notify_parent(void)
+{
+       if (notify_pipe[1] != -1)
+               write(notify_pipe[1], "", 1);
+}
+static void
+notify_prepare(fd_set *readset)
+{
+       if (notify_pipe[0] != -1)
+               FD_SET(notify_pipe[0], readset);
+}
+static void
+notify_done(fd_set *readset)
+{
+       char c;
+
+       if (notify_pipe[0] != -1 && FD_ISSET(notify_pipe[0], readset))
+               while (read(notify_pipe[0], &c, 1) != -1)
+                       debug2("notify_done: reading");
+}
 
 static void
 sigchld_handler(int sig)
@@ -100,6 +144,7 @@ sigchld_handler(int sig)
        debug("Received SIGCHLD.");
        child_terminated = 1;
        mysignal(SIGCHLD, sigchld_handler);
+       notify_parent();
        errno = save_errno;
 }
 
@@ -161,6 +206,26 @@ make_packets_from_stdout_data(void)
        }
 }
 
+static void
+client_alive_check(void)
+{
+       int id;
+
+       /* timeout, check to see how many we have had */
+       if (++client_alive_timeouts > options.client_alive_count_max)
+               packet_disconnect("Timeout, your session not responding.");
+
+       id = channel_find_open();
+       if (id == -1)
+               packet_disconnect("No open channels after timeout!");
+       /*
+        * send a bogus channel request with "wantreply",
+        * we should get back a failure
+        */
+       channel_request_start(id, "keepalive@openssh.com", 1);
+       packet_send();
+}
+
 /*
  * Sleep in select() until we can do something.  This will initialize the
  * select masks.  Upon return, the masks will indicate which descriptors
@@ -176,12 +241,12 @@ wait_until_can_do_something(fd_set **readsetp, fd_set **writesetp, int *maxfdp,
        int client_alive_scheduled = 0;
 
        /*
-        * if using client_alive, set the max timeout accordingly, 
+        * if using client_alive, set the max timeout accordingly,
         * and indicate that this particular timeout was for client
         * alive by setting the client_alive_scheduled flag.
         *
         * this could be randomized somewhat to make traffic
-        * analysis more difficult, but we're not doing it yet.  
+        * analysis more difficult, but we're not doing it yet.
         */
        if (compat20 &&
            max_time_milliseconds == 0 && options.client_alive_interval) {
@@ -189,9 +254,6 @@ wait_until_can_do_something(fd_set **readsetp, fd_set **writesetp, int *maxfdp,
                max_time_milliseconds = options.client_alive_interval * 1000;
        }
 
-       /* When select fails we restart from here. */
-retry_select:
-
        /* Allocate and update select() masks for channel descriptors. */
        channel_prepare_select(readsetp, writesetp, maxfdp, nallocp, 0);
 
@@ -226,6 +288,7 @@ retry_select:
                if (fdin != -1 && buffer_len(&stdin_buffer) > 0)
                        FD_SET(fdin, *writesetp);
        }
+       notify_prepare(*readsetp);
 
        /*
         * If we have buffered packet data going to the client, mark that
@@ -250,41 +313,21 @@ retry_select:
                tvp = &tv;
        }
        if (tvp!=NULL)
-               debug3("tvp!=NULL kid %d mili %d", child_terminated, max_time_milliseconds);
+               debug3("tvp!=NULL kid %d mili %d", (int) child_terminated,
+                   max_time_milliseconds);
 
        /* Wait for something to happen, or the timeout to expire. */
        ret = select((*maxfdp)+1, *readsetp, *writesetp, NULL, tvp);
 
        if (ret == -1) {
+               memset(*readsetp, 0, *nallocp);
+               memset(*writesetp, 0, *nallocp);
                if (errno != EINTR)
                        error("select: %.100s", strerror(errno));
-               else
-                       goto retry_select;
-       }
-       if (ret == 0 && client_alive_scheduled) {
-               /* timeout, check to see how many we have had */
-               client_alive_timeouts++;
+       } else if (ret == 0 && client_alive_scheduled)
+               client_alive_check();
 
-               if (client_alive_timeouts > options.client_alive_count_max ) {
-                       packet_disconnect(
-                               "Timeout, your session not responding.");
-               } else {
-                       /*
-                        * send a bogus channel request with "wantreply" 
-                        * we should get back a failure
-                        */
-                       int id;
-                       
-                       id = channel_find_open();
-                       if (id != -1) {
-                               channel_request_start(id,
-                                 "keepalive@openssh.com", 1);
-                               packet_send();
-                       } else 
-                               packet_disconnect(
-                                       "No open channels after timeout!");
-               }
-       } 
+       notify_done(*readsetp);
 }
 
 /*
@@ -474,6 +517,8 @@ server_loop(pid_t pid, int fdin_arg, int fdout_arg, int fderr_arg)
        connection_in = packet_get_connection_in();
        connection_out = packet_get_connection_out();
 
+       notify_setup();
+
        previous_stdout_buffer_bytes = 0;
 
        /* Set approximate I/O buffer size. */
@@ -579,6 +624,7 @@ server_loop(pid_t pid, int fdin_arg, int fdout_arg, int fderr_arg)
                max_fd = MAX(max_fd, fdin);
                max_fd = MAX(max_fd, fdout);
                max_fd = MAX(max_fd, fderr);
+               max_fd = MAX(max_fd, notify_pipe[0]);
 
                /* Sleep in select() until we can do something. */
                wait_until_can_do_something(&readset, &writeset, &max_fd,
@@ -604,7 +650,7 @@ server_loop(pid_t pid, int fdin_arg, int fdout_arg, int fderr_arg)
        drain_output();
 
        debug("End of interactive session; stdin %ld, stdout (read %ld, sent %ld), stderr %ld bytes.",
-             stdin_bytes, fdout_bytes, stdout_bytes, stderr_bytes);
+           stdin_bytes, fdout_bytes, stdout_bytes, stderr_bytes);
 
        /* Free and clear the buffers. */
        buffer_free(&stdin_buffer);
@@ -671,12 +717,30 @@ server_loop(pid_t pid, int fdin_arg, int fdout_arg, int fderr_arg)
        /* NOTREACHED */
 }
 
+static void
+collect_children(void)
+{
+       pid_t pid;
+       sigset_t oset, nset;
+       int status;
+
+       /* block SIGCHLD while we check for dead children */
+       sigemptyset(&nset);
+       sigaddset(&nset, SIGCHLD);
+       sigprocmask(SIG_BLOCK, &nset, &oset);
+       if (child_terminated) {
+               while ((pid = waitpid(-1, &status, WNOHANG)) > 0)
+                       session_close_by_pid(pid, status);
+               child_terminated = 0;
+       }
+       sigprocmask(SIG_SETMASK, &oset, NULL);
+}
+
 void
 server_loop2(Authctxt *authctxt)
 {
        fd_set *readset = NULL, *writeset = NULL;
-       int rekeying = 0, max_fd, status, nalloc = 0;
-       pid_t pid;
+       int rekeying = 0, max_fd, nalloc = 0;
 
        debug("Entering interactive session for SSH2.");
 
@@ -685,7 +749,11 @@ server_loop2(Authctxt *authctxt)
        connection_in = packet_get_connection_in();
        connection_out = packet_get_connection_out();
 
+       notify_setup();
+
        max_fd = MAX(connection_in, connection_out);
+       max_fd = MAX(max_fd, notify_pipe[0]);
+
        xxx_authctxt = authctxt;
 
        server_init_dispatch();
@@ -699,11 +767,8 @@ server_loop2(Authctxt *authctxt)
                        channel_output_poll();
                wait_until_can_do_something(&readset, &writeset, &max_fd,
                    &nalloc, 0);
-               if (child_terminated) {
-                       while ((pid = waitpid(-1, &status, WNOHANG)) > 0)
-                               session_close_by_pid(pid, status);
-                       child_terminated = 0;
-               }
+
+               collect_children();
                if (!rekeying)
                        channel_after_select(readset, writeset);
                process_input(readset);
@@ -711,49 +776,35 @@ server_loop2(Authctxt *authctxt)
                        break;
                process_output(writeset);
        }
+       collect_children();
+
        if (readset)
                xfree(readset);
        if (writeset)
                xfree(writeset);
 
-       mysignal(SIGCHLD, SIG_DFL);
-
-       while ((pid = waitpid(-1, &status, WNOHANG)) > 0)
-               session_close_by_pid(pid, status);
-       /*
-        * there is a race between channel_free_all() killing children and
-        * children dying before kill()
-        */
-       channel_detach_all();
-       channel_stop_listening();
-
-       while (session_have_children()) {
-               pid = waitpid(-1, &status, 0);
-               if (pid > 0)
-                       session_close_by_pid(pid, status);
-               else {
-                       error("waitpid returned %d: %s", pid, strerror(errno));
-                       break;
-               }
-       }
+       /* free all channels, no more reads and writes */
        channel_free_all();
+
+       /* free remaining sessions, e.g. remove wtmp entries */
+       session_destroy_all();
 }
 
 static void
-server_input_channel_failure(int type, int plen, void *ctxt)
+server_input_channel_failure(int type, int plen, u_int32_t seq, void *ctxt)
 {
        debug("Got CHANNEL_FAILURE for keepalive");
-       /* 
+       /*
         * reset timeout, since we got a sane answer from the client.
         * even if this was generated by something other than
         * the bogus CHANNEL_REQUEST we send for keepalives.
         */
-       client_alive_timeouts = 0; 
+       client_alive_timeouts = 0;
 }
 
 
 static void
-server_input_stdin_data(int type, int plen, void *ctxt)
+server_input_stdin_data(int type, int plen, u_int32_t seq, void *ctxt)
 {
        char *data;
        u_int data_len;
@@ -770,7 +821,7 @@ server_input_stdin_data(int type, int plen, void *ctxt)
 }
 
 static void
-server_input_eof(int type, int plen, void *ctxt)
+server_input_eof(int type, int plen, u_int32_t seq, void *ctxt)
 {
        /*
         * Eof from the client.  The stdin descriptor to the
@@ -783,7 +834,7 @@ server_input_eof(int type, int plen, void *ctxt)
 }
 
 static void
-server_input_window_size(int type, int plen, void *ctxt)
+server_input_window_size(int type, int plen, u_int32_t seq, void *ctxt)
 {
        int row = packet_get_int();
        int col = packet_get_int();
@@ -861,7 +912,7 @@ server_request_session(char *ctype)
 }
 
 static void
-server_input_channel_open(int type, int plen, void *ctxt)
+server_input_channel_open(int type, int plen, u_int32_t seq, void *ctxt)
 {
        Channel *c = NULL;
        char *ctype;
@@ -911,7 +962,7 @@ server_input_channel_open(int type, int plen, void *ctxt)
 }
 
 static void
-server_input_global_request(int type, int plen, void *ctxt)
+server_input_global_request(int type, int plen, u_int32_t seq, void *ctxt)
 {
        char *rtype;
        int want_reply;
This page took 1.19786 seconds and 4 git commands to generate.