|
1 -- |
|
2 -- server.lua by blastbeat of the luadch project |
|
3 -- Re-used here under the MIT/X Consortium License |
|
4 -- |
|
5 -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain |
|
6 -- |
|
7 |
|
8 -- // wrapping luadch stuff // -- |
|
9 |
|
--- Fetch a global by name (compatibility shim kept from luadch).
-- @param what  name of the global
-- @return the value stored in _G under that name (nil when unset)
local function use( what )
	local globals = _G
	return globals[ what ]
end
|
13 |
|
-- Logging shims: luadch's out_put/out_error, routed to the "socket" logger.
local log = require ("util.logger").init("socket")
local table_concat = table.concat

-- Log all arguments, concatenated without separators, at "debug" level.
local function out_put (...)
	return log("debug", table_concat{...})
end

-- Log all arguments, concatenated without separators, at "warn" level.
local function out_error (...)
	return log("warn", table_concat{...})
end
|
17 |
|
18 ----------------------------------// DECLARATION //-- |
|
19 |
|
20 --// constants //-- |
|
21 |
|
22 local STAT_UNIT = 1 -- byte |
|
23 |
|
24 --// lua functions //-- |
|
25 |
|
26 local type = use "type" |
|
27 local pairs = use "pairs" |
|
28 local ipairs = use "ipairs" |
|
29 local tonumber = use "tonumber" |
|
30 local tostring = use "tostring" |
|
31 |
|
32 --// lua libs //-- |
|
33 |
|
34 local os = use "os" |
|
35 local table = use "table" |
|
36 local string = use "string" |
|
37 local coroutine = use "coroutine" |
|
38 |
|
39 --// lua lib methods //-- |
|
40 |
|
41 local os_difftime = os.difftime |
|
42 local math_min = math.min |
|
43 local math_huge = math.huge |
|
44 local table_concat = table.concat |
|
45 local table_insert = table.insert |
|
46 local string_sub = string.sub |
|
47 local coroutine_wrap = coroutine.wrap |
|
48 local coroutine_yield = coroutine.yield |
|
49 |
|
50 --// extern libs //-- |
|
51 |
|
52 local has_luasec, luasec = pcall ( require , "ssl" ) |
|
53 local luasocket = use "socket" or require "socket" |
|
54 local luasocket_gettime = luasocket.gettime |
|
55 local getaddrinfo = luasocket.dns.getaddrinfo |
|
56 |
|
57 --// extern lib methods //-- |
|
58 |
|
59 local ssl_wrap = ( has_luasec and luasec.wrap ) |
|
60 local socket_bind = luasocket.bind |
|
61 local socket_sleep = luasocket.sleep |
|
62 local socket_select = luasocket.select |
|
63 |
|
64 --// functions //-- |
|
65 |
|
66 local id |
|
67 local loop |
|
68 local stats |
|
69 local idfalse |
|
70 local closeall |
|
71 local addsocket |
|
72 local addserver |
|
73 local addtimer |
|
74 local getserver |
|
75 local wrapserver |
|
76 local getsettings |
|
77 local closesocket |
|
78 local removesocket |
|
79 local removeserver |
|
80 local wrapconnection |
|
81 local changesettings |
|
82 |
|
83 --// tables //-- |
|
84 |
|
85 local _server |
|
86 local _readlist |
|
87 local _timerlist |
|
88 local _sendlist |
|
89 local _socketlist |
|
90 local _closelist |
|
91 local _readtimes |
|
92 local _writetimes |
|
93 |
|
94 --// simple data types //-- |
|
95 |
|
96 local _ |
|
97 local _readlistlen |
|
98 local _sendlistlen |
|
99 local _timerlistlen |
|
100 |
|
101 local _sendtraffic |
|
102 local _readtraffic |
|
103 |
|
104 local _selecttimeout |
|
105 local _sleeptime |
|
106 local _tcpbacklog |
|
107 |
|
108 local _starttime |
|
109 local _currenttime |
|
110 |
|
111 local _maxsendlen |
|
112 local _maxreadlen |
|
113 |
|
114 local _checkinterval |
|
115 local _sendtimeout |
|
116 local _readtimeout |
|
117 |
|
118 local _timer |
|
119 |
|
120 local _maxselectlen |
|
121 local _maxfd |
|
122 |
|
123 local _maxsslhandshake |
|
124 |
|
125 ----------------------------------// DEFINITION //-- |
|
126 |
|
-- Registries -----------------------------------------------------------------
_server = { }      -- "addr:port" -> server handler; the listening servers
_readlist = { }    -- sockets to select() for reading (array part + reverse index)
_sendlist = { }    -- sockets to select() for writing (array part + reverse index)
_timerlist = { }   -- array of timer callbacks
_socketlist = { }  -- socket -> handler wrapping it
_readtimes = { }   -- handler -> timestamp of the last data read
_writetimes = { }  -- handler -> timestamp used to detect stalled sending
_closelist = { }   -- handler -> reason; closed at the end of a loop iteration

-- Array lengths of _readlist, _sendlist and _timerlist
_readlistlen, _sendlistlen, _timerlistlen = 0, 0, 0

-- Global traffic statistics
_sendtraffic, _readtraffic = 0, 0

_selecttimeout = 1   -- timeout passed to socket.select
_sleeptime = 0       -- time to sleep at the end of every loop iteration
_tcpbacklog = 128    -- listen() backlog, a hint to the OS

_maxsendlen = 51000 * 1024   -- max length of a connection's send buffer
_maxreadlen = 25000 * 1024   -- max length of a single read

_checkinterval = 30          -- interval in seconds between idle-client checks
_sendtimeout = 60000         -- allowed send idle time in seconds
_readtimeout = 6 * 60 * 60   -- allowed read idle time in seconds

-- The first character of package.config is the directory separator;
-- a backslash means we are running on Windows.
local is_windows = package.config:sub(1,1) == "\\"

-- Highest allowed fd number. Outside Windows it is capped at LuaSocket's
-- _SETSIZE (1024 if unknown) to prevent a glibc fd_set buffer overflow.
if is_windows then
	_maxfd = math.huge
else
	_maxfd = luasocket._SETSIZE or 1024
end

-- Max number of sockets in one select() set; this limit applies on Windows too.
_maxselectlen = luasocket._SETSIZE or 1024

_maxsslhandshake = 30 -- max SSL handshake round-trips
|
159 |
|
160 ----------------------------------// PRIVATE //-- |
|
161 |
|
--- Wrap a bound, listening socket into a server handler object.
-- Incoming clients are accepted in handler.readbuffer(), wrapped with
-- wrapconnection() and announced via listeners.onconnect. For SSL servers,
-- onconnect is fired by the connection once its handshake completes.
-- @param listeners  listener callback table; onconnect is used here, and the
--                   whole table is handed on to every accepted connection
-- @param socket     bound LuaSocket server socket
-- @param ip         address the socket is bound to
-- @param serverport port the socket is bound to
-- @param pattern    receive pattern for accepted connections
-- @param sslctx     optional LuaSec context for accepted connections
-- @return the server handler, or nil, "fd-too-large" (with the socket
--         closed) when the socket's fd is not below _maxfd
wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx )

	if socket:getfd( ) >= _maxfd then
		out_error( "server.lua: Disallowed FD number: "..socket:getfd( ) )
		socket:close( )
		return nil, "fd-too-large"
	end

	local connections = 0 -- accepted connections not yet closed (bookkeeping only)

	local dispatch = listeners.onconnect

	local accept = socket.accept

	--// public methods of the object //--

	local handler = { }

	handler.shutdown = function( ) end

	handler.ssl = function( )
		return sslctx ~= nil
	end
	handler.sslctx = function( )
		return sslctx
	end
	-- Called by a client connection of this server when it closes. Resumes
	-- accepting in case the server was paused because it was full.
	handler.remove = function( )
		connections = connections - 1
		if handler then
			handler.resume( )
		end
	end
	-- Stop listening and unregister this server. Safe to call after a hard
	-- pause, when the listening socket has already been closed.
	handler.close = function( )
		if socket then
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_socketlist[ socket ] = nil
		end
		_server[ ip..":"..serverport ] = nil;
		handler = nil
		socket = nil
		out_put "server.lua: closed server handler and removed sockets from list"
	end
	-- Stop accepting new clients. With `hard`, the listening socket is also
	-- closed; resume() binds a new one.
	handler.pause = function( hard )
		if not handler.paused then
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			if hard then
				_socketlist[ socket ] = nil
				socket:close( )
				socket = nil;
			end
			handler.paused = true;
		end
	end
	-- Start accepting again after pause(). Returns nil, err (and stays
	-- paused) when the socket closed by a hard pause cannot be re-bound.
	handler.resume = function( )
		if handler.paused then
			if not socket then
				local bind_err
				socket, bind_err = socket_bind( ip, serverport, _tcpbacklog );
				if not socket then
					out_error( "server.lua: failed to re-bind [", tostring( ip ), "]:", tostring( serverport ), ": ", tostring( bind_err ) )
					return nil, bind_err
				end
				socket:settimeout( 0 )
			end
			_readlistlen = addsocket( _readlist, socket, _readlistlen )
			_socketlist[ socket ] = handler
			handler.paused = false;
		end
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.socket = function( )
		return socket
	end
	-- Invoked by the loop when the listening socket is readable: accept
	-- one client. When the select() sets are full, accepting is paused instead.
	handler.readbuffer = function( )
		if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
			handler.pause( )
			out_put( "server.lua: refused new client connection: server full" )
			return false
		end
		local client, err = accept( socket ) -- try to accept
		if client then
			local client_ip, clientport = client:getpeername( )
			-- wrap new client socket; this server handler becomes its owner
			local conn, _, wrap_err = wrapconnection( handler, listeners, client, client_ip, serverport, clientport, pattern, sslctx )
			if wrap_err then -- fd limit, or TLS setup/handshake failed
				return false
			end
			connections = connections + 1
			out_put( "server.lua: accepted new client connection from ", tostring(client_ip), ":", tostring(clientport), " to ", tostring(serverport))
			if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
				return dispatch( conn );
			end
			return;
		elseif err then -- maybe timeout or something else
			out_put( "server.lua: error with new client connection: ", tostring(err) )
			return false
		end
	end
	return handler
end
|
262 |
|
--- Wrap a connected client socket into a connection handler object.
-- The returned handler table doubles as the send queue: its array part holds
-- pending outgoing chunks. The socket is registered in the read list. When
-- `sslctx` is given and LuaSec is available, TLS negotiation starts immediately.
-- @param server      owning server handler (from wrapserver), or nil
-- @param listeners   callback table: onincoming, onstatus, ondisconnect,
--                    ondrain, onreadtimeout, ondetach, onconnect
-- @param socket      connected LuaSocket socket
-- @param ip          peer address
-- @param serverport  server-side port
-- @param clientport  peer port
-- @param pattern     receive pattern passed to receive()
-- @param sslctx      optional LuaSec context
-- @return handler, socket on success (socket is the TLS-wrapped one when TLS
--         was started); nil, nil, err on failure
wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx )

	if socket:getfd( ) >= _maxfd then
		out_error( "server.lua: Disallowed FD number: "..socket:getfd( ) ) -- PROTIP: Switch to libevent
		socket:close( ) -- Should we send some kind of error here?
		if server then
			server.pause( )
		end
		return nil, nil, "fd-too-large"
	end
	socket:settimeout( 0 )

	--// local import of socket methods //--

	local send     -- send function of the current socket
	local receive  -- receive function of the current socket
	local shutdown -- shutdown function; the no-op `id` once TLS has started

	--// private closures of the object //--

	local ssl -- true once TLS has been started on this connection

	local dispatch = listeners.onincoming
	local status = listeners.onstatus
	local disconnect = listeners.ondisconnect
	local drain = listeners.ondrain
	local onreadtimeout = listeners.onreadtimeout;
	local detach = listeners.ondetach

	local bufferqueue = { } -- buffer array
	local bufferqueuelen = 0 -- end of buffer array

	local toclose    -- close() was requested while output was still queued
	local fatalerror -- set after a read or write error
	local needtls    -- starttls() was requested while output was still queued

	local bufferlen = 0 -- number of bytes queued for sending

	local noread = false -- reading suspended via lock_read(true)
	local nosend = false -- sending suspended via lock(true)

	local sendtraffic, readtraffic = 0, 0

	local maxsendlen = _maxsendlen
	local maxreadlen = _maxreadlen

	--// public methods of the object //--

	-- The handler object is the send queue table itself.
	local handler = bufferqueue -- saves a table ^_^

	handler.dispatch = function( )
		return dispatch
	end
	handler.disconnect = function( )
		return disconnect
	end
	handler.onreadtimeout = onreadtimeout;

	-- Switch the connection to a new listener table. The previous listener's
	-- ondetach, if set, is notified first.
	handler.setlistener = function( self, listeners )
		if detach then
			detach(self) -- Notify listener that it is no longer responsible for this connection
		end
		dispatch = listeners.onincoming
		disconnect = listeners.ondisconnect
		status = listeners.onstatus
		drain = listeners.ondrain
		handler.onreadtimeout = listeners.onreadtimeout
		detach = listeners.ondetach
	end
	handler.getstats = function( )
		return readtraffic, sendtraffic
	end
	handler.ssl = function( )
		return ssl
	end
	handler.sslctx = function ( )
		return sslctx
	end
	handler.send = function( _, data, i, j )
		return send( socket, data, i, j )
	end
	handler.receive = function( pattern, prefix )
		return receive( socket, pattern, prefix )
	end
	handler.shutdown = function( pattern )
		return shutdown( socket, pattern )
	end
	handler.setoption = function (self, option, value)
		if socket.setoption then
			return socket:setoption(option, value);
		end
		return false, "setoption not implemented";
	end
	-- Close now, discarding any output that has not been sent yet.
	handler.force_close = function ( self, err )
		if bufferqueuelen ~= 0 then
			out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
			bufferqueuelen = 0;
		end
		return self:close(err);
	end
	-- Close the connection. If queued output cannot be flushed right away,
	-- further writes are disabled and the close is deferred (returns false).
	-- Otherwise the socket is released, ondisconnect is notified (once), the
	-- owning server is told, and true is returned.
	handler.close = function( self, err )
		if not handler then return true; end
		_readlistlen = removesocket( _readlist, socket, _readlistlen )
		_readtimes[ handler ] = nil
		if bufferqueuelen ~= 0 then
			handler.sendbuffer() -- Try now to send any outstanding data
			if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
				if handler then
					handler.write = nil -- ... but no further writing allowed
				end
				toclose = true
				return false
			end
		end
		if socket then
			_ = shutdown and shutdown( socket )
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_socketlist[ socket ] = nil
			socket = nil
		else
			out_put "server.lua: socket already closed"
		end
		if handler then
			_writetimes[ handler ] = nil
			_closelist[ handler ] = nil
			local _handler = handler;
			handler = nil
			if disconnect then
				disconnect(_handler, err or false);
				disconnect = nil
			end
		end
		if server then
			server.remove( )
		end
		out_put "server.lua: closed client handler and removed socket from list"
		return true
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.clientport = function( )
		return clientport
	end
	handler.port = handler.clientport -- COMPAT server_event
	-- Queue data for sending. If the send buffer would exceed maxsendlen, the
	-- connection is scheduled for closing via _closelist and false is returned.
	local write = function( self, data )
		bufferlen = bufferlen + #data
		if bufferlen > maxsendlen then
			if handler then -- handler is nil once the connection has been closed
				_closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
				handler.write = idfalse -- dont write anymore
			end
			return false
		elseif socket and not _sendlist[ socket ] then
			_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
		end
		bufferqueuelen = bufferqueuelen + 1
		bufferqueue[ bufferqueuelen ] = data
		if handler then
			_writetimes[ handler ] = _writetimes[ handler ] or _currenttime
		end
		return true
	end
	handler.write = write
	handler.bufferqueue = function( self )
		return bufferqueue
	end
	handler.socket = function( self )
		return socket
	end
	handler.set_mode = function( self, new )
		pattern = new or pattern
		return pattern
	end
	handler.set_send = function ( self, newsend )
		send = newsend or send
		return send
	end
	handler.bufferlen = function( self, readlen, sendlen )
		maxsendlen = sendlen or maxsendlen
		maxreadlen = readlen or maxreadlen
		return bufferlen, maxreadlen, maxsendlen
	end
	--TODO: Deprecate
	-- Suspend reading (switch == true) or resume it (switch == false).
	-- Returns whether reading is currently suspended.
	handler.lock_read = function (self, switch)
		if switch == true then
			local tmp = _readlistlen
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_readtimes[ handler ] = nil
			if _readlistlen ~= tmp then
				noread = true
			end
		elseif switch == false then
			if noread then
				noread = false
				_readlistlen = addsocket(_readlist, socket, _readlistlen)
				_readtimes[ handler ] = _currenttime
			end
		end
		return noread
	end
	handler.pause = function (self)
		return self:lock_read(true);
	end
	handler.resume = function (self)
		return self:lock_read(false);
	end
	-- Suspend (switch == true) or resume (switch == false) both reading and
	-- writing. Returns noread, nosend.
	handler.lock = function( self, switch )
		-- lock_read and write take `self` first. The previous code passed
		-- `switch` as self, so reads were never actually locked or unlocked.
		handler.lock_read( self, switch )
		if switch == true then
			handler.write = idfalse
			local tmp = _sendlistlen
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_writetimes[ handler ] = nil
			if _sendlistlen ~= tmp then
				nosend = true
			end
		elseif switch == false then
			handler.write = write
			if nosend then
				nosend = false
				write( self, "" ) -- re-registers the socket for writing so queued data gets flushed
			end
		end
		return noread, nosend
	end
	local _readbuffer = function( ) -- this function reads data
		local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
		if not err or (err == "wantread" or err == "timeout") then -- received something
			local buffer = buffer or part or ""
			local len = #buffer
			if len > maxreadlen then
				handler:close( "receive buffer exceeded" )
				return false
			end
			local count = len * STAT_UNIT
			readtraffic = readtraffic + count
			_readtraffic = _readtraffic + count
			_readtimes[ handler ] = _currenttime
			--out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
			return dispatch( handler, buffer, err )
		else -- connections was closed or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
			fatalerror = true
			_ = handler and handler:force_close( err )
			return false
		end
	end
	local _sendbuffer = function( ) -- this function sends data
		local succ, err, byte, buffer, count;
		if socket then
			buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
			succ, err, byte = send( socket, buffer, 1, bufferlen )
			count = ( succ or byte or 0 ) * STAT_UNIT
			sendtraffic = sendtraffic + count
			_sendtraffic = _sendtraffic + count
			for i = bufferqueuelen,1,-1 do
				bufferqueue[ i ] = nil
			end
			--out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
		else
			succ, err, count = false, "unexpected close", 0;
		end
		if succ then -- sending succesful
			bufferqueuelen = 0
			bufferlen = 0
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
			_writetimes[ handler ] = nil
			if drain then
				drain(handler)
			end
			_ = needtls and handler:starttls(nil)
			_ = toclose and handler:force_close( )
			return true
		elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
			buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
			bufferqueue[ 1 ] = buffer -- insert new buffer in queue
			bufferqueuelen = 1
			bufferlen = bufferlen - byte
			_writetimes[ handler ] = _currenttime
			return true
		else -- connection was closed during sending or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
			fatalerror = true
			_ = handler and handler:force_close( err )
			return false
		end
	end

	-- Set the sslctx and (re)create the handshake coroutine. The coroutine
	-- drives client:dohandshake() for up to _maxsslhandshake round-trips,
	-- yielding while LuaSec reports "wantread"/"wantwrite".
	local handshake;
	function handler.set_sslctx(self, new_sslctx)
		sslctx = new_sslctx;
		local read, wrote
		handshake = coroutine_wrap( function( client ) -- create handshake coroutine
			local err
			for i = 1, _maxsslhandshake do
				_sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen
				_readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen
				read, wrote = nil, nil
				_, err = client:dohandshake( )
				if not err then
					out_put( "server.lua: ssl handshake done" )
					handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions
					handler.sendbuffer = _sendbuffer
					_ = status and status( handler, "ssl-handshake-complete" )
					if self.autostart_ssl and listeners.onconnect then
						listeners.onconnect(self);
						if bufferqueuelen ~= 0 then
							_sendlistlen = addsocket(_sendlist, client, _sendlistlen)
						end
					end
					_readlistlen = addsocket(_readlist, client, _readlistlen)
					return true
				else
					if err == "wantwrite" then
						_sendlistlen = addsocket(_sendlist, client, _sendlistlen)
						wrote = true
					elseif err == "wantread" then
						_readlistlen = addsocket(_readlist, client, _readlistlen)
						read = true
					else
						break;
					end
					err = nil;
					coroutine_yield( ) -- handshake not finished
				end
			end
			out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
			_ = handler and handler:force_close("ssl handshake failed")
			return false, err -- handshake failed
		end
		)
	end
	if has_luasec then
		-- Upgrade the connection to TLS. The upgrade is deferred until the
		-- send buffer is empty. Returns the handshake's result, or nil, err
		-- when ssl_wrap fails; in that case the plain socket is kept so the
		-- handler can still be closed cleanly.
		handler.starttls = function( self, _sslctx)
			if _sslctx then
				handler:set_sslctx(_sslctx);
			end
			if bufferqueuelen > 0 then
				out_put "server.lua: we need to do tls, but delaying until send buffer empty"
				needtls = true
				return
			end
			out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
			local oldsocket, err = socket
			socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
			if not socket then
				-- Restore the plain socket: leaving it nil made later reads/writes
				-- call receive/send with a nil socket.
				socket = oldsocket
				out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
				return nil, err -- fatal error
			end

			socket:settimeout( 0 )

			-- add the new socket to our system
			send = socket.send
			receive = socket.receive
			shutdown = id
			_socketlist[ socket ] = handler
			_readlistlen = addsocket(_readlist, socket, _readlistlen)

			-- remove traces of the old socket
			_readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
			_sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
			_socketlist[ oldsocket ] = nil

			handler.starttls = nil
			needtls = nil

			-- Secure now (if handshake fails connection will close)
			ssl = true

			handler.readbuffer = handshake
			handler.sendbuffer = handshake
			return handshake( socket ) -- do handshake
		end
	end

	handler.readbuffer = _readbuffer
	handler.sendbuffer = _sendbuffer
	send = socket.send
	receive = socket.receive
	shutdown = ( ssl and id ) or socket.shutdown

	_socketlist[ socket ] = handler
	_readlistlen = addsocket(_readlist, socket, _readlistlen)

	if sslctx and has_luasec then
		out_put "server.lua: auto-starting ssl negotiation..."
		handler.autostart_ssl = true;
		local ok, err = handler:starttls(sslctx);
		if ok == false then
			-- handshake failed; the handshake coroutine already force-closed the handler
			return nil, nil, err
		elseif ok == nil and err ~= nil then
			-- ssl_wrap failed: release the plain socket instead of leaving a
			-- half-initialised handler registered in the select lists
			handler:force_close( err )
			return nil, nil, err
		end
	end

	return handler, socket
end
|
663 |
|
-- Does nothing and returns no values. Used as a stand-in callback, e.g. the
-- shutdown function of TLS-wrapped sockets.
id = function( )
	return
end
|
666 |
|
-- Always returns false. Installed as `write` on connections that must not
-- accept any more output.
idfalse = function( )
	local result = false
	return result
end
|
670 |
|
--- Append a socket to an indexed list unless it is already present.
-- Lists hold both an array part (position -> socket) and a reverse index
-- (socket -> position).
-- @param list   the list table
-- @param socket the socket to add
-- @param len    current length of the array part
-- @return the (possibly unchanged) new length
addsocket = function( list, socket, len )
	if list[ socket ] then
		return len
	end
	local slot = len + 1
	list[ slot ] = socket
	list[ socket ] = slot
	return slot
end
|
679 |
|
--- Remove a socket from an indexed list (technique copied from copas).
-- The last element is moved into the vacated slot, so order is not preserved.
-- @param list   the list table
-- @param socket the socket to remove
-- @param len    current length of the array part
-- @return the new length; unchanged when the socket was not in the list
removesocket = function( list, socket, len )
	local pos = list[ socket ]
	if not pos then
		return len
	end
	list[ socket ] = nil
	local tail = list[ len ]
	list[ len ] = nil
	if tail ~= socket then
		-- move the former last element into the freed position
		list[ tail ] = pos
		list[ pos ] = tail
	end
	return len - 1
end
|
694 |
|
--- Drop a socket that has no handler from all registries, then close it.
closesocket = function( sock )
	-- unregister from both select() lists first ...
	_sendlistlen = removesocket( _sendlist, sock, _sendlistlen )
	_readlistlen = removesocket( _readlist, sock, _readlistlen )
	-- ... then from the socket -> handler map
	_socketlist[ sock ] = nil
	sock:close( )
end
|
702 |
|
--- Pipe data from `sender` to `receiver` with back-pressure.
-- Wraps both handlers' buffer functions. Once the receiver's send buffer
-- reaches `buffersize`, reading from the sender is suspended. Reading resumes
-- when the receiver's buffer drops below that size. The sender is switched
-- to "*a" read mode.
local function link( sender, receiver, buffersize )
	local paused -- true while reading on `sender` is suspended by us

	local original_sendbuffer = receiver.sendbuffer
	receiver.sendbuffer = function( )
		original_sendbuffer( )
		if paused and receiver.bufferlen( ) < buffersize then
			sender:lock_read( false ) -- Unlock now
			paused = nil
		end
	end

	local original_readbuffer = sender.readbuffer
	sender.readbuffer = function( )
		original_readbuffer( )
		if not paused and receiver.bufferlen( ) >= buffersize then
			paused = true
			sender:lock_read( true )
		end
	end

	sender:set_mode( "*a" )
end
|
724 |
|
725 ----------------------------------// PUBLIC //-- |
|
726 |
|
--- Register a listening TCP server on addr:port.
-- @param addr      bind address (defaults to "*")
-- @param port      port number in 0..65535
-- @param listeners listener callback table
-- @param pattern   receive pattern for accepted connections
-- @param sslctx    optional LuaSec context (requires luasec)
-- @return the server handler, or nil and an error string
addserver = function( addr, port, listeners, pattern, sslctx )
	addr = addr or "*"

	-- Log the error for this addr:port and return the nil, err pair.
	local function fail( reason )
		out_error( "server.lua, [", addr, "]:", port, ": ", reason )
		return nil, reason
	end

	if type( listeners ) ~= "table" then
		return fail( "invalid listener table" )
	elseif type ( addr ) ~= "string" then
		return fail( "invalid address" )
	elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
		return fail( "invalid port" )
	elseif _server[ addr..":"..port ] then
		return fail( "listeners on '[" .. addr .. "]:" .. port .. "' already exist" )
	elseif sslctx and not has_luasec then
		return fail( "luasec not found" )
	end

	local server, bind_err = socket_bind( addr, port, _tcpbacklog )
	if bind_err then
		return fail( bind_err )
	end

	local handler, wrap_err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
	if not handler then
		server:close( )
		return nil, wrap_err
	end
	server:settimeout( 0 )
	_readlistlen = addsocket( _readlist, server, _readlistlen )
	_server[ addr..":"..port ] = handler
	_socketlist[ server ] = handler
	out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
	return handler
end
|
762 |
|
--- Return the server handler registered for addr:port, or nil.
getserver = function ( addr, port )
	local key = addr..":"..port
	return _server[ key ];
end
|
766 |
|
--- Close and unregister the server on addr:port.
-- @return true, or nil and an error message when no such server exists
removeserver = function( addr, port )
	local key = addr..":"..port
	local handler = _server[ key ]
	if handler then
		handler:close( )
		_server[ key ] = nil
		return true
	end
	return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
end
|
776 |
|
--- Close every registered handler. Afterwards, reset the select lists, the
-- server registry, the socket map and the timer list to empty.
closeall = function( )
	for sock, handler in pairs( _socketlist ) do
		handler:close( )
		_socketlist[ sock ] = nil
	end
	_readlistlen, _sendlistlen, _timerlistlen = 0, 0, 0
	_server, _readlist, _sendlist = { }, { }, { }
	_timerlist, _socketlist = { }, { }
end
|
792 |
|
--- Snapshot of the current server settings, using the key names that
-- changesettings() accepts.
getsettings = function( )
	local settings = { }
	settings.select_timeout = _selecttimeout
	settings.select_sleep_time = _sleeptime
	settings.tcp_backlog = _tcpbacklog
	settings.max_send_buffer_size = _maxsendlen
	settings.max_receive_buffer_size = _maxreadlen
	settings.select_idle_check_interval = _checkinterval
	settings.send_timeout = _sendtimeout
	settings.read_timeout = _readtimeout
	settings.max_connections = _maxselectlen
	settings.max_ssl_handshake_roundtrips = _maxsslhandshake
	settings.highest_allowed_fd = _maxfd
	return settings
end
|
808 |
|
--- Update server settings (same key names as getsettings()).
-- Every value goes through tonumber(). Missing or non-numeric values leave
-- the current setting unchanged.
-- @param new  table of settings to apply
-- @return true, or nil, "invalid settings table" when `new` is not a table
changesettings = function( new )
	if type( new ) ~= "table" then
		return nil, "invalid settings table"
	end
	_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
	_sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
	_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
	_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
	_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
	_tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog
	_sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
	_readtimeout = tonumber( new.read_timeout ) or _readtimeout
	-- These three used to be stored verbatim. A string (e.g. from a config
	-- file) would later break numeric comparisons such as
	-- `_readlistlen >= _maxselectlen` and `socket:getfd( ) >= _maxfd`.
	_maxselectlen = tonumber( new.max_connections ) or _maxselectlen
	_maxsslhandshake = tonumber( new.max_ssl_handshake_roundtrips ) or _maxsslhandshake
	_maxfd = tonumber( new.highest_allowed_fd ) or _maxfd
	return true
end
|
826 |
|
--- Register a function to be called by the main loop's timer pass.
-- @param listener  function(current_time), returning an optional delay in
--                  seconds until it next wants to run
-- @return true, or nil, "invalid listener function"
addtimer = function( listener )
	if type( listener ) == "function" then
		_timerlistlen = _timerlistlen + 1
		_timerlist[ _timerlistlen ] = listener
		return true
	end
	return nil, "invalid listener function"
end
|
835 |
|
local add_task do
	-- Tasks are { due_time, callback } pairs. Tasks added via add_task are
	-- collected in `new_data` and merged into `data` at the next timer tick.
	local data = {};     -- scheduled tasks (gains holes as tasks fire)
	local new_data = {}; -- tasks added since the last tick

	--- Run `callback(now)` after `delay` seconds.
	-- A negative delay runs the callback immediately. A numeric return value
	-- from the callback schedules it again that many seconds later.
	function add_task(delay, callback)
		local now = luasocket_gettime();
		local due = delay + now;
		if due >= now then
			table_insert(new_data, {due, callback});
			return
		end
		local again = callback(now);
		if again and type(again) == "number" then
			return add_task(again, callback);
		end
	end

	-- Timer body: merge newly added tasks and run every task that is due.
	-- Returns the smallest delay until a pending or rescheduled task
	-- (math_huge when there is none).
	local function run_due_tasks()
		local now = luasocket_gettime();
		if #new_data > 0 then
			for _, task in pairs(new_data) do
				table_insert(data, task);
			end
			new_data = {};
		end

		local wait = math_huge;
		for index, task in pairs(data) do
			local due, callback = task[1], task[2];
			if due <= now then
				data[index] = nil;
				local again = callback(now);
				if type(again) == "number" then
					add_task(again, callback);
					wait = math_min(wait, again);
				end
			else
				wait = math_min(wait, due - now);
			end
		end
		return wait;
	end

	addtimer(run_due_tasks);
end
|
879 |
|
--- Global counters: bytes read, bytes sent, length of the read list, length
-- of the send list, and number of registered timers.
stats = function( )
	local read_bytes, sent_bytes = _readtraffic, _sendtraffic
	return read_bytes, sent_bytes, _readlistlen, _sendlistlen, _timerlistlen
end
|
883 |
|
-- Stop flag checked by loop(). It holds nil initially, a boolean once
-- setquitting() has been called, or "once" while loop(true) runs one iteration.
local quitting;

--- Ask the main loop to stop (truthy `quit`) or to keep running (falsy).
local function setquitting(quit)
	if quit then
		quitting = true;
	else
		quitting = false;
	end
end
|
889 |
|
-- Call a handler's current ondisconnect listener, if it has one.
-- handler.disconnect() returns nil for listener tables without ondisconnect
-- and after the handler has closed. Calling that result directly used to
-- raise "attempt to call a nil value" inside the loop.
local function notify_disconnect( handler, reason )
	local ondisconnect = handler.disconnect( )
	if ondisconnect then
		ondisconnect( handler, reason )
	end
end

--- The main select() loop.
-- Each iteration:
--  1. selects on the read and send lists;
--  2. flushes writable sockets and reads from readable ones;
--  3. closes handlers queued in _closelist;
--  4. checks for send/read timeouts every _checkinterval seconds;
--  5. fires timers;
--  6. sleeps for _sleeptime.
-- @param once  if true, run a single iteration
-- @return "quitting" when asked to stop via setquitting() (or already
--         stopping); nothing after an undisturbed single `once` iteration
loop = function(once) -- this is the main loop of the program
	if quitting then return "quitting"; end
	if once then quitting = "once"; end
	local next_timer_time = math_huge;
	repeat
		local read, write = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
		for _, socket in ipairs( write ) do -- send data waiting in writequeues
			local handler = _socketlist[ socket ]
			if handler then
				handler.sendbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
			end
		end
		for _, socket in ipairs( read ) do -- receive data
			local handler = _socketlist[ socket ]
			if handler then
				handler.readbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
			end
		end
		-- Handlers scheduled for closing (e.g. send buffer exceeded) are closed
		-- here, after the read/write passes.
		-- NOTE(review): close() notifies ondisconnect as well, so the listener
		-- may be called twice (with `err`, then with false).
		for handler, err in pairs( _closelist ) do
			notify_disconnect( handler, err )
			handler:force_close() -- forced disconnect
			_closelist[ handler ] = nil;
		end
		_currenttime = luasocket_gettime( )

		-- Check for socket timeouts. Plain subtraction is used because
		-- os.difftime requires two integral time_t arguments on Lua 5.3+, and
		-- on 5.1 it truncated the fractional gettime() values.
		if _currenttime - _starttime > _checkinterval then
			_starttime = _currenttime
			for handler, timestamp in pairs( _writetimes ) do
				if _currenttime - timestamp > _sendtimeout then
					notify_disconnect( handler, "send timeout" )
					handler:force_close() -- forced disconnect
				end
			end
			for handler, timestamp in pairs( _readtimes ) do
				if _currenttime - timestamp > _readtimeout then
					if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then
						notify_disconnect( handler, "read timeout" )
						handler:close( ) -- forced disconnect?
					else
						_readtimes[ handler ] = _currenttime -- reset timer
					end
				end
			end
		end

		-- Fire timers
		if _currenttime - _timer >= math_min(next_timer_time, 1) then
			next_timer_time = math_huge;
			for i = 1, _timerlistlen do
				local t = _timerlist[ i ]( _currenttime ) -- fire timers
				if t then next_timer_time = math_min(next_timer_time, t); end
			end
			_timer = _currenttime
		else
			next_timer_time = next_timer_time - (_currenttime - _timer);
		end

		-- wait some time (0 by default)
		socket_sleep( _sleeptime )
	until quitting;
	if once and quitting == "once" then quitting = nil; return; end
	return "quitting"
end
|
961 |
|
--- Run a single pass of the main loop.
-- @return whatever loop(true) returns
local step = function( )
	local run_once = true
	return loop( run_once )
end
|
965 |
|
--- Name of this network backend implementation.
-- @return the string "select"
local get_backend = function( )
	local backend_name = "select"
	return backend_name
end
|
969 |
|
970 --// EXPERIMENTAL //-- |
|
971 |
|
--- Register an already-created client socket with the event loop.
-- The socket is wrapped via wrapconnection and recorded in _socketlist.
-- For non-TLS sockets (no sslctx) it is also queued on the send list. If
-- listeners.onconnect exists, it is called once, on the first
-- handler.sendbuffer() call, before any queued data is flushed.
-- @return handler, socket on success; nil, err if wrapping failed
local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
	local handler, conn, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
	if not handler then
		return nil, err
	end
	_socketlist[ conn ] = handler
	if sslctx then
		-- TLS sockets are not put on the send list here
		return handler, conn
	end
	_sendlistlen = addsocket( _sendlist, conn, _sendlistlen )
	if not listeners.onconnect then
		return handler, conn
	end
	-- One-shot hook: the first sendbuffer() restores the real one, notifies
	-- onconnect, then flushes whatever is queued.
	local real_sendbuffer = handler.sendbuffer
	handler.sendbuffer = function ()
		handler.sendbuffer = real_sendbuffer
		listeners.onconnect( handler )
		return real_sendbuffer( )
	end
	return handler, conn
end
|
990 |
|
--- Open a non-blocking outgoing connection and register it with the event loop.
-- Arguments are validated before anything with side effects (DNS lookup,
-- socket creation) happens.
-- @param address host name or IP address (string)
-- @param port number in 0..65535
-- @param listeners listener callback table
-- @param pattern passed through to wrapclient
-- @param sslctx optional TLS context; requires luasec
-- @param typ optional luasocket constructor name (default "tcp"; switched to
--   "tcp6" when getaddrinfo's first result is an inet6 address)
-- @return whatever wrapclient returns on success; nil, err on failure
local addclient = function( address, port, listeners, pattern, sslctx, typ )
	local err
	if type( listeners ) ~= "table" then
		err = "invalid listener table"
	elseif type ( address ) ~= "string" then
		err = "invalid address"
	elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
		err = "invalid port"
	elseif sslctx and not has_luasec then
		err = "luasec not found"
	end
	if err then
		-- Bail out before any (blocking) DNS lookup on bad arguments
		out_error( "server.lua, addclient: ", err )
		return nil, err
	end

	if getaddrinfo and not typ then
		local addrinfo, dns_err = getaddrinfo(address)
		if not addrinfo then return nil, dns_err end
		if addrinfo[1] and addrinfo[1].family == "inet6" then
			typ = "tcp6"
		end
	end
	local create = luasocket[typ or "tcp"]
	if type( create ) ~= "function" then
		err = "invalid socket type"
		out_error( "server.lua, addclient: ", err )
		return nil, err
	end

	local client, create_err = create( )
	if not client then
		return nil, create_err
	end
	client:settimeout( 0 )
	local ok, conn_err = client:connect( address, port )
	-- A non-blocking connect normally reports "in progress"; that is success here
	if ok or conn_err == "timeout" or conn_err == "Operation already in progress" then
		return wrapclient( client, address, port, listeners, pattern, sslctx )
	end
	-- connect() failed outright: release the socket so its descriptor is not leaked
	client:close( )
	return nil, conn_err
end
|
1031 |
|
1032 --// EXPERIMENTAL //-- |
|
1033 |
|
1034 ----------------------------------// BEGIN //-- |
|
1035 |
|
do
	local setmetatable = use "setmetatable"
	-- Weak keys: an entry does not keep its handler/socket key alive on its own.
	setmetatable( _socketlist, { __mode = "k" } )
	setmetatable( _readtimes, { __mode = "k" } )
	setmetatable( _writetimes, { __mode = "k" } )
end

-- Reference points for the main loop: _timer for firing timers,
-- _starttime for the periodic timeout sweep.
_timer = luasocket_gettime( )
_starttime = luasocket_gettime( )
|
1042 |
|
--- Replace the logger used by this module.
-- A falsy argument leaves the current logger in place.
-- @param new_logger logger function, or nil/false to keep the current one
-- @return the logger that was active before the call
local function setlogger(new_logger)
	local previous = log
	log = new_logger or log
	return previous
end
|
1050 |
|
1051 ----------------------------------// PUBLIC INTERFACE //-- |
|
1052 |
|
-- Module interface (keys listed alphabetically).
return {
	_addtimer = addtimer,
	add_task = add_task,
	addclient = addclient,
	addserver = addserver,
	changesettings = changesettings,
	closeall = closeall,
	get_backend = get_backend,
	getserver = getserver,
	getsettings = getsettings,
	link = link,
	loop = loop,
	removeserver = removeserver,
	setlogger = setlogger,
	setquitting = setquitting,
	stats = stats,
	step = step,
	wrapclient = wrapclient,
}