andersk Git - gssapi-openssh.git/blobdiff - openssh/sshconnect.c
Import of OpenSSH 4.9p1
[gssapi-openssh.git] / openssh / sshconnect.c
index a222233d0aa99b29f48d2bed061ee3d318ab5a3e..a604c9724aa0512f2b7732ab509c1142364fd6dc 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: sshconnect.c,v 1.200 2006/10/10 10:12:45 markus Exp $ */
+/* $OpenBSD: sshconnect.c,v 1.203 2007/12/27 14:22:08 dtucker Exp $ */
 /*
  * Author: Tatu Ylonen <ylo@cs.hut.fi>
  * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -77,6 +77,44 @@ extern pid_t proxy_command_pid;
 static int show_other_keys(const char *, Key *);
 static void warn_changed_key(Key *);
 
+/*
+ * Subtract the wall-clock time elapsed since *start (per gettimeofday) from
+ * the millisecond budget *ms.  Callers treat *ms <= 0 as "timeout expired".
+ */
+static void
+ms_subtract_diff(struct timeval *start, int *ms)
+{
+       struct timeval diff, finish;
+       long long elapsed, remaining;
+
+       gettimeofday(&finish, NULL);
+       timersub(&finish, start, &diff);
+       /* Widen before multiplying so tv_sec * 1000 cannot overflow. */
+       elapsed = (long long)diff.tv_sec * 1000 + diff.tv_usec / 1000;
+       /* A backwards clock step must not extend the remaining budget. */
+       if (elapsed < 0)
+               elapsed = 0;
+       remaining = (long long)*ms - elapsed;
+       /*
+        * Floor at zero: converting a large negative value back to int may
+        * otherwise wrap to a positive, i.e. apparently unexpired, budget.
+        */
+       *ms = remaining > 0 ? (int)remaining : 0;
+}
+
+/*
+ * Convert a millisecond count to a struct timeval.  Negative counts are
+ * treated as zero so the result is always a valid select(2) timeout.
+ */
+static void
+ms_to_timeval(struct timeval *tv, int ms)
+{
+       if (ms < 0)
+               ms = 0;
+       tv->tv_sec = ms / 1000;
+       tv->tv_usec = (ms % 1000) * 1000;
+}
+
 /*
  * Connect to the given ssh server using a proxy command.
  */
@@ -86,7 +103,10 @@ ssh_proxy_connect(const char *host, u_short port, const char *proxy_command)
        char *command_string, *tmp;
        int pin[2], pout[2];
        pid_t pid;
-       char strport[NI_MAXSERV];
+       char *shell, strport[NI_MAXSERV];
+
+       if ((shell = getenv("SHELL")) == NULL)
+               shell = _PATH_BSHELL;
 
        /* Convert the port number into a string. */
        snprintf(strport, sizeof strport, "%hu", port);
@@ -132,7 +152,7 @@ ssh_proxy_connect(const char *host, u_short port, const char *proxy_command)
 
                /* Stderr is left as it is so that error messages get
                   printed on the user's terminal. */
-               argv[0] = _PATH_BSHELL;
+               argv[0] = shell;
                argv[1] = "-c";
                argv[2] = command_string;
                argv[3] = NULL;
@@ -204,7 +224,7 @@ ssh_create_socket(int privileged, struct addrinfo *ai)
        gaierr = getaddrinfo(options.bind_address, "0", &hints, &res);
        if (gaierr) {
                error("getaddrinfo: %s: %s", options.bind_address,
-                   gai_strerror(gaierr));
+                   ssh_gai_strerror(gaierr));
                close(sock);
                return -1;
        }
@@ -220,30 +240,36 @@ ssh_create_socket(int privileged, struct addrinfo *ai)
 
 static int
 timeout_connect(int sockfd, const struct sockaddr *serv_addr,
-    socklen_t addrlen, int timeout)
+    socklen_t addrlen, int *timeoutp)
 {
        fd_set *fdset;
-       struct timeval tv;
+       struct timeval tv, t_start;
        socklen_t optlen;
        int optval, rc, result = -1;
 
-       if (timeout <= 0)
-               return (connect(sockfd, serv_addr, addrlen));
+       gettimeofday(&t_start, NULL);
+
+       if (*timeoutp <= 0) {
+               result = connect(sockfd, serv_addr, addrlen);
+               goto done;
+       }
 
        set_nonblock(sockfd);
        rc = connect(sockfd, serv_addr, addrlen);
        if (rc == 0) {
                unset_nonblock(sockfd);
-               return (0);
+               result = 0;
+               goto done;
+       }
+       if (errno != EINPROGRESS) {
+               result = -1;
+               goto done;
        }
-       if (errno != EINPROGRESS)
-               return (-1);
 
        fdset = (fd_set *)xcalloc(howmany(sockfd + 1, NFDBITS),
            sizeof(fd_mask));
        FD_SET(sockfd, fdset);
-       tv.tv_sec = timeout;
-       tv.tv_usec = 0;
+       ms_to_timeval(&tv, *timeoutp);
 
        for (;;) {
                rc = select(sockfd + 1, NULL, fdset, NULL, &tv);
@@ -282,6 +308,16 @@ timeout_connect(int sockfd, const struct sockaddr *serv_addr,
        }
 
        xfree(fdset);
+
+ done:
+       if (result == 0 && *timeoutp > 0) {
+               ms_subtract_diff(&t_start, timeoutp);
+               if (*timeoutp <= 0) {
+                       errno = ETIMEDOUT;
+                       result = -1;
+               }
+       }
+
        return (result);
 }
 
@@ -298,8 +334,8 @@ timeout_connect(int sockfd, const struct sockaddr *serv_addr,
  */
 int
 ssh_connect(const char *host, struct sockaddr_storage * hostaddr,
-    u_short port, int family, int connection_attempts,
-    int needpriv, const char *proxy_command)
+    u_short port, int family, int connection_attempts, int *timeout_ms,
+    int want_keepalive, int needpriv, const char *proxy_command)
 {
        int gaierr;
        int on = 1;
@@ -320,8 +356,8 @@ ssh_connect(const char *host, struct sockaddr_storage * hostaddr,
        hints.ai_socktype = SOCK_STREAM;
        snprintf(strport, sizeof strport, "%u", port);
        if ((gaierr = getaddrinfo(host, strport, &hints, &aitop)) != 0)
-               fatal("%s: %.100s: %s", __progname, host,
-                   gai_strerror(gaierr));
+               fatal("%s: Could not resolve hostname %.100s: %s", __progname,
+                   host, ssh_gai_strerror(gaierr));
 
        for (attempt = 0; attempt < connection_attempts; attempt++) {
                if (attempt > 0) {
@@ -352,7 +388,7 @@ ssh_connect(const char *host, struct sockaddr_storage * hostaddr,
                                continue;
 
                        if (timeout_connect(sock, ai->ai_addr, ai->ai_addrlen,
-                           options.connection_timeout) >= 0) {
+                           timeout_ms) >= 0) {
                                /* Successful connection. */
                                memcpy(hostaddr, ai->ai_addr, ai->ai_addrlen);
                                break;
@@ -379,7 +415,7 @@ ssh_connect(const char *host, struct sockaddr_storage * hostaddr,
        debug("Connection established.");
 
        /* Set SO_KEEPALIVE if requested. */
-       if (options.tcp_keep_alive &&
+       if (want_keepalive &&
            setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (void *)&on,
            sizeof(on)) < 0)
                error("setsockopt SO_KEEPALIVE: %.100s", strerror(errno));
@@ -395,7 +431,7 @@ ssh_connect(const char *host, struct sockaddr_storage * hostaddr,
  * identification string.
  */
 static void
-ssh_exchange_identification(void)
+ssh_exchange_identification(int timeout_ms)
 {
        char buf[256], remote_version[256];     /* must be same size! */
        int remote_major, remote_minor, mismatch;
@@ -403,16 +439,44 @@ ssh_exchange_identification(void)
        int connection_out = packet_get_connection_out();
        int minor1 = PROTOCOL_MINOR_1;
        u_int i, n;
+       size_t len;
+       int fdsetsz, remaining, rc;
+       struct timeval t_start, t_remaining;
+       fd_set *fdset;
+
+       fdsetsz = howmany(connection_in + 1, NFDBITS) * sizeof(fd_mask);
+       fdset = xcalloc(1, fdsetsz);
 
        /* Read other side's version identification. */
+       remaining = timeout_ms;
        for (n = 0;;) {
                for (i = 0; i < sizeof(buf) - 1; i++) {
-                       size_t len = atomicio(read, connection_in, &buf[i], 1);
+                       if (timeout_ms > 0) {
+                               gettimeofday(&t_start, NULL);
+                               ms_to_timeval(&t_remaining, remaining);
+                               FD_SET(connection_in, fdset);
+                               rc = select(connection_in + 1, fdset, NULL,
+                                   fdset, &t_remaining);
+                               ms_subtract_diff(&t_start, &remaining);
+                               if (rc == 0 || remaining <= 0)
+                                       fatal("Connection timed out during "
+                                           "banner exchange");
+                               if (rc == -1) {
+                                       if (errno == EINTR)
+                                               continue;
+                                       fatal("ssh_exchange_identification: "
+                                           "select: %s", strerror(errno));
+                               }
+                       }
+
+                       len = atomicio(read, connection_in, &buf[i], 1);
 
                        if (len != 1 && errno == EPIPE)
-                               fatal("ssh_exchange_identification: Connection closed by remote host");
+                               fatal("ssh_exchange_identification: "
+                                   "Connection closed by remote host");
                        else if (len != 1)
-                               fatal("ssh_exchange_identification: read: %.100s", strerror(errno));
+                               fatal("ssh_exchange_identification: "
+                                   "read: %.100s", strerror(errno));
                        if (buf[i] == '\r') {
                                buf[i] = '\n';
                                buf[i + 1] = 0;
@@ -423,7 +487,8 @@ ssh_exchange_identification(void)
                                break;
                        }
                        if (++n > 65536)
-                               fatal("ssh_exchange_identification: No banner received");
+                               fatal("ssh_exchange_identification: "
+                                   "No banner received");
                }
                buf[sizeof(buf) - 1] = 0;
                if (strncmp(buf, "SSH-", 4) == 0)
@@ -431,6 +496,7 @@ ssh_exchange_identification(void)
                debug("ssh_exchange_identification: %s", buf);
        }
        server_version_string = xstrdup(buf);
+       xfree(fdset);
 
        /*
         * Check that the versions match.  In future this might accept
@@ -943,7 +1009,7 @@ verify_host_key(char *host, struct sockaddr *hostaddr, Key *host_key)
  */
 void
 ssh_login(Sensitive *sensitive, const char *orighost,
-    struct sockaddr *hostaddr, struct passwd *pw)
+    struct sockaddr *hostaddr, struct passwd *pw, int timeout_ms)
 {
        char *host, *cp;
        char *server_user, *local_user;
@@ -958,7 +1024,7 @@ ssh_login(Sensitive *sensitive, const char *orighost,
                        *cp = (char)tolower(*cp);
 
        /* Exchange protocol version identification strings with the server. */
-       ssh_exchange_identification();
+       ssh_exchange_identification(timeout_ms);
 
        /* Put the connection into non-blocking mode. */
        packet_set_nonblocking();
This page took 0.292334 seconds and 4 git commands to generate.