1 // server.c -- Generic server that can deal with HTTP connections
2 // Copyright (C) 2008-2010 Markus Gutschke <markus@shellinabox.com>
4 // This program is free software; you can redistribute it and/or modify
5 // it under the terms of the GNU General Public License version 2 as
6 // published by the Free Software Foundation.
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 // GNU General Public License for more details.
13 // You should have received a copy of the GNU General Public License along
14 // with this program; if not, write to the Free Software Foundation, Inc.,
15 // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17 // In addition to these license terms, the author grants the following additional rights:
20 // If you modify this program, or any covered work, by linking or
21 // combining it with the OpenSSL project's OpenSSL library (or a
22 // modified version of that library), containing parts covered by the
23 // terms of the OpenSSL or SSLeay licenses, the author
24 // grants you additional permission to convey the resulting work.
25 // Corresponding Source for a non-source form of such a combination
26 // shall include the source code for the parts of OpenSSL used as well
27 // as that of the covered work.
29 // You may at your option choose to remove this additional permission from
30 // the work, or from any part of it.
32 // It is possible to build this program in a way that it loads OpenSSL
33 // libraries at run-time. If doing so, the following notices are required
34 // by the OpenSSL and SSLeay licenses:
36 // This product includes software developed by the OpenSSL Project
37 // for use in the OpenSSL Toolkit. (http://www.openssl.org/)
39 // This product includes cryptographic software written by Eric Young
40 // (eay@cryptsoft.com)
43 // The most up-to-date version of this program is always available from
44 // http://shellinabox.com
48 #include <arpa/inet.h>
50 #include <netinet/in.h>
54 #include <sys/socket.h>
56 #include <sys/types.h>
59 #include "libhttp/server.h"
60 #include "libhttp/httpconnection.h"
61 #include "libhttp/ssl.h"
62 #include "logging/logging.h"
// Initial timeout expressed as 10*60. NOTE(review): not referenced anywhere in
// the visible code; presumably seconds (i.e. 10 minutes) -- confirm at the use site.
64 #define INITIAL_TIMEOUT (10*60)
66 // Maximum amount of payload (e.g. form values that have been POST'd) that we
67 // read into memory. If the application needs any more than this, the streaming
68 // API should be used, instead.
// (64<<10) == 65536 bytes; enforced by serverCollectFullPayload().
69 #define MAX_PAYLOAD_LENGTH (64<<10)
// Member of struct PayLoad: the user callback that receives the fully buffered
// request body. serverCollectFullPayload() calls it as
// handler(http, arg, bytes, len), and again as handler(http, arg, NULL, 0) to
// announce teardown. NOTE(review): the opening "struct PayLoad {" line and the
// arg/bytes/len members used by serverCollectFullPayload() are not visible in
// this copy.
74 int (*handler)(struct HttpConnection *, void *, const char *, int);
// Data callback installed by serverCollectHandler(). It accumulates request
// body chunks in payload->bytes and hands the complete buffer to
// payload->handler once the body is deemed complete.
//   http     - connection the data arrived on.
//   payload_ - the struct PayLoad allocated by serverCollectHandler().
//   buf/len  - next chunk of request data. A NULL buf forces the buffered data
//              to be dispatched (see the condition below) and is also the case
//              in which the teardown notification further down is relevant.
// rc starts as HTTP_READ_MORE and is replaced by the handler's return value
// once the handler has run. NOTE(review): several lines of this function,
// including its return statement, are not visible in this copy.
80 static int serverCollectFullPayload(struct HttpConnection *http,
81 void *payload_, const char *buf, int len) {
82 int rc = HTTP_READ_MORE;
83 struct PayLoad *payload = (struct PayLoad *)payload_;
// Reject bodies that would grow beyond MAX_PAYLOAD_LENGTH with "400 Bad Request".
85 if (payload->len + len > MAX_PAYLOAD_LENGTH) {
86 httpSendReply(http, 400, "Bad Request", NO_MSG);
// Grow the buffer and append the new chunk; check() guards the realloc() result.
90 check(payload->bytes = realloc(payload->bytes, payload->len + len));
91 memcpy(payload->bytes + payload->len, buf, len);
// Dispatch once the Content-Length header says the body is complete, or when
// buf is NULL. NOTE(review): part of this condition is not visible in this copy.
// NOTE(review): atoi() performs no error checking on this client-supplied
// header; a malformed or negative value yields a number <= payload->len and
// therefore triggers immediate dispatch.
94 const char *contentLength = getFromHashMap(httpGetHeaders(http),
98 ((contentLength && atoi(contentLength) <= payload->len) || !buf))) {
// An empty body is passed as "" rather than NULL, so the handler can tell data
// from the NULL teardown notification.
99 rc = payload->handler(http, payload->arg,
100 payload->bytes ? payload->bytes : "", payload->len);
// The buffer is released right after the call; the handler must not retain it.
101 free(payload->bytes);
102 payload->bytes = NULL;
106 if (rc == HTTP_SUSPEND || rc == HTTP_PARTIAL_REPLY) {
107 // Tell the other party that the connection is getting torn down, even
108 // though it requested it to be suspended.
109 payload->handler(http, payload->arg, NULL, 0);
// Streaming entry point that serverRegisterHttpHandler() stores in
// HttpHandler.handler. It allocates a PayLoad that remembers the user's
// handler and argument, installs serverCollectFullPayload() as the
// connection's data callback, and asks the connection for more input.
//   handler_ - the struct HttpHandler registered for the requested URL.
// NOTE(review): payload->len is not initialized in the visible lines;
// confirm it is zeroed before serverCollectFullPayload() reads it.
118 static int serverCollectHandler(struct HttpConnection *http, void *handler_) {
119 struct HttpHandler *handler = handler_;
120 struct PayLoad *payload;
121 check(payload = malloc(sizeof(struct PayLoad)));
122 payload->handler = handler->streamingHandler;
123 payload->arg = handler->streamingArg;
// malloc(0) may return NULL or a unique pointer; either is valid input to the
// later realloc()/free() calls in serverCollectFullPayload().
125 payload->bytes = malloc(0);
126 httpSetCallback(http, serverCollectFullPayload, payload);
127 return HTTP_READ_MORE;
// Value destructor passed to initTrie() for server->handlers. Each value is a
// struct HttpHandler * stored (cast to char *) by the serverRegister*Handler()
// functions. NOTE(review): the body is not visible in this copy.
131 static void serverDestroyHandlers(void *arg, char *value) {
// Registers handler for url. The handler receives the complete request body in
// a single call: the stored HttpHandler routes the request through
// serverCollectHandler()/serverCollectFullPayload(), which buffer up to
// MAX_PAYLOAD_LENGTH bytes. arg is passed back to handler unchanged.
136 void serverRegisterHttpHandler(struct Server *server, const char *url,
137 int (*handler)(struct HttpConnection *, void *,
138 const char *, int), void *arg) {
// NOTE(review): the condition guarding this call is not visible in this copy;
// presumably it unregisters url when handler is NULL -- confirm.
140 addToTrie(&server->handlers, url, NULL);
142 struct HttpHandler *h;
143 check(h = malloc(sizeof(struct HttpHandler)));
144 h->handler = serverCollectHandler;
146 h->streamingHandler = handler;
147 h->websocketHandler = NULL;
148 h->streamingArg = arg;
149 addToTrie(&server->handlers, url, (char *)h);
// Registers a streaming handler for url. Unlike serverRegisterHttpHandler(),
// the callback is stored directly in HttpHandler.handler, with no buffering of
// the request body. initServer() uses this to install the built-in "/quit"
// handler.
153 void serverRegisterStreamingHttpHandler(struct Server *server, const char *url,
154 int (*handler)(struct HttpConnection *, void *),
// NOTE(review): the remaining parameter line and the condition guarding this
// call are not visible in this copy; presumably it unregisters url when
// handler is NULL -- confirm.
157 addToTrie(&server->handlers, url, NULL);
159 struct HttpHandler *h;
160 check(h = malloc(sizeof(struct HttpHandler)));
161 h->handler = handler;
162 h->streamingHandler = NULL;
163 h->websocketHandler = NULL;
// NOTE(review): streamingArg is set to NULL in this visible line even though a
// caller-supplied argument may exist on the hidden parameter line -- confirm
// how that argument reaches the handler.
164 h->streamingArg = NULL;
166 addToTrie(&server->handlers, url, (char *)h);
// Registers a WebSocket handler for url; the callback is stored in
// HttpHandler.websocketHandler.
// NOTE(review): h->handler and h->streamingArg are not assigned in the visible
// lines, whereas the other serverRegister*Handler() functions set all four
// fields. Confirm they are initialized before the trie entry is used.
170 void serverRegisterWebSocketHandler(struct Server *server, const char *url,
171 int (*handler)(struct HttpConnection *, void *, int, const char *, int),
// NOTE(review): the guarding condition is not visible; presumably this
// unregisters url when handler is NULL -- confirm.
174 addToTrie(&server->handlers, url, NULL);
176 struct HttpHandler *h;
177 check(h = malloc(sizeof(struct HttpHandler)));
179 h->streamingHandler = NULL;
180 h->websocketHandler = handler;
182 addToTrie(&server->handlers, url, (char *)h);
// Built-in handler for "/quit", registered by initServer(). It replies
// "200 Good Bye" and calls httpExitLoop(http, 1).
// NOTE(review): httpExitLoop() presumably requests that the server's event
// loop stop (compare serverExitLoop()'s exitAll flag) -- confirm. If so, any
// client that can reach this URL can stop the server. The return statement is
// not visible in this copy.
186 static int serverQuitHandler(struct HttpConnection *http, void *arg) {
188 httpSendReply(http, 200, "Good Bye", NO_MSG);
189 httpExitLoop(http, 1);
// Heap-allocates a Server and initializes it with initServer(): it listens on
// a port chosen from [portMin, portMax] and uses timeout as the server idle
// timeout consulted by serverLoop().
// NOTE(review): the line declaring the timeout parameter and the return
// statement are not visible in this copy.
193 struct Server *newCGIServer(int localhostOnly, int portMin, int portMax,
195 struct Server *server;
196 check(server = malloc(sizeof(struct Server)));
197 initServer(server, localhostOnly, portMin, portMax, timeout);
// Convenience constructor for a server on a single fixed port
// (portMin == portMax == port). The timeout of -1 disables the idle timeout,
// because serverLoop() only applies serverTimeout when it is >= 0. Passing
// port == 0 skips the port search in initServer(), since that search runs only
// when portMin or portMax is non-zero.
201 struct Server *newServer(int localhostOnly, int port) {
202 return newCGIServer(localhostOnly, port, port, -1);
// Initializes *server in place:
//   - creates an IPv4 TCP listening socket with SO_REUSEADDR, bound to
//     127.0.0.1 when localhostOnly is set and to INADDR_ANY otherwise;
//   - when portMin or portMax is non-zero, tries every port in
//     [portMin, portMax] starting at a random offset, and calls fatal() if
//     none can be bound;
//   - records the port actually bound (via getsockname()) in server->port;
//   - sets pollFds[0] to the listening socket (pollFds[i + 1] later mirrors
//     connections[i]);
//   - creates the handler trie, registers the built-in "/quit" handler, and
//     initializes the SSL state.
// timeout is stored as-is in serverTimeout; serverLoop() treats it as a
// relative number of seconds and ignores negative values.
// NOTE(review): no O_NONBLOCK flag is set on the listening socket in the
// visible lines, so accept() on it can block.
205 void initServer(struct Server *server, int localhostOnly, int portMin,
206 int portMax, int timeout) {
209 server->serverTimeout = timeout;
210 server->numericHosts = 0;
211 server->connections = NULL;
212 server->numConnections = 0;
215 server->serverFd = socket(PF_INET, SOCK_STREAM, 0);
216 check(server->serverFd >= 0);
// NOTE(review): "true" here is an lvalue whose declaration is not visible in
// this copy (presumably an int set to 1) -- confirm.
217 check(!setsockopt(server->serverFd, SOL_SOCKET, SO_REUSEADDR,
218 &true, sizeof(true)));
219 struct sockaddr_in serverAddr = { 0 };
220 serverAddr.sin_family = AF_INET;
221 serverAddr.sin_addr.s_addr = htonl(localhostOnly
222 ? INADDR_LOOPBACK : INADDR_ANY);
224 // Linux, unlike BSD, does not support picking a local port from a range.
225 // So, we have to randomly pick a port from our allowed port range, and then
226 // keep iterating until we find an unused port.
227 if (portMin || portMax) {
229 check(!gettimeofday(&tv, NULL));
230 srand((int)(tv.tv_usec ^ tv.tv_sec));
232 check(portMax < 65536);
233 check(portMax >= portMin);
234 int portStart = rand() % (portMax - portMin + 1) + portMin;
// Visits every port in [portMin, portMax] exactly once, starting near portStart.
235 for (int p = 0; p <= portMax-portMin; p++) {
236 int port = (p+portStart)%(portMax-portMin+1)+ portMin;
237 serverAddr.sin_port = htons(port);
238 if (!bind(server->serverFd, (struct sockaddr *)&serverAddr,
239 sizeof(serverAddr))) {
242 serverAddr.sin_port = 0;
// sin_port == 0 is used as the "no port could be bound" marker.
244 if (!serverAddr.sin_port) {
245 fatal("Failed to find any available port");
249 check(!listen(server->serverFd, SOMAXCONN));
// Ask the kernel which port was actually bound; this also covers the case in
// which no explicit port range was requested.
250 socklen_t socklen = (socklen_t)sizeof(serverAddr);
251 check(!getsockname(server->serverFd, (struct sockaddr *)&serverAddr,
253 check(socklen == sizeof(serverAddr));
254 server->port = ntohs(serverAddr.sin_port);
255 info("Listening on port %d", server->port);
// pollFds[0] is reserved for the listening socket.
257 check(server->pollFds = malloc(sizeof(struct pollfd)));
258 server->pollFds->fd = server->serverFd;
259 server->pollFds->events = POLLIN;
261 initTrie(&server->handlers, serverDestroyHandlers, NULL);
262 serverRegisterStreamingHttpHandler(server, "/quit", serverQuitHandler, NULL);
263 initSSL(&server->ssl);
// Releases the resources set up by initServer() and serverAddConnection():
// closes the listening socket, runs every connection's destroyConnection
// callback, and frees the connection and pollfd arrays, the handler trie, and
// the SSL state. It does not free the Server struct itself; deleteServer()
// calls this function.
// NOTE(review): destroyConnection is invoked for every entry, including ones
// already marked deleted. serverLoop() only removes deleted entries at the end
// of each iteration, so calling this from inside a handler could destroy a
// connection twice -- confirm it is only called outside serverLoop().
266 void destroyServer(struct Server *server) {
268 if (server->serverFd >= 0) {
269 info("Shutting down server");
270 close(server->serverFd);
272 for (int i = 0; i < server->numConnections; i++) {
273 server->connections[i].destroyConnection(server->connections[i].arg);
275 free(server->connections);
276 free(server->pollFds);
277 destroyTrie(&server->handlers);
278 destroySSL(&server->ssl);
// Counterpart of newServer()/newCGIServer(): tears the server down with
// destroyServer(). NOTE(review): the rest of the body (presumably
// free(server)) is not visible in this copy.
282 void deleteServer(struct Server *server) {
283 destroyServer(server);
// Accessor for the listening port. initServer() stores the actually bound port
// in server->port. NOTE(review): the body is not visible in this copy.
287 int serverGetListeningPort(struct Server *server) {
// Returns the listening socket descriptor, which is -1 or negative if socket
// creation did not happen; destroyServer() only closes it when it is >= 0.
291 int serverGetFd(struct Server *server) {
292 return server->serverFd;
// Appends a connection for fd to the server. The new ServerConnection is
// stored at connections[numConnections - 1] and its pollfd at
// pollFds[numConnections] (pollFds[0] is the listening socket). It initially
// polls for POLLIN, has no timeout, and is not deleted.
//   handleConnection  - called by serverLoop() on events or timeout; a zero
//                       return makes serverLoop() destroy the connection.
//   destroyConnection - releases arg; called when the connection is removed.
// NOTE(review): some parameter lines and the return statement (presumably
// "return connection;") are not visible in this copy. Pointers into
// server->connections are invalidated by later realloc() calls and by the
// reordering/compaction in serverLoop(); serverGetConnection() re-validates them.
295 struct ServerConnection *serverAddConnection(struct Server *server, int fd,
296 int (*handleConnection)(struct ServerConnection *c,
297 void *arg, short *events,
299 void (*destroyConnection)(void *arg),
301 check(server->connections = realloc(server->connections,
302 ++server->numConnections*
303 sizeof(struct ServerConnection)));
// One extra pollfd slot for the listening socket at index 0.
304 check(server->pollFds = realloc(server->pollFds,
305 (server->numConnections + 1) *
306 sizeof(struct pollfd)));
307 server->pollFds[server->numConnections].fd = fd;
308 server->pollFds[server->numConnections].events = POLLIN;
309 struct ServerConnection *connection =
310 server->connections + server->numConnections - 1;
311 connection->deleted = 0;
312 connection->timeout = 0;
313 connection->handleConnection = handleConnection;
314 connection->destroyConnection = destroyConnection;
315 connection->arg = arg;
// Marks the live (not yet deleted) connection on fd as deleted and runs its
// destroyConnection callback immediately. The array slot itself is only
// removed by the compaction pass at the end of a serverLoop() iteration.
// pollFds[i + 1] corresponds to connections[i].
319 void serverDeleteConnection(struct Server *server, int fd) {
320 for (int i = 0; i < server->numConnections; i++) {
321 if (fd == server->pollFds[i + 1].fd && !server->connections[i].deleted) {
322 server->connections[i].deleted = 1;
323 server->connections[i].destroyConnection(server->connections[i].arg);
// Sets a per-connection timeout of timeout seconds from now. The value is
// stored as an absolute deadline (timeout + currentTime), which serverLoop()
// compares against time(). A timeout <= 0 clears it (stored as 0).
// NOTE(review): currentTime is a file-level variable declared outside this
// view, and the condition around its refresh is not visible.
329 void serverSetTimeout(struct ServerConnection *connection, time_t timeout) {
331 currentTime = time(NULL);
333 connection->timeout = timeout > 0 ? timeout + currentTime : 0;
// Reports the state of connection's timeout relative to currentTime: 0 when
// none is set (connection->timeout == 0), otherwise derived from the stored
// absolute deadline minus currentTime (see the inline comment below for the
// sign convention). NOTE(review): the remainder of this function is not
// visible in this copy.
336 time_t serverGetTimeout(struct ServerConnection *connection) {
337 if (connection->timeout) {
338 // Returns <0 if expired, 0 if not set, and >0 if still pending.
340 currentTime = time(NULL);
// Keep the full time_t width. The deadline was computed as timeout +
// currentTime in serverSetTimeout(), and narrowing the difference to int could
// truncate it or flip its sign for large timeouts, making a pending deadline
// look expired.
342 time_t remaining = connection->timeout - currentTime;
// Looks up the live connection whose descriptor is fd. A previously obtained
// pointer (hint) is trusted only if it lies inside server->connections, sits
// exactly on an element boundary, and its pollfd slot still refers to fd
// (plus any further checks on lines not visible here). Otherwise the function
// scans pollFds[1..numConnections], skipping deleted entries.
// NOTE(review): the fd parameter line, the start of the hint condition, and
// the final return (presumably NULL when nothing matches) are not visible in
// this copy.
352 struct ServerConnection *serverGetConnection(struct Server *server,
353 struct ServerConnection *hint,
356 server->connections <= hint &&
357 server->connections + server->numConnections > hint) {
358 // The compiler would like to optimize the expression:
359 // &server->connections[hint - server->connections] <=>
360 // server->connections + hint - server->connections <=>
//   hint
362 // This transformation is correct as far as the language specification is
363 // concerned, but it is unintended as we actually want to check whether
364 // the alignment is correct. So, instead of comparing
365 // &server->connections[hint - server->connections] == hint
366 // we first use memcpy() to break aliasing.
367 uintptr_t ptr1, ptr2;
368 memcpy(&ptr1, &hint, sizeof(ptr1));
369 memcpy(&ptr2, &server->connections, sizeof(ptr2));
// Integer division discards any partial-element offset, so a misaligned hint
// maps to an element whose address differs from hint.
370 int idx = (ptr1 - ptr2)/sizeof(*server->connections);
371 if (&server->connections[idx] == hint &&
373 server->pollFds[hint - server->connections + 1].fd == fd) {
// Fallback: linear scan; pollFds[i + 1] corresponds to connections[i].
377 for (int i = 0; i < server->numConnections; i++) {
378 if (server->pollFds[i + 1].fd == fd && !server->connections[i].deleted) {
379 return server->connections + i;
// Replaces the poll() event mask of connection and records the previous mask
// in oldEvents (presumably the return value -- the return statement is not
// visible in this copy). The dcheck()s assert that connection points exactly
// at a live element of server->connections.
// A zero mask makes serverLoop() move the connection out of the set passed to
// poll().
385 short serverConnectionSetEvents(struct Server *server,
386 struct ServerConnection *connection,
390 dcheck(connection >= server->connections);
391 dcheck(connection < server->connections + server->numConnections);
392 dcheck(connection == &server->connections[connection - server->connections]);
393 dcheck(!connection->deleted);
394 int idx = connection - server->connections;
// pollFds[idx + 1] mirrors connections[idx]; index 0 is the listening socket.
395 short oldEvents = server->pollFds[idx + 1].events;
396 server->pollFds[idx + 1].events = events;
// Requests that serverLoop() stop. A non-zero exitAll sets the sticky
// server->exitAll flag, which terminates every nested serverLoop() invocation
// (the loop condition tests !server->exitAll).
// NOTE(review): the handling of server->looping for the exitAll == 0 case is
// on a line not visible in this copy.
400 void serverExitLoop(struct Server *server, int exitAll) {
402 server->exitAll |= exitAll;
// Runs the poll()-based event loop. Nested invocations are supported: each call
// increments server->looping and keeps running while server->looping is at
// least its own depth and server->exitAll is clear. On exit it restores
// server->looping to loopDepth - 1. Each iteration:
//   1. moves connections whose event mask is empty to the end of the arrays
//      and computes the earliest absolute deadline from the per-connection
//      timeouts and the (relative) server idle timeout;
//   2. poll()s the listening socket plus the active connections;
//   3. accepts a new client when the listening socket is readable;
//   4. calls each connection's handleConnection on events or an expired
//      timeout, destroying it when the callback returns 0;
//   5. compacts entries that were marked deleted.
// NOTE(review): several lines of this function are not visible in this copy.
405 void serverLoop(struct Server *server) {
406 check(server->serverFd >= 0);
408 currentTime = time(&lastTime);
409 int loopDepth = ++server->looping;
410 while (server->looping >= loopDepth && !server->exitAll) {
411 // TODO: There probably should be some limit on the maximum number
412 // of concurrently opened HTTP connections, as this could lead to
413 // memory exhaustion and a DoS attack.
415 int numFds = server->numConnections + 1;
417 for (int i = 0; i < server->numConnections; i++) {
418 while (i < numFds - 1 && !server->pollFds[i + 1].events) {
419 // Sort filedescriptors that currently do not expect any events to
420 // the end of the list.
422 struct pollfd tmpPollFd;
423 memmove(&tmpPollFd, server->pollFds + numFds, sizeof(struct pollfd));
424 memmove(server->pollFds + numFds, server->pollFds + i + 1,
425 sizeof(struct pollfd));
426 memmove(server->pollFds + i + 1, &tmpPollFd, sizeof(struct pollfd));
// The connections array is swapped in lockstep so that pollFds[k + 1]
// keeps mirroring connections[k].
427 struct ServerConnection tmpConnection;
428 memmove(&tmpConnection, server->connections + numFds - 1,
429 sizeof(struct ServerConnection));
430 memmove(server->connections + numFds - 1, server->connections + i,
431 sizeof(struct ServerConnection));
432 memmove(server->connections + i, &tmpConnection,
433 sizeof(struct ServerConnection));
// Track the earliest absolute per-connection deadline.
436 if (server->connections[i].timeout &&
437 (timeout < 0 || timeout > server->connections[i].timeout)) {
438 timeout = server->connections[i].timeout;
442 // serverTimeout is always a delta value, unlike connection timeouts
443 // which are absolute times.
444 if (server->serverTimeout >= 0) {
445 if (timeout < 0 || timeout > server->serverTimeout + currentTime) {
446 timeout = server->serverTimeout+currentTime;
451 // Wait at least one second longer than needed, so that even if
452 // poll() decides to return a second early (due to possible rounding
453 // errors), we still correctly detect a timeout condition.
454 if (timeout >= lastTime) {
455 timeout = (timeout - lastTime + 1) * 1000;
461 int eventCount = NOINTR(poll(server->pollFds,
464 check(eventCount >= 0);
468 currentTime = time(&lastTime);
469 int isTimeout = timeout >= 0 &&
470 timeout/1000 <= lastTime;
471 if (server->pollFds[0].revents) {
// Only accept() when the listening socket is actually readable. The previous
// test, "revents && POLLIN", is true for any non-zero revents (already
// guaranteed by the enclosing if), so POLLERR/POLLHUP/POLLNVAL alone would
// reach accept(). No O_NONBLOCK flag is set on the listening socket in the
// visible code of initServer(), so such a call can block the whole event loop.
473 if (server->pollFds[0].revents & POLLIN) {
474 struct sockaddr_in clientAddr;
475 socklen_t sockLen = sizeof(clientAddr);
476 int clientFd = accept(
477 server->serverFd, (struct sockaddr *)&clientAddr, &sockLen);
478 dcheck(clientFd >= 0);
480 check(!fcntl(clientFd, F_SETFL, O_RDWR | O_NONBLOCK));
481 struct HttpConnection *http;
482 http = newHttpConnection(
483 server, clientFd, server->port,
484 server->ssl.enabled ? &server->ssl : NULL,
485 server->numericHosts);
487 serverAddConnection(server, clientFd, httpHandleConnection,
488 (void (*)(void *))deleteHttpConnection,
494 if (server->serverTimeout > 0 && !server->numConnections) {
495 // In CGI mode, exit the server, if we haven't had any active
496 // connections in a while.
// Dispatch loop over pollFds[1..numConnections]; pollFds[i] mirrors
// connections[i - 1].
501 (isTimeout || eventCount > 0) && i <= server->numConnections;
503 struct ServerConnection *connection = server->connections + i - 1;
504 if (connection->deleted) {
508 server->pollFds[i].revents = 0;
510 if (server->pollFds[i].revents ||
511 (connection->timeout && lastTime >= connection->timeout)) {
512 if (server->pollFds[i].revents) {
515 short events = server->pollFds[i].events;
516 if (!connection->handleConnection(connection, connection->arg,
517 &events, server->pollFds[i].revents)){
// The callback may have grown (realloc'd) server->connections, so the
// pointer is recomputed before use.
518 connection = server->connections + i - 1;
519 connection->destroyConnection(connection->arg);
520 connection->deleted = 1;
522 server->pollFds[i].events = events;
// Compact away entries that were marked deleted during this iteration.
526 for (int i = 1; i <= server->numConnections; i++) {
527 if (server->connections[i-1].deleted) {
528 memmove(server->pollFds + i, server->pollFds + i + 1,
529 (server->numConnections - i) * sizeof(struct pollfd));
530 memmove(server->connections + i - 1, server->connections + i,
531 (server->numConnections - i)*sizeof(struct ServerConnection));
533 check(--server->numConnections >= 0);
537 // Even if multiple clients requested for us to exit the loop, we only
538 // ever exit the outer most loop.
539 server->looping = loopDepth - 1;
// Enables or disables TLS for subsequently accepted connections. serverLoop()
// passes &server->ssl to newHttpConnection() only when server->ssl.enabled is
// set. check() aborts if serverSupportsSSL() reports no SSL support.
// NOTE(review): the condition around that check is not visible in this copy;
// presumably it only applies when flag is non-zero.
542 void serverEnableSSL(struct Server *server, int flag) {
544 check(serverSupportsSSL());
546 sslEnable(&server->ssl, flag);
// Forwards the certificate file name and the autoGenerateMissing flag to
// sslSetCertificate() for this server's SSL state.
549 void serverSetCertificate(struct Server *server, const char *filename,
550 int autoGenerateMissing) {
551 sslSetCertificate(&server->ssl, filename, autoGenerateMissing);
// Forwards a certificate supplied as an already-open file descriptor to
// sslSetCertificateFd() for this server's SSL state.
554 void serverSetCertificateFd(struct Server *server, int fd) {
555 sslSetCertificateFd(&server->ssl, fd);
// Sets the numericHosts flag (initServer() defaults it to 0). serverLoop()
// passes it to newHttpConnection() for every accepted client.
// NOTE(review): presumably it selects numeric addresses instead of reverse DNS
// names -- confirm in httpconnection.
558 void serverSetNumericHosts(struct Server *server, int numericHosts) {
559 server->numericHosts = numericHosts;
// Exposes the URL -> struct HttpHandler * trie populated by the
// serverRegister*Handler() functions. The trie is owned by the server and is
// destroyed by destroyServer().
562 struct Trie *serverGetHttpHandlers(struct Server *server) {
563 return &server->handlers;