# HG changeset patch # User Matthew Wild # Date 1419504486 0 # Node ID d363a6692a10234d70cedaffabecf3c00030ea17 Initial commit. Tortoises are fun. diff -r 000000000000 -r d363a6692a10 clients.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/clients.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,80 @@ +local socket = require "socket"; +local server = require "net.server_select"; +local http_server = require"net.http.server"; +local log = require "util.logger".init("clients"); + +local response_head = table.concat({ + "HTTP/1.1 200 OK"; + "Max-Age: 0"; + "Expires: 0"; + "Cache-Control: no-cache, private"; + "Pragma: no-cache"; + "Content-Type: multipart/x-mixed-replace; boundary=--BoundaryString"; + "Keep-Alive: timeout=5, max=99"; + "Connection: Keep-Alive"; + "Transfer-Encoding: chunked"; + ""; + ""; +}, "\r\n"); + +local listener = { onconnect = function () end; onincoming = function () end; } + +local last_chunk; + +local clients = {}; + +-- Called when a HTTP stream client closes +function listener.ondisconnect(conn) + clients[conn] = nil; + if not next(clients) then + log("debug", "No more clients"); + events.fire_event("no-clients"); + end +end +listener.ondetach = listener.ondisconnect; + +function handle_request(event, path) + local path = event.request.url.path; + if path ~= "/cam" then + return 404; + end + + if not next(clients) then + log("debug", "Have clients now"); + events.fire_event("have-clients"); + end + + local conn = event.response.conn; + + conn:write(response_head); + clients[conn] = true; + + if last_chunk then + conn:write(last_chunk); + end + + conn:setlistener(listener); + + return true; +end + +events.add_handler("image-changed", function (event) + log("debug", "New image"); + local chunk_data = table.concat({ + "--BoundaryString", + "Content-Type: image/jpeg"; + "Content-Length: "..#event.image; + ""; + event.image; + }, "\r\n"); + last_chunk = ("%x\r\n%s\r\n"):format(#chunk_data, chunk_data); + + for client in pairs(clients) do + client:write(last_chunk); + end +end); + +http_server.add_host("localhost"); +http_server.set_default_host("localhost"); +http_server.add_handler("GET localhost/*", handle_request); +http_server.listen_on(8006); diff -r 000000000000 -r d363a6692a10 main.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/main.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,42 @@ +collectgarbage("setpause", 110); + +assert(arg[1], "Please supply upstream URL"); + +local server = require "net.server_select"; +local http = require "net.http"; +local logger = require "util.logger"; +events = require "util.events".new(); + +require "clients"; + +local function printf(source, level, fmt, ...) return print(source, level, fmt:format(...)); end +for _, level in ipairs{"debug", "info", "warn", "error"} do + logger.add_level_sink(level, printf); +end + +local request; + +events.add_handler("have-clients", function () + request = http.request(arg[1], {success_on_chunk=true}, function (data, code, response) + local boundary = response.headers["content-type"]:match("; boundary=(.+)$").."\r\n"; + if #data > #boundary and data:sub(1, #boundary) ~= boundary then + print(string.format(("%02X "):rep(6), data:byte(1,6))); + error("bad format"); + end + local img_header, header_len = data:sub(#boundary-1):match("(\r\n.-\r\n)\r\n()"); + if not img_header then return; end + local img_size = tonumber(img_header:match("\r\nContent%-Length:%s*(%d+)\r\n")); + local start_idx, end_idx = #boundary-2+header_len, #boundary-2+header_len+img_size-1; + if end_idx + 2 > #data then return; end + local img = data:sub(start_idx, end_idx); + assert(#img == img_size, "incorrect image size: actual:"..#img.." vs expected: "..img_size.." data: "..#data) + response.body = data:sub(end_idx+3); + events.fire_event("image-changed", { image = img }); + end); +end); + +events.add_handler("no-clients", function () + request.conn:close(); +end); + +server.loop(); diff -r 000000000000 -r d363a6692a10 net/adns.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/adns.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,91 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local server = require "net.server"; +local dns = require "net.dns"; + +local log = require "util.logger".init("adns"); + +local t_insert, t_remove = table.insert, table.remove; +local coroutine, tostring, pcall = coroutine, tostring, pcall; + +local function dummy_send(sock, data, i, j) return (j-i)+1; end + +module "adns" + +function lookup(handler, qname, qtype, qclass) + return coroutine.wrap(function (peek) + if peek then + log("debug", "Records for %s already cached, using those...", qname); + handler(peek); + return; + end + log("debug", "Records for %s not in cache, sending query (%s)...", qname, tostring(coroutine.running())); + local ok, err = dns.query(qname, qtype, qclass); + if ok then + coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply + log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running())); + end + if ok then + ok, err = pcall(handler, dns.peek(qname, qtype, qclass)); + else + log("error", "Error sending DNS query: %s", err); + ok, err = pcall(handler, nil, err); + end + if not ok then + log("error", "Error in DNS response handler: %s", tostring(err)); + end + end)(dns.peek(qname, qtype, qclass)); +end + +function cancel(handle, call_handler, reason) + log("warn", "Cancelling DNS lookup for %s", tostring(handle[3])); + dns.cancel(handle[1], handle[2], handle[3], handle[4], call_handler); +end + +function new_async_socket(sock, resolver) + local peername = ""; + local listener = {}; + local handler = {}; + local err; + function listener.onincoming(conn, data) + if data then + dns.feed(handler, data); + end + end + function listener.ondisconnect(conn, err) + if err then + log("warn", "DNS socket for %s disconnected: %s", peername, err); + local servers = resolver.server; + if resolver.socketset[conn] == resolver.best_server and resolver.best_server == #servers then + log("error", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]); + end + + resolver:servfail(conn); -- Let the magic commence + end + end + handler, err = server.wrapclient(sock, "dns", 53, listener); + if not handler then + return nil, err; + end + + handler.settimeout = function () end + handler.setsockname = function (_, ...) return sock:setsockname(...); end + handler.setpeername = function (_, ...) peername = (...); local ret, err = sock:setpeername(...); _:set_send(dummy_send); return ret, err; end + handler.connect = function (_, ...) return sock:connect(...) end + --handler.send = function (_, data) _:write(data); return _.sendbuffer and _.sendbuffer(); end + handler.send = function (_, data) + log("debug", "Sending DNS query to %s", peername); + return sock:send(data); + end + return handler; +end + +dns.socket_wrapper_set(new_async_socket); + +return _M; diff -r 000000000000 -r d363a6692a10 net/cqueues.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/cqueues.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,65 @@ +-- Prosody IM +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- +-- This module allows you to use cqueues with a net.server mainloop +-- + +local server = require "net.server"; +local cqueues = require "cqueues"; + +-- Create a single top level cqueue +local cq; + +if server.cq then -- server provides cqueues object + cq = server.cq; +elseif server.get_backend() == "select" and server._addtimer then -- server_select + cq = cqueues.new(); + local function step() + assert(cq:loop(0)); + end + + -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd + local handler = server.wrapclient({ + getfd = function() return cq:pollfd(); end; + settimeout = function() end; -- Method just needs to exist + close = function() end; -- Need close method for 'closeall' + }, nil, nil, {}); + + -- Only need to listen for readable; cqueues handles everything under the hood + -- readbuffer is called when `select` notes an fd as readable + handler.readbuffer = step; + + -- Use server_select low lever timer facility, + -- this callback gets called *every* time there is a timeout in the main loop + server._addtimer(function(current_time) + -- This may end up in extra step()'s, but cqueues handles it for us. + step(); + return cq:timeout(); + end); +elseif server.event and server.base then -- server_event + cq = cqueues.new(); + -- Only need to listen for readable; cqueues handles everything under the hood + local EV_READ = server.event.EV_READ; + server.base:addevent(cq:pollfd(), EV_READ, function(e) + assert(cq:loop(0)); + -- Convert a cq timeout to an acceptable timeout for luaevent + local t = cq:timeout(); + if t == 0 then -- if you give luaevent 0, it won't call this callback again + t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`) + elseif t == nil then -- you always need to give a timeout, pick something big if we don't have one + t = 0x7FFFFFFF; -- largest 32bit int + end + return EV_READ, t; + end, + -- Schedule the callback to fire on first tick to ensure any cq:wrap calls that happen during start-up are serviced. + 0.000001); +else + error "NYI" +end + +return { + cq = cq; +} diff -r 000000000000 -r d363a6692a10 net/dns.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/dns.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,1104 @@ +-- Prosody IM +-- This file is included with Prosody IM. It has modifications, +-- which are hereby placed in the public domain. + + +-- todo: quick (default) header generation +-- todo: nxdomain, error handling +-- todo: cache results of encodeName + + +-- reference: http://tools.ietf.org/html/rfc1035 +-- reference: http://tools.ietf.org/html/rfc1876 (LOC) + + +local socket = require "socket"; +local timer = require "util.timer"; +local new_ip = require "util.ip".new_ip; + +local _, windows = pcall(require, "util.windows"); +local is_windows = (_ and windows) or os.getenv("WINDIR"); + +local coroutine, io, math, string, table = + coroutine, io, math, string, table; + +local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type= + ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type; + +local ztact = { -- public domain 20080404 lua@ztact.com + get = function(parent, ...) + local len = select('#', ...); + for i=1,len do + parent = parent[select(i, ...)]; + if parent == nil then break; end + end + return parent; + end; + set = function(parent, ...) + local len = select('#', ...); + local key, value = select(len-1, ...); + local cutpoint, cutkey; + + for i=1,len-2 do + local key = select (i, ...) + local child = parent[key] + + if value == nil then + if child == nil then + return; + elseif next(child, next(child)) then + cutpoint = nil; cutkey = nil; + elseif cutpoint == nil then + cutpoint = parent; cutkey = key; + end + elseif child == nil then + child = {}; + parent[key] = child; + end + parent = child + end + + if value == nil and cutpoint then + cutpoint[cutkey] = nil; + else + parent[key] = value; + return value; + end + end; +}; +local get, set = ztact.get, ztact.set; + +local default_timeout = 15; + +-------------------------------------------------- module dns +module('dns') +local dns = _M; + + +-- dns type & class codes ------------------------------ dns type & class codes + + +local append = table.insert + + +local function highbyte(i) -- - - - - - - - - - - - - - - - - - - highbyte + return (i-(i%0x100))/0x100; +end + + +local function augment (t) -- - - - - - - - - - - - - - - - - - - - augment + local a = {}; + for i,s in pairs(t) do + a[i] = s; + a[s] = s; + a[string.lower(s)] = s; + end + return a; +end + + +local function encode (t) -- - - - - - - - - - - - - - - - - - - - - encode + local code = {}; + for i,s in pairs(t) do + local word = string.char(highbyte(i), i%0x100); + code[i] = word; + code[s] = word; + code[string.lower(s)] = word; + end + return code; +end + + +dns.types = { + 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS', + 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT', + [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV', + [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }; + + +dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }; + + +dns.type = augment (dns.types); +dns.class = augment (dns.classes); +dns.typecode = encode (dns.types); +dns.classcode = encode (dns.classes); + + + +local function standardize(qname, qtype, qclass) -- - - - - - - standardize + if string.byte(qname, -1) ~= 0x2E then qname = qname..'.'; end + qname = string.lower(qname); + return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']; +end + + +local function prune(rrs, time, soft) -- - - - - - - - - - - - - - - prune + time = time or socket.gettime(); + for i,rr in ipairs(rrs) do + if rr.tod then + -- rr.tod = rr.tod - 50 -- accelerated decripitude + rr.ttl = math.floor(rr.tod - time); + if rr.ttl <= 0 then + rrs[rr[rr.type:lower()]] = nil; + table.remove(rrs, i); + return prune(rrs, time, soft); -- Re-iterate + end + elseif soft == 'soft' then -- What is this? I forget! + assert(rr.ttl == 0); + rrs[rr[rr.type:lower()]] = nil; + table.remove(rrs, i); + end + end +end + + +-- metatables & co. ------------------------------------------ metatables & co. + + +local resolver = {}; +resolver.__index = resolver; + +resolver.timeout = default_timeout; + +local function default_rr_tostring(rr) + local rr_val = rr.type and rr[rr.type:lower()]; + if type(rr_val) ~= "string" then + return ""; + end + return rr_val; +end + +local special_tostrings = { + LOC = resolver.LOC_tostring; + MX = function (rr) + return string.format('%2i %s', rr.pref, rr.mx); + end; + SRV = function (rr) + local s = rr.srv; + return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target); + end; +}; + +local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable +function rr_metatable.__tostring(rr) + local rr_string = (special_tostrings[rr.type] or default_rr_tostring)(rr); + return string.format('%2s %-5s %6i %-28s %s', rr.class, rr.type, rr.ttl, rr.name, rr_string); +end + + +local rrs_metatable = {}; -- - - - - - - - - - - - - - - - - - rrs_metatable +function rrs_metatable.__tostring(rrs) + local t = {}; + for i,rr in ipairs(rrs) do + append(t, tostring(rr)..'\n'); + end + return table.concat(t); +end + + +local cache_metatable = {}; -- - - - - - - - - - - - - - - - cache_metatable +function cache_metatable.__tostring(cache) + local time = socket.gettime(); + local t = {}; + for class,types in pairs(cache) do + for type,names in pairs(types) do + for name,rrs in pairs(names) do + prune(rrs, time); + append(t, tostring(rrs)); + end + end + end + return table.concat(t); +end + + +function resolver:new() -- - - - - - - - - - - - - - - - - - - - - resolver + local r = { active = {}, cache = {}, unsorted = {} }; + setmetatable(r, resolver); + setmetatable(r.cache, cache_metatable); + setmetatable(r.unsorted, { __mode = 'kv' }); + return r; +end + + +-- packet layer -------------------------------------------------- packet layer + + +function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random + math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000); + dns.random = math.random; + return dns.random(...); +end + + +local function encodeHeader(o) -- - - - - - - - - - - - - - - encodeHeader + o = o or {}; + o.id = o.id or dns.random(0, 0xffff); -- 16b (random) id + + o.rd = o.rd or 1; -- 1b 1 recursion desired + o.tc = o.tc or 0; -- 1b 1 truncated response + o.aa = o.aa or 0; -- 1b 1 authoritative response + o.opcode = o.opcode or 0; -- 4b 0 query + -- 1 inverse query + -- 2 server status request + -- 3-15 reserved + o.qr = o.qr or 0; -- 1b 0 query, 1 response + + o.rcode = o.rcode or 0; -- 4b 0 no error + -- 1 format error + -- 2 server failure + -- 3 name error + -- 4 not implemented + -- 5 refused + -- 6-15 reserved + o.z = o.z or 0; -- 3b 0 resvered + o.ra = o.ra or 0; -- 1b 1 recursion available + + o.qdcount = o.qdcount or 1; -- 16b number of question RRs + o.ancount = o.ancount or 0; -- 16b number of answers RRs + o.nscount = o.nscount or 0; -- 16b number of nameservers RRs + o.arcount = o.arcount or 0; -- 16b number of additional RRs + + -- string.char() rounds, so prevent roundup with -0.4999 + local header = string.char( + highbyte(o.id), o.id %0x100, + o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr, + o.rcode + 16*o.z + 128*o.ra, + highbyte(o.qdcount), o.qdcount %0x100, + highbyte(o.ancount), o.ancount %0x100, + highbyte(o.nscount), o.nscount %0x100, + highbyte(o.arcount), o.arcount %0x100 + ); + + return header, o.id; +end + + +local function encodeName(name) -- - - - - - - - - - - - - - - - encodeName + local t = {}; + for part in string.gmatch(name, '[^.]+') do + append(t, string.char(string.len(part))); + append(t, part); + end + append(t, string.char(0)); + return table.concat(t); +end + + +local function encodeQuestion(qname, qtype, qclass) -- - - - encodeQuestion + qname = encodeName(qname); + qtype = dns.typecode[qtype or 'a']; + qclass = dns.classcode[qclass or 'in']; + return qname..qtype..qclass; +end + + +function resolver:byte(len) -- - - - - - - - - - - - - - - - - - - - - byte + len = len or 1; + local offset = self.offset; + local last = offset + len - 1; + if last > #self.packet then + error(string.format('out of bounds: %i>%i', last, #self.packet)); + end + self.offset = offset + len; + return string.byte(self.packet, offset, last); +end + + +function resolver:word() -- - - - - - - - - - - - - - - - - - - - - - word + local b1, b2 = self:byte(2); + return 0x100*b1 + b2; +end + + +function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword + local b1, b2, b3, b4 = self:byte(4); + --print('dword', b1, b2, b3, b4); + return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4; +end + + +function resolver:sub(len) -- - - - - - - - - - - - - - - - - - - - - - sub + len = len or 1; + local s = string.sub(self.packet, self.offset, self.offset + len - 1); + self.offset = self.offset + len; + return s; +end + + +function resolver:header(force) -- - - - - - - - - - - - - - - - - - header + local id = self:word(); + --print(string.format(':header id %x', id)); + if not self.active[id] and not force then return nil; end + + local h = { id = id }; + + local b1, b2 = self:byte(2); + + h.rd = b1 %2; + h.tc = b1 /2%2; + h.aa = b1 /4%2; + h.opcode = b1 /8%16; + h.qr = b1 /128; + + h.rcode = b2 %16; + h.z = b2 /16%8; + h.ra = b2 /128; + + h.qdcount = self:word(); + h.ancount = self:word(); + h.nscount = self:word(); + h.arcount = self:word(); + + for k,v in pairs(h) do h[k] = v-v%1; end + + return h; +end + + +function resolver:name() -- - - - - - - - - - - - - - - - - - - - - - name + local remember, pointers = nil, 0; + local len = self:byte(); + local n = {}; + if len == 0 then return "." end -- Root label + while len > 0 do + if len >= 0xc0 then -- name is "compressed" + pointers = pointers + 1; + if pointers >= 20 then error('dns error: 20 pointers'); end; + local offset = ((len-0xc0)*0x100) + self:byte(); + remember = remember or self.offset; + self.offset = offset + 1; -- +1 for lua + else -- name is not compressed + append(n, self:sub(len)..'.'); + end + len = self:byte(); + end + self.offset = remember or self.offset; + return table.concat(n); +end + + +function resolver:question() -- - - - - - - - - - - - - - - - - - question + local q = {}; + q.name = self:name(); + q.type = dns.type[self:word()]; + q.class = dns.class[self:word()]; + return q; +end + + +function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A + local b1, b2, b3, b4 = self:byte(4); + rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4); +end + +function resolver:AAAA(rr) + local addr = {}; + for i = 1, rr.rdlength, 2 do + local b1, b2 = self:byte(2); + table.insert(addr, ("%02x%02x"):format(b1, b2)); + end + addr = table.concat(addr, ":"):gsub("%f[%x]0+(%x)","%1"); + local zeros = {}; + for item in addr:gmatch(":[0:]+:") do + table.insert(zeros, item) + end + if #zeros == 0 then + rr.aaaa = addr; + return + elseif #zeros > 1 then + table.sort(zeros, function(a, b) return #a > #b end); + end + rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::"); +end + +function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME + rr.cname = self:name(); +end + + +function resolver:MX(rr) -- - - - - - - - - - - - - - - - - - - - - - - MX + rr.pref = self:word(); + rr.mx = self:name(); +end + + +function resolver:LOC_nibble_power() -- - - - - - - - - - LOC_nibble_power + local b = self:byte(); + --print('nibbles', ((b-(b%0x10))/0x10), (b%0x10)); + return ((b-(b%0x10))/0x10) * (10^(b%0x10)); +end + + +function resolver:LOC(rr) -- - - - - - - - - - - - - - - - - - - - - - LOC + rr.version = self:byte(); + if rr.version == 0 then + rr.loc = rr.loc or {}; + rr.loc.size = self:LOC_nibble_power(); + rr.loc.horiz_pre = self:LOC_nibble_power(); + rr.loc.vert_pre = self:LOC_nibble_power(); + rr.loc.latitude = self:dword(); + rr.loc.longitude = self:dword(); + rr.loc.altitude = self:dword(); + end +end + + +local function LOC_tostring_degrees(f, pos, neg) -- - - - - - - - - - - - - + f = f - 0x80000000; + if f < 0 then pos = neg; f = -f; end + local deg, min, msec; + msec = f%60000; + f = (f-msec)/60000; + min = f%60; + deg = (f-min)/60; + return string.format('%3d %2d %2.3f %s', deg, min, msec/1000, pos); +end + + +function resolver.LOC_tostring(rr) -- - - - - - - - - - - - - LOC_tostring + local t = {}; + + --[[ + for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', 'latitude', 'longitude', 'altitude' } do + append(t, string.format('%4s%-10s: %12.0f\n', '', name, rr.loc[name])); + end + --]] + + append(t, string.format( + '%s %s %.2fm %.2fm %.2fm %.2fm', + LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'), + LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'), + (rr.loc.altitude - 10000000) / 100, + rr.loc.size / 100, + rr.loc.horiz_pre / 100, + rr.loc.vert_pre / 100 + )); + + return table.concat(t); +end + + +function resolver:NS(rr) -- - - - - - - - - - - - - - - - - - - - - - - NS + rr.ns = self:name(); +end + + +function resolver:SOA(rr) -- - - - - - - - - - - - - - - - - - - - - - SOA +end + + +function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV + rr.srv = {}; + rr.srv.priority = self:word(); + rr.srv.weight = self:word(); + rr.srv.port = self:word(); + rr.srv.target = self:name(); +end + +function resolver:PTR(rr) + rr.ptr = self:name(); +end + +function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT + rr.txt = self:sub (self:byte()); +end + + +function resolver:rr() -- - - - - - - - - - - - - - - - - - - - - - - - rr + local rr = {}; + setmetatable(rr, rr_metatable); + rr.name = self:name(self); + rr.type = dns.type[self:word()] or rr.type; + rr.class = dns.class[self:word()] or rr.class; + rr.ttl = 0x10000*self:word() + self:word(); + rr.rdlength = self:word(); + + if rr.ttl <= 0 then + rr.tod = self.time + 30; + else + rr.tod = self.time + rr.ttl; + end + + local remember = self.offset; + local rr_parser = self[dns.type[rr.type]]; + if rr_parser then rr_parser(self, rr); end + self.offset = remember; + rr.rdata = self:sub(rr.rdlength); + return rr; +end + + +function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs + local rrs = {}; + for i = 1,count do append(rrs, self:rr()); end + return rrs; +end + + +function resolver:decode(packet, force) -- - - - - - - - - - - - - - decode + self.packet, self.offset = packet, 1; + local header = self:header(force); + if not header then return nil; end + local response = { header = header }; + + response.question = {}; + local offset = self.offset; + for i = 1,response.header.qdcount do + append(response.question, self:question()); + end + response.question.raw = string.sub(self.packet, offset, self.offset - 1); + + if not force then + if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then + self.active[response.header.id] = nil; + return nil; + end + end + + response.answer = self:rrs(response.header.ancount); + response.authority = self:rrs(response.header.nscount); + response.additional = self:rrs(response.header.arcount); + + return response; +end + + +-- socket layer -------------------------------------------------- socket layer + + +resolver.delays = { 1, 3 }; + + +function resolver:addnameserver(address) -- - - - - - - - - - addnameserver + self.server = self.server or {}; + append(self.server, address); +end + + +function resolver:setnameserver(address) -- - - - - - - - - - setnameserver + self.server = {}; + self:addnameserver(address); +end + + +function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers + if is_windows then + if windows and windows.get_nameservers then + for _, server in ipairs(windows.get_nameservers()) do + self:addnameserver(server); + end + end + if not self.server or #self.server == 0 then + -- TODO log warning about no nameservers, adding opendns servers as fallback + self:addnameserver("208.67.222.222"); + self:addnameserver("208.67.220.220"); + end + else -- posix + local resolv_conf = io.open("/etc/resolv.conf"); + if resolv_conf then + for line in resolv_conf:lines() do + line = line:gsub("#.*$", "") + :match('^%s*nameserver%s+([%x:%.]*)%s*$'); + if line then + local ip = new_ip(line); + if ip then + self:addnameserver(ip.addr); + end + end + end + end + if not self.server or #self.server == 0 then + -- TODO log warning about no nameservers, adding localhost as the default nameserver + self:addnameserver("127.0.0.1"); + end + end +end + + +function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket + self.socket = self.socket or {}; + self.socketset = self.socketset or {}; + + local sock = self.socket[servernum]; + if sock then return sock; end + + local ok, err; + local peer = self.server[servernum]; + if peer:find(":") then + sock, err = socket.udp6(); + else + sock, err = socket.udp(); + end + if sock and self.socket_wrapper then sock, err = self.socket_wrapper(sock, self); end + if not sock then + return nil, err; + end + sock:settimeout(0); + -- todo: attempt to use a random port, fallback to 0 + self.socket[servernum] = sock; + self.socketset[sock] = servernum; + -- set{sock,peer}name can fail, eg because of local routing table + -- if so, try the next server + ok, err = sock:setsockname('*', 0); + if not ok then return self:servfail(sock, err); end + ok, err = sock:setpeername(peer, 53); + if not ok then return self:servfail(sock, err); end + return sock; +end + +function resolver:voidsocket(sock) + if self.socket[sock] then + self.socketset[self.socket[sock]] = nil; + self.socket[sock] = nil; + elseif self.socketset[sock] then + self.socket[self.socketset[sock]] = nil; + self.socketset[sock] = nil; + end + sock:close(); +end + +function resolver:socket_wrapper_set(func) -- - - - - - - socket_wrapper_set + self.socket_wrapper = func; +end + + +function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall + for i,sock in ipairs(self.socket) do + self.socket[i] = nil; + self.socketset[sock] = nil; + sock:close(); + end +end + + +function resolver:remember(rr, type) -- - - - - - - - - - - - - - remember + --print ('remember', type, rr.class, rr.type, rr.name) + local qname, qtype, qclass = standardize(rr.name, rr.type, rr.class); + + if type ~= '*' then + type = qtype; + local all = get(self.cache, qclass, '*', qname); + --print('remember all', all); + if all then append(all, rr); end + end + + self.cache = self.cache or setmetatable({}, cache_metatable); + local rrs = get(self.cache, qclass, type, qname) or + set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable)); + if not rrs[rr[qtype:lower()]] then + rrs[rr[qtype:lower()]] = true; + append(rrs, rr); + end + + if type == 'MX' then self.unsorted[rrs] = true; end +end + + +local function comp_mx(a, b) -- - - - - - - - - - - - - - - - - - - comp_mx + return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref); +end + + +function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek + qname, qtype, qclass = standardize(qname, qtype, qclass); + local rrs = get(self.cache, qclass, qtype, qname); + if not rrs then return nil; end + if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then + set(self.cache, qclass, qtype, qname, nil); + return nil; + end + if self.unsorted[rrs] then table.sort (rrs, comp_mx); end + return rrs; +end + + +function resolver:purge(soft) -- - - - - - - - - - - - - - - - - - - purge + if soft == 'soft' then + self.time = socket.gettime(); + for class,types in pairs(self.cache or {}) do + for type,names in pairs(types) do + for name,rrs in pairs(names) do + prune(rrs, self.time, 'soft') + end + end + end + else self.cache = setmetatable({}, cache_metatable); end +end + + +function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query + qname, qtype, qclass = standardize(qname, qtype, qclass) + + local co = coroutine.running(); + local q = get(self.wanted, qclass, qtype, qname); + if co and q then + -- We are already waiting for a reply to an identical query. + set(self.wanted, qclass, qtype, qname, co, true); + return true; + end + + if not self.server then self:adddefaultnameservers(); end + + local question = encodeQuestion(qname, qtype, qclass); + local peek = self:peek (qname, qtype, qclass); + if peek then return peek; end + + local header, id = encodeHeader(); + --print ('query id', id, qclass, qtype, qname) + local o = { + packet = header..question, + server = self.best_server, + delay = 1, + retry = socket.gettime() + self.delays[1] + }; + + -- remember the query + self.active[id] = self.active[id] or {}; + self.active[id][question] = o; + + -- remember which coroutine wants the answer + if co then + set(self.wanted, qclass, qtype, qname, co, true); + end + + local conn, err = self:getsocket(o.server) + if not conn then + return nil, err; + end + conn:send (o.packet) + + if timer and self.timeout then + local num_servers = #self.server; + local i = 1; + timer.add_task(self.timeout, function () + if get(self.wanted, qclass, qtype, qname, co) then + if i < num_servers then + i = i + 1; + self:servfail(conn); + o.server = self.best_server; + conn, err = self:getsocket(o.server); + if conn then + conn:send(o.packet); + return self.timeout; + end + end + -- Tried everything, failed + self:cancel(qclass, qtype, qname); + end + end) + end + return true; +end + +function resolver:servfail(sock, err) + -- Resend all queries for this server + + local num = self.socketset[sock] + + -- Socket is dead now + sock = self:voidsocket(sock); + + -- Find all requests to the down server, and retry on the next server + self.time = socket.gettime(); + for id,queries in pairs(self.active) do + for question,o in pairs(queries) do + if o.server == num then -- This request was to the broken server + o.server = o.server + 1 -- Use next server + if o.server > #self.server then + o.server = 1; + end + + o.retries = (o.retries or 0) + 1; + if o.retries >= #self.server then + --print('timeout'); + queries[question] = nil; + else + sock, err = self:getsocket(o.server); + if sock then sock:send(o.packet); end + end + end + end + if next(queries) == nil then + self.active[id] = nil; + end + end + + if num == self.best_server then + self.best_server = self.best_server + 1; + if self.best_server > #self.server then + -- Exhausted all servers, try first again + self.best_server = 1; + end + end + return sock, err; +end + +function resolver:settimeout(seconds) + self.timeout = seconds; +end + +function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive + --print('receive'); print(self.socket); + self.time = socket.gettime(); + rset = rset or self.socket; + + local response; + for i,sock in pairs(rset) do + + if self.socketset[sock] then + local packet = sock:receive(); + if packet then + response = self:decode(packet); + if response and self.active[response.header.id] + and self.active[response.header.id][response.question.raw] then + --print('received response'); + --self.print(response); + + for j,rr in pairs(response.answer) do + if rr.name:sub(-#response.question[1].name, -1) == response.question[1].name then + self:remember(rr, response.question[1].type) + end + end + + -- retire the query + local queries = self.active[response.header.id]; + queries[response.question.raw] = nil; + + if not next(queries) then self.active[response.header.id] = nil; end + if not next(self.active) then self:closeall(); end + + -- was the query on the wanted list? + local q = response.question[1]; + local cos = get(self.wanted, q.class, q.type, q.name); + if cos then + for co in pairs(cos) do + if coroutine.status(co) == "suspended" then coroutine.resume(co); end + end + set(self.wanted, q.class, q.type, q.name, nil); + end + end + + end + end + end + + return response; +end + + +function resolver:feed(sock, packet, force) + --print('receive'); print(self.socket); + self.time = socket.gettime(); + + local response = self:decode(packet, force); + if response and self.active[response.header.id] + and self.active[response.header.id][response.question.raw] then + --print('received response'); + --self.print(response); + + for j,rr in pairs(response.answer) do + self:remember(rr, response.question[1].type); + end + + -- retire the query + local queries = self.active[response.header.id]; + queries[response.question.raw] = nil; + if not next(queries) then self.active[response.header.id] = nil; end + if not next(self.active) then self:closeall(); end + + -- was the query on the wanted list? + local q = response.question[1]; + if q then + local cos = get(self.wanted, q.class, q.type, q.name); + if cos then + for co in pairs(cos) do + if coroutine.status(co) == "suspended" then coroutine.resume(co); end + end + set(self.wanted, q.class, q.type, q.name, nil); + end + end + end + + return response; +end + +function resolver:cancel(qclass, qtype, qname) + local cos = get(self.wanted, qclass, qtype, qname); + if cos then + for co in pairs(cos) do + if coroutine.status(co) == "suspended" then coroutine.resume(co); end + end + set(self.wanted, qclass, qtype, qname, nil); + end +end + +function resolver:pulse() -- - - - - - - - - - - - - - - - - - - - - pulse + --print(':pulse'); + while self:receive() do end + if not next(self.active) then return nil; end + + self.time = socket.gettime(); + for id,queries in pairs(self.active) do + for question,o in pairs(queries) do + if self.time >= o.retry then + + o.server = o.server + 1; + if o.server > #self.server then + o.server = 1; + o.delay = o.delay + 1; + end + + if o.delay > #self.delays then + --print('timeout'); + queries[question] = nil; + if not next(queries) then self.active[id] = nil; end + if not next(self.active) then return nil; end + else + --print('retry', o.server, o.delay); + local _a = self.socket[o.server]; + if _a then _a:send(o.packet); end + o.retry = self.time + self.delays[o.delay]; + end + end + end + end + + if next(self.active) then return true; end + return nil; +end + + +function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup + self:query (qname, qtype, qclass) + while self:pulse() do + local recvt = {} + for i, s in ipairs(self.socket) do + recvt[i] = s + end + socket.select(recvt, nil, 4) + end + --print(self.cache); + return self:peek(qname, qtype, qclass); +end + +function resolver:lookupex(handler, qname, qtype, qclass) -- - - - - - - - - - lookup + return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass); +end + +function resolver:tohostname(ip) + return dns.lookup(ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa."), "PTR"); +end + +--print ---------------------------------------------------------------- print + + +local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints + qr = { [0]='query', 'response' }, + opcode = { [0]='query', 'inverse query', 'server status request' }, + aa = { [0]='non-authoritative', 'authoritative' }, + tc = { [0]='complete', 'truncated' }, + rd = { [0]='recursion not desired', 'recursion desired' }, + ra = { [0]='recursion not available', 'recursion available' }, + z = { [0]='(reserved)' }, + rcode = { [0]='no error', 'format error', 'server failure', 'name error', 'not implemented' }, + + type = dns.type, + class = dns.class +}; + + +local function hint(p, s) -- - - - - - - - - - - - - - - - - - - - - - hint + return (hints[s] and hints[s][p[s]]) or ''; +end + + +function resolver.print(response) -- - - - - - - - - - - - - resolver.print + for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z', + 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do + print( string.format('%-30s', 'header.'..s), response.header[s], hint(response.header, s) ); + end + + for i,question in ipairs(response.question) do + print(string.format ('question[%i].name ', i), question.name); + print(string.format ('question[%i].type ', i), question.type); + print(string.format ('question[%i].class ', i), question.class); + end + + local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }; + local tmp; + for s,s in pairs({'answer', 'authority', 'additional'}) do + for i,rr in pairs(response[s]) do + for j,t in pairs({ 'name', 'type', 'class', 'ttl', 'rdlength' }) do + tmp = string.format('%s[%i].%s', s, i, t); + print(string.format('%-30s', tmp), rr[t], hint(rr, t)); + end + for j,t in pairs(rr) do + if not common[j] then + tmp = string.format('%s[%i].%s', s, i, j); + print(string.format('%-30s %s', tostring(tmp), tostring(t))); + end + end + end + end +end + + +-- module api ------------------------------------------------------ module api + + +function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver + -- this function seems to be redundant with resolver.new () + + local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, best_server = 1 }; + setmetatable (r, resolver); + setmetatable (r.cache, cache_metatable); + setmetatable (r.unsorted, { __mode = 'kv' }); + return r; +end + +local _resolver = dns.resolver(); +dns._resolver = _resolver; + +function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup + return _resolver:lookup(...); +end + +function dns.tohostname(...) + return _resolver:tohostname(...); +end + +function dns.purge(...) -- - - - - - - - - - - - - - - - - - - - - - purge + return _resolver:purge(...); +end + +function dns.peek(...) -- - - - - - - - - - - - - - - - - - - - - - - peek + return _resolver:peek(...); +end + +function dns.query(...) -- - - - - - - - - - - - - - - - - - - - - - query + return _resolver:query(...); +end + +function dns.feed(...) -- - - - - - - - - - - - - - - - - - - - - - - feed + return _resolver:feed(...); +end + +function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel + return _resolver:cancel(...); +end + +function dns.settimeout(...) + return _resolver:settimeout(...); +end + +function dns.cache() + return _resolver.cache; +end + +function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set + return _resolver:socket_wrapper_set(...); +end + +return dns; diff -r 000000000000 -r d363a6692a10 net/http.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/http.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,202 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local b64 = require "mime".b64; +local url = require "socket.url" +local httpstream_new = require "net.http.parser".new; +local util_http = require "util.http"; + +local ssl_available = pcall(require, "ssl"); + +local server = require "net.server" + +local t_insert, t_concat = table.insert, table.concat; +local pairs = pairs; +local tonumber, tostring, xpcall, select, traceback = + tonumber, tostring, xpcall, select, debug.traceback; +local assert, error = assert, error + +local log = require "util.logger".init("http"); + +module "http" + +local requests = {}; -- Open requests + +local listener = { default_port = 80, default_mode = "*a" }; + +function listener.onconnect(conn) + local req = requests[conn]; + -- Send the request + local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" }; + if req.query then + t_insert(request_line, 4, "?"..req.query); + end + + conn:write(t_concat(request_line)); + local t = { [2] = ": ", [4] = "\r\n" }; + for k, v in pairs(req.headers) do + t[1], t[3] = k, v; + conn:write(t_concat(t)); + end + conn:write("\r\n"); + + if req.body then + conn:write(req.body); + end +end + +function listener.onincoming(conn, data) + local request = requests[conn]; + + if not request then + log("warn", "Received response from connection %s with no request attached!", tostring(conn)); + return; + end + + if data and request.reader then + request:reader(data); + end +end + +function listener.ondisconnect(conn, err) + local request = requests[conn]; + if request and request.conn then + request:reader(nil, err); + end + requests[conn] = nil; +end + +function listener.ondetach(conn) + requests[conn] = nil; +end + +local function request_reader(request, data, err) + if not request.parser then + local function error_cb(reason) + if request.callback then + request.callback(reason or "connection-closed", 0, request); + request.callback = nil; + end + destroy_request(request); + end + + if not data then + error_cb(err); + return; + end + + local function success_cb(r) + if request.callback then + request.callback(r.body, r.code, r, request); + if r.partial then return; end + request.callback = nil; + end + destroy_request(request); + end + local function options_cb() + return request; + end + request.parser = httpstream_new(success_cb, error_cb, "client", options_cb); + end + request.parser:feed(data); +end + +local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end +function request(u, ex, callback) + local req = url.parse(u); + + if not (req and req.host) then + callback(nil, 0, req); + return nil, "invalid-url"; + end + + if not req.path then + req.path = "/"; + end + + local method, headers, body; + + local host, port = req.host, req.port; + local host_header = host; + if (port == "80" and req.scheme == "http") + or (port == "443" and req.scheme == "https") then + port = nil; + elseif port then + host_header = host_header..":"..port; + end + + headers = { + ["Host"] = host_header; + ["User-Agent"] = "Prosody XMPP Server"; + }; + + if req.userinfo then + headers["Authorization"] = "Basic "..b64(req.userinfo); + end + + if ex then + req.onlystatus = ex.onlystatus; + body = ex.body; + if body then + method = "POST"; + headers["Content-Length"] = tostring(#body); + headers["Content-Type"] = "application/x-www-form-urlencoded"; + end + if ex.method then method = ex.method; end + if ex.headers then + for k, v in pairs(ex.headers) do + headers[k] = v; + end + end + end + + -- Attach to request object + req.method, req.headers, req.body = method, headers, body; + + local using_https = req.scheme == "https"; + if using_https and not ssl_available then + error("SSL not available, unable to contact https URL"); + end + local port_number = port and tonumber(port) or (using_https and 443 or 80); + + local sslctx = false; + if using_https then + sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" } }; + end + + local handler, conn = server.addclient(host, port_number, listener, "*a", sslctx) + if not handler then + callback(nil, 0, req); + return nil, conn; + end + req.handler, req.conn = handler, conn + req.write = function (...) return req.handler:write(...); end + + req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end + req.success_on_chunk = ex.success_on_chunk; + req.reader = request_reader; + req.state = "status"; + + requests[req.handler] = req; + return req; +end + +function destroy_request(request) + if request.conn then + request.conn = nil; + request.handler:close() + end +end + +local urlencode, urldecode = util_http.urlencode, util_http.urldecode; +local formencode, formdecode = util_http.formencode, util_http.formdecode; + +_M.urlencode, _M.urldecode = urlencode, urldecode; +_M.formencode, _M.formdecode = formencode, formdecode; + +return _M; diff -r 000000000000 -r d363a6692a10 net/http/codes.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/http/codes.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,67 @@ + +local response_codes = { + -- Source: http://www.iana.org/assignments/http-status-codes + -- s/^\(\d*\)\s*\(.*\S\)\s*\[RFC.*\]\s*$/^I["\1"] = "\2"; + [100] = "Continue"; + [101] = "Switching Protocols"; + [102] = "Processing"; + + [200] = "OK"; + [201] = "Created"; + [202] = "Accepted"; + [203] = "Non-Authoritative Information"; + [204] = "No Content"; + [205] = "Reset Content"; + [206] = "Partial Content"; + [207] = "Multi-Status"; + [208] = "Already Reported"; + [226] = "IM Used"; + + [300] = "Multiple Choices"; + [301] = "Moved Permanently"; + [302] = "Found"; + [303] = "See Other"; + [304] = "Not Modified"; + [305] = "Use Proxy"; + -- The 306 status code was used in a previous version of [RFC2616], is no longer used, and the code is reserved. + [307] = "Temporary Redirect"; + + [400] = "Bad Request"; + [401] = "Unauthorized"; + [402] = "Payment Required"; + [403] = "Forbidden"; + [404] = "Not Found"; + [405] = "Method Not Allowed"; + [406] = "Not Acceptable"; + [407] = "Proxy Authentication Required"; + [408] = "Request Timeout"; + [409] = "Conflict"; + [410] = "Gone"; + [411] = "Length Required"; + [412] = "Precondition Failed"; + [413] = "Request Entity Too Large"; + [414] = "Request-URI Too Long"; + [415] = "Unsupported Media Type"; + [416] = "Requested Range Not Satisfiable"; + [417] = "Expectation Failed"; + [418] = "I'm a teapot"; + [422] = "Unprocessable Entity"; + [423] = "Locked"; + [424] = "Failed Dependency"; + -- The 425 status code is reserved for the WebDAV advanced collections expired proposal [RFC2817] + [426] = "Upgrade Required"; + + [500] = "Internal Server Error"; + [501] = "Not Implemented"; + [502] = "Bad Gateway"; + [503] = "Service Unavailable"; + [504] = "Gateway Timeout"; + [505] = "HTTP Version Not Supported"; + [506] = "Variant Also Negotiates"; -- Experimental + [507] = "Insufficient Storage"; + [508] = "Loop Detected"; + [510] = "Not Extended"; +}; + +for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end +return setmetatable(response_codes, { __index = function(t, k) return k.." Unassigned"; end }) diff -r 000000000000 -r d363a6692a10 net/http/parser.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/http/parser.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,170 @@ +local tonumber = tonumber; +local assert = assert; +local url_parse = require "socket.url".parse; +local urldecode = require "util.http".urldecode; + +local function preprocess_path(path) + path = urldecode((path:gsub("//+", "/"))); + if path:sub(1,1) ~= "/" then + path = "/"..path; + end + local level = 0; + for component in path:gmatch("([^/]+)/") do + if component == ".." then + level = level - 1; + elseif component ~= "." then + level = level + 1; + end + if level < 0 then + return nil; + end + end + return path; +end + +local httpstream = {}; + +function httpstream.new(success_cb, error_cb, parser_type, options_cb) + local client = true; + if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end + local buf = ""; + local chunked, chunk_size, chunk_start; + local success_on_chunk = (options_cb and options_cb().success_on_chunk); + local state = nil; + local packet; + local len; + local have_body; + local error; + return { + feed = function(self, data) + if error then return nil, "parse has failed"; end + if not data then -- EOF + if state and client and not len then -- reading client body until EOF + packet.body = buf; + success_cb(packet); + elseif buf ~= "" then -- unexpected EOF + error = true; return error_cb(); + end + return; + end + buf = buf..data; + while #buf > 0 do + if state == nil then -- read request + local index = buf:find("\r\n\r\n", nil, true); + if not index then return; end -- not enough data + local method, path, httpversion, status_code, reason_phrase; + local first_line; + local headers = {}; + for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request + if first_line then + local key, val = line:match("^([^%s:]+): *(.*)$"); + if not key then error = true; return error_cb("invalid-header-line"); end -- TODO handle multi-line and invalid headers + key = key:lower(); + headers[key] = headers[key] and headers[key]..","..val or val; + else + first_line = line; + if client then + httpversion, status_code, reason_phrase = line:match("^HTTP/(1%.[01]) (%d%d%d) (.*)$"); + status_code = tonumber(status_code); + if not status_code then error = true; return error_cb("invalid-status-line"); end + have_body = not + ( (options_cb and options_cb().method == "HEAD") + or (status_code == 204 or status_code == 304 or status_code == 301) + or (status_code >= 100 and status_code < 200) ); + else + method, path, httpversion = line:match("^(%w+) (%S+) HTTP/(1%.[01])$"); + if not method then error = true; return error_cb("invalid-status-line"); end + end + end + end + if not first_line then error = true; return error_cb("invalid-status-line"); end + chunked = have_body and headers["transfer-encoding"] == "chunked"; + len = tonumber(headers["content-length"]); -- TODO check for invalid len + if client then + -- FIXME handle '100 Continue' response (by skipping it) + if not have_body then len = 0; end + packet = { + code = status_code; + httpversion = httpversion; + headers = headers; + body = have_body and "" or nil; + partial = chunked; + -- COMPAT the properties below are deprecated + responseversion = httpversion; + responseheaders = headers; + }; + else + local parsed_url; + if path:byte() == 47 then -- starts with / + local _path, _query = path:match("([^?]*).?(.*)"); + if _query == "" then _query = nil; end + parsed_url = { path = _path, query = _query }; + else + parsed_url = url_parse(path); + if not(parsed_url and parsed_url.path) then error = true; return error_cb("invalid-url"); end + end + path = preprocess_path(parsed_url.path); + headers.host = parsed_url.host or headers.host; + + len = len or 0; + packet = { + method = method; + url = parsed_url; + path = path; + httpversion = httpversion; + headers = headers; + body = nil; + }; + end + buf = buf:sub(index + 4); + state = true; + end + if state then -- read body + if client then + if chunked then + if not buf:find("\r\n", nil, true) then + return; + end -- not enough data + if not chunk_size then + chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()"); + chunk_size = chunk_size and tonumber(chunk_size, 16); + if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end + end + if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then + state, chunk_size = nil, nil; + buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped + packet.partial = nil; + success_cb(packet); + elseif #buf - chunk_start + 2 >= chunk_size then -- we have a chunk + packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1)); + buf = buf:sub(chunk_start + chunk_size + 2); + chunk_size, chunk_start = nil, nil; + if success_on_chunk then + success_cb(packet); + end + else -- Partial chunk remaining + break; + end + elseif len and #buf >= len then + if packet.code == 101 then + packet.body, buf = buf, "" + else + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + end + state = nil; success_cb(packet); + else + break; + end + elseif #buf >= len then + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + state = nil; success_cb(packet); + else + break; + end + end + end + end; + }; +end + +return httpstream; diff -r 000000000000 -r d363a6692a10 net/http/server.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/http/server.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,314 @@ + +local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; +local parser_new = require "net.http.parser".new; +local events = require "util.events".new(); +local addserver = require "net.server".addserver; +local log = require "util.logger".init("http.server"); +local os_date = os.date; +local pairs = pairs; +local s_upper = string.upper; +local setmetatable = setmetatable; +local xpcall = xpcall; +local traceback = debug.traceback; +local tostring = tostring; +local codes = require "net.http.codes"; + +local _M = {}; + +local sessions = {}; +local listener = {}; +local hosts = {}; +local default_host; + +local function is_wildcard_event(event) + return event:sub(-2, -1) == "/*"; +end +local function is_wildcard_match(wildcard_event, event) + return wildcard_event:sub(1, -2) == event:sub(1, #wildcard_event-1); +end + +local recent_wildcard_events, max_cached_wildcard_events = {}, 10000; + +local event_map = events._event_map; +setmetatable(events._handlers, { + -- Called when firing an event that doesn't exist (but may match a wildcard handler) + __index = function (handlers, curr_event) + if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired + -- Find all handlers that could match this event, sort them + -- and then put the array into handlers[curr_event] (and return it) + local matching_handlers_set = {}; + local handlers_array = {}; + for event, handlers_set in pairs(event_map) do + if event == curr_event or + is_wildcard_event(event) and is_wildcard_match(event, curr_event) then + for handler, priority in pairs(handlers_set) do + matching_handlers_set[handler] = { (select(2, event:gsub("/", "%1"))), is_wildcard_event(event) and 0 or 1, priority }; + table.insert(handlers_array, handler); + end + end + end + if #handlers_array > 0 then + table.sort(handlers_array, function(b, a) + local a_score, b_score = matching_handlers_set[a], matching_handlers_set[b]; + for i = 1, #a_score do + if a_score[i] ~= b_score[i] then -- If equal, compare next score value + return a_score[i] < b_score[i]; + end + end + return false; + end); + else + handlers_array = false; + end + rawset(handlers, curr_event, handlers_array); + if not event_map[curr_event] then -- Only wildcard handlers match, if any + table.insert(recent_wildcard_events, curr_event); + if #recent_wildcard_events > max_cached_wildcard_events then + rawset(handlers, table.remove(recent_wildcard_events, 1), nil); + end + end + return handlers_array; + end; + __newindex = function (handlers, curr_event, handlers_array) + if handlers_array == nil + and is_wildcard_event(curr_event) then + -- Invalidate the indexes of all matching events + for event in pairs(handlers) do + if is_wildcard_match(curr_event, event) then + handlers[event] = nil; + end + end + end + rawset(handlers, curr_event, handlers_array); + end; +}); + +local handle_request; +local _1, _2, _3; +local function _handle_request() return handle_request(_1, _2, _3); end + +local last_err; +local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end +events.add_handler("http-error", function (error) + return "Error processing request: "..codes[error.code]..". Check your error log for more information."; +end, -1); + +function listener.onconnect(conn) + local secure = conn:ssl() and true or nil; + local pending = {}; + local waiting = false; + local function process_next() + if waiting then return; end -- log("debug", "can't process_next, waiting"); + waiting = true; + while sessions[conn] and #pending > 0 do + local request = t_remove(pending); + --log("debug", "process_next: %s", request.path); + --handle_request(conn, request, process_next); + _1, _2, _3 = conn, request, process_next; + if not xpcall(_handle_request, _traceback_handler) then + conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err })); + conn:close(); + end + end + --log("debug", "ready for more"); + waiting = false; + end + local function success_cb(request) + --log("debug", "success_cb: %s", request.path); + if waiting then + log("error", "http connection handler is not reentrant: %s", request.path); + assert(false, "http connection handler is not reentrant"); + end + request.secure = secure; + t_insert(pending, request); + process_next(); + end + local function error_cb(err) + log("debug", "error_cb: %s", err or ""); + -- FIXME don't close immediately, wait until we process current stuff + -- FIXME if err, send off a bad-request response + sessions[conn] = nil; + conn:close(); + end + sessions[conn] = parser_new(success_cb, error_cb); +end + +function listener.ondisconnect(conn) + local open_response = conn._http_open_response; + if open_response and open_response.on_destroy then + open_response.finished = true; + open_response:on_destroy(); + end + sessions[conn] = nil; +end + +function listener.ondetach(conn) + sessions[conn] = nil; +end + +function listener.onincoming(conn, data) + sessions[conn]:feed(data); +end + +local headerfix = setmetatable({}, { + __index = function(t, k) + local v = "\r\n"..k:gsub("_", "-"):gsub("%f[%w].", s_upper)..": "; + t[k] = v; + return v; + end +}); + +function _M.hijack_response(response, listener) + error("TODO"); +end +function handle_request(conn, request, finish_cb) + --log("debug", "handler: %s", request.path); + local headers = {}; + for k,v in pairs(request.headers) do headers[k:gsub("-", "_")] = v; end + request.headers = headers; + request.conn = conn; + + local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME use + local conn_header = request.headers.connection; + conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or "" + local httpversion = request.httpversion + local persistent = conn_header:find(",keep-alive,", 1, true) + or (httpversion == "1.1" and not conn_header:find(",close,", 1, true)); + + local response_conn_header; + if persistent then + response_conn_header = "Keep-Alive"; + else + response_conn_header = httpversion == "1.1" and "close" or nil + end + + local response = { + request = request; + status_code = 200; + headers = { date = date_header, connection = response_conn_header }; + persistent = persistent; + conn = conn; + send = _M.send_response; + done = _M.finish_response; + finish_cb = finish_cb; + }; + conn._http_open_response = response; + + local host = (request.headers.host or ""):match("[^:]+"); + + -- Some sanity checking + local err_code, err; + if not request.path then + err_code, err = 400, "Invalid path"; + elseif not hosts[host] then + if hosts[default_host] then + host = default_host; + elseif host then + err_code, err = 404, "Unknown host: "..host; + else + err_code, err = 400, "Missing or invalid 'Host' header"; + end + end + + if err then + response.status_code = err_code; + response:send(events.fire_event("http-error", { code = err_code, message = err })); + return; + end + + local event = request.method.." "..host..request.path:match("[^?]*"); + local payload = { request = request, response = response }; + --log("debug", "Firing event: %s", event); + local result = events.fire_event(event, payload); + if result ~= nil then + if result ~= true then + local body; + local result_type = type(result); + if result_type == "number" then + response.status_code = result; + if result >= 400 then + body = events.fire_event("http-error", { code = result }); + end + elseif result_type == "string" then + body = result; + elseif result_type == "table" then + for k, v in pairs(result) do + if k ~= "headers" then + response[k] = v; + else + for header_name, header_value in pairs(v) do + response.headers[header_name] = header_value; + end + end + end + end + response:send(body); + end + return; + end + + -- if handler not called, return 404 + response.status_code = 404; + response:send(events.fire_event("http-error", { code = 404 })); +end +local function prepare_header(response) + local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]); + local headers = response.headers; + local output = { status_line }; + for k,v in pairs(headers) do + t_insert(output, headerfix[k]..v); + end + t_insert(output, "\r\n\r\n"); + return output; +end +_M.prepare_header = prepare_header; +function _M.send_response(response, body) + if response.finished then return; end + body = body or response.body or ""; + response.headers.content_length = #body; + local output = prepare_header(response); + t_insert(output, body); + response.conn:write(t_concat(output)); + response:done(); +end +function _M.finish_response(response) + if response.finished then return; end + response.finished = true; + response.conn._http_open_response = nil; + if response.on_destroy then + response:on_destroy(); + response.on_destroy = nil; + end + if response.persistent then + response:finish_cb(); + else + response.conn:close(); + end +end +function _M.add_handler(event, handler, priority) + events.add_handler(event, handler, priority); +end +function _M.remove_handler(event, handler) + events.remove_handler(event, handler); +end + +function _M.listen_on(port, interface, ssl) + addserver(interface or "*", port, listener, "*a", ssl); +end +function _M.add_host(host) + hosts[host] = true; +end +function _M.remove_host(host) + hosts[host] = nil; +end +function _M.set_default_host(host) + default_host = host; +end +function _M.fire_event(event, ...) + return events.fire_event(event, ...); +end + +_M.listener = listener; +_M.codes = codes; +_M._events = events; +return _M; diff -r 000000000000 -r d363a6692a10 net/httpserver.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/httpserver.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,15 @@ +-- COMPAT w/pre-0.9 +local log = require "util.logger".init("net.httpserver"); +local traceback = debug.traceback; + +module "httpserver" + +function fail() + log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http"); + log("error", "Legacy HTTP API usage, %s", traceback("", 2)); +end + +new, new_from_config = fail, fail; +set_default_handler = fail; + +return _M; diff -r 000000000000 -r d363a6692a10 net/server.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/server.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,104 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local server_type = prosody and require "core.configmanager".get("*", "network_backend") or "select"; +if prosody and require "core.configmanager".get("*", "use_libevent") then + server_type = "event"; +end + +if server_type == "event" then + if not pcall(require, "luaevent.core") then + log("error", "libevent not found, falling back to select()"); + server_type = "select" + end +end + +local server; +local set_config; +if server_type == "event" then + server = require "net.server_event"; + + local defaults = {}; + for k,v in pairs(server.cfg) do + defaults[k] = v; + end + function set_config(settings) + local event_settings = { + ACCEPT_DELAY = settings.event_accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; + CLEAR_DELAY = settings.event_clear_interval; + CONNECT_TIMEOUT = settings.connect_timeout; + DEBUG = settings.debug; + HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; + MAX_CONNECTIONS = settings.max_connections; + MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; + MAX_READ_LENGTH = settings.max_receive_buffer_size; + MAX_SEND_LENGTH = settings.max_send_buffer_size; + READ_TIMEOUT = settings.read_timeout; + WRITE_TIMEOUT = settings.send_timeout; + }; + + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; + end + end +elseif server_type == "select" then + server = require "net.server_select"; + + local defaults = {}; + for k,v in pairs(server.getsettings()) do + defaults[k] = v; + end + function set_config(settings) + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); + end +else + error("Unsupported server type") +end + +-- If server.hook_signal exists, replace signal.signal() +local has_signal, signal = pcall(require, "util.signal"); +if has_signal then + if server.hook_signal then + function signal.signal(signal_id, handler) + if type(signal_id) == "string" then + signal_id = signal[signal_id:upper()]; + end + if type(signal_id) ~= "number" then + return false, "invalid-signal"; + end + return server.hook_signal(signal_id, handler); + end + else + server.hook_signal = signal.signal; + end +else + if not server.hook_signal then + server.hook_signal = function() + return false, "signal hooking not supported" + end + end +end + +if prosody then + local config_get = require "core.configmanager".get; + local function load_config() + local settings = config_get("*", "network_settings") or {}; + return set_config(settings); + end + load_config(); + prosody.events.add_handler("config-reloaded", load_config); +end + +-- require "net.server" shall now forever return this, +-- ie. server_select or server_event as chosen above. +return server; diff -r 000000000000 -r d363a6692a10 net/server_event.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/server_event.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,891 @@ +--[[ + + + server.lua based on lua/libevent by blastbeat + + notes: + -- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE + -- you cant even register a new EV_READ/EV_WRITE callback inside another one + -- to do some of the above, use timeout events or something what will called from outside + -- dont let garbagecollect eventcallbacks, as long they are running + -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing + +--]] + +local SCRIPT_NAME = "server_event.lua" +local SCRIPT_VERSION = "0.05" +local SCRIPT_AUTHOR = "blastbeat" +local LAST_MODIFIED = "2009/11/20" + +local cfg = { + MAX_CONNECTIONS = 100000, -- max per server connections (use "ulimit -n" on *nix) + MAX_HANDSHAKE_ATTEMPTS= 1000, -- attempts to finish ssl handshake + HANDSHAKE_TIMEOUT = 60, -- timeout in seconds per handshake attempt + MAX_READ_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes allowed to read from sockets + MAX_SEND_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes size of write buffer (for writing on sockets) + ACCEPT_QUEUE = 128, -- might influence the length of the pending sockets queue + ACCEPT_DELAY = 10, -- seconds to wait until the next attempt of a full server to accept + READ_TIMEOUT = 60 * 60 * 6, -- timeout in seconds for read data from socket + WRITE_TIMEOUT = 180, -- timeout in seconds for write data on socket + CONNECT_TIMEOUT = 20, -- timeout in seconds for connection attempts + CLEAR_DELAY = 5, -- seconds to wait for clearing interface list (and calling ondisconnect listeners) + DEBUG = true, -- show debug messages +} + +local function use(x) return rawget(_G, x); end +local ipairs = use "ipairs" +local string = use "string" +local select = use "select" +local require = use "require" +local tostring = use "tostring" +local coroutine = use "coroutine" +local setmetatable = use "setmetatable" + +local t_insert = table.insert +local t_concat = table.concat + +local has_luasec, ssl = pcall ( require , "ssl" ) +local socket = use "socket" or require "socket" +local getaddrinfo = socket.dns.getaddrinfo + +local log = require ("util.logger").init("socket") + +local function debug(...) + return log("debug", ("%s "):rep(select('#', ...)), ...) +end +local vdebug = debug; + +local bitor = ( function( ) -- thx Rici Lake + local hasbit = function( x, p ) + return x % ( p + p ) >= p + end + return function( x, y ) + local p = 1 + local z = 0 + local limit = x > y and x or y + while p <= limit do + if hasbit( x, p ) or hasbit( y, p ) then + z = z + p + end + p = p + p + end + return z + end +end )( ) + +local event = require "luaevent.core" +local base = event.new( ) +local EV_READ = event.EV_READ +local EV_WRITE = event.EV_WRITE +local EV_TIMEOUT = event.EV_TIMEOUT +local EV_SIGNAL = event.EV_SIGNAL + +local EV_READWRITE = bitor( EV_READ, EV_WRITE ) + +local interfacelist = ( function( ) -- holds the interfaces for sockets + local array = { } + local len = 0 + return function( method, arg ) + if "add" == method then + len = len + 1 + array[ len ] = arg + arg:_position( len ) + return len + elseif "delete" == method then + if len <= 0 then + return nil, "array is already empty" + end + local position = arg:_position() -- get position in array + if position ~= len then + local interface = array[ len ] -- get last interface + array[ position ] = interface -- copy it into free position + array[ len ] = nil -- free last position + interface:_position( position ) -- set new position in array + else -- free last position + array[ len ] = nil + end + len = len - 1 + return len + else + return array + end + end +end )( ) + +-- Client interface methods +local interface_mt +do + interface_mt = {}; interface_mt.__index = interface_mt; + + local addevent = base.addevent + local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield + + -- Private methods + function interface_mt:_position(new_position) + self.position = new_position or self.position + return self.position; + end + function interface_mt:_close() + return self:_destroy(); + end + + function interface_mt:_start_connection(plainssl) -- called from wrapclient + local callback = function( event ) + if EV_TIMEOUT == event then -- timeout during connection + self.fatalerror = "connection timeout" + self:ontimeout() -- call timeout listener + self:_close() + debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) + else + if plainssl and has_luasec then -- start ssl session + self:starttls(self._sslctx, true) + else -- normal connection + self:_start_session(true) + end + debug( "new connection established. id:", self.id ) + end + self.eventconnect = nil + return -1 + end + self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT ) + return true + end + function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl + if self.type == "client" then + local callback = function( ) + self:_lock( false, false, false ) + --vdebug( "start listening on client socket with id:", self.id ) + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback + if call_onconnect then + self:onconnect() + end + self.eventsession = nil + return -1 + end + self.eventsession = addevent( base, nil, EV_TIMEOUT, callback, 0 ) + else + self:_lock( false ) + --vdebug( "start listening on server socket with id:", self.id ) + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback ) -- register callback + end + return true + end + function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed, therefore we have to close read/write events first + --vdebug( "starting ssl session with client id:", self.id ) + local _ + _ = self.eventread and self.eventread:close( ) -- close events; this must be called outside of the event callbacks! + _ = self.eventwrite and self.eventwrite:close( ) + self.eventread, self.eventwrite = nil, nil + local err + self.conn, err = ssl.wrap( self.conn, self._sslctx ) + if err then + self.fatalerror = err + self.conn = nil -- cannot be used anymore + if call_onconnect then + self.ondisconnect = nil -- dont call this when client isnt really connected + end + self:_close() + debug( "fatal error while ssl wrapping:", err ) + return false + end + self.conn:settimeout( 0 ) -- set non blocking + local handshakecallback = coroutine_wrap( + function( event ) + local _, err + local attempt = 0 + local maxattempt = cfg.MAX_HANDSHAKE_ATTEMPTS + while attempt < maxattempt do -- no endless loop + attempt = attempt + 1 + debug( "ssl handshake of client with id:"..tostring(self)..", attempt:"..attempt ) + if attempt > maxattempt then + self.fatalerror = "max handshake attempts exceeded" + elseif EV_TIMEOUT == event then + self.fatalerror = "timeout during handshake" + else + _, err = self.conn:dohandshake( ) + if not err then + self:_lock( false, false, false ) -- unlock the interface; sending, closing etc allowed + self.send = self.conn.send -- caching table lookups with new client object + self.receive = self.conn.receive + if not call_onconnect then -- trigger listener + self:onstatus("ssl-handshake-complete"); + end + self:_start_session( call_onconnect ) + debug( "ssl handshake done" ) + self.eventhandshake = nil + return -1 + end + if err == "wantwrite" then + event = EV_WRITE + elseif err == "wantread" then + event = EV_READ + else + debug( "ssl handshake error:", err ) + self.fatalerror = err + end + end + if self.fatalerror then + if call_onconnect then + self.ondisconnect = nil -- dont call this when client isnt really connected + end + self:_close() + debug( "handshake failed because:", self.fatalerror ) + self.eventhandshake = nil + return -1 + end + event = coroutine_yield( event, cfg.HANDSHAKE_TIMEOUT ) -- yield this monster... + end + end + ) + debug "starting handshake..." + self:_lock( false, true, true ) -- unlock read/write events, but keep interface locked + self.eventhandshake = addevent( base, self.conn, EV_READWRITE, handshakecallback, cfg.HANDSHAKE_TIMEOUT ) + return true + end + function interface_mt:_destroy() -- close this interface + events and call last listener + debug( "closing client with id:", self.id, self.fatalerror ) + self:_lock( true, true, true ) -- first of all, lock the interface to avoid further actions + local _ + _ = self.eventread and self.eventread:close( ) + if self.type == "client" then + _ = self.eventwrite and self.eventwrite:close( ) + _ = self.eventhandshake and self.eventhandshake:close( ) + _ = self.eventstarthandshake and self.eventstarthandshake:close( ) + _ = self.eventconnect and self.eventconnect:close( ) + _ = self.eventsession and self.eventsession:close( ) + _ = self.eventwritetimeout and self.eventwritetimeout:close( ) + _ = self.eventreadtimeout and self.eventreadtimeout:close( ) + _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect) + _ = self.conn and self.conn:close( ) -- close connection + _ = self._server and self._server:counter(-1); + self.eventread, self.eventwrite = nil, nil + self.eventstarthandshake, self.eventhandshake, self.eventclose = nil, nil, nil + self.readcallback, self.writecallback = nil, nil + else + self.conn:close( ) + self.eventread, self.eventclose = nil, nil + self.interface, self.readcallback = nil, nil + end + interfacelist( "delete", self ) + return true + end + + function interface_mt:_lock(nointerface, noreading, nowriting) -- lock or unlock this interface or events + self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting + return nointerface, noreading, nowriting + end + + --TODO: Deprecate + function interface_mt:lock_read(switch) + if switch then + return self:pause(); + else + return self:resume(); + end + end + + function interface_mt:pause() + return self:_lock(self.nointerface, true, self.nowriting); + end + + function interface_mt:resume() + self:_lock(self.nointerface, false, self.nowriting); + if not self.eventread then + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback + end + end + + function interface_mt:counter(c) + if c then + self._connections = self._connections + c + end + return self._connections + end + + -- Public methods + function interface_mt:write(data) + if self.nowriting then return nil, "locked" end + --vdebug( "try to send data to client, id/data:", self.id, data ) + data = tostring( data ) + local len = #data + local total = len + self.writebufferlen + if total > cfg.MAX_SEND_LENGTH then -- check buffer length + local err = "send buffer exceeded" + debug( "error:", err ) -- to much, check your app + return nil, err + end + t_insert(self.writebuffer, data) -- new buffer + self.writebufferlen = total + if not self.eventwrite then -- register new write event + --vdebug( "register new write event" ) + self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ) + end + return true + end + function interface_mt:close() + if self.nointerface then return nil, "locked"; end + debug( "try to close client connection with id:", self.id ) + if self.type == "client" then + self.fatalerror = "client to close" + if self.eventwrite then -- wait for incomplete write request + self:_lock( true, true, false ) + debug "closing delayed until writebuffer is empty" + return nil, "writebuffer not empty, waiting" + else -- close now + self:_lock( true, true, true ) + self:_close() + return true + end + else + debug( "try to close server with id:", tostring(self.id)) + self.fatalerror = "server to close" + self:_lock( true ) + self:_close( 0 ) + return true + end + end + + function interface_mt:socket() + return self.conn + end + + function interface_mt:server() + return self._server or self; + end + + function interface_mt:port() + return self._port + end + + function interface_mt:serverport() + return self._serverport + end + + function interface_mt:ip() + return self._ip + end + + function interface_mt:ssl() + return self._usingssl + end + interface_mt.clientport = interface_mt.port -- COMPAT server_select + + function interface_mt:type() + return self._type or "client" + end + + function interface_mt:connections() + return self._connections + end + + function interface_mt:address() + return self.addr + end + + function interface_mt:set_sslctx(sslctx) + self._sslctx = sslctx; + if sslctx then + self.starttls = nil; -- use starttls() of interface_mt + else + self.starttls = false; -- prevent starttls() + end + end + + function interface_mt:set_mode(pattern) + if pattern then + self._pattern = pattern; + end + return self._pattern; + end + + function interface_mt:set_send(new_send) + -- No-op, we always use the underlying connection's send + end + + function interface_mt:starttls(sslctx, call_onconnect) + debug( "try to start ssl at client id:", self.id ) + local err + self._sslctx = sslctx; + if self._usingssl then -- startssl was already called + err = "ssl already active" + end + if err then + debug( "error:", err ) + return nil, err + end + self._usingssl = true + self.startsslcallback = function( ) -- we have to start the handshake outside of a read/write event + self.startsslcallback = nil + self:_start_ssl(call_onconnect); + self.eventstarthandshake = nil + return -1 + end + if not self.eventwrite then + self:_lock( true, true, true ) -- lock the interface, to not disturb the handshake + self.eventstarthandshake = addevent( base, nil, EV_TIMEOUT, self.startsslcallback, 0 ) -- add event to start handshake + else -- wait until writebuffer is empty + self:_lock( true, true, false ) + debug "ssl session delayed until writebuffer is empty..." + end + self.starttls = false; + return true + end + + function interface_mt:setoption(option, value) + if self.conn.setoption then + return self.conn:setoption(option, value); + end + return false, "setoption not implemented"; + end + + function interface_mt:setlistener(listener) + self:ondetach(); -- Notify listener that it is no longer responsible for this connection + self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, + self.onreadtimeout, self.onstatus, self.ondetach + = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout, + listener.onreadtimeout, listener.onstatus, listener.ondetach; + end + + -- Stub handlers + function interface_mt:onconnect() + end + function interface_mt:onincoming() + end + function interface_mt:ondisconnect() + end + function interface_mt:ontimeout() + end + function interface_mt:onreadtimeout() + self.fatalerror = "timeout during receiving" + debug( "connection failed:", self.fatalerror ) + self:_close() + self.eventread = nil + end + function interface_mt:ondrain() + end + function interface_mt:ondetach() + end + function interface_mt:onstatus() + end +end + +-- End of client interface methods + +local handleclient; +do + local string_sub = string.sub -- caching table lookups + local addevent = base.addevent + local socket_gettime = socket.gettime + function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface + --vdebug("creating client interfacce...") + local interface = { + type = "client"; + conn = client; + currenttime = socket_gettime( ); -- safe the origin + writebuffer = {}; -- writebuffer + writebufferlen = 0; -- length of writebuffer + send = client.send; -- caching table lookups + receive = client.receive; + onconnect = listener.onconnect; -- will be called when client disconnects + ondisconnect = listener.ondisconnect; -- will be called when client disconnects + onincoming = listener.onincoming; -- will be called when client sends data + ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs + onreadtimeout = listener.onreadtimeout; -- called when socket inactivity timeout occurs + ondrain = listener.ondrain; -- called when writebuffer is empty + ondetach = listener.ondetach; -- called when disassociating this listener from this connection + onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS) + eventread = false, eventwrite = false, eventclose = false, + eventhandshake = false, eventstarthandshake = false; -- event handler + eventconnect = false, eventsession = false; -- more event handler... + eventwritetimeout = false; -- even more event handler... + eventreadtimeout = false; + fatalerror = false; -- error message + writecallback = false; -- will be called on write events + readcallback = false; -- will be called on read events + nointerface = true; -- lock/unlock parameter of this interface + noreading = false, nowriting = false; -- locks of the read/writecallback + startsslcallback = false; -- starting handshake callback + position = false; -- position of client in interfacelist + + -- Properties + _ip = ip, _port = port, _server = server, _pattern = pattern, + _serverport = (server and server:port() or nil), + _sslctx = sslctx; -- parameters + _usingssl = false; -- client is using ssl; + } + if not has_luasec then interface.starttls = false; end + interface.id = tostring(interface):match("%x+$"); + interface.writecallback = function( event ) -- called on write events + --vdebug( "new client write event, id/ip/port:", interface, ip, port ) + if interface.nowriting or ( interface.fatalerror and ( "client to close" ~= interface.fatalerror ) ) then -- leave this event + --vdebug( "leaving this event because:", interface.nowriting or interface.fatalerror ) + interface.eventwrite = false + return -1 + end + if EV_TIMEOUT == event then -- took too long to write some data to socket -> disconnect + interface.fatalerror = "timeout during writing" + debug( "writing failed:", interface.fatalerror ) + interface:_close() + interface.eventwrite = false + return -1 + else -- can write :) + if interface._usingssl then -- handle luasec + if interface.eventreadtimeout then -- we have to read first + local ret = interface.readcallback( ) -- call readcallback + --vdebug( "tried to read in writecallback, result:", ret ) + end + if interface.eventwritetimeout then -- luasec only + interface.eventwritetimeout:close( ) -- first we have to close timeout event which where regged after a wantread error + interface.eventwritetimeout = false + end + end + interface.writebuffer = { t_concat(interface.writebuffer) } + local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen ) + --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte ) + if succ then -- writing succesful + interface.writebuffer[1] = nil + interface.writebufferlen = 0 + interface:ondrain(); + if interface.fatalerror then + debug "closing client after writing" + interface:_close() -- close interface if needed + elseif interface.startsslcallback then -- start ssl connection if needed + debug "starting ssl handshake after writing" + interface.eventstarthandshake = addevent( base, nil, EV_TIMEOUT, interface.startsslcallback, 0 ) + elseif interface.eventreadtimeout then + return EV_WRITE, EV_TIMEOUT + end + interface.eventwrite = nil + return -1 + elseif byte and (err == "timeout" or err == "wantwrite") then -- want write again + --vdebug( "writebuffer is not empty:", err ) + interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen ) -- new buffer + interface.writebufferlen = interface.writebufferlen - byte + if "wantread" == err then -- happens only with luasec + local callback = function( ) + interface:_close() + interface.eventwritetimeout = nil + return -1; + end + interface.eventwritetimeout = addevent( base, nil, EV_TIMEOUT, callback, cfg.WRITE_TIMEOUT ) -- reg a new timeout event + debug( "wantread during write attempt, reg it in readcallback but dont know what really happens next..." ) + -- hopefully this works with luasec; its simply not possible to use 2 different write events on a socket in luaevent + return -1 + end + return EV_WRITE, cfg.WRITE_TIMEOUT + else -- connection was closed during writing or fatal error + interface.fatalerror = err or "fatal error" + debug( "connection failed in write event:", interface.fatalerror ) + interface:_close() + interface.eventwrite = nil + return -1 + end + end + end + + interface.readcallback = function( event ) -- called on read events + --vdebug( "new client read event, id/ip/port:", tostring(interface.id), tostring(ip), tostring(port) ) + if interface.noreading or interface.fatalerror then -- leave this event + --vdebug( "leaving this event because:", tostring(interface.noreading or interface.fatalerror) ) + interface.eventread = nil + return -1 + end + if EV_TIMEOUT == event and interface:onreadtimeout() ~= true then + return -1 -- took too long to get some data from client -> disconnect + end + if interface._usingssl then -- handle luasec + if interface.eventwritetimeout then -- ok, in the past writecallback was regged + local ret = interface.writecallback( ) -- call it + --vdebug( "tried to write in readcallback, result:", tostring(ret) ) + end + if interface.eventreadtimeout then + interface.eventreadtimeout:close( ) + interface.eventreadtimeout = nil + end + end + local buffer, err, part = interface.conn:receive( interface._pattern ) -- receive buffer with "pattern" + --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) ) + buffer = buffer or part + if buffer and #buffer > cfg.MAX_READ_LENGTH then -- check buffer length + interface.fatalerror = "receive buffer exceeded" + debug( "fatal error:", interface.fatalerror ) + interface:_close() + interface.eventread = nil + return -1 + end + if err and ( err ~= "timeout" and err ~= "wantread" ) then + if "wantwrite" == err then -- need to read on write event + if not interface.eventwrite then -- register new write event if needed + interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT ) + end + interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT, + function( ) + interface:_close() + end, cfg.READ_TIMEOUT + ) + debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." ) + -- to be honest i dont know what happens next, if it is allowed to first read, the write etc... + else -- connection was closed or fatal error + interface.fatalerror = err + debug( "connection failed in read event:", interface.fatalerror ) + interface:_close() + interface.eventread = nil + return -1 + end + else + interface.onincoming( interface, buffer, err ) -- send new data to listener + end + if interface.noreading then + interface.eventread = nil; + return -1; + end + return EV_READ, cfg.READ_TIMEOUT + end + + client:settimeout( 0 ) -- set non blocking + setmetatable(interface, interface_mt) + interfacelist( "add", interface ) -- add to interfacelist + return interface + end +end + +local handleserver +do + function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface + debug "creating server interface..." + local interface = { + _connections = 0; + + conn = server; + onconnect = listener.onconnect; -- will be called when new client connected + eventread = false; -- read event handler + eventclose = false; -- close event handler + readcallback = false; -- read event callback + fatalerror = false; -- error message + nointerface = true; -- lock/unlock parameter + + _ip = addr, _port = port, _pattern = pattern, + _sslctx = sslctx; + } + interface.id = tostring(interface):match("%x+$"); + interface.readcallback = function( event ) -- server handler, called on incoming connections + --vdebug( "server can accept, id/addr/port:", interface, addr, port ) + if interface.fatalerror then + --vdebug( "leaving this event because:", self.fatalerror ) + interface.eventread = nil + return -1 + end + local delay = cfg.ACCEPT_DELAY + if EV_TIMEOUT == event then + if interface._connections >= cfg.MAX_CONNECTIONS then -- check connection count + debug( "to many connections, seconds to wait for next accept:", delay ) + return EV_TIMEOUT, delay -- timeout... + else + return EV_READ -- accept again + end + end + --vdebug("max connection check ok, accepting...") + local client, err = server:accept() -- try to accept; TODO: check err + while client do + if interface._connections >= cfg.MAX_CONNECTIONS then + client:close( ) -- refuse connection + debug( "maximal connections reached, refuse client connection; accept delay:", delay ) + return EV_TIMEOUT, delay -- delay for next accept attempt + end + local client_ip, client_port = client:getpeername( ) + interface._connections = interface._connections + 1 -- increase connection count + local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) + --vdebug( "client id:", clientinterface, "startssl:", startssl ) + if has_luasec and sslctx then + clientinterface:starttls(sslctx, true) + else + clientinterface:_start_session( true ) + end + debug( "accepted incoming client connection from:", client_ip or "", client_port or "", "to", port or ""); + + client, err = server:accept() -- try to accept again + end + return EV_READ + end + + server:settimeout( 0 ) + setmetatable(interface, interface_mt) + interfacelist( "add", interface ) + interface:_start_session() + return interface + end +end + +local addserver = ( function( ) + return function( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments + --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil") + if sslctx and not has_luasec then + debug "fatal error: luasec not found" + return nil, "luasec not found" + end + local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket + if not server then + debug( "creating server socket on "..addr.." port "..port.." failed:", err ) + return nil, err + end + local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler + debug( "new server created with id:", tostring(interface)) + return interface + end +end )( ) + +local addclient, wrapclient +do + function wrapclient( client, ip, port, listeners, pattern, sslctx ) + local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx ) + interface:_start_connection(sslctx) + return interface, client + --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface + end + + function addclient( addr, serverport, listener, pattern, sslctx, typ ) + if sslctx and not has_luasec then + debug "need luasec, but not available" + return nil, "luasec not found" + end + if getaddrinfo and not typ then + local addrinfo, err = getaddrinfo(addr) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + end + end + local create = socket[typ or "tcp"] + if type( create ) ~= "function" then + return nil, "invalid socket type" + end + local client, err = create() -- creating new socket + if not client then + debug( "cannot create socket:", err ) + return nil, err + end + client:settimeout( 0 ) -- set nonblocking + local res, err = client:connect( addr, serverport ) -- connect + if res or ( err == "timeout" or err == "Operation already in progress" ) then + if client.getsockname then + addr = client:getsockname( ) + end + local interface = wrapclient( client, addr, serverport, listener, pattern, sslctx ) + debug( "new connection id:", interface.id ) + return interface, err + else + debug( "new connection failed:", err ) + return nil, err + end + end +end + + +local loop = function( ) -- starts the event loop + base:loop( ) + return "quitting"; +end + +local newevent = ( function( ) + local add = base.addevent + return function( ... ) + return add( base, ... ) + end +end )( ) + +local closeallservers = function( arg ) + for _, item in ipairs( interfacelist( ) ) do + if item.type == "server" then + item:close( arg ) + end + end +end + +local function setquitting(yes) + if yes then + -- Quit now + closeallservers(); + base:loopexit(); + end +end + +local function get_backend() + return base:method(); +end + +-- We need to hold onto the events to stop them +-- being garbage-collected +local signal_events = {}; -- [signal_num] -> event object +local function hook_signal(signal_num, handler) + local function _handler(event) + local ret = handler(); + if ret ~= false then -- Continue handling this signal? + return EV_SIGNAL; -- Yes + end + return -1; -- Close this event + end + signal_events[signal_num] = base:addevent(signal_num, EV_SIGNAL, _handler); + return signal_events[signal_num]; +end + +local function link(sender, receiver, buffersize) + local sender_locked; + + function receiver:ondrain() + if sender_locked then + sender:resume(); + sender_locked = nil; + end + end + + function sender:onincoming(data) + receiver:write(data); + if receiver.writebufferlen >= buffersize then + sender_locked = true; + sender:pause(); + end + end + sender:set_mode("*a"); +end + +local add_task do + local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1; + local socket_gettime = socket.gettime + function add_task(delay, callback) + local event_handle; + event_handle = base:addevent(nil, 0, function () + local ret = callback(socket_gettime()); + if ret then + return 0, ret; + elseif event_handle then + return EVENT_LEAVE; + end + end + , delay); + end +end + +return { + + cfg = cfg, + base = base, + loop = loop, + link = link, + event = event, + event_base = base, + addevent = newevent, + addserver = addserver, + addclient = addclient, + wrapclient = wrapclient, + setquitting = setquitting, + closeall = closeallservers, + get_backend = get_backend, + hook_signal = hook_signal, + add_task = add_task, + + __NAME = SCRIPT_NAME, + __DATE = LAST_MODIFIED, + __AUTHOR = SCRIPT_AUTHOR, + __VERSION = SCRIPT_VERSION, + +} diff -r 000000000000 -r d363a6692a10 net/server_select.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/server_select.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,1073 @@ +-- +-- server.lua by blastbeat of the luadch project +-- Re-used here under the MIT/X Consortium License +-- +-- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain +-- + +-- // wrapping luadch stuff // -- + +local use = function( what ) + return _G[ what ] +end + +local log, table_concat = require ("util.logger").init("socket"), table.concat; +local out_put = function (...) return log("debug", table_concat{...}); end +local out_error = function (...) return log("warn", table_concat{...}); end + +----------------------------------// DECLARATION //-- + +--// constants //-- + +local STAT_UNIT = 1 -- byte + +--// lua functions //-- + +local type = use "type" +local pairs = use "pairs" +local ipairs = use "ipairs" +local tonumber = use "tonumber" +local tostring = use "tostring" + +--// lua libs //-- + +local os = use "os" +local table = use "table" +local string = use "string" +local coroutine = use "coroutine" + +--// lua lib methods //-- + +local os_difftime = os.difftime +local math_min = math.min +local math_huge = math.huge +local table_concat = table.concat +local table_insert = table.insert +local string_sub = string.sub +local coroutine_wrap = coroutine.wrap +local coroutine_yield = coroutine.yield + +--// extern libs //-- + +local has_luasec, luasec = pcall ( require , "ssl" ) +local luasocket = use "socket" or require "socket" +local luasocket_gettime = luasocket.gettime +local getaddrinfo = luasocket.dns.getaddrinfo + +--// extern lib methods //-- + +local ssl_wrap = ( has_luasec and luasec.wrap ) +local socket_bind = luasocket.bind +local socket_sleep = luasocket.sleep +local socket_select = luasocket.select + +--// functions //-- + +local id +local loop +local stats +local idfalse +local closeall +local addsocket +local addserver +local addtimer +local getserver +local wrapserver +local getsettings +local closesocket +local removesocket +local removeserver +local wrapconnection +local changesettings + +--// tables //-- + +local _server +local _readlist +local _timerlist +local _sendlist +local _socketlist +local _closelist +local _readtimes +local _writetimes + +--// simple data types //-- + +local _ +local _readlistlen +local _sendlistlen +local _timerlistlen + +local _sendtraffic +local _readtraffic + +local _selecttimeout +local _sleeptime +local _tcpbacklog + +local _starttime +local _currenttime + +local _maxsendlen +local _maxreadlen + +local _checkinterval +local _sendtimeout +local _readtimeout + +local _timer + +local _maxselectlen +local _maxfd + +local _maxsslhandshake + +----------------------------------// DEFINITION //-- + +_server = { } -- key = port, value = table; list of listening servers +_readlist = { } -- array with sockets to read from +_sendlist = { } -- arrary with sockets to write to +_timerlist = { } -- array of timer functions +_socketlist = { } -- key = socket, value = wrapped socket (handlers) +_readtimes = { } -- key = handler, value = timestamp of last data reading +_writetimes = { } -- key = handler, value = timestamp of last data writing/sending +_closelist = { } -- handlers to close + +_readlistlen = 0 -- length of readlist +_sendlistlen = 0 -- length of sendlist +_timerlistlen = 0 -- lenght of timerlist + +_sendtraffic = 0 -- some stats +_readtraffic = 0 + +_selecttimeout = 1 -- timeout of socket.select +_sleeptime = 0 -- time to wait at the end of every loop +_tcpbacklog = 128 -- some kind of hint to the OS + +_maxsendlen = 51000 * 1024 -- max len of send buffer +_maxreadlen = 25000 * 1024 -- max len of read buffer + +_checkinterval = 30 -- interval in secs to check idle clients +_sendtimeout = 60000 -- allowed send idle time in secs +_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs + +local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows +_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows +_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows + +_maxsslhandshake = 30 -- max handshake round-trips + +----------------------------------// PRIVATE //-- + +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd + + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) + socket:close() + return nil, "fd-too-large" + end + + local connections = 0 + + local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect + + local accept = socket.accept + + --// public methods of the object //-- + + local handler = { } + + handler.shutdown = function( ) end + + handler.ssl = function( ) + return sslctx ~= nil + end + handler.sslctx = function( ) + return sslctx + end + handler.remove = function( ) + connections = connections - 1 + if handler then + handler.resume( ) + end + end + handler.close = function() + socket:close( ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _server[ip..":"..serverport] = nil; + _socketlist[ socket ] = nil + handler = nil + socket = nil + --mem_free( ) + out_put "server.lua: closed server handler and removed sockets from list" + end + handler.pause = function( hard ) + if not handler.paused then + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + if hard then + _socketlist[ socket ] = nil + socket:close( ) + socket = nil; + end + handler.paused = true; + end + end + handler.resume = function( ) + if handler.paused then + if not socket then + socket = socket_bind( ip, serverport, _tcpbacklog ); + socket:settimeout( 0 ) + end + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _socketlist[ socket ] = handler + handler.paused = false; + end + end + handler.ip = function( ) + return ip + end + handler.serverport = function( ) + return serverport + end + handler.socket = function( ) + return socket + end + handler.readbuffer = function( ) + if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then + handler.pause( ) + out_put( "server.lua: refused new client connection: server full" ) + return false + end + local client, err = accept( socket ) -- try to accept + if client then + local ip, clientport = client:getpeername( ) + local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket + if err then -- error while wrapping ssl socket + return false + end + connections = connections + 1 + out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) + if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes + return dispatch( handler ); + end + return; + elseif err then -- maybe timeout or something else + out_put( "server.lua: error with new client connection: ", tostring(err) ) + return false + end + end + return handler +end + +wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object + + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent + socket:close( ) -- Should we send some kind of error here? + if server then + server.pause( ) + end + return nil, nil, "fd-too-large" + end + socket:settimeout( 0 ) + + --// local import of socket methods //-- + + local send + local receive + local shutdown + + --// private closures of the object //-- + + local ssl + + local dispatch = listeners.onincoming + local status = listeners.onstatus + local disconnect = listeners.ondisconnect + local drain = listeners.ondrain + local onreadtimeout = listeners.onreadtimeout; + local detach = listeners.ondetach + + local bufferqueue = { } -- buffer array + local bufferqueuelen = 0 -- end of buffer array + + local toclose + local fatalerror + local needtls + + local bufferlen = 0 + + local noread = false + local nosend = false + + local sendtraffic, readtraffic = 0, 0 + + local maxsendlen = _maxsendlen + local maxreadlen = _maxreadlen + + --// public methods of the object //-- + + local handler = bufferqueue -- saves a table ^_^ + + handler.dispatch = function( ) + return dispatch + end + handler.disconnect = function( ) + return disconnect + end + handler.onreadtimeout = onreadtimeout; + + handler.setlistener = function( self, listeners ) + if detach then + detach(self) -- Notify listener that it is no longer responsible for this connection + end + dispatch = listeners.onincoming + disconnect = listeners.ondisconnect + status = listeners.onstatus + drain = listeners.ondrain + handler.onreadtimeout = listeners.onreadtimeout + detach = listeners.ondetach + end + handler.getstats = function( ) + return readtraffic, sendtraffic + end + handler.ssl = function( ) + return ssl + end + handler.sslctx = function ( ) + return sslctx + end + handler.send = function( _, data, i, j ) + return send( socket, data, i, j ) + end + handler.receive = function( pattern, prefix ) + return receive( socket, pattern, prefix ) + end + handler.shutdown = function( pattern ) + return shutdown( socket, pattern ) + end + handler.setoption = function (self, option, value) + if socket.setoption then + return socket:setoption(option, value); + end + return false, "setoption not implemented"; + end + handler.force_close = function ( self, err ) + if bufferqueuelen ~= 0 then + out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport)) + bufferqueuelen = 0; + end + return self:close(err); + end + handler.close = function( self, err ) + if not handler then return true; end + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _readtimes[ handler ] = nil + if bufferqueuelen ~= 0 then + handler.sendbuffer() -- Try now to send any outstanding data + if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later + if handler then + handler.write = nil -- ... but no further writing allowed + end + toclose = true + return false + end + end + if socket then + _ = shutdown and shutdown( socket ) + socket:close( ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _socketlist[ socket ] = nil + socket = nil + else + out_put "server.lua: socket already closed" + end + if handler then + _writetimes[ handler ] = nil + _closelist[ handler ] = nil + local _handler = handler; + handler = nil + if disconnect then + disconnect(_handler, err or false); + disconnect = nil + end + end + if server then + server.remove( ) + end + out_put "server.lua: closed client handler and removed socket from list" + return true + end + handler.ip = function( ) + return ip + end + handler.serverport = function( ) + return serverport + end + handler.clientport = function( ) + return clientport + end + handler.port = handler.clientport -- COMPAT server_event + local write = function( self, data ) + bufferlen = bufferlen + #data + if bufferlen > maxsendlen then + _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle + handler.write = idfalse -- dont write anymore + return false + elseif socket and not _sendlist[ socket ] then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + end + bufferqueuelen = bufferqueuelen + 1 + bufferqueue[ bufferqueuelen ] = data + if handler then + _writetimes[ handler ] = _writetimes[ handler ] or _currenttime + end + return true + end + handler.write = write + handler.bufferqueue = function( self ) + return bufferqueue + end + handler.socket = function( self ) + return socket + end + handler.set_mode = function( self, new ) + pattern = new or pattern + return pattern + end + handler.set_send = function ( self, newsend ) + send = newsend or send + return send + end + handler.bufferlen = function( self, readlen, sendlen ) + maxsendlen = sendlen or maxsendlen + maxreadlen = readlen or maxreadlen + return bufferlen, maxreadlen, maxsendlen + end + --TODO: Deprecate + handler.lock_read = function (self, switch) + if switch == true then + local tmp = _readlistlen + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _readtimes[ handler ] = nil + if _readlistlen ~= tmp then + noread = true + end + elseif switch == false then + if noread then + noread = false + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _readtimes[ handler ] = _currenttime + end + end + return noread + end + handler.pause = function (self) + return self:lock_read(true); + end + handler.resume = function (self) + return self:lock_read(false); + end + handler.lock = function( self, switch ) + handler.lock_read (switch) + if switch == true then + handler.write = idfalse + local tmp = _sendlistlen + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _writetimes[ handler ] = nil + if _sendlistlen ~= tmp then + nosend = true + end + elseif switch == false then + handler.write = write + if nosend then + nosend = false + write( "" ) + end + end + return noread, nosend + end + local _readbuffer = function( ) -- this function reads data + local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" + if not err or (err == "wantread" or err == "timeout") then -- received something + local buffer = buffer or part or "" + local len = #buffer + if len > maxreadlen then + handler:close( "receive buffer exceeded" ) + return false + end + local count = len * STAT_UNIT + readtraffic = readtraffic + count + _readtraffic = _readtraffic + count + _readtimes[ handler ] = _currenttime + --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err ) + return dispatch( handler, buffer, err ) + else -- connections was closed or fatal error + out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) + fatalerror = true + _ = handler and handler:force_close( err ) + return false + end + end + local _sendbuffer = function( ) -- this function sends data + local succ, err, byte, buffer, count; + if socket then + buffer = table_concat( bufferqueue, "", 1, bufferqueuelen ) + succ, err, byte = send( socket, buffer, 1, bufferlen ) + count = ( succ or byte or 0 ) * STAT_UNIT + sendtraffic = sendtraffic + count + _sendtraffic = _sendtraffic + count + for i = bufferqueuelen,1,-1 do + bufferqueue[ i ] = nil + end + --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) ) + else + succ, err, count = false, "unexpected close", 0; + end + if succ then -- sending succesful + bufferqueuelen = 0 + bufferlen = 0 + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist + _writetimes[ handler ] = nil + if drain then + drain(handler) + end + _ = needtls and handler:starttls(nil) + _ = toclose and handler:force_close( ) + return true + elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write + buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer + bufferqueue[ 1 ] = buffer -- insert new buffer in queue + bufferqueuelen = 1 + bufferlen = bufferlen - byte + _writetimes[ handler ] = _currenttime + return true + else -- connection was closed during sending or fatal error + out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) + fatalerror = true + _ = handler and handler:force_close( err ) + return false + end + end + + -- Set the sslctx + local handshake; + function handler.set_sslctx(self, new_sslctx) + sslctx = new_sslctx; + local read, wrote + handshake = coroutine_wrap( function( client ) -- create handshake coroutine + local err + for i = 1, _maxsslhandshake do + _sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen + _readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen + read, wrote = nil, nil + _, err = client:dohandshake( ) + if not err then + out_put( "server.lua: ssl handshake done" ) + handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions + handler.sendbuffer = _sendbuffer + _ = status and status( handler, "ssl-handshake-complete" ) + if self.autostart_ssl and listeners.onconnect then + listeners.onconnect(self); + if bufferqueuelen ~= 0 then + _sendlistlen = addsocket(_sendlist, client, _sendlistlen) + end + end + _readlistlen = addsocket(_readlist, client, _readlistlen) + return true + else + if err == "wantwrite" then + _sendlistlen = addsocket(_sendlist, client, _sendlistlen) + wrote = true + elseif err == "wantread" then + _readlistlen = addsocket(_readlist, client, _readlistlen) + read = true + else + break; + end + err = nil; + coroutine_yield( ) -- handshake not finished + end + end + out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) + _ = handler and handler:force_close("ssl handshake failed") + return false, err -- handshake failed + end + ) + end + if has_luasec then + handler.starttls = function( self, _sslctx) + if _sslctx then + handler:set_sslctx(_sslctx); + end + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" + needtls = true + return + end + out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) + local oldsocket, err = socket + socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + if not socket then + out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) + return nil, err -- fatal error + end + + socket:settimeout( 0 ) + + -- add the new socket to our system + send = socket.send + receive = socket.receive + shutdown = id + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) + + -- remove traces of the old socket + _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) + _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) + _socketlist[ oldsocket ] = nil + + handler.starttls = nil + needtls = nil + + -- Secure now (if handshake fails connection will close) + ssl = true + + handler.readbuffer = handshake + handler.sendbuffer = handshake + return handshake( socket ) -- do handshake + end + end + + handler.readbuffer = _readbuffer + handler.sendbuffer = _sendbuffer + send = socket.send + receive = socket.receive + shutdown = ( ssl and id ) or socket.shutdown + + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) + + if sslctx and has_luasec then + out_put "server.lua: auto-starting ssl negotiation..." + handler.autostart_ssl = true; + local ok, err = handler:starttls(sslctx); + if ok == false then + return nil, nil, err + end + end + + return handler, socket +end + +id = function( ) +end + +idfalse = function( ) + return false +end + +addsocket = function( list, socket, len ) + if not list[ socket ] then + len = len + 1 + list[ len ] = socket + list[ socket ] = len + end + return len; +end + +removesocket = function( list, socket, len ) -- this function removes sockets from a list ( copied from copas ) + local pos = list[ socket ] + if pos then + list[ socket ] = nil + local last = list[ len ] + list[ len ] = nil + if last ~= socket then + list[ last ] = pos + list[ pos ] = last + end + return len - 1 + end + return len +end + +closesocket = function( socket ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _socketlist[ socket ] = nil + socket:close( ) + --mem_free( ) +end + +local function link(sender, receiver, buffersize) + local sender_locked; + local _sendbuffer = receiver.sendbuffer; + function receiver.sendbuffer() + _sendbuffer(); + if sender_locked and receiver.bufferlen() < buffersize then + sender:lock_read(false); -- Unlock now + sender_locked = nil; + end + end + + local _readbuffer = sender.readbuffer; + function sender.readbuffer() + _readbuffer(); + if not sender_locked and receiver.bufferlen() >= buffersize then + sender_locked = true; + sender:lock_read(true); + end + end + sender:set_mode("*a"); +end + +----------------------------------// PUBLIC //-- + +addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + addr = addr or "*" + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + elseif type ( addr ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif _server[ addr..":"..port ] then + err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" + elseif sslctx and not has_luasec then + err = "luasec not found" + end + if err then + out_error( "server.lua, [", addr, "]:", port, ": ", err ) + return nil, err + end + local server, err = socket_bind( addr, port, _tcpbacklog ) + if err then + out_error( "server.lua, [", addr, "]:", port, ": ", err ) + return nil, err + end + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket + if not handler then + server:close( ) + return nil, err + end + server:settimeout( 0 ) + _readlistlen = addsocket(_readlist, server, _readlistlen) + _server[ addr..":"..port ] = handler + _socketlist[ server ] = handler + out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" ) + return handler +end + +getserver = function ( addr, port ) + return _server[ addr..":"..port ]; +end + +removeserver = function( addr, port ) + local handler = _server[ addr..":"..port ] + if not handler then + return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'" + end + handler:close( ) + _server[ addr..":"..port ] = nil + return true +end + +closeall = function( ) + for _, handler in pairs( _socketlist ) do + handler:close( ) + _socketlist[ _ ] = nil + end + _readlistlen = 0 + _sendlistlen = 0 + _timerlistlen = 0 + _server = { } + _readlist = { } + _sendlist = { } + _timerlist = { } + _socketlist = { } + --mem_free( ) +end + +getsettings = function( ) + return { + select_timeout = _selecttimeout; + select_sleep_time = _sleeptime; + tcp_backlog = _tcpbacklog; + max_send_buffer_size = _maxsendlen; + max_receive_buffer_size = _maxreadlen; + select_idle_check_interval = _checkinterval; + send_timeout = _sendtimeout; + read_timeout = _readtimeout; + max_connections = _maxselectlen; + max_ssl_handshake_roundtrips = _maxsslhandshake; + highest_allowed_fd = _maxfd; + } +end + +changesettings = function( new ) + if type( new ) ~= "table" then + return nil, "invalid settings table" + end + _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout + _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime + _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen + _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen + _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval + _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog + _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout + _readtimeout = tonumber( new.read_timeout ) or _readtimeout + _maxselectlen = new.max_connections or _maxselectlen + _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake + _maxfd = new.highest_allowed_fd or _maxfd + return true +end + +addtimer = function( listener ) + if type( listener ) ~= "function" then + return nil, "invalid listener function" + end + _timerlistlen = _timerlistlen + 1 + _timerlist[ _timerlistlen ] = listener + return true +end + +local add_task do + local data = {}; + local new_data = {}; + + function add_task(delay, callback) + local current_time = luasocket_gettime(); + delay = delay + current_time; + if delay >= current_time then + table_insert(new_data, {delay, callback}); + else + local r = callback(current_time); + if r and type(r) == "number" then + return add_task(r, callback); + end + end + end + + addtimer(function() + local current_time = luasocket_gettime(); + if #new_data > 0 then + for _, d in pairs(new_data) do + table_insert(data, d); + end + new_data = {}; + end + + local next_time = math_huge; + for i, d in pairs(data) do + local t, callback = d[1], d[2]; + if t <= current_time then + data[i] = nil; + local r = callback(current_time); + if type(r) == "number" then + add_task(r, callback); + next_time = math_min(next_time, r); + end + else + next_time = math_min(next_time, t - current_time); + end + end + return next_time; + end); +end + +stats = function( ) + return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen +end + +local quitting; + +local function setquitting(quit) + quitting = not not quit; +end + +loop = function(once) -- this is the main loop of the program + if quitting then return "quitting"; end + if once then quitting = "once"; end + local next_timer_time = math_huge; + repeat + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) + for i, socket in ipairs( write ) do -- send data waiting in writequeues + local handler = _socketlist[ socket ] + if handler then + handler.sendbuffer( ) + else + closesocket( socket ) + out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen + end + end + for i, socket in ipairs( read ) do -- receive data + local handler = _socketlist[ socket ] + if handler then + handler.readbuffer( ) + else + closesocket( socket ) + out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen + end + end + for handler, err in pairs( _closelist ) do + handler.disconnect( )( handler, err ) + handler:force_close() -- forced disconnect + _closelist[ handler ] = nil; + end + _currenttime = luasocket_gettime( ) + + -- Check for socket timeouts + local difftime = os_difftime( _currenttime - _starttime ) + if difftime > _checkinterval then + _starttime = _currenttime + for handler, timestamp in pairs( _writetimes ) do + if os_difftime( _currenttime - timestamp ) > _sendtimeout then + handler.disconnect( )( handler, "send timeout" ) + handler:force_close() -- forced disconnect + end + end + for handler, timestamp in pairs( _readtimes ) do + if os_difftime( _currenttime - timestamp ) > _readtimeout then + if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + else + _readtimes[ handler ] = _currenttime -- reset timer + end + end + end + end + + -- Fire timers + if _currenttime - _timer >= math_min(next_timer_time, 1) then + next_timer_time = math_huge; + for i = 1, _timerlistlen do + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end + end + _timer = _currenttime + else + next_timer_time = next_timer_time - (_currenttime - _timer); + end + + -- wait some time (0 by default) + socket_sleep( _sleeptime ) + until quitting; + if once and quitting == "once" then quitting = nil; return; end + return "quitting" +end + +local function step() + return loop(true); +end + +local function get_backend() + return "select"; +end + +--// EXPERIMENTAL //-- + +local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + if not handler then return nil, err end + _socketlist[ socket ] = handler + if not sslctx then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + return _sendbuffer(); -- Send any queued outgoing data + end + end + end + return handler, socket +end + +local addclient = function( address, port, listeners, pattern, sslctx, typ ) + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + elseif type ( address ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif sslctx and not has_luasec then + err = "luasec not found" + end + if getaddrinfo and not typ then + local addrinfo, err = getaddrinfo(address) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + end + end + local create = luasocket[typ or "tcp"] + if type( create ) ~= "function" then + err = "invalid socket type" + end + + if err then + out_error( "server.lua, addclient: ", err ) + return nil, err + end + + local client, err = create( ) + if err then + return nil, err + end + client:settimeout( 0 ) + local ok, err = client:connect( address, port ) + if ok or err == "timeout" or err == "Operation already in progress" then + return wrapclient( client, address, port, listeners, pattern, sslctx ) + else + return nil, err + end +end + +--// EXPERIMENTAL //-- + +----------------------------------// BEGIN //-- + +use "setmetatable" ( _socketlist, { __mode = "k" } ) +use "setmetatable" ( _readtimes, { __mode = "k" } ) +use "setmetatable" ( _writetimes, { __mode = "k" } ) + +_timer = luasocket_gettime( ) +_starttime = luasocket_gettime( ) + +local function setlogger(new_logger) + local old_logger = log; + if new_logger then + log = new_logger; + end + return old_logger; +end + +----------------------------------// PUBLIC INTERFACE //-- + +return { + _addtimer = addtimer, + add_task = add_task; + + addclient = addclient, + wrapclient = wrapclient, + + loop = loop, + link = link, + step = step, + stats = stats, + closeall = closeall, + addserver = addserver, + getserver = getserver, + setlogger = setlogger, + getsettings = getsettings, + setquitting = setquitting, + removeserver = removeserver, + get_backend = get_backend, + changesettings = changesettings, +} diff -r 000000000000 -r d363a6692a10 net/websocket.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/websocket.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,272 @@ +-- Prosody IM +-- Copyright (C) 2012 Florian Zeitz +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local t_concat = table.concat; + +local http = require "net.http"; +local frames = require "net.websocket.frames"; +local base64 = require "util.encodings".base64; +local sha1 = require "util.hashes".sha1; +local random_bytes = require "util.random".bytes; +local timer = require "util.timer"; +local log = require "util.logger".init "websocket"; + +local close_timeout = 3; -- Seconds to wait after sending close frame until closing connection. + +local websockets = {}; + +local websocket_listeners = {}; +function websocket_listeners.ondisconnect(handler, err) + local s = websockets[handler]; + websockets[handler] = nil; + if s.close_timer then + timer.stop(s.close_timer); + s.close_timer = nil; + end + s.readyState = 3; + if s.close_code == nil and s.onerror then s:onerror(err); end + if s.onclose then s:onclose(s.close_code, s.close_message or err); end +end + +function websocket_listeners.ondetach(handler) + websockets[handler] = nil; +end + +local function fail(s, code, reason) + module:log("warn", "WebSocket connection failed, closing. %d %s", code, reason); + s:close(code, reason); + s.handler:close(); + return false +end + +function websocket_listeners.onincoming(handler, buffer, err) + local s = websockets[handler]; + s.readbuffer = s.readbuffer..buffer; + while true do + local frame, len = frames.parse(s.readbuffer); + if frame == nil then break end + s.readbuffer = s.readbuffer:sub(len+1); + + log("debug", "Websocket received frame: opcode=%0x, %i bytes", frame.opcode, #frame.data); + + -- Error cases + if frame.RSV1 or frame.RSV2 or frame.RSV3 then -- Reserved bits non zero + return fail(s, 1002, "Reserved bits not zero"); + end + + if frame.opcode < 0x8 then + local databuffer = s.databuffer; + if frame.opcode == 0x0 then -- Continuation frames + if not databuffer then + return fail(s, 1002, "Unexpected continuation frame"); + end + databuffer[#databuffer+1] = frame.data; + elseif frame.opcode == 0x1 or frame.opcode == 0x2 then -- Text or Binary frame + if databuffer then + return fail(s, 1002, "Continuation frame expected"); + end + databuffer = {type=frame.opcode, frame.data}; + s.databuffer = databuffer; + else + return fail(s, 1002, "Reserved opcode"); + end + if frame.FIN then + s.databuffer = nil; + if s.onmessage then + s:onmessage(t_concat(databuffer), databuffer.type); + end + end + else -- Control frame + if frame.length > 125 then -- Control frame with too much payload + return fail(s, 1002, "Payload too large"); + elseif not frame.FIN then -- Fragmented control frame + return fail(s, 1002, "Fragmented control frame"); + end + if frame.opcode == 0x8 then -- Close request + if frame.length == 1 then + return fail(s, 1002, "Close frame with payload, but too short for status code"); + end + local status_code, message = frames.parse_close(frame.data); + if status_code == nil then + --[[ RFC 6455 7.4.1 + 1005 is a reserved value and MUST NOT be set as a status code in a + Close control frame by an endpoint. It is designated for use in + applications expecting a status code to indicate that no status + code was actually present. + ]] + status_code = 1005 + elseif status_code < 1000 then + return fail(s, 1002, "Closed with invalid status code"); + elseif ((status_code > 1003 and status_code < 1007) or status_code > 1011) and status_code < 3000 then + return fail(s, 1002, "Closed with reserved status code"); + end + s.close_code, s.close_message = status_code, message; + s:close(1000); + return true; + elseif frame.opcode == 0x9 then -- Ping frame + frame.opcode = 0xA; + frame.MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked + handler:write(frames.build(frame)); + elseif frame.opcode == 0xA then -- Pong frame + log("debug", "Received unexpected pong frame: " .. tostring(frame.data)); + else + return fail(s, 1002, "Reserved opcode"); + end + end + end + return true; +end + +local websocket_methods = {}; +local function close_timeout_cb(now, timerid, s) + s.close_timer = nil; + log("warn", "Close timeout waiting for server to close, closing manually."); + s.handler:close(); +end +function websocket_methods:close(code, reason) + if self.readyState < 2 then + code = code or 1000; + log("debug", "closing WebSocket with code %i: %s" , code , tostring(reason)); + self.readyState = 2; + local handler = self.handler; + handler:write(frames.build_close(code, reason, true)); + -- Do not close socket straight away, wait for acknowledgement from server. + self.close_timer = timer.add_task(close_timeout, close_timeout_cb, self); + elseif self.readyState == 2 then + log("debug", "tried to close a closing WebSocket, closing the raw socket."); + -- Stop timer + if self.close_timer then + timer.stop(self.close_timer); + self.close_timer = nil; + end + local handler = self.handler; + handler:close(); + else + log("debug", "tried to close a closed WebSocket, ignoring."); + end +end +function websocket_methods:send(data, opcode) + if self.readyState < 1 then + return nil, "WebSocket not open yet, unable to send data."; + elseif self.readyState >= 2 then + return nil, "WebSocket closed, unable to send data."; + end + if opcode == "text" or opcode == nil then + opcode = 0x1; + elseif opcode == "binary" then + opcode = 0x2; + end + local frame = { + FIN = true; + MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked + opcode = opcode; + data = tostring(data); + }; + log("debug", "WebSocket sending frame: opcode=%0x, %i bytes", frame.opcode, #frame.data); + return self.handler:write(frames.build(frame)); +end + +local websocket_metatable = { + __index = websocket_methods; +}; + +local function connect(url, ex, listeners) + ex = ex or {}; + + --[[RFC 6455 4.1.7: + The request MUST include a header field with the name + |Sec-WebSocket-Key|. The value of this header field MUST be a + nonce consisting of a randomly selected 16-byte value that has + been base64-encoded (see Section 4 of [RFC4648]). The nonce + MUST be selected randomly for each connection. + ]] + local key = base64.encode(random_bytes(16)); + + -- Either a single protocol string or an array of protocol strings. + local protocol = ex.protocol; + if type(protocol) == "string" then + protocol = { protocol, [protocol] = true }; + elseif type(protocol) == "table" and protocol[1] then + for _, v in ipairs(protocol) do + protocol[v] = true; + end + else + protocol = nil; + end + + local headers = { + ["Upgrade"] = "websocket"; + ["Connection"] = "Upgrade"; + ["Sec-WebSocket-Key"] = key; + ["Sec-WebSocket-Protocol"] = protocol and t_concat(protocol, ", "); + ["Sec-WebSocket-Version"] = "13"; + ["Sec-WebSocket-Extensions"] = ex.extensions; + } + if ex.headers then + for k,v in pairs(ex.headers) do + headers[k] = v; + end + end + + local s = setmetatable({ + readbuffer = ""; + databuffer = nil; + handler = nil; + close_code = nil; + close_message = nil; + close_timer = nil; + readyState = 0; + protocol = nil; + + url = url; + + onopen = listeners.onopen; + onclose = listeners.onclose; + onmessage = listeners.onmessage; + onerror = listeners.onerror; + }, websocket_metatable); + + local http_url = url:gsub("^(ws)", "http"); + local http_req = http.request(http_url, { + method = "GET"; + headers = headers; + sslctx = ex.sslctx; + }, function(b, c, r, http_req) + if c ~= 101 + or r.headers["connection"]:lower() ~= "upgrade" + or r.headers["upgrade"] ~= "websocket" + or r.headers["sec-websocket-accept"] ~= base64.encode(sha1(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + or (protocol and not protocol[r.headers["sec-websocket-protocol"]]) + then + s.readyState = 3; + log("warn", "WebSocket connection to %s failed: %s", url, tostring(b)); + if s.onerror then s:onerror("connecting-failed"); end + return; + end + + s.protocol = r.headers["sec-websocket-protocol"]; + + -- Take possession of socket from http + http_req.conn = nil; + local handler = http_req.handler; + s.handler = handler; + websockets[handler] = s; + handler:setlistener(websocket_listeners); + + log("debug", "WebSocket connected successfully to %s", url); + s.readyState = 1; + if s.onopen then s:onopen(); end + websocket_listeners.onincoming(handler, b); + end); + + return s; +end + +return { + connect = connect; +}; diff -r 000000000000 -r d363a6692a10 net/websocket/frames.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/websocket/frames.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,195 @@ +-- Prosody IM +-- Copyright (C) 2012 Florian Zeitz +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local softreq = require "util.dependencies".softreq; +local log = require "util.logger".init "websocket.frames"; +local random_bytes = require "util.random".bytes; + +local bit; +pcall(function() bit = require"bit"; end); +bit = bit or softreq"bit32" +if not bit then log("error", "No bit module found. Either LuaJIT 2, lua-bitop or Lua 5.2 is required"); end +local band = bit.band; +local bor = bit.bor; +local bxor = bit.bxor; +local lshift = bit.lshift; +local rshift = bit.rshift; + +local t_concat = table.concat; +local s_byte = string.byte; +local s_char= string.char; +local s_sub = string.sub; + +local function read_uint16be(str, pos) + local l1, l2 = s_byte(str, pos, pos+1); + return l1*256 + l2; +end +-- FIXME: this may lose precision +local function read_uint64be(str, pos) + local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7); + return lshift(l1, 56) + lshift(l2, 48) + lshift(l3, 40) + lshift(l4, 32) + + lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8; +end +local function pack_uint16be(x) + return s_char(rshift(x, 8), band(x, 0xFF)); +end +local function get_byte(x, n) + return band(rshift(x, n), 0xFF); +end +local function pack_uint64be(x) + return s_char(rshift(x, 56), get_byte(x, 48), get_byte(x, 40), get_byte(x, 32), + get_byte(x, 24), get_byte(x, 16), get_byte(x, 8), band(x, 0xFF)); +end + +local function parse_frame_header(frame) + if #frame < 2 then return; end + + local byte1, byte2 = s_byte(frame, 1, 2); + local result = { + FIN = band(byte1, 0x80) > 0; + RSV1 = band(byte1, 0x40) > 0; + RSV2 = band(byte1, 0x20) > 0; + RSV3 = band(byte1, 0x10) > 0; + opcode = band(byte1, 0x0F); + + MASK = band(byte2, 0x80) > 0; + length = band(byte2, 0x7F); + }; + + local length_bytes = 0; + if result.length == 126 then + length_bytes = 2; + elseif result.length == 127 then + length_bytes = 8; + end + + local header_length = 2 + length_bytes + (result.MASK and 4 or 0); + if #frame < header_length then return; end + + if length_bytes == 2 then + result.length = read_uint16be(frame, 3); + elseif length_bytes == 8 then + result.length = read_uint64be(frame, 3); + end + + if result.MASK then + result.key = { s_byte(frame, length_bytes+3, length_bytes+6) }; + end + + return result, header_length; +end + +-- XORs the string `str` with the array of bytes `key` +-- TODO: optimize +local function apply_mask(str, key, from, to) + from = from or 1 + if from < 0 then from = #str + from + 1 end -- negative indicies + to = to or #str + if to < 0 then to = #str + to + 1 end -- negative indicies + local key_len = #key + local counter = 0; + local data = {}; + for i = from, to do + local key_index = counter%key_len + 1; + counter = counter + 1; + data[counter] = s_char(bxor(key[key_index], s_byte(str, i))); + end + return t_concat(data); +end + +local function parse_frame_body(frame, header, pos) + if header.MASK then + return apply_mask(frame, header.key, pos, pos + header.length - 1); + else + return frame:sub(pos, pos + header.length - 1); + end +end + +local function parse_frame(frame) + local result, pos = parse_frame_header(frame); + if result == nil or #frame < (pos + result.length) then return; end + result.data = parse_frame_body(frame, result, pos+1); + return result, pos + result.length; +end + +local function build_frame(desc) + local data = desc.data or ""; + + assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode"); + if desc.opcode >= 0x8 then + -- RFC 6455 5.5 + assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less."); + end + + local b1 = bor(desc.opcode, + desc.FIN and 0x80 or 0, + desc.RSV1 and 0x40 or 0, + desc.RSV2 and 0x20 or 0, + desc.RSV3 and 0x10 or 0); + + local b2 = #data; + local length_extra; + if b2 <= 125 then -- 7-bit length + length_extra = ""; + elseif b2 <= 0xFFFF then -- 2-byte length + b2 = 126; + length_extra = pack_uint16be(#data); + else -- 8-byte length + b2 = 127; + length_extra = pack_uint64be(#data); + end + + local key = "" + if desc.MASK then + local key_a = desc.key + if key_a then + key = s_char(unpack(key_a, 1, 4)); + else + key = random_bytes(4); + key_a = {key:byte(1,4)}; + end + b2 = bor(b2, 0x80); + data = apply_mask(data, key_a); + end + + return s_char(b1, b2) .. length_extra .. key .. data +end + +local function parse_close(data) + local code, message + if #data >= 2 then + code = read_uint16be(data, 1); + if #data > 2 then + message = s_sub(data, 3); + end + end + return code, message +end + +local function build_close(code, message, mask) + local data = pack_uint16be(code); + if message then + assert(#message<=123, "Close reason must be <=123 bytes"); + data = data .. message; + end + return build_frame({ + opcode = 0x8; + FIN = true; + MASK = mask; + data = data; + }); +end + +return { + parse_header = parse_frame_header; + parse_body = parse_frame_body; + parse = parse_frame; + build = build_frame; + parse_close = parse_close; + build_close = build_close; +}; diff -r 000000000000 -r d363a6692a10 util/events.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/util/events.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,83 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + + +local pairs = pairs; +local t_insert = table.insert; +local t_sort = table.sort; +local setmetatable = setmetatable; +local next = next; + +module "events" + +function new() + local handlers = {}; + local event_map = {}; + local function _rebuild_index(handlers, event) + local _handlers = event_map[event]; + if not _handlers or next(_handlers) == nil then return; end + local index = {}; + for handler in pairs(_handlers) do + t_insert(index, handler); + end + t_sort(index, function(a, b) return _handlers[a] > _handlers[b]; end); + handlers[event] = index; + return index; + end; + setmetatable(handlers, { __index = _rebuild_index }); + local function add_handler(event, handler, priority) + local map = event_map[event]; + if map then + map[handler] = priority or 0; + else + map = {[handler] = priority or 0}; + event_map[event] = map; + end + handlers[event] = nil; + end; + local function remove_handler(event, handler) + local map = event_map[event]; + if map then + map[handler] = nil; + handlers[event] = nil; + if next(map) == nil then + event_map[event] = nil; + end + end + end; + local function add_handlers(handlers) + for event, handler in pairs(handlers) do + add_handler(event, handler); + end + end; + local function remove_handlers(handlers) + for event, handler in pairs(handlers) do + remove_handler(event, handler); + end + end; + local function fire_event(event_name, event_data) + local h = handlers[event_name]; + if h then + for i=1,#h do + local ret = h[i](event_data); + if ret ~= nil then return ret; end + end + end + end; + return { + add_handler = add_handler; + remove_handler = remove_handler; + add_handlers = add_handlers; + remove_handlers = remove_handlers; + fire_event = fire_event; + _handlers = handlers; + _event_map = event_map; + }; +end + +return _M; diff -r 000000000000 -r d363a6692a10 util/http.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/util/http.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,64 @@ +-- Prosody IM +-- Copyright (C) 2013 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local format, char = string.format, string.char; +local pairs, ipairs, tonumber = pairs, ipairs, tonumber; +local t_insert, t_concat = table.insert, table.concat; + +local function urlencode(s) + return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end)); +end +local function urldecode(s) + return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); +end + +local function _formencodepart(s) + return s and (s:gsub("%W", function (c) + if c ~= " " then + return format("%%%02x", c:byte()); + else + return "+"; + end + end)); +end + +local function formencode(form) + local result = {}; + if form[1] then -- Array of ordered { name, value } + for _, field in ipairs(form) do + t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value)); + end + else -- Unordered map of name -> value + for name, value in pairs(form) do + t_insert(result, _formencodepart(name).."=".._formencodepart(value)); + end + end + return t_concat(result, "&"); +end + +local function formdecode(s) + if not s:match("=") then return urldecode(s); end + local r = {}; + for k, v in s:gmatch("([^=&]*)=([^&]*)") do + k, v = k:gsub("%+", "%%20"), v:gsub("%+", "%%20"); + k, v = urldecode(k), urldecode(v); + t_insert(r, { name = k, value = v }); + r[k] = v; + end + return r; +end + +local function contains_token(field, token) + field = ","..field:gsub("[ \t]", ""):lower()..","; + return field:find(","..token:lower()..",", 1, true) ~= nil; +end + +return { + urlencode = urlencode, urldecode = urldecode; + formencode = formencode, formdecode = formdecode; + contains_token = contains_token; +}; diff -r 000000000000 -r d363a6692a10 util/logger.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/util/logger.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,74 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local pcall = pcall; + +local find = string.find; +local ipairs, pairs, setmetatable = ipairs, pairs, setmetatable; + +module "logger" + +local level_sinks = {}; + +local make_logger; + +function init(name) + local log_debug = make_logger(name, "debug"); + local log_info = make_logger(name, "info"); + local log_warn = make_logger(name, "warn"); + local log_error = make_logger(name, "error"); + + return function (level, message, ...) + if level == "debug" then + return log_debug(message, ...); + elseif level == "info" then + return log_info(message, ...); + elseif level == "warn" then + return log_warn(message, ...); + elseif level == "error" then + return log_error(message, ...); + end + end +end + +function make_logger(source_name, level) + local level_handlers = level_sinks[level]; + if not level_handlers then + level_handlers = {}; + level_sinks[level] = level_handlers; + end + + local logger = function (message, ...) + for i = 1,#level_handlers do + level_handlers[i](source_name, level, message, ...); + end + end + + return logger; +end + +function reset() + for level, handler_list in pairs(level_sinks) do + -- Clear all handlers for this level + for i = 1, #handler_list do + handler_list[i] = nil; + end + end +end + +function add_level_sink(level, sink_function) + if not level_sinks[level] then + level_sinks[level] = { sink_function }; + else + level_sinks[level][#level_sinks[level] + 1 ] = sink_function; + end +end + +_M.new = make_logger; + +return _M; diff -r 000000000000 -r d363a6692a10 util/timer.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/util/timer.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,78 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local indexedbheap = require "util.indexedbheap"; +local log = require "util.logger".init("timer"); +local server = require "net.server"; +local get_time = require "socket".gettime; +local type = type; +local debug_traceback = debug.traceback; +local tostring = tostring; +local xpcall = xpcall; + +module "timer" + +local _add_task = server.add_task; +--add_task = _add_task; + +local h = indexedbheap.create(); +local params = {}; +local next_time = nil; +local _id, _callback, _now, _param; +local function _call() return _callback(_now, _id, _param); end +local function _traceback_handler(err) log("error", "Traceback[timer]: %s", debug_traceback(tostring(err), 2)); end +local function _on_timer(now) + local peek; + while true do + peek = h:peek(); + if peek == nil or peek > now then break; end + local _; + _, _callback, _id = h:pop(); + _now = now; + _param = params[_id]; + params[_id] = nil; + --item(now, id, _param); -- FIXME pcall + local success, err = xpcall(_call, _traceback_handler); + if success and type(err) == "number" then + h:insert(_callback, err + now, _id); -- re-add + params[_id] = _param; + end + end + next_time = peek; + if peek ~= nil then + return peek - now; + end +end +function add_task(delay, callback, param) + local current_time = get_time(); + local event_time = current_time + delay; + + local id = h:insert(callback, event_time); + params[id] = param; + if next_time == nil or event_time < next_time then + next_time = event_time; + _add_task(next_time - current_time, _on_timer); + end + return id; +end +function stop(id) + params[id] = nil; + return h:remove(id); +end +function reschedule(id, delay) + local current_time = get_time(); + local event_time = current_time + delay; + h:reprioritize(id, delay); + if next_time == nil or event_time < next_time then + next_time = event_time; + _add_task(next_time - current_time, _on_timer); + end + return id; +end + +return _M; diff -r 000000000000 -r d363a6692a10 util/watchdog.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/util/watchdog.lua Thu Dec 25 10:48:06 2014 +0000 @@ -0,0 +1,34 @@ +local timer = require "util.timer"; +local setmetatable = setmetatable; +local os_time = os.time; + +module "watchdog" + +local watchdog_methods = {}; +local watchdog_mt = { __index = watchdog_methods }; + +function new(timeout, callback) + local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt); + timer.add_task(timeout+1, function (current_time) + local last_reset = watchdog.last_reset; + if not last_reset then + return; + end + local time_left = (last_reset + timeout) - current_time; + if time_left < 0 then + return watchdog:callback(); + end + return time_left + 1; + end); + return watchdog; +end + +function watchdog_methods:reset() + self.last_reset = os_time(); +end + +function watchdog_methods:cancel() + self.last_reset = nil; +end + +return _M;