]> andersk Git - test.git/blob - libhttp/server.c
The server could sometimes end up listening for events even though it
[test.git] / libhttp / server.c
1 // server.c -- Generic server that can deal with HTTP connections
2 // Copyright (C) 2008-2010 Markus Gutschke <markus@shellinabox.com>
3 //
4 // This program is free software; you can redistribute it and/or modify
5 // it under the terms of the GNU General Public License version 2 as
6 // published by the Free Software Foundation.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License along
14 // with this program; if not, write to the Free Software Foundation, Inc.,
15 // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16 //
17 // In addition to these license terms, the author grants the following
18 // additional rights:
19 //
20 // If you modify this program, or any covered work, by linking or
21 // combining it with the OpenSSL project's OpenSSL library (or a
22 // modified version of that library), containing parts covered by the
23 // terms of the OpenSSL or SSLeay licenses, the author
24 // grants you additional permission to convey the resulting work.
25 // Corresponding Source for a non-source form of such a combination
26 // shall include the source code for the parts of OpenSSL used as well
27 // as that of the covered work.
28 //
29 // You may at your option choose to remove this additional permission from
30 // the work, or from any part of it.
31 //
32 // It is possible to build this program in a way that it loads OpenSSL
33 // libraries at run-time. If doing so, the following notices are required
34 // by the OpenSSL and SSLeay licenses:
35 //
36 // This product includes software developed by the OpenSSL Project
37 // for use in the OpenSSL Toolkit. (http://www.openssl.org/)
38 //
39 // This product includes cryptographic software written by Eric Young
40 // (eay@cryptsoft.com)
41 //
42 //
43 // The most up-to-date version of this program is always available from
44 // http://shellinabox.com
45
46 #include "config.h"
47
48 #include <arpa/inet.h>
49 #include <fcntl.h>
50 #include <netinet/in.h>
51 #include <stdlib.h>
52 #include <string.h>
53 #include <sys/poll.h>
54 #include <sys/socket.h>
55 #include <sys/time.h>
56 #include <sys/types.h>
57 #include <unistd.h>
58
59 #include "libhttp/server.h"
60 #include "libhttp/httpconnection.h"
61 #include "libhttp/ssl.h"
62 #include "logging/logging.h"
63
64 #ifdef HAVE_UNUSED
65 #defined ATTR_UNUSED __attribute__((unused))
66 #defined UNUSED(x)   do { } while (0)
67 #else
68 #define ATTR_UNUSED
69 #define UNUSED(x)    do { (void)(x); } while (0)
70 #endif
71
72 #define INITIAL_TIMEOUT    (10*60)
73
74 // Maximum amount of payload (e.g. form values that have been POST'd) that we
75 // read into memory. If the application needs any more than this, the streaming
76 // API should be used, instead.
77 #define MAX_PAYLOAD_LENGTH (64<<10)
78
79
80 #if defined(__APPLE__) && defined(__MACH__)
81 // While MacOS X does ship with an implementation of poll(), this
82 // implementation is apparently known to be broken and does not comply
83 // with POSIX standards. Fortunately, the operating system is not entirely
84 // unable to check for input events. We can fall back on calling select()
85 // instead. This is generally not desirable, as it is less efficient and
86 // has a compile-time restriction on the maximum number of file
87 // descriptors. But on MacOS X, that's the best we can do.
88
89 int x_poll(struct pollfd *fds, nfds_t nfds, int timeout) {
90   fd_set r, w, x;
91   FD_ZERO(&r);
92   FD_ZERO(&w);
93   FD_ZERO(&x);
94   int maxFd             = -1;
95   for (int i = 0; i < nfds; ++i) {
96     if (fds[i].fd > maxFd) {
97       maxFd = fds[i].fd;
98     } else if (fds[i].fd < 0) {
99       continue;
100     }
101     if (fds[i].events & POLLIN) {
102       FD_SET(fds[i].fd, &r);
103     }
104     if (fds[i].events & POLLOUT) {
105       FD_SET(fds[i].fd, &w);
106     }
107     if (fds[i].events & POLLPRI) {
108       FD_SET(fds[i].fd, &x);
109     }
110   }
111   struct timeval tmoVal = { 0 }, *tmo;
112   if (timeout < 0) {
113     tmo                 = NULL;
114   } else {
115     tmoVal.tv_sec       =  timeout / 1000;
116     tmoVal.tv_usec      = (timeout % 1000) * 1000;
117     tmo                 = &tmoVal;
118   }
119   int numRet            = select(maxFd + 1, &r, &w, &x, tmo);
120   for (int i = 0, n = numRet; i < nfds && n > 0; ++i) {
121     if (fds[i].fd < 0) {
122       continue;
123     }
124     if (FD_ISSET(fds[i].fd, &x)) {
125       fds[i].revents    = POLLPRI;
126     } else if (FD_ISSET(fds[i].fd, &r)) {
127       fds[i].revents    = POLLIN;
128     } else {
129       fds[i].revents    = 0;
130     }
131     if (FD_ISSET(fds[i].fd, &w)) {
132       fds[i].revents   |= POLLOUT;
133     }
134   }
135   return numRet;
136 }
137 #define poll x_poll
138 #endif
139
140 time_t currentTime;
141
142 struct PayLoad {
143   int (*handler)(struct HttpConnection *, void *, const char *, int);
144   void *arg;
145   int  len;
146   char *bytes;
147 };
148
149 static int serverCollectFullPayload(struct HttpConnection *http,
150                                     void *payload_, const char *buf, int len) {
151   int rc                        = HTTP_READ_MORE;
152   struct PayLoad *payload       = (struct PayLoad *)payload_;
153   if (buf && len) {
154     if (payload->len + len > MAX_PAYLOAD_LENGTH) {
155       httpSendReply(http, 400, "Bad Request", NO_MSG);
156       return HTTP_DONE;
157     }
158     check(len > 0);
159     check(payload->bytes        = realloc(payload->bytes, payload->len + len));
160     memcpy(payload->bytes + payload->len, buf, len);
161     payload->len               += len;
162   }
163   const char *contentLength     = getFromHashMap(httpGetHeaders(http),
164                                                  "content-length");
165   if (!contentLength ||
166       (payload->bytes &&
167        ((contentLength && atoi(contentLength) <= payload->len) || !buf))) {
168     rc = payload->handler(http, payload->arg,
169                           payload->bytes ? payload->bytes : "", payload->len);
170     free(payload->bytes);
171     payload->bytes              = NULL;
172     payload->len                = 0;
173   }
174   if (!buf) {
175     if (rc == HTTP_SUSPEND || rc == HTTP_PARTIAL_REPLY) {
176       // Tell the other party that the connection is getting torn down, even
177       // though it requested it to be suspended.
178       payload->handler(http, payload->arg, NULL, 0);
179       rc                        = HTTP_DONE;
180     }
181     free(payload);
182   }
183   return rc;
184   
185 }
186
187 static int serverCollectHandler(struct HttpConnection *http, void *handler_) {
188   struct HttpHandler *handler = handler_;
189   struct PayLoad *payload;
190   check(payload               = malloc(sizeof(struct PayLoad)));
191   payload->handler            = handler->streamingHandler;
192   payload->arg                = handler->streamingArg;
193   payload->len                = 0;
194   payload->bytes              = malloc(0);
195   httpSetCallback(http, serverCollectFullPayload, payload);
196   return HTTP_READ_MORE;
197
198 }
199
200 static void serverDestroyHandlers(void *arg ATTR_UNUSED, char *value) {
201   UNUSED(arg);
202   free(value);
203 }
204
205 void serverRegisterHttpHandler(struct Server *server, const char *url,
206                                int (*handler)(struct HttpConnection *, void *,
207                                               const char *, int), void *arg) {
208   if (!handler) {
209     addToTrie(&server->handlers, url, NULL);
210   } else {
211     struct HttpHandler *h;
212     check(h             = malloc(sizeof(struct HttpHandler)));
213     h->handler          = serverCollectHandler;
214     h->arg              = h;
215     h->streamingHandler = handler;
216     h->websocketHandler = NULL;
217     h->streamingArg     = arg;
218     addToTrie(&server->handlers, url, (char *)h);
219   }
220 }
221
222 void serverRegisterStreamingHttpHandler(struct Server *server, const char *url,
223                                int (*handler)(struct HttpConnection *, void *),
224                                void *arg) {
225   if (!handler) {
226     addToTrie(&server->handlers, url, NULL);
227   } else {
228     struct HttpHandler *h;
229     check(h             = malloc(sizeof(struct HttpHandler)));
230     h->handler          = handler;
231     h->streamingHandler = NULL;
232     h->websocketHandler = NULL;
233     h->streamingArg     = NULL;
234     h->arg              = arg;
235     addToTrie(&server->handlers, url, (char *)h);
236   }
237 }
238
239 void serverRegisterWebSocketHandler(struct Server *server, const char *url,
240        int (*handler)(struct HttpConnection *, void *, int, const char *, int),
241        void *arg) {
242   if (!handler) {
243     addToTrie(&server->handlers, url, NULL);
244   } else {
245     struct HttpHandler *h;
246     check(h             = malloc(sizeof(struct HttpHandler)));
247     h->handler          = NULL;
248     h->streamingHandler = NULL;
249     h->websocketHandler = handler;
250     h->arg              = arg;
251     addToTrie(&server->handlers, url, (char *)h);
252   }
253 }
254
255 static int serverQuitHandler(struct HttpConnection *http ATTR_UNUSED,
256                              void *arg) {
257   UNUSED(arg);
258   httpSendReply(http, 200, "Good Bye", NO_MSG);
259   httpExitLoop(http, 1);
260   return HTTP_DONE;
261 }
262
263 struct Server *newCGIServer(int localhostOnly, int portMin, int portMax,
264                             int timeout) {
265   struct Server *server;
266   check(server = malloc(sizeof(struct Server)));
267   initServer(server, localhostOnly, portMin, portMax, timeout);
268   return server;
269 }
270
271 struct Server *newServer(int localhostOnly, int port) {
272   return newCGIServer(localhostOnly, port, port, -1);
273 }
274
275 void initServer(struct Server *server, int localhostOnly, int portMin,
276                 int portMax, int timeout) {
277   server->looping               = 0;
278   server->exitAll               = 0;
279   server->serverTimeout         = timeout;
280   server->numericHosts          = 0;
281   server->connections           = NULL;
282   server->numConnections        = 0;
283
284   int true                      = 1;
285   server->serverFd              = socket(PF_INET, SOCK_STREAM, 0);
286   check(server->serverFd >= 0);
287   check(!setsockopt(server->serverFd, SOL_SOCKET, SO_REUSEADDR,
288                     &true, sizeof(true)));
289   struct sockaddr_in serverAddr = { 0 };
290   serverAddr.sin_family         = AF_INET;
291   serverAddr.sin_addr.s_addr    = htonl(localhostOnly
292                                         ? INADDR_LOOPBACK : INADDR_ANY);
293
294   // Linux unlike BSD does not have support for picking a local port range.
295   // So, we have to randomly pick a port from our allowed port range, and then
296   // keep iterating until we find an unused port.
297   if (portMin || portMax) {
298     struct timeval tv;
299     check(!gettimeofday(&tv, NULL));
300     srand((int)(tv.tv_usec ^ tv.tv_sec));
301     check(portMin > 0);
302     check(portMax < 65536);
303     check(portMax >= portMin);
304     int portStart               = rand() % (portMax - portMin + 1) + portMin;
305     for (int p = 0; p <= portMax-portMin; p++) {
306       int port                  = (p+portStart)%(portMax-portMin+1)+ portMin;
307       serverAddr.sin_port       = htons(port);
308       if (!bind(server->serverFd, (struct sockaddr *)&serverAddr,
309                 sizeof(serverAddr))) {
310         break;
311       }
312       serverAddr.sin_port       = 0;
313     }
314     if (!serverAddr.sin_port) {
315       fatal("Failed to find any available port");
316     }
317   }
318
319   check(!listen(server->serverFd, SOMAXCONN));
320   socklen_t socklen             = (socklen_t)sizeof(serverAddr);
321   check(!getsockname(server->serverFd, (struct sockaddr *)&serverAddr,
322                      &socklen));
323   check(socklen == sizeof(serverAddr));
324   server->port                  = ntohs(serverAddr.sin_port);
325   info("Listening on port %d", server->port);
326
327   check(server->pollFds         = malloc(sizeof(struct pollfd)));
328   server->pollFds->fd           = server->serverFd;
329   server->pollFds->events       = POLLIN;
330
331   initTrie(&server->handlers, serverDestroyHandlers, NULL);
332   serverRegisterStreamingHttpHandler(server, "/quit", serverQuitHandler, NULL);
333   initSSL(&server->ssl);
334 }
335
336 void destroyServer(struct Server *server) {
337   if (server) {
338     if (server->serverFd >= 0) {
339       info("Shutting down server");
340       close(server->serverFd);
341     }
342     for (int i = 0; i < server->numConnections; i++) {
343       server->connections[i].destroyConnection(server->connections[i].arg);
344     }
345     free(server->connections);
346     free(server->pollFds);
347     destroyTrie(&server->handlers);
348     destroySSL(&server->ssl);
349   }
350 }
351
352 void deleteServer(struct Server *server) {
353   destroyServer(server);
354   free(server);
355 }
356
357 int serverGetListeningPort(struct Server *server) {
358   return server->port;
359 }
360
361 int serverGetFd(struct Server *server) {
362   return server->serverFd;
363 }
364
365 struct ServerConnection *serverAddConnection(struct Server *server, int fd,
366                          int (*handleConnection)(struct ServerConnection *c,
367                                                  void *arg, short *events,
368                                                  short revents),
369                          void (*destroyConnection)(void *arg),
370                          void *arg) {
371   check(server->connections     = realloc(server->connections,
372                                           ++server->numConnections*
373                                           sizeof(struct ServerConnection)));
374   check(server->pollFds         = realloc(server->pollFds,
375                                           (server->numConnections + 1) *
376                                           sizeof(struct pollfd)));
377   server->pollFds[server->numConnections].fd     = fd;
378   server->pollFds[server->numConnections].events = POLLIN;
379   struct ServerConnection *connection            =
380                               server->connections + server->numConnections - 1;
381   connection->deleted           = 0;
382   connection->timeout           = 0;
383   connection->handleConnection  = handleConnection;
384   connection->destroyConnection = destroyConnection;
385   connection->arg               = arg;
386   return connection;
387 }
388
389 void serverDeleteConnection(struct Server *server, int fd) {
390   for (int i = 0; i < server->numConnections; i++) {
391     if (fd == server->pollFds[i + 1].fd && !server->connections[i].deleted) {
392       server->connections[i].deleted = 1;
393       server->connections[i].destroyConnection(server->connections[i].arg);
394       return;
395     }
396   }
397 }
398
399 void serverSetTimeout(struct ServerConnection *connection, time_t timeout) {
400   if (!currentTime) {
401     currentTime       = time(NULL);
402   }
403   connection->timeout = timeout > 0 ? timeout + currentTime : 0;
404 }
405
406 time_t serverGetTimeout(struct ServerConnection *connection) {
407   if (connection->timeout) {
408     // Returns <0 if expired, 0 if not set, and >0 if still pending.
409     if (!currentTime) {
410       currentTime = time(NULL);
411     }
412     int remaining = connection->timeout - currentTime;
413     if (!remaining) {
414       remaining--;
415     }
416     return remaining;
417   } else {
418     return 0;
419   }
420 }
421
422 struct ServerConnection *serverGetConnection(struct Server *server,
423                                              struct ServerConnection *hint,
424                                              int fd) {
425   if (hint &&
426       server->connections <= hint &&
427       server->connections + server->numConnections > hint) {
428     // The compiler would like to optimize the expression:
429     //   &server->connections[hint - server->connections]     <=>
430     //   server->connections + hint - server->connections     <=>
431     //   hint
432     // This transformation is correct as far as the language specification is
433     // concerned, but it is unintended as we actually want to check whether
434     // the alignment is correct. So, instead of comparing
435     //   &server->connections[hint - server->connections] == hint
436     // we first use memcpy() to break aliasing.
437     uintptr_t ptr1, ptr2;
438     memcpy(&ptr1, &hint, sizeof(ptr1));
439     memcpy(&ptr2, &server->connections, sizeof(ptr2));
440     int idx = (ptr1 - ptr2)/sizeof(*server->connections);
441     if (&server->connections[idx] == hint &&
442         !hint->deleted &&
443         server->pollFds[hint - server->connections + 1].fd == fd) {
444       return hint;
445     }
446   }
447   for (int i = 0; i < server->numConnections; i++) {
448     if (server->pollFds[i + 1].fd == fd && !server->connections[i].deleted) {
449       return server->connections + i;
450     }
451   }
452   return NULL;
453 }
454
455 short serverConnectionSetEvents(struct Server *server,
456                                 struct ServerConnection *connection, int fd,
457                                 short events) {
458   dcheck(server);
459   dcheck(connection);
460   dcheck(connection >= server->connections);
461   dcheck(connection < server->connections + server->numConnections);
462   dcheck(connection == &server->connections[connection - server->connections]);
463   dcheck(!connection->deleted);
464   int   idx                       = connection - server->connections;
465   short oldEvents                 = server->pollFds[idx + 1].events;
466   dcheck(fd == server->pollFds[idx + 1].fd);
467   server->pollFds[idx + 1].events = events;
468   return oldEvents;
469 }
470
471 void serverExitLoop(struct Server *server, int exitAll) {
472   server->looping--;
473   server->exitAll |= exitAll;
474 }
475
476 void serverLoop(struct Server *server) {
477   check(server->serverFd >= 0);
478   time_t lastTime;
479   currentTime                             = time(&lastTime);
480   int loopDepth                           = ++server->looping;
481   while (server->looping >= loopDepth && !server->exitAll) {
482     // TODO: There probably should be some limit on the maximum number
483     // of concurrently opened HTTP connections, as this could lead to
484     // memory exhaustion and a DoS attack.
485     time_t timeout                        = -1;
486     int numFds                            = server->numConnections + 1;
487
488     for (int i = 0; i < server->numConnections; i++) {
489       while (i < numFds - 1 && !server->pollFds[i + 1].events) {
490         // Sort filedescriptors that currently do not expect any events to
491         // the end of the list.
492         check(--numFds > 0);
493         struct pollfd tmpPollFd;
494         memmove(&tmpPollFd, server->pollFds + numFds, sizeof(struct pollfd));
495         memmove(server->pollFds + numFds, server->pollFds + i + 1,
496                 sizeof(struct pollfd));
497         memmove(server->pollFds + i + 1, &tmpPollFd, sizeof(struct pollfd));
498         struct ServerConnection tmpConnection;
499         memmove(&tmpConnection, server->connections + numFds - 1,
500                 sizeof(struct ServerConnection));
501         memmove(server->connections + numFds - 1, server->connections + i,
502                 sizeof(struct ServerConnection));
503         memmove(server->connections + i, &tmpConnection,
504                 sizeof(struct ServerConnection));
505       }
506
507       if (server->connections[i].timeout &&
508           (timeout < 0 || timeout > server->connections[i].timeout)) {
509         timeout                           = server->connections[i].timeout;
510       }
511     }
512
513     // serverTimeout is always a delta value, unlike connection timeouts
514     // which are absolute times.
515     if (server->serverTimeout >= 0) {
516       if (timeout < 0 || timeout > server->serverTimeout + currentTime) {
517         timeout                           = server->serverTimeout+currentTime;
518       }
519     }
520
521     if (timeout >= 0) {
522       // Wait at least one second longer than needed, so that even if
523       // poll() decides to return a second early (due to possible rounding
524       // errors), we still correctly detect a timeout condition.
525       if (timeout >= lastTime) {
526         timeout                           = (timeout - lastTime + 1) * 1000;
527       } else {
528         timeout                           = 1000;
529       }
530     }
531
532     int eventCount                        = NOINTR(poll(server->pollFds,
533                                                         numFds,
534                                                         timeout));
535     check(eventCount >= 0);
536     if (timeout >= 0) {
537       timeout                            += lastTime;
538     }
539     currentTime                           = time(&lastTime);
540     int isTimeout                         = timeout >= 0 &&
541                                             timeout/1000 <= lastTime;
542     if (eventCount > 0 && server->pollFds[0].revents) {
543       eventCount--;
544       if (server->pollFds[0].revents && POLLIN) {
545         struct sockaddr_in clientAddr;
546         socklen_t sockLen                 = sizeof(clientAddr);
547         int clientFd                      = accept(
548                    server->serverFd, (struct sockaddr *)&clientAddr, &sockLen);
549         dcheck(clientFd >= 0);
550         if (clientFd >= 0) {
551           check(!fcntl(clientFd, F_SETFL, O_RDWR | O_NONBLOCK));
552           struct HttpConnection *http;
553           http                            = newHttpConnection(
554                                      server, clientFd, server->port,
555                                      server->ssl.enabled ? &server->ssl : NULL,
556                                      server->numericHosts);
557           serverSetTimeout(
558             serverAddConnection(server, clientFd, httpHandleConnection,
559                                 (void (*)(void *))deleteHttpConnection,
560                                 http),
561             INITIAL_TIMEOUT);
562         }
563       }
564     } else {
565       if (server->serverTimeout > 0 && !server->numConnections) {
566         // In CGI mode, exit the server, if we haven't had any active
567         // connections in a while.
568         break;
569       }
570     }
571     for (int i = 1;
572          (isTimeout || eventCount > 0) && i <= server->numConnections;
573          i++) {
574       struct ServerConnection *connection = server->connections + i - 1;
575       if (connection->deleted) {
576         continue;
577       }
578       if (!eventCount) {
579         server->pollFds[i].revents        = 0;
580       }
581       if (server->pollFds[i].revents ||
582           (connection->timeout && lastTime >= connection->timeout)) {
583         if (server->pollFds[i].revents) {
584           eventCount--;
585         }
586         if (!connection->handleConnection(connection, connection->arg,
587                                           &server->pollFds[i].events,
588                                           server->pollFds[i].revents)) {
589           connection                      = server->connections + i - 1;
590           connection->destroyConnection(connection->arg);
591           connection->deleted             = 1;
592         }
593       }
594     }
595     for (int i = 1; i <= server->numConnections; i++) {
596       if (server->connections[i-1].deleted) {
597         memmove(server->pollFds + i, server->pollFds + i + 1,
598                 (server->numConnections - i) * sizeof(struct pollfd));
599         memmove(server->connections + i - 1, server->connections + i,
600                 (server->numConnections - i)*sizeof(struct ServerConnection));
601         check(--i >= 0);
602         check(--server->numConnections >= 0);
603       }
604     }
605   }
606   // Even if multiple clients requested for us to exit the loop, we only
607   // ever exit the outer most loop.
608   server->looping                         = loopDepth - 1;
609 }
610
611 void serverEnableSSL(struct Server *server, int flag) {
612   if (flag) {
613     check(serverSupportsSSL());
614   }
615   sslEnable(&server->ssl, flag);
616 }
617
618 void serverSetCertificate(struct Server *server, const char *filename,
619                           int autoGenerateMissing) {
620   sslSetCertificate(&server->ssl, filename, autoGenerateMissing);
621 }
622
623 void serverSetCertificateFd(struct Server *server, int fd) {
624   sslSetCertificateFd(&server->ssl, fd);
625 }
626
627 void serverSetNumericHosts(struct Server *server, int numericHosts) {
628   server->numericHosts = numericHosts;
629 }
630
631 struct Trie *serverGetHttpHandlers(struct Server *server) {
632   return &server->handlers;
633 }
This page took 0.079804 seconds and 5 git commands to generate.