]> andersk Git - test.git/blob - libhttp/server.c
3b115869c5167e1449fa5d4c78f6e9c3ba60f0f8
[test.git] / libhttp / server.c
1 // server.c -- Generic server that can deal with HTTP connections
2 // Copyright (C) 2008-2010 Markus Gutschke <markus@shellinabox.com>
3 //
4 // This program is free software; you can redistribute it and/or modify
5 // it under the terms of the GNU General Public License version 2 as
6 // published by the Free Software Foundation.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License along
14 // with this program; if not, write to the Free Software Foundation, Inc.,
15 // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16 //
17 // In addition to these license terms, the author grants the following
18 // additional rights:
19 //
20 // If you modify this program, or any covered work, by linking or
21 // combining it with the OpenSSL project's OpenSSL library (or a
22 // modified version of that library), containing parts covered by the
23 // terms of the OpenSSL or SSLeay licenses, the author
24 // grants you additional permission to convey the resulting work.
25 // Corresponding Source for a non-source form of such a combination
26 // shall include the source code for the parts of OpenSSL used as well
27 // as that of the covered work.
28 //
29 // You may at your option choose to remove this additional permission from
30 // the work, or from any part of it.
31 //
32 // It is possible to build this program in a way that it loads OpenSSL
33 // libraries at run-time. If doing so, the following notices are required
34 // by the OpenSSL and SSLeay licenses:
35 //
36 // This product includes software developed by the OpenSSL Project
37 // for use in the OpenSSL Toolkit. (http://www.openssl.org/)
38 //
39 // This product includes cryptographic software written by Eric Young
40 // (eay@cryptsoft.com)
41 //
42 //
43 // The most up-to-date version of this program is always available from
44 // http://shellinabox.com
45
46 #include "config.h"
47
48 #include <arpa/inet.h>
49 #include <fcntl.h>
50 #include <netinet/in.h>
51 #include <stdlib.h>
52 #include <string.h>
53 #include <sys/poll.h>
54 #include <sys/socket.h>
55 #include <sys/time.h>
56 #include <sys/types.h>
57 #include <unistd.h>
58
59 #include "libhttp/server.h"
60 #include "libhttp/httpconnection.h"
61 #include "libhttp/ssl.h"
62 #include "logging/logging.h"
63
64 #define INITIAL_TIMEOUT    (10*60)
65
66 // Maximum amount of payload (e.g. form values that have been POST'd) that we
67 // read into memory. If the application needs any more than this, the streaming
68 // API should be used, instead.
69 #define MAX_PAYLOAD_LENGTH (64<<10)
70
71 time_t currentTime;
72
73 struct PayLoad {
74   int (*handler)(struct HttpConnection *, void *, const char *, int);
75   void *arg;
76   int  len;
77   char *bytes;
78 };
79
80 static int serverCollectFullPayload(struct HttpConnection *http,
81                                     void *payload_, const char *buf, int len) {
82   int rc                        = HTTP_READ_MORE;
83   struct PayLoad *payload       = (struct PayLoad *)payload_;
84   if (buf && len) {
85     if (payload->len + len > MAX_PAYLOAD_LENGTH) {
86       httpSendReply(http, 400, "Bad Request", NO_MSG);
87       return HTTP_DONE;
88     }
89     check(len > 0);
90     check(payload->bytes        = realloc(payload->bytes, payload->len + len));
91     memcpy(payload->bytes + payload->len, buf, len);
92     payload->len               += len;
93   }
94   const char *contentLength     = getFromHashMap(httpGetHeaders(http),
95                                                  "content-length");
96   if (!contentLength ||
97       (payload->bytes &&
98        ((contentLength && atoi(contentLength) <= payload->len) || !buf))) {
99     rc = payload->handler(http, payload->arg,
100                           payload->bytes ? payload->bytes : "", payload->len);
101     free(payload->bytes);
102     payload->bytes              = NULL;
103     payload->len                = 0;
104   }
105   if (!buf) {
106     if (rc == HTTP_SUSPEND || rc == HTTP_PARTIAL_REPLY) {
107       // Tell the other party that the connection is getting torn down, even
108       // though it requested it to be suspended.
109       payload->handler(http, payload->arg, NULL, 0);
110       rc                        = HTTP_DONE;
111     }
112     free(payload);
113   }
114   return rc;
115   
116 }
117
118 static int serverCollectHandler(struct HttpConnection *http, void *handler_) {
119   struct HttpHandler *handler = handler_;
120   struct PayLoad *payload;
121   check(payload               = malloc(sizeof(struct PayLoad)));
122   payload->handler            = handler->streamingHandler;
123   payload->arg                = handler->streamingArg;
124   payload->len                = 0;
125   payload->bytes              = malloc(0);
126   httpSetCallback(http, serverCollectFullPayload, payload);
127   return HTTP_READ_MORE;
128
129 }
130
131 static void serverDestroyHandlers(void *arg, char *value) {
132   (void)arg;
133   free(value);
134 }
135
136 void serverRegisterHttpHandler(struct Server *server, const char *url,
137                                int (*handler)(struct HttpConnection *, void *,
138                                               const char *, int), void *arg) {
139   if (!handler) {
140     addToTrie(&server->handlers, url, NULL);
141   } else {
142     struct HttpHandler *h;
143     check(h             = malloc(sizeof(struct HttpHandler)));
144     h->handler          = serverCollectHandler;
145     h->arg              = h;
146     h->streamingHandler = handler;
147     h->websocketHandler = NULL;
148     h->streamingArg     = arg;
149     addToTrie(&server->handlers, url, (char *)h);
150   }
151 }
152
153 void serverRegisterStreamingHttpHandler(struct Server *server, const char *url,
154                                int (*handler)(struct HttpConnection *, void *),
155                                void *arg) {
156   if (!handler) {
157     addToTrie(&server->handlers, url, NULL);
158   } else {
159     struct HttpHandler *h;
160     check(h             = malloc(sizeof(struct HttpHandler)));
161     h->handler          = handler;
162     h->streamingHandler = NULL;
163     h->websocketHandler = NULL;
164     h->streamingArg     = NULL;
165     h->arg              = arg;
166     addToTrie(&server->handlers, url, (char *)h);
167   }
168 }
169
170 void serverRegisterWebSocketHandler(struct Server *server, const char *url,
171        int (*handler)(struct HttpConnection *, void *, int, const char *, int),
172        void *arg) {
173   if (!handler) {
174     addToTrie(&server->handlers, url, NULL);
175   } else {
176     struct HttpHandler *h;
177     check(h             = malloc(sizeof(struct HttpHandler)));
178     h->handler          = NULL;
179     h->streamingHandler = NULL;
180     h->websocketHandler = handler;
181     h->arg              = arg;
182     addToTrie(&server->handlers, url, (char *)h);
183   }
184 }
185
186 static int serverQuitHandler(struct HttpConnection *http, void *arg) {
187   (void)arg;
188   httpSendReply(http, 200, "Good Bye", NO_MSG);
189   httpExitLoop(http, 1);
190   return HTTP_DONE;
191 }
192
193 struct Server *newCGIServer(int localhostOnly, int portMin, int portMax,
194                             int timeout) {
195   struct Server *server;
196   check(server = malloc(sizeof(struct Server)));
197   initServer(server, localhostOnly, portMin, portMax, timeout);
198   return server;
199 }
200
201 struct Server *newServer(int localhostOnly, int port) {
202   return newCGIServer(localhostOnly, port, port, -1);
203 }
204
205 void initServer(struct Server *server, int localhostOnly, int portMin,
206                 int portMax, int timeout) {
207   server->looping               = 0;
208   server->exitAll               = 0;
209   server->serverTimeout         = timeout;
210   server->numericHosts          = 0;
211   server->connections           = NULL;
212   server->numConnections        = 0;
213
214   int true                      = 1;
215   server->serverFd              = socket(PF_INET, SOCK_STREAM, 0);
216   check(server->serverFd >= 0);
217   check(!setsockopt(server->serverFd, SOL_SOCKET, SO_REUSEADDR,
218                     &true, sizeof(true)));
219   struct sockaddr_in serverAddr = { 0 };
220   serverAddr.sin_family         = AF_INET;
221   serverAddr.sin_addr.s_addr    = htonl(localhostOnly
222                                         ? INADDR_LOOPBACK : INADDR_ANY);
223
224   // Linux unlike BSD does not have support for picking a local port range.
225   // So, we have to randomly pick a port from our allowed port range, and then
226   // keep iterating until we find an unused port.
227   if (portMin || portMax) {
228     struct timeval tv;
229     check(!gettimeofday(&tv, NULL));
230     srand((int)(tv.tv_usec ^ tv.tv_sec));
231     check(portMin > 0);
232     check(portMax < 65536);
233     check(portMax >= portMin);
234     int portStart               = rand() % (portMax - portMin + 1) + portMin;
235     for (int p = 0; p <= portMax-portMin; p++) {
236       int port                  = (p+portStart)%(portMax-portMin+1)+ portMin;
237       serverAddr.sin_port       = htons(port);
238       if (!bind(server->serverFd, (struct sockaddr *)&serverAddr,
239                 sizeof(serverAddr))) {
240         break;
241       }
242       serverAddr.sin_port       = 0;
243     }
244     if (!serverAddr.sin_port) {
245       fatal("Failed to find any available port");
246     }
247   }
248
249   check(!listen(server->serverFd, SOMAXCONN));
250   socklen_t socklen             = (socklen_t)sizeof(serverAddr);
251   check(!getsockname(server->serverFd, (struct sockaddr *)&serverAddr,
252                      &socklen));
253   check(socklen == sizeof(serverAddr));
254   server->port                  = ntohs(serverAddr.sin_port);
255   info("Listening on port %d", server->port);
256
257   check(server->pollFds         = malloc(sizeof(struct pollfd)));
258   server->pollFds->fd           = server->serverFd;
259   server->pollFds->events       = POLLIN;
260
261   initTrie(&server->handlers, serverDestroyHandlers, NULL);
262   serverRegisterStreamingHttpHandler(server, "/quit", serverQuitHandler, NULL);
263   initSSL(&server->ssl);
264 }
265
266 void destroyServer(struct Server *server) {
267   if (server) {
268     if (server->serverFd >= 0) {
269       info("Shutting down server");
270       close(server->serverFd);
271     }
272     for (int i = 0; i < server->numConnections; i++) {
273       server->connections[i].destroyConnection(server->connections[i].arg);
274     }
275     free(server->connections);
276     free(server->pollFds);
277     destroyTrie(&server->handlers);
278     destroySSL(&server->ssl);
279   }
280 }
281
282 void deleteServer(struct Server *server) {
283   destroyServer(server);
284   free(server);
285 }
286
287 int serverGetListeningPort(struct Server *server) {
288   return server->port;
289 }
290
291 int serverGetFd(struct Server *server) {
292   return server->serverFd;
293 }
294
295 struct ServerConnection *serverAddConnection(struct Server *server, int fd,
296                          int (*handleConnection)(struct ServerConnection *c,
297                                                  void *arg, short *events,
298                                                  short revents),
299                          void (*destroyConnection)(void *arg),
300                          void *arg) {
301   check(server->connections     = realloc(server->connections,
302                                           ++server->numConnections*
303                                           sizeof(struct ServerConnection)));
304   check(server->pollFds         = realloc(server->pollFds,
305                                           (server->numConnections + 1) *
306                                           sizeof(struct pollfd)));
307   server->pollFds[server->numConnections].fd     = fd;
308   server->pollFds[server->numConnections].events = POLLIN;
309   struct ServerConnection *connection            =
310                               server->connections + server->numConnections - 1;
311   connection->deleted           = 0;
312   connection->timeout           = 0;
313   connection->handleConnection  = handleConnection;
314   connection->destroyConnection = destroyConnection;
315   connection->arg               = arg;
316   return connection;
317 }
318
319 void serverDeleteConnection(struct Server *server, int fd) {
320   for (int i = 0; i < server->numConnections; i++) {
321     if (fd == server->pollFds[i + 1].fd && !server->connections[i].deleted) {
322       server->connections[i].deleted = 1;
323       server->connections[i].destroyConnection(server->connections[i].arg);
324       return;
325     }
326   }
327 }
328
329 void serverSetTimeout(struct ServerConnection *connection, time_t timeout) {
330   if (!currentTime) {
331     currentTime       = time(NULL);
332   }
333   connection->timeout = timeout > 0 ? timeout + currentTime : 0;
334 }
335
336 time_t serverGetTimeout(struct ServerConnection *connection) {
337   if (connection->timeout) {
338     // Returns <0 if expired, 0 if not set, and >0 if still pending.
339     if (!currentTime) {
340       currentTime = time(NULL);
341     }
342     int remaining = connection->timeout - currentTime;
343     if (!remaining) {
344       remaining--;
345     }
346     return remaining;
347   } else {
348     return 0;
349   }
350 }
351
352 struct ServerConnection *serverGetConnection(struct Server *server,
353                                              struct ServerConnection *hint,
354                                              int fd) {
355   if (hint &&
356       server->connections <= hint &&
357       server->connections + server->numConnections > hint) {
358     // The compiler would like to optimize the expression:
359     //   &server->connections[hint - server->connections]     <=>
360     //   server->connections + hint - server->connections     <=>
361     //   hint
362     // This transformation is correct as far as the language specification is
363     // concerned, but it is unintended as we actually want to check whether
364     // the alignment is correct. So, instead of comparing
365     //   &server->connections[hint - server->connections] == hint
366     // we first use memcpy() to break aliasing.
367     uintptr_t ptr1, ptr2;
368     memcpy(&ptr1, &hint, sizeof(ptr1));
369     memcpy(&ptr2, &server->connections, sizeof(ptr2));
370     int idx = (ptr1 - ptr2)/sizeof(*server->connections);
371     if (&server->connections[idx] == hint &&
372         !hint->deleted &&
373         server->pollFds[hint - server->connections + 1].fd == fd) {
374       return hint;
375     }
376   }
377   for (int i = 0; i < server->numConnections; i++) {
378     if (server->pollFds[i + 1].fd == fd && !server->connections[i].deleted) {
379       return server->connections + i;
380     }
381   }
382   return NULL;
383 }
384
385 short serverConnectionSetEvents(struct Server *server,
386                                 struct ServerConnection *connection,
387                                 short events) {
388   dcheck(server);
389   dcheck(connection);
390   dcheck(connection >= server->connections);
391   dcheck(connection < server->connections + server->numConnections);
392   dcheck(connection == &server->connections[connection - server->connections]);
393   dcheck(!connection->deleted);
394   int   idx                       = connection - server->connections;
395   short oldEvents                 = server->pollFds[idx + 1].events;
396   server->pollFds[idx + 1].events = events;
397   return oldEvents;
398 }
399
400 void serverExitLoop(struct Server *server, int exitAll) {
401   server->looping--;
402   server->exitAll |= exitAll;
403 }
404
405 void serverLoop(struct Server *server) {
406   check(server->serverFd >= 0);
407   time_t lastTime;
408   currentTime                             = time(&lastTime);
409   int loopDepth                           = ++server->looping;
410   while (server->looping >= loopDepth && !server->exitAll) {
411     // TODO: There probably should be some limit on the maximum number
412     // of concurrently opened HTTP connections, as this could lead to
413     // memory exhaustion and a DoS attack.
414     time_t timeout                        = -1;
415     int numFds                            = server->numConnections + 1;
416
417     for (int i = 0; i < server->numConnections; i++) {
418       while (i < numFds - 1 && !server->pollFds[i + 1].events) {
419         // Sort filedescriptors that currently do not expect any events to
420         // the end of the list.
421         check(--numFds > 0);
422         struct pollfd tmpPollFd;
423         memmove(&tmpPollFd, server->pollFds + numFds, sizeof(struct pollfd));
424         memmove(server->pollFds + numFds, server->pollFds + i + 1,
425                 sizeof(struct pollfd));
426         memmove(server->pollFds + i + 1, &tmpPollFd, sizeof(struct pollfd));
427         struct ServerConnection tmpConnection;
428         memmove(&tmpConnection, server->connections + numFds - 1,
429                 sizeof(struct ServerConnection));
430         memmove(server->connections + numFds - 1, server->connections + i,
431                 sizeof(struct ServerConnection));
432         memmove(server->connections + i, &tmpConnection,
433                 sizeof(struct ServerConnection));
434       }
435
436       if (server->connections[i].timeout &&
437           (timeout < 0 || timeout > server->connections[i].timeout)) {
438         timeout                           = server->connections[i].timeout;
439       }
440     }
441
442     // serverTimeout is always a delta value, unlike connection timeouts
443     // which are absolute times.
444     if (server->serverTimeout >= 0) {
445       if (timeout < 0 || timeout > server->serverTimeout + currentTime) {
446         timeout                           = server->serverTimeout+currentTime;
447       }
448     }
449
450     if (timeout >= 0) {
451       // Wait at least one second longer than needed, so that even if
452       // poll() decides to return a second early (due to possible rounding
453       // errors), we still correctly detect a timeout condition.
454       if (timeout >= lastTime) {
455         timeout                           = (timeout - lastTime + 1) * 1000;
456       } else {
457         timeout                           = 1000;
458       }
459     }
460
461     int eventCount                        = NOINTR(poll(server->pollFds,
462                                                         numFds,
463                                                         timeout));
464     check(eventCount >= 0);
465     if (timeout >= 0) {
466       timeout                            += lastTime;
467     }
468     currentTime                           = time(&lastTime);
469     int isTimeout                         = timeout >= 0 &&
470                                             timeout/1000 <= lastTime;
471     if (server->pollFds[0].revents) {
472       eventCount--;
473       if (server->pollFds[0].revents && POLLIN) {
474         struct sockaddr_in clientAddr;
475         socklen_t sockLen                 = sizeof(clientAddr);
476         int clientFd                      = accept(
477                    server->serverFd, (struct sockaddr *)&clientAddr, &sockLen);
478         dcheck(clientFd >= 0);
479         if (clientFd >= 0) {
480           check(!fcntl(clientFd, F_SETFL, O_RDWR | O_NONBLOCK));
481           struct HttpConnection *http;
482           http                            = newHttpConnection(
483                                      server, clientFd, server->port,
484                                      server->ssl.enabled ? &server->ssl : NULL,
485                                      server->numericHosts);
486           serverSetTimeout(
487             serverAddConnection(server, clientFd, httpHandleConnection,
488                                 (void (*)(void *))deleteHttpConnection,
489                                 http),
490             INITIAL_TIMEOUT);
491         }
492       }
493     } else {
494       if (server->serverTimeout > 0 && !server->numConnections) {
495         // In CGI mode, exit the server, if we haven't had any active
496         // connections in a while.
497         break;
498       }
499     }
500     for (int i = 1;
501          (isTimeout || eventCount > 0) && i <= server->numConnections;
502          i++) {
503       struct ServerConnection *connection = server->connections + i - 1;
504       if (connection->deleted) {
505         continue;
506       }
507       if (!eventCount) {
508         server->pollFds[i].revents        = 0;
509       }
510       if (server->pollFds[i].revents ||
511           (connection->timeout && lastTime >= connection->timeout)) {
512         if (server->pollFds[i].revents) {
513           eventCount--;
514         }
515         short events                      = server->pollFds[i].events;
516         if (!connection->handleConnection(connection, connection->arg,
517                                          &events, server->pollFds[i].revents)){
518           connection                      = server->connections + i - 1;
519           connection->destroyConnection(connection->arg);
520           connection->deleted             = 1;
521         } else {
522           server->pollFds[i].events       = events;
523         }
524       }
525     }
526     for (int i = 1; i <= server->numConnections; i++) {
527       if (server->connections[i-1].deleted) {
528         memmove(server->pollFds + i, server->pollFds + i + 1,
529                 (server->numConnections - i) * sizeof(struct pollfd));
530         memmove(server->connections + i - 1, server->connections + i,
531                 (server->numConnections - i)*sizeof(struct ServerConnection));
532         check(--i >= 0);
533         check(--server->numConnections >= 0);
534       }
535     }
536   }
537   // Even if multiple clients requested for us to exit the loop, we only
538   // ever exit the outer most loop.
539   server->looping                         = loopDepth - 1;
540 }
541
542 void serverEnableSSL(struct Server *server, int flag) {
543   if (flag) {
544     check(serverSupportsSSL());
545   }
546   sslEnable(&server->ssl, flag);
547 }
548
549 void serverSetCertificate(struct Server *server, const char *filename,
550                           int autoGenerateMissing) {
551   sslSetCertificate(&server->ssl, filename, autoGenerateMissing);
552 }
553
554 void serverSetCertificateFd(struct Server *server, int fd) {
555   sslSetCertificateFd(&server->ssl, fd);
556 }
557
558 void serverSetNumericHosts(struct Server *server, int numericHosts) {
559   server->numericHosts = numericHosts;
560 }
561
562 struct Trie *serverGetHttpHandlers(struct Server *server) {
563   return &server->handlers;
564 }
This page took 1.271297 seconds and 3 git commands to generate.