1 // server.c -- Generic server that can deal with HTTP connections
2 // Copyright (C) 2008-2010 Markus Gutschke <markus@shellinabox.com>
4 // This program is free software; you can redistribute it and/or modify
5 // it under the terms of the GNU General Public License version 2 as
6 // published by the Free Software Foundation.
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 // GNU General Public License for more details.
13 // You should have received a copy of the GNU General Public License along
14 // with this program; if not, write to the Free Software Foundation, Inc.,
15 // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17 // In addition to these license terms, the author grants the following
20 // If you modify this program, or any covered work, by linking or
21 // combining it with the OpenSSL project's OpenSSL library (or a
22 // modified version of that library), containing parts covered by the
23 // terms of the OpenSSL or SSLeay licenses, the author
24 // grants you additional permission to convey the resulting work.
25 // Corresponding Source for a non-source form of such a combination
26 // shall include the source code for the parts of OpenSSL used as well
27 // as that of the covered work.
29 // You may at your option choose to remove this additional permission from
30 // the work, or from any part of it.
32 // It is possible to build this program in a way that it loads OpenSSL
33 // libraries at run-time. If doing so, the following notices are required
34 // by the OpenSSL and SSLeay licenses:
36 // This product includes software developed by the OpenSSL Project
37 // for use in the OpenSSL Toolkit. (http://www.openssl.org/)
39 // This product includes cryptographic software written by Eric Young
40 // (eay@cryptsoft.com)
43 // The most up-to-date version of this program is always available from
44 // http://shellinabox.com
48 #include <arpa/inet.h>
50 #include <netinet/in.h>
54 #include <sys/socket.h>
56 #include <sys/types.h>
59 #include "libhttp/server.h"
60 #include "libhttp/httpconnection.h"
61 #include "libhttp/ssl.h"
62 #include "logging/logging.h"
// NOTE(review): "#defined" is not a valid preprocessing directive; GCC and
// Clang reject it unless it sits inside a conditional group that is skipped.
// The #if/#else/#endif lines that select between these definitions and the
// fallback below are not visible in this listing (gaps in the embedded line
// numbers). "#define" was presumably intended.
65 #defined ATTR_UNUSED __attribute__((unused))
66 #defined UNUSED(x) do { } while (0)
// Portable UNUSED(): evaluate and discard the argument so the compiler does
// not warn about an otherwise unused variable or parameter.
69 #define UNUSED(x) do { (void)(x); } while (0)
// NOTE(review): not referenced anywhere in the visible code; reads as ten
// minutes if the unit is seconds -- confirm at the point of use.
72 #define INITIAL_TIMEOUT (10*60)
74 // Maximum amount of payload (e.g. form values that have been POST'd) that we
75 // read into memory. If the application needs any more than this, the streaming
76 // API should be used, instead.
77 #define MAX_PAYLOAD_LENGTH (64<<10)  // 65536 bytes (64 KiB)
80 #if defined(__APPLE__) && defined(__MACH__)
81 // While MacOS X does ship with an implementation of poll(), this
82 // implementation is apparently known to be broken and does not comply
83 // with POSIX standards. Fortunately, the operating system is not entirely
84 // unable to check for input events. We can fall back on calling select()
85 // instead. This is generally not desirable, as it is less efficient and
86 // has a compile-time restriction on the maximum number of file
87 // descriptors. But on MacOS X, that's the best we can do.
// poll() replacement built on select(). It copies the POLLIN/POLLOUT/POLLPRI
// requests in fds[] into the read/write/exception fd_sets, converts the
// millisecond timeout into a struct timeval, calls select(), and then maps
// the ready sets back into fds[i].revents.
// NOTE(review): this listing omits several lines of the function (gaps in
// the embedded line numbers), including the fd_set/maxFd setup, the choice
// of the tmo pointer, and the return statement.
// NOTE(review): no FD_SETSIZE bound check is visible before FD_SET(); a
// descriptor >= FD_SETSIZE would write past the end of the fd_set -- confirm.
89 int x_poll(struct pollfd *fds, nfds_t nfds, int timeout) {
// NOTE(review): i is int while nfds is nfds_t (unsigned): signed/unsigned
// comparison.
95 for (int i = 0; i < nfds; ++i) {
96 if (fds[i].fd > maxFd) {
98 } else if (fds[i].fd < 0) {
101 if (fds[i].events & POLLIN) {
102 FD_SET(fds[i].fd, &r);
104 if (fds[i].events & POLLOUT) {
105 FD_SET(fds[i].fd, &w);
107 if (fds[i].events & POLLPRI) {
108 FD_SET(fds[i].fd, &x);
// Split the millisecond poll() timeout into seconds + microseconds.
111 struct timeval tmoVal = { 0 }, *tmo;
115 tmoVal.tv_sec = timeout / 1000;
116 tmoVal.tv_usec = (timeout % 1000) * 1000;
119 int numRet = select(maxFd + 1, &r, &w, &x, tmo);
// Translate the ready sets back into revents. When a descriptor is in both
// the exception and read sets, only POLLPRI is reported (else-if below);
// POLLOUT is OR-ed in independently.
120 for (int i = 0, n = numRet; i < nfds && n > 0; ++i) {
124 if (FD_ISSET(fds[i].fd, &x)) {
125 fds[i].revents = POLLPRI;
126 } else if (FD_ISSET(fds[i].fd, &r)) {
127 fds[i].revents = POLLIN;
131 if (FD_ISSET(fds[i].fd, &w)) {
132 fds[i].revents |= POLLOUT;
// Callback that receives the fully buffered request body. This is a member of
// a struct whose opening line is not visible in this listing -- presumably
// struct PayLoad, since serverCollectFullPayload() calls payload->handler
// with exactly this signature.
143 int (*handler)(struct HttpConnection *, void *, const char *, int);
// Body-collecting callback installed by serverCollectHandler(). It appends
// each chunk of request body to payload->bytes. Once the body is complete
// (no Content-Length header, the declared length reached, or buf == NULL),
// it passes the whole buffer to the user's handler in a single call.
//   http     - connection being served
//   payload_ - the struct PayLoad allocated by serverCollectHandler()
//   buf/len  - next chunk of body data; buf == NULL is treated below as
//              "no more data"
// rc starts as HTTP_READ_MORE and is replaced by the handler's result when
// the handler runs. The return statement itself is not visible here.
// NOTE(review): several lines of this function are missing from this
// listing (gaps in the embedded line numbers), including any guard around
// the append and whatever follows the 400 reply.
149 static int serverCollectFullPayload(struct HttpConnection *http,
150 void *payload_, const char *buf, int len) {
151 int rc = HTTP_READ_MORE;
152 struct PayLoad *payload = (struct PayLoad *)payload_;
// Refuse to buffer more than MAX_PAYLOAD_LENGTH bytes; clients that need
// more should use the streaming API.
154 if (payload->len + len > MAX_PAYLOAD_LENGTH) {
155 httpSendReply(http, 400, "Bad Request", NO_MSG);
// NOTE(review): check() appears to be this project's assertion macro, so an
// allocation failure here ends the process rather than just this request.
159 check(payload->bytes = realloc(payload->bytes, payload->len + len));
160 memcpy(payload->bytes + payload->len, buf, len);
163 const char *contentLength = getFromHashMap(httpGetHeaders(http),
// Dispatch to the user's handler once there is nothing more to wait for.
// NOTE(review): atoi() performs no validation on this client-supplied
// header; a malformed or negative value makes the "<= payload->len"
// comparison succeed immediately. The inner "contentLength &&" is redundant
// after "!contentLength ||".
165 if (!contentLength ||
167 ((contentLength && atoi(contentLength) <= payload->len) || !buf))) {
168 rc = payload->handler(http, payload->arg,
169 payload->bytes ? payload->bytes : "", payload->len);
// The handler has consumed the buffer; release it.
170 free(payload->bytes);
171 payload->bytes = NULL;
175 if (rc == HTTP_SUSPEND || rc == HTTP_PARTIAL_REPLY) {
176 // Tell the other party that the connection is getting torn down, even
177 // though it requested it to be suspended.
178 payload->handler(http, payload->arg, NULL, 0);
// Per-request entry point for handlers registered via
// serverRegisterHttpHandler(). It allocates a PayLoad that remembers the
// user's streaming handler and argument, then switches the connection's
// callback to serverCollectFullPayload() so the body is buffered before the
// user's handler runs. Returns HTTP_READ_MORE to request body data.
// NOTE(review): where this PayLoad is freed is not visible in this listing.
187 static int serverCollectHandler(struct HttpConnection *http, void *handler_) {
188 struct HttpHandler *handler = handler_;
189 struct PayLoad *payload;
190 check(payload = malloc(sizeof(struct PayLoad)));
191 payload->handler = handler->streamingHandler;
192 payload->arg = handler->streamingArg;
// NOTE(review): malloc(0) may return NULL or a unique pointer, and the result
// is not checked; later code tests payload->bytes for NULL.
194 payload->bytes = malloc(0);
195 httpSetCallback(http, serverCollectFullPayload, payload);
196 return HTTP_READ_MORE;
// Value destructor for the server's handler trie (passed to initTrie() in
// initServer()); value is a stored struct HttpHandler pointer and arg is
// unused. The function body is not visible in this listing.
200 static void serverDestroyHandlers(void *arg ATTR_UNUSED, char *value) {
// Registers a handler for url that receives the complete request body in a
// single call. Bodies are buffered by serverCollectHandler() and
// serverCollectFullPayload(), up to MAX_PAYLOAD_LENGTH bytes. arg is passed
// back to the handler as its second argument.
// NOTE(review): the lines that decide between storing NULL for url and
// storing a new HttpHandler are missing from this listing (gaps in the
// embedded line numbers) -- presumably a NULL handler unregisters url.
205 void serverRegisterHttpHandler(struct Server *server, const char *url,
206 int (*handler)(struct HttpConnection *, void *,
207 const char *, int), void *arg) {
209 addToTrie(&server->handlers, url, NULL);
211 struct HttpHandler *h;
212 check(h = malloc(sizeof(struct HttpHandler)));
// The trie entry points at the buffering wrapper; the user's handler is kept
// as the streaming handler that the wrapper eventually calls.
213 h->handler = serverCollectHandler;
215 h->streamingHandler = handler;
216 h->websocketHandler = NULL;
217 h->streamingArg = arg;
218 addToTrie(&server->handlers, url, (char *)h);
// Registers handler for url as the connection's handler itself (no body
// buffering), so it is responsible for consuming any request body.
// NOTE(review): the remaining parameter line(s) and any assignment of the
// handler's own argument are missing from this listing (gaps in the
// embedded line numbers).
222 void serverRegisterStreamingHttpHandler(struct Server *server, const char *url,
223 int (*handler)(struct HttpConnection *, void *),
226 addToTrie(&server->handlers, url, NULL);
228 struct HttpHandler *h;
229 check(h = malloc(sizeof(struct HttpHandler)));
230 h->handler = handler;
231 h->streamingHandler = NULL;
232 h->websocketHandler = NULL;
233 h->streamingArg = NULL;
235 addToTrie(&server->handlers, url, (char *)h);
// Registers a WebSocket handler for url, stored as h->websocketHandler in
// the server's handler trie.
// NOTE(review): no assignment to h->handler or h->streamingArg is visible.
// Those lines may simply be missing from this listing (gaps in the embedded
// line numbers); otherwise those fields are left uninitialized after
// malloc() -- confirm.
239 void serverRegisterWebSocketHandler(struct Server *server, const char *url,
240 int (*handler)(struct HttpConnection *, void *, int, const char *, int),
243 addToTrie(&server->handlers, url, NULL);
245 struct HttpHandler *h;
246 check(h = malloc(sizeof(struct HttpHandler)));
248 h->streamingHandler = NULL;
249 h->websocketHandler = handler;
251 addToTrie(&server->handlers, url, (char *)h);
// Built-in handler for "/quit": replies 200 "Good Bye" and asks the HTTP
// connection's loop to exit (the meaning of the 1 argument is defined by
// httpExitLoop(), presumably "exit all nesting levels").
// NOTE(review): http is marked ATTR_UNUSED but is used below; the attribute
// is misleading.
// NOTE(review): initServer() registers this unconditionally, so any client
// that can reach the port can shut the server down -- confirm intended.
255 static int serverQuitHandler(struct HttpConnection *http ATTR_UNUSED,
258 httpSendReply(http, 200, "Good Bye", NO_MSG);
259 httpExitLoop(http, 1);
// Heap-allocates a Server and initializes it with initServer(), listening on
// a port chosen from [portMin, portMax]. timeout becomes the server-level
// timeout; a positive value is what serverLoop's "CGI mode" idle check tests.
// NOTE(review): the remaining parameter line and the return statement are
// missing from this listing; presumably it returns server.
263 struct Server *newCGIServer(int localhostOnly, int portMin, int portMax,
265 struct Server *server;
266 check(server = malloc(sizeof(struct Server)));
267 initServer(server, localhostOnly, portMin, portMax, timeout);
// Creates a server on exactly one port with timeout -1. serverLoop only
// applies the server-level timeout when it is >= 0, so -1 disables it.
271 struct Server *newServer(int localhostOnly, int port) {
272 return newCGIServer(localhostOnly, port, port, -1);
// Initializes *server:
//  - creates an IPv4 TCP listening socket bound to the loopback address
//    (localhostOnly) or to all interfaces;
//  - when portMin or portMax is non-zero, binds to a free port in
//    [portMin, portMax] or calls fatal(); the other case is handled on lines
//    not visible in this listing;
//  - sets up pollFds[0] for the listening socket, the URL handler trie
//    (with the built-in "/quit" handler), and the SSL state.
// timeout is stored as serverTimeout (a delta in the units serverLoop uses;
// < 0 disables it).
// NOTE(review): several lines are missing from this listing (gaps in the
// embedded line numbers), including the declarations of tv and of the
// `true` object passed to setsockopt(). `true` must be an addressable int
// here, not the <stdbool.h> macro.
275 void initServer(struct Server *server, int localhostOnly, int portMin,
276 int portMax, int timeout) {
279 server->serverTimeout = timeout;
280 server->numericHosts = 0;
281 server->connections = NULL;
282 server->numConnections = 0;
285 server->serverFd = socket(PF_INET, SOCK_STREAM, 0);
286 check(server->serverFd >= 0);
// SO_REUSEADDR lets a restarted server rebind a port whose previous
// connections are still in TIME_WAIT.
287 check(!setsockopt(server->serverFd, SOL_SOCKET, SO_REUSEADDR,
288 &true, sizeof(true)));
289 struct sockaddr_in serverAddr = { 0 };
290 serverAddr.sin_family = AF_INET;
291 serverAddr.sin_addr.s_addr = htonl(localhostOnly
292 ? INADDR_LOOPBACK : INADDR_ANY);
294 // Linux unlike BSD does not have support for picking a local port range.
295 // So, we have to randomly pick a port from our allowed port range, and then
296 // keep iterating until we find an unused port.
297 if (portMin || portMax) {
299 check(!gettimeofday(&tv, NULL));
// Note: srand() reseeds the process-wide rand() generator.
300 srand((int)(tv.tv_usec ^ tv.tv_sec));
// NOTE(review): if portMin may be 0 (any check for that is not visible
// here), bind() to port 0 succeeds with sin_port still 0, and the "no port
// found" test below would then call fatal().
302 check(portMax < 65536);
303 check(portMax >= portMin);
// Start at a random port in the range and visit every port in the range
// exactly once, wrapping around: (p + portStart) mod range takes each residue
// once as p runs over 0..range-1.
304 int portStart = rand() % (portMax - portMin + 1) + portMin;
305 for (int p = 0; p <= portMax-portMin; p++) {
306 int port = (p+portStart)%(portMax-portMin+1)+ portMin;
307 serverAddr.sin_port = htons(port);
308 if (!bind(server->serverFd, (struct sockaddr *)&serverAddr,
309 sizeof(serverAddr))) {
// bind() failed for this port: reset sin_port so a zero value after the
// loop means no port in the range was available.
312 serverAddr.sin_port = 0;
314 if (!serverAddr.sin_port) {
315 fatal("Failed to find any available port");
319 check(!listen(server->serverFd, SOMAXCONN));
// Read back the bound address to learn the actual port number.
320 socklen_t socklen = (socklen_t)sizeof(serverAddr);
321 check(!getsockname(server->serverFd, (struct sockaddr *)&serverAddr,
323 check(socklen == sizeof(serverAddr));
324 server->port = ntohs(serverAddr.sin_port);
325 info("Listening on port %d", server->port);
// pollFds[0] is always the listening socket; connection i later occupies
// pollFds[i + 1].
327 check(server->pollFds = malloc(sizeof(struct pollfd)));
328 server->pollFds->fd = server->serverFd;
329 server->pollFds->events = POLLIN;
331 initTrie(&server->handlers, serverDestroyHandlers, NULL);
332 serverRegisterStreamingHttpHandler(server, "/quit", serverQuitHandler, NULL);
333 initSSL(&server->ssl);
// Releases everything initServer() and serverAddConnection() set up: closes
// the listening socket, destroys every connection, and frees the connection
// and pollfd arrays, the handler trie and the SSL state. Does not free
// *server itself.
// NOTE(review): lines are missing from this listing (gaps in the embedded
// line numbers), e.g. the line before the serverFd test.
336 void destroyServer(struct Server *server) {
338 if (server->serverFd >= 0) {
339 info("Shutting down server");
340 close(server->serverFd);
// NOTE(review): entries with .deleted set are not skipped. serverLoop
// compacts them away at the end of each iteration, but if this runs before
// that (e.g. from inside a handler), destroyConnection would be called a
// second time for such entries.
342 for (int i = 0; i < server->numConnections; i++) {
343 server->connections[i].destroyConnection(server->connections[i].arg);
345 free(server->connections);
346 free(server->pollFds);
347 destroyTrie(&server->handlers);
348 destroySSL(&server->ssl);
// Counterpart to newServer()/newCGIServer(): tears down the server's
// resources. The rest of the body (presumably freeing the struct allocated
// in newCGIServer()) is not visible in this listing.
352 void deleteServer(struct Server *server) {
353 destroyServer(server);
// Returns the port the server listens on. initServer() records it in
// server->port; the function body is not visible in this listing.
357 int serverGetListeningPort(struct Server *server) {
// Returns the listening socket's file descriptor.
361 int serverGetFd(struct Server *server) {
362 return server->serverFd;
// Adds a connection record for fd plus a matching pollfd that initially
// waits for POLLIN. connections[i] pairs with pollFds[i + 1], because slot 0
// is the listening socket. Both arrays grow by exactly one element per call.
// NOTE(review): the pointer set up below points into server->connections,
// which later calls realloc() and serverLoop moves or compacts. Callers must
// not hold it across those. (The return statement is not visible in this
// listing; presumably it returns that pointer.)
365 struct ServerConnection *serverAddConnection(struct Server *server, int fd,
366 int (*handleConnection)(struct ServerConnection *c,
367 void *arg, short *events,
369 void (*destroyConnection)(void *arg),
// numConnections is incremented before the realloc() result is known; this
// relies on check() aborting on failure.
371 check(server->connections = realloc(server->connections,
372 ++server->numConnections*
373 sizeof(struct ServerConnection)));
374 check(server->pollFds = realloc(server->pollFds,
375 (server->numConnections + 1) *
376 sizeof(struct pollfd)));
377 server->pollFds[server->numConnections].fd = fd;
378 server->pollFds[server->numConnections].events = POLLIN;
379 struct ServerConnection *connection =
380 server->connections + server->numConnections - 1;
381 connection->deleted = 0;
382 connection->timeout = 0;
383 connection->handleConnection = handleConnection;
384 connection->destroyConnection = destroyConnection;
385 connection->arg = arg;
// Marks the live connection whose descriptor is fd as deleted and destroys
// it immediately. The array slot itself is removed later by serverLoop's
// compaction pass. Whether the loop stops after the first match is decided
// on lines not visible in this listing.
389 void serverDeleteConnection(struct Server *server, int fd) {
390 for (int i = 0; i < server->numConnections; i++) {
391 if (fd == server->pollFds[i + 1].fd && !server->connections[i].deleted) {
// Set the flag first so the entry is treated as dead from here on.
392 server->connections[i].deleted = 1;
393 server->connections[i].destroyConnection(server->connections[i].arg);
// Sets an absolute deadline for connection, timeout seconds after
// currentTime; timeout <= 0 clears it (stores 0). currentTime is a shared
// variable (declared outside the visible lines) that serverLoop also
// refreshes. Any condition guarding the refresh below is not visible here.
399 void serverSetTimeout(struct ServerConnection *connection, time_t timeout) {
401 currentTime = time(NULL);
403 connection->timeout = timeout > 0 ? timeout + currentTime : 0;
// Reports the state of connection's deadline relative to currentTime. The
// rest of the body, including the return statements, is not visible in this
// listing.
406 time_t serverGetTimeout(struct ServerConnection *connection) {
407 if (connection->timeout) {
408 // Returns <0 if expired, 0 if not set, and >0 if still pending.
410 currentTime = time(NULL);
// NOTE(review): the time_t difference is narrowed to int, which truncates a
// very distant deadline; time_t would match the return type.
412 int remaining = connection->timeout - currentTime;
// Finds the connection record for fd. If hint points at a properly aligned
// element of server->connections whose pollfd matches fd, that element is
// used; otherwise it falls back to a linear scan for a non-deleted entry
// with that descriptor.
// NOTE(review): the fd parameter line, the start of the hint test, the early
// return and the final "not found" return are missing from this listing
// (gaps in the embedded line numbers).
422 struct ServerConnection *serverGetConnection(struct Server *server,
423 struct ServerConnection *hint,
426 server->connections <= hint &&
427 server->connections + server->numConnections > hint) {
428 // The compiler would like to optimize the expression:
429 // &server->connections[hint - server->connections] <=>
430 // server->connections + hint - server->connections <=>
// hint   (so the comparison would be folded to "always true")
432 // This transformation is correct as far as the language specification is
433 // concerned, but it is unintended as we actually want to check whether
434 // the alignment is correct. So, instead of comparing
435 // &server->connections[hint - server->connections] == hint
436 // we first use memcpy() to break aliasing.
437 uintptr_t ptr1, ptr2;
438 memcpy(&ptr1, &hint, sizeof(ptr1));
439 memcpy(&ptr2, &server->connections, sizeof(ptr2));
// Element index computed on raw addresses; a misaligned hint truncates to
// an index whose address differs from hint.
440 int idx = (ptr1 - ptr2)/sizeof(*server->connections);
441 if (&server->connections[idx] == hint &&
443 server->pollFds[hint - server->connections + 1].fd == fd) {
// Slow path: scan all live connections (pollFds[0] is the listening socket).
447 for (int i = 0; i < server->numConnections; i++) {
448 if (server->pollFds[i + 1].fd == fd && !server->connections[i].deleted) {
449 return server->connections + i;
// Replaces the poll() event mask for connection, whose pollfd must hold fd,
// and reads the previous mask into oldEvents (the return statement is not
// visible in this listing). The dchecks verify that connection is a live
// element of server->connections.
455 short serverConnectionSetEvents(struct Server *server,
456 struct ServerConnection *connection, int fd,
460 dcheck(connection >= server->connections);
461 dcheck(connection < server->connections + server->numConnections);
// NOTE(review): this is the same expression that serverGetConnection's
// comment says the compiler may fold to "true"; as written it does not
// reliably detect a misaligned pointer.
462 dcheck(connection == &server->connections[connection - server->connections]);
463 dcheck(!connection->deleted);
// connections[idx] pairs with pollFds[idx + 1]; slot 0 is the listener.
464 int idx = connection - server->connections;
465 short oldEvents = server->pollFds[idx + 1].events;
466 dcheck(fd == server->pollFds[idx + 1].fd);
467 server->pollFds[idx + 1].events = events;
// Asks serverLoop to stop. A non-zero exitAll is OR-ed into a sticky flag
// that the loop condition checks at every nesting level. The other line(s)
// of this function are not visible in this listing.
471 void serverExitLoop(struct Server *server, int exitAll) {
473 server->exitAll |= exitAll;
// Runs the server's poll() loop. Each iteration:
//  - moves connections with no requested events to the end of the arrays,
//    and computes the nearest deadline (per-connection or server-level);
//  - accepts a new client when the listening socket is readable;
//  - dispatches readiness events and expired per-connection deadlines to
//    each connection's handleConnection() callback;
//  - compacts away connections marked deleted.
// Calls may nest: each call records its depth in server->looping and runs
// while looping stays at or above that depth and exitAll is not set.
// NOTE(review): this listing omits many lines of the function (gaps in the
// embedded line numbers); only visible lines were touched.
476 void serverLoop(struct Server *server) {
477 check(server->serverFd >= 0);
479 currentTime = time(&lastTime);
480 int loopDepth = ++server->looping;
481 while (server->looping >= loopDepth && !server->exitAll) {
482 // TODO: There probably should be some limit on the maximum number
483 // of concurrently opened HTTP connections, as this could lead to
484 // memory exhaustion and a DoS attack.
486 int numFds = server->numConnections + 1;
488 for (int i = 0; i < server->numConnections; i++) {
489 while (i < numFds - 1 && !server->pollFds[i + 1].events) {
490 // Sort filedescriptors that currently do not expect any events to
491 // the end of the list.
// Swap both the pollfd and its paired ServerConnection so connections[i]
// keeps matching pollFds[i + 1].
493 struct pollfd tmpPollFd;
494 memmove(&tmpPollFd, server->pollFds + numFds, sizeof(struct pollfd));
495 memmove(server->pollFds + numFds, server->pollFds + i + 1,
496 sizeof(struct pollfd));
497 memmove(server->pollFds + i + 1, &tmpPollFd, sizeof(struct pollfd));
498 struct ServerConnection tmpConnection;
499 memmove(&tmpConnection, server->connections + numFds - 1,
500 sizeof(struct ServerConnection));
501 memmove(server->connections + numFds - 1, server->connections + i,
502 sizeof(struct ServerConnection));
503 memmove(server->connections + i, &tmpConnection,
504 sizeof(struct ServerConnection));
// Track the earliest absolute per-connection deadline.
507 if (server->connections[i].timeout &&
508 (timeout < 0 || timeout > server->connections[i].timeout)) {
509 timeout = server->connections[i].timeout;
513 // serverTimeout is always a delta value, unlike connection timeouts
514 // which are absolute times.
515 if (server->serverTimeout >= 0) {
516 if (timeout < 0 || timeout > server->serverTimeout + currentTime) {
517 timeout = server->serverTimeout+currentTime;
522 // Wait at least one second longer than needed, so that even if
523 // poll() decides to return a second early (due to possible rounding
524 // errors), we still correctly detect a timeout condition.
525 if (timeout >= lastTime) {
526 timeout = (timeout - lastTime + 1) * 1000;
532 int eventCount = NOINTR(poll(server->pollFds,
535 check(eventCount >= 0);
539 currentTime = time(&lastTime);
// NOTE(review): above, timeout became a millisecond delta for poll(), while
// lastTime is an absolute time. This comparison is only meaningful if the
// lines not visible after poll() convert timeout back -- confirm.
540 int isTimeout = timeout >= 0 &&
541 timeout/1000 <= lastTime;
542 if (eventCount > 0 && server->pollFds[0].revents) {
// Only accept() when the listening socket is actually readable. The former
// test "revents && POLLIN" was always true here (revents is already known to
// be non-zero), so POLLERR/POLLHUP/POLLNVAL alone also triggered accept().
544 if (server->pollFds[0].revents & POLLIN) {
545 struct sockaddr_in clientAddr;
546 socklen_t sockLen = sizeof(clientAddr);
547 int clientFd = accept(
548 server->serverFd, (struct sockaddr *)&clientAddr, &sockLen);
// NOTE(review): accept() can legitimately fail (e.g. ECONNABORTED, EMFILE,
// or EAGAIN on a spurious wakeup). If dcheck() is compiled out in release
// builds, fcntl(-1) then trips check() and aborts the whole server; a failed
// accept should be skipped instead.
549 dcheck(clientFd >= 0);
// O_RDWR is an access-mode flag and is ignored by F_SETFL; only O_NONBLOCK
// takes effect.
551 check(!fcntl(clientFd, F_SETFL, O_RDWR | O_NONBLOCK));
552 struct HttpConnection *http;
553 http = newHttpConnection(
554 server, clientFd, server->port,
555 server->ssl.enabled ? &server->ssl : NULL,
556 server->numericHosts);
558 serverAddConnection(server, clientFd, httpHandleConnection,
559 (void (*)(void *))deleteHttpConnection,
565 if (server->serverTimeout > 0 && !server->numConnections) {
566 // In CGI mode, exit the server, if we haven't had any active
567 // connections in a while.
572 (isTimeout || eventCount > 0) && i <= server->numConnections;
574 struct ServerConnection *connection = server->connections + i - 1;
575 if (connection->deleted) {
579 server->pollFds[i].revents = 0;
581 if (server->pollFds[i].revents ||
582 (connection->timeout && lastTime >= connection->timeout)) {
583 if (server->pollFds[i].revents) {
586 if (!connection->handleConnection(connection, connection->arg,
587 &server->pollFds[i].events,
588 server->pollFds[i].revents)) {
// Re-derive the pointer: handleConnection() may have added connections, and
// serverAddConnection() realloc()s server->connections.
589 connection = server->connections + i - 1;
590 connection->destroyConnection(connection->arg);
591 connection->deleted = 1;
// Compaction: drop deleted entries from both parallel arrays.
595 for (int i = 1; i <= server->numConnections; i++) {
596 if (server->connections[i-1].deleted) {
597 memmove(server->pollFds + i, server->pollFds + i + 1,
598 (server->numConnections - i) * sizeof(struct pollfd));
599 memmove(server->connections + i - 1, server->connections + i,
600 (server->numConnections - i)*sizeof(struct ServerConnection));
602 check(--server->numConnections >= 0);
606 // Even if multiple clients requested for us to exit the loop, we only
607 // ever exit the outer most loop.
608 server->looping = loopDepth - 1;
// Turns SSL on or off for newly accepted connections (serverLoop passes
// &server->ssl to newHttpConnection only when ssl.enabled is set). The line
// that decides when serverSupportsSSL() is required is not visible in this
// listing; presumably it applies only when flag is set.
611 void serverEnableSSL(struct Server *server, int flag) {
613 check(serverSupportsSSL());
615 sslEnable(&server->ssl, flag);
// Forwards the certificate file name and auto-generation flag to the
// server's SSL state (sslSetCertificate()).
618 void serverSetCertificate(struct Server *server, const char *filename,
619 int autoGenerateMissing) {
620 sslSetCertificate(&server->ssl, filename, autoGenerateMissing);
// Forwards a certificate file descriptor to the server's SSL state
// (sslSetCertificateFd()).
623 void serverSetCertificateFd(struct Server *server, int fd) {
624 sslSetCertificateFd(&server->ssl, fd);
// Stores the numericHosts setting; serverLoop passes it to
// newHttpConnection() for each newly accepted client.
627 void serverSetNumericHosts(struct Server *server, int numericHosts) {
628 server->numericHosts = numericHosts;
// Returns the server's URL-to-handler trie; the serverRegister*Handler()
// functions store struct HttpHandler pointers in it via addToTrie().
631 struct Trie *serverGetHttpHandlers(struct Server *server) {
632 return &server->handlers;