net/dns.lua

changeset 0
d363a6692a10
equal deleted inserted replaced
-1:000000000000 0:d363a6692a10
1 -- Prosody IM
2 -- This file is included with Prosody IM. It has modifications,
3 -- which are hereby placed in the public domain.
4
5
6 -- todo: quick (default) header generation
7 -- todo: nxdomain, error handling
8 -- todo: cache results of encodeName
9
10
11 -- reference: http://tools.ietf.org/html/rfc1035
12 -- reference: http://tools.ietf.org/html/rfc1876 (LOC)
13
14
15 local socket = require "socket";
16 local timer = require "util.timer";
17 local new_ip = require "util.ip".new_ip;
18
19 local _, windows = pcall(require, "util.windows");
20 local is_windows = (_ and windows) or os.getenv("WINDIR");
21
22 local coroutine, io, math, string, table =
23 coroutine, io, math, string, table;
24
25 local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type=
26 ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type;
27
28 local ztact = { -- public domain 20080404 lua@ztact.com
29 get = function(parent, ...)
30 local len = select('#', ...);
31 for i=1,len do
32 parent = parent[select(i, ...)];
33 if parent == nil then break; end
34 end
35 return parent;
36 end;
37 set = function(parent, ...)
38 local len = select('#', ...);
39 local key, value = select(len-1, ...);
40 local cutpoint, cutkey;
41
42 for i=1,len-2 do
43 local key = select (i, ...)
44 local child = parent[key]
45
46 if value == nil then
47 if child == nil then
48 return;
49 elseif next(child, next(child)) then
50 cutpoint = nil; cutkey = nil;
51 elseif cutpoint == nil then
52 cutpoint = parent; cutkey = key;
53 end
54 elseif child == nil then
55 child = {};
56 parent[key] = child;
57 end
58 parent = child
59 end
60
61 if value == nil and cutpoint then
62 cutpoint[cutkey] = nil;
63 else
64 parent[key] = value;
65 return value;
66 end
67 end;
68 };
69 local get, set = ztact.get, ztact.set;
70
71 local default_timeout = 15;
72
73 -------------------------------------------------- module dns
74 module('dns')
75 local dns = _M;
76
77
78 -- dns type & class codes ------------------------------ dns type & class codes
79
80
81 local append = table.insert
82
83
84 local function highbyte(i) -- - - - - - - - - - - - - - - - - - - highbyte
85 return (i-(i%0x100))/0x100;
86 end
87
88
89 local function augment (t) -- - - - - - - - - - - - - - - - - - - - augment
90 local a = {};
91 for i,s in pairs(t) do
92 a[i] = s;
93 a[s] = s;
94 a[string.lower(s)] = s;
95 end
96 return a;
97 end
98
99
100 local function encode (t) -- - - - - - - - - - - - - - - - - - - - - encode
101 local code = {};
102 for i,s in pairs(t) do
103 local word = string.char(highbyte(i), i%0x100);
104 code[i] = word;
105 code[s] = word;
106 code[string.lower(s)] = word;
107 end
108 return code;
109 end
110
111
112 dns.types = {
113 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
114 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
115 [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV',
116 [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' };
117
118
119 dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' };
120
121
122 dns.type = augment (dns.types);
123 dns.class = augment (dns.classes);
124 dns.typecode = encode (dns.types);
125 dns.classcode = encode (dns.classes);
126
127
128
129 local function standardize(qname, qtype, qclass) -- - - - - - - standardize
130 if string.byte(qname, -1) ~= 0x2E then qname = qname..'.'; end
131 qname = string.lower(qname);
132 return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN'];
133 end
134
135
136 local function prune(rrs, time, soft) -- - - - - - - - - - - - - - - prune
137 time = time or socket.gettime();
138 for i,rr in ipairs(rrs) do
139 if rr.tod then
140 -- rr.tod = rr.tod - 50 -- accelerated decripitude
141 rr.ttl = math.floor(rr.tod - time);
142 if rr.ttl <= 0 then
143 rrs[rr[rr.type:lower()]] = nil;
144 table.remove(rrs, i);
145 return prune(rrs, time, soft); -- Re-iterate
146 end
147 elseif soft == 'soft' then -- What is this? I forget!
148 assert(rr.ttl == 0);
149 rrs[rr[rr.type:lower()]] = nil;
150 table.remove(rrs, i);
151 end
152 end
153 end
154
155
156 -- metatables & co. ------------------------------------------ metatables & co.
157
158
159 local resolver = {};
160 resolver.__index = resolver;
161
162 resolver.timeout = default_timeout;
163
164 local function default_rr_tostring(rr)
165 local rr_val = rr.type and rr[rr.type:lower()];
166 if type(rr_val) ~= "string" then
167 return "<UNKNOWN RDATA TYPE>";
168 end
169 return rr_val;
170 end
171
172 local special_tostrings = {
173 LOC = resolver.LOC_tostring;
174 MX = function (rr)
175 return string.format('%2i %s', rr.pref, rr.mx);
176 end;
177 SRV = function (rr)
178 local s = rr.srv;
179 return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target);
180 end;
181 };
182
183 local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable
184 function rr_metatable.__tostring(rr)
185 local rr_string = (special_tostrings[rr.type] or default_rr_tostring)(rr);
186 return string.format('%2s %-5s %6i %-28s %s', rr.class, rr.type, rr.ttl, rr.name, rr_string);
187 end
188
189
190 local rrs_metatable = {}; -- - - - - - - - - - - - - - - - - - rrs_metatable
191 function rrs_metatable.__tostring(rrs)
192 local t = {};
193 for i,rr in ipairs(rrs) do
194 append(t, tostring(rr)..'\n');
195 end
196 return table.concat(t);
197 end
198
199
200 local cache_metatable = {}; -- - - - - - - - - - - - - - - - cache_metatable
201 function cache_metatable.__tostring(cache)
202 local time = socket.gettime();
203 local t = {};
204 for class,types in pairs(cache) do
205 for type,names in pairs(types) do
206 for name,rrs in pairs(names) do
207 prune(rrs, time);
208 append(t, tostring(rrs));
209 end
210 end
211 end
212 return table.concat(t);
213 end
214
215
216 function resolver:new() -- - - - - - - - - - - - - - - - - - - - - resolver
217 local r = { active = {}, cache = {}, unsorted = {} };
218 setmetatable(r, resolver);
219 setmetatable(r.cache, cache_metatable);
220 setmetatable(r.unsorted, { __mode = 'kv' });
221 return r;
222 end
223
224
225 -- packet layer -------------------------------------------------- packet layer
226
227
228 function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random
229 math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000);
230 dns.random = math.random;
231 return dns.random(...);
232 end
233
234
235 local function encodeHeader(o) -- - - - - - - - - - - - - - - encodeHeader
236 o = o or {};
237 o.id = o.id or dns.random(0, 0xffff); -- 16b (random) id
238
239 o.rd = o.rd or 1; -- 1b 1 recursion desired
240 o.tc = o.tc or 0; -- 1b 1 truncated response
241 o.aa = o.aa or 0; -- 1b 1 authoritative response
242 o.opcode = o.opcode or 0; -- 4b 0 query
243 -- 1 inverse query
244 -- 2 server status request
245 -- 3-15 reserved
246 o.qr = o.qr or 0; -- 1b 0 query, 1 response
247
248 o.rcode = o.rcode or 0; -- 4b 0 no error
249 -- 1 format error
250 -- 2 server failure
251 -- 3 name error
252 -- 4 not implemented
253 -- 5 refused
254 -- 6-15 reserved
255 o.z = o.z or 0; -- 3b 0 resvered
256 o.ra = o.ra or 0; -- 1b 1 recursion available
257
258 o.qdcount = o.qdcount or 1; -- 16b number of question RRs
259 o.ancount = o.ancount or 0; -- 16b number of answers RRs
260 o.nscount = o.nscount or 0; -- 16b number of nameservers RRs
261 o.arcount = o.arcount or 0; -- 16b number of additional RRs
262
263 -- string.char() rounds, so prevent roundup with -0.4999
264 local header = string.char(
265 highbyte(o.id), o.id %0x100,
266 o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr,
267 o.rcode + 16*o.z + 128*o.ra,
268 highbyte(o.qdcount), o.qdcount %0x100,
269 highbyte(o.ancount), o.ancount %0x100,
270 highbyte(o.nscount), o.nscount %0x100,
271 highbyte(o.arcount), o.arcount %0x100
272 );
273
274 return header, o.id;
275 end
276
277
278 local function encodeName(name) -- - - - - - - - - - - - - - - - encodeName
279 local t = {};
280 for part in string.gmatch(name, '[^.]+') do
281 append(t, string.char(string.len(part)));
282 append(t, part);
283 end
284 append(t, string.char(0));
285 return table.concat(t);
286 end
287
288
289 local function encodeQuestion(qname, qtype, qclass) -- - - - encodeQuestion
290 qname = encodeName(qname);
291 qtype = dns.typecode[qtype or 'a'];
292 qclass = dns.classcode[qclass or 'in'];
293 return qname..qtype..qclass;
294 end
295
296
297 function resolver:byte(len) -- - - - - - - - - - - - - - - - - - - - - byte
298 len = len or 1;
299 local offset = self.offset;
300 local last = offset + len - 1;
301 if last > #self.packet then
302 error(string.format('out of bounds: %i>%i', last, #self.packet));
303 end
304 self.offset = offset + len;
305 return string.byte(self.packet, offset, last);
306 end
307
308
309 function resolver:word() -- - - - - - - - - - - - - - - - - - - - - - word
310 local b1, b2 = self:byte(2);
311 return 0x100*b1 + b2;
312 end
313
314
315 function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword
316 local b1, b2, b3, b4 = self:byte(4);
317 --print('dword', b1, b2, b3, b4);
318 return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4;
319 end
320
321
322 function resolver:sub(len) -- - - - - - - - - - - - - - - - - - - - - - sub
323 len = len or 1;
324 local s = string.sub(self.packet, self.offset, self.offset + len - 1);
325 self.offset = self.offset + len;
326 return s;
327 end
328
329
330 function resolver:header(force) -- - - - - - - - - - - - - - - - - - header
331 local id = self:word();
332 --print(string.format(':header id %x', id));
333 if not self.active[id] and not force then return nil; end
334
335 local h = { id = id };
336
337 local b1, b2 = self:byte(2);
338
339 h.rd = b1 %2;
340 h.tc = b1 /2%2;
341 h.aa = b1 /4%2;
342 h.opcode = b1 /8%16;
343 h.qr = b1 /128;
344
345 h.rcode = b2 %16;
346 h.z = b2 /16%8;
347 h.ra = b2 /128;
348
349 h.qdcount = self:word();
350 h.ancount = self:word();
351 h.nscount = self:word();
352 h.arcount = self:word();
353
354 for k,v in pairs(h) do h[k] = v-v%1; end
355
356 return h;
357 end
358
359
360 function resolver:name() -- - - - - - - - - - - - - - - - - - - - - - name
361 local remember, pointers = nil, 0;
362 local len = self:byte();
363 local n = {};
364 if len == 0 then return "." end -- Root label
365 while len > 0 do
366 if len >= 0xc0 then -- name is "compressed"
367 pointers = pointers + 1;
368 if pointers >= 20 then error('dns error: 20 pointers'); end;
369 local offset = ((len-0xc0)*0x100) + self:byte();
370 remember = remember or self.offset;
371 self.offset = offset + 1; -- +1 for lua
372 else -- name is not compressed
373 append(n, self:sub(len)..'.');
374 end
375 len = self:byte();
376 end
377 self.offset = remember or self.offset;
378 return table.concat(n);
379 end
380
381
382 function resolver:question() -- - - - - - - - - - - - - - - - - - question
383 local q = {};
384 q.name = self:name();
385 q.type = dns.type[self:word()];
386 q.class = dns.class[self:word()];
387 return q;
388 end
389
390
391 function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A
392 local b1, b2, b3, b4 = self:byte(4);
393 rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4);
394 end
395
396 function resolver:AAAA(rr)
397 local addr = {};
398 for i = 1, rr.rdlength, 2 do
399 local b1, b2 = self:byte(2);
400 table.insert(addr, ("%02x%02x"):format(b1, b2));
401 end
402 addr = table.concat(addr, ":"):gsub("%f[%x]0+(%x)","%1");
403 local zeros = {};
404 for item in addr:gmatch(":[0:]+:") do
405 table.insert(zeros, item)
406 end
407 if #zeros == 0 then
408 rr.aaaa = addr;
409 return
410 elseif #zeros > 1 then
411 table.sort(zeros, function(a, b) return #a > #b end);
412 end
413 rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::");
414 end
415
416 function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME
417 rr.cname = self:name();
418 end
419
420
421 function resolver:MX(rr) -- - - - - - - - - - - - - - - - - - - - - - - MX
422 rr.pref = self:word();
423 rr.mx = self:name();
424 end
425
426
427 function resolver:LOC_nibble_power() -- - - - - - - - - - LOC_nibble_power
428 local b = self:byte();
429 --print('nibbles', ((b-(b%0x10))/0x10), (b%0x10));
430 return ((b-(b%0x10))/0x10) * (10^(b%0x10));
431 end
432
433
434 function resolver:LOC(rr) -- - - - - - - - - - - - - - - - - - - - - - LOC
435 rr.version = self:byte();
436 if rr.version == 0 then
437 rr.loc = rr.loc or {};
438 rr.loc.size = self:LOC_nibble_power();
439 rr.loc.horiz_pre = self:LOC_nibble_power();
440 rr.loc.vert_pre = self:LOC_nibble_power();
441 rr.loc.latitude = self:dword();
442 rr.loc.longitude = self:dword();
443 rr.loc.altitude = self:dword();
444 end
445 end
446
447
448 local function LOC_tostring_degrees(f, pos, neg) -- - - - - - - - - - - - -
449 f = f - 0x80000000;
450 if f < 0 then pos = neg; f = -f; end
451 local deg, min, msec;
452 msec = f%60000;
453 f = (f-msec)/60000;
454 min = f%60;
455 deg = (f-min)/60;
456 return string.format('%3d %2d %2.3f %s', deg, min, msec/1000, pos);
457 end
458
459
460 function resolver.LOC_tostring(rr) -- - - - - - - - - - - - - LOC_tostring
461 local t = {};
462
463 --[[
464 for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', 'latitude', 'longitude', 'altitude' } do
465 append(t, string.format('%4s%-10s: %12.0f\n', '', name, rr.loc[name]));
466 end
467 --]]
468
469 append(t, string.format(
470 '%s %s %.2fm %.2fm %.2fm %.2fm',
471 LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'),
472 LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'),
473 (rr.loc.altitude - 10000000) / 100,
474 rr.loc.size / 100,
475 rr.loc.horiz_pre / 100,
476 rr.loc.vert_pre / 100
477 ));
478
479 return table.concat(t);
480 end
481
482
483 function resolver:NS(rr) -- - - - - - - - - - - - - - - - - - - - - - - NS
484 rr.ns = self:name();
485 end
486
487
488 function resolver:SOA(rr) -- - - - - - - - - - - - - - - - - - - - - - SOA
489 end
490
491
492 function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV
493 rr.srv = {};
494 rr.srv.priority = self:word();
495 rr.srv.weight = self:word();
496 rr.srv.port = self:word();
497 rr.srv.target = self:name();
498 end
499
500 function resolver:PTR(rr)
501 rr.ptr = self:name();
502 end
503
504 function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT
505 rr.txt = self:sub (self:byte());
506 end
507
508
509 function resolver:rr() -- - - - - - - - - - - - - - - - - - - - - - - - rr
510 local rr = {};
511 setmetatable(rr, rr_metatable);
512 rr.name = self:name(self);
513 rr.type = dns.type[self:word()] or rr.type;
514 rr.class = dns.class[self:word()] or rr.class;
515 rr.ttl = 0x10000*self:word() + self:word();
516 rr.rdlength = self:word();
517
518 if rr.ttl <= 0 then
519 rr.tod = self.time + 30;
520 else
521 rr.tod = self.time + rr.ttl;
522 end
523
524 local remember = self.offset;
525 local rr_parser = self[dns.type[rr.type]];
526 if rr_parser then rr_parser(self, rr); end
527 self.offset = remember;
528 rr.rdata = self:sub(rr.rdlength);
529 return rr;
530 end
531
532
533 function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs
534 local rrs = {};
535 for i = 1,count do append(rrs, self:rr()); end
536 return rrs;
537 end
538
539
540 function resolver:decode(packet, force) -- - - - - - - - - - - - - - decode
541 self.packet, self.offset = packet, 1;
542 local header = self:header(force);
543 if not header then return nil; end
544 local response = { header = header };
545
546 response.question = {};
547 local offset = self.offset;
548 for i = 1,response.header.qdcount do
549 append(response.question, self:question());
550 end
551 response.question.raw = string.sub(self.packet, offset, self.offset - 1);
552
553 if not force then
554 if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then
555 self.active[response.header.id] = nil;
556 return nil;
557 end
558 end
559
560 response.answer = self:rrs(response.header.ancount);
561 response.authority = self:rrs(response.header.nscount);
562 response.additional = self:rrs(response.header.arcount);
563
564 return response;
565 end
566
567
568 -- socket layer -------------------------------------------------- socket layer
569
570
571 resolver.delays = { 1, 3 };
572
573
574 function resolver:addnameserver(address) -- - - - - - - - - - addnameserver
575 self.server = self.server or {};
576 append(self.server, address);
577 end
578
579
580 function resolver:setnameserver(address) -- - - - - - - - - - setnameserver
581 self.server = {};
582 self:addnameserver(address);
583 end
584
585
586 function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers
587 if is_windows then
588 if windows and windows.get_nameservers then
589 for _, server in ipairs(windows.get_nameservers()) do
590 self:addnameserver(server);
591 end
592 end
593 if not self.server or #self.server == 0 then
594 -- TODO log warning about no nameservers, adding opendns servers as fallback
595 self:addnameserver("208.67.222.222");
596 self:addnameserver("208.67.220.220");
597 end
598 else -- posix
599 local resolv_conf = io.open("/etc/resolv.conf");
600 if resolv_conf then
601 for line in resolv_conf:lines() do
602 line = line:gsub("#.*$", "")
603 :match('^%s*nameserver%s+([%x:%.]*)%s*$');
604 if line then
605 local ip = new_ip(line);
606 if ip then
607 self:addnameserver(ip.addr);
608 end
609 end
610 end
611 end
612 if not self.server or #self.server == 0 then
613 -- TODO log warning about no nameservers, adding localhost as the default nameserver
614 self:addnameserver("127.0.0.1");
615 end
616 end
617 end
618
619
620 function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket
621 self.socket = self.socket or {};
622 self.socketset = self.socketset or {};
623
624 local sock = self.socket[servernum];
625 if sock then return sock; end
626
627 local ok, err;
628 local peer = self.server[servernum];
629 if peer:find(":") then
630 sock, err = socket.udp6();
631 else
632 sock, err = socket.udp();
633 end
634 if sock and self.socket_wrapper then sock, err = self.socket_wrapper(sock, self); end
635 if not sock then
636 return nil, err;
637 end
638 sock:settimeout(0);
639 -- todo: attempt to use a random port, fallback to 0
640 self.socket[servernum] = sock;
641 self.socketset[sock] = servernum;
642 -- set{sock,peer}name can fail, eg because of local routing table
643 -- if so, try the next server
644 ok, err = sock:setsockname('*', 0);
645 if not ok then return self:servfail(sock, err); end
646 ok, err = sock:setpeername(peer, 53);
647 if not ok then return self:servfail(sock, err); end
648 return sock;
649 end
650
651 function resolver:voidsocket(sock)
652 if self.socket[sock] then
653 self.socketset[self.socket[sock]] = nil;
654 self.socket[sock] = nil;
655 elseif self.socketset[sock] then
656 self.socket[self.socketset[sock]] = nil;
657 self.socketset[sock] = nil;
658 end
659 sock:close();
660 end
661
662 function resolver:socket_wrapper_set(func) -- - - - - - - socket_wrapper_set
663 self.socket_wrapper = func;
664 end
665
666
667 function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall
668 for i,sock in ipairs(self.socket) do
669 self.socket[i] = nil;
670 self.socketset[sock] = nil;
671 sock:close();
672 end
673 end
674
675
676 function resolver:remember(rr, type) -- - - - - - - - - - - - - - remember
677 --print ('remember', type, rr.class, rr.type, rr.name)
678 local qname, qtype, qclass = standardize(rr.name, rr.type, rr.class);
679
680 if type ~= '*' then
681 type = qtype;
682 local all = get(self.cache, qclass, '*', qname);
683 --print('remember all', all);
684 if all then append(all, rr); end
685 end
686
687 self.cache = self.cache or setmetatable({}, cache_metatable);
688 local rrs = get(self.cache, qclass, type, qname) or
689 set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable));
690 if not rrs[rr[qtype:lower()]] then
691 rrs[rr[qtype:lower()]] = true;
692 append(rrs, rr);
693 end
694
695 if type == 'MX' then self.unsorted[rrs] = true; end
696 end
697
698
699 local function comp_mx(a, b) -- - - - - - - - - - - - - - - - - - - comp_mx
700 return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref);
701 end
702
703
704 function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek
705 qname, qtype, qclass = standardize(qname, qtype, qclass);
706 local rrs = get(self.cache, qclass, qtype, qname);
707 if not rrs then return nil; end
708 if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then
709 set(self.cache, qclass, qtype, qname, nil);
710 return nil;
711 end
712 if self.unsorted[rrs] then table.sort (rrs, comp_mx); end
713 return rrs;
714 end
715
716
717 function resolver:purge(soft) -- - - - - - - - - - - - - - - - - - - purge
718 if soft == 'soft' then
719 self.time = socket.gettime();
720 for class,types in pairs(self.cache or {}) do
721 for type,names in pairs(types) do
722 for name,rrs in pairs(names) do
723 prune(rrs, self.time, 'soft')
724 end
725 end
726 end
727 else self.cache = setmetatable({}, cache_metatable); end
728 end
729
730
731 function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query
732 qname, qtype, qclass = standardize(qname, qtype, qclass)
733
734 local co = coroutine.running();
735 local q = get(self.wanted, qclass, qtype, qname);
736 if co and q then
737 -- We are already waiting for a reply to an identical query.
738 set(self.wanted, qclass, qtype, qname, co, true);
739 return true;
740 end
741
742 if not self.server then self:adddefaultnameservers(); end
743
744 local question = encodeQuestion(qname, qtype, qclass);
745 local peek = self:peek (qname, qtype, qclass);
746 if peek then return peek; end
747
748 local header, id = encodeHeader();
749 --print ('query id', id, qclass, qtype, qname)
750 local o = {
751 packet = header..question,
752 server = self.best_server,
753 delay = 1,
754 retry = socket.gettime() + self.delays[1]
755 };
756
757 -- remember the query
758 self.active[id] = self.active[id] or {};
759 self.active[id][question] = o;
760
761 -- remember which coroutine wants the answer
762 if co then
763 set(self.wanted, qclass, qtype, qname, co, true);
764 end
765
766 local conn, err = self:getsocket(o.server)
767 if not conn then
768 return nil, err;
769 end
770 conn:send (o.packet)
771
772 if timer and self.timeout then
773 local num_servers = #self.server;
774 local i = 1;
775 timer.add_task(self.timeout, function ()
776 if get(self.wanted, qclass, qtype, qname, co) then
777 if i < num_servers then
778 i = i + 1;
779 self:servfail(conn);
780 o.server = self.best_server;
781 conn, err = self:getsocket(o.server);
782 if conn then
783 conn:send(o.packet);
784 return self.timeout;
785 end
786 end
787 -- Tried everything, failed
788 self:cancel(qclass, qtype, qname);
789 end
790 end)
791 end
792 return true;
793 end
794
795 function resolver:servfail(sock, err)
796 -- Resend all queries for this server
797
798 local num = self.socketset[sock]
799
800 -- Socket is dead now
801 sock = self:voidsocket(sock);
802
803 -- Find all requests to the down server, and retry on the next server
804 self.time = socket.gettime();
805 for id,queries in pairs(self.active) do
806 for question,o in pairs(queries) do
807 if o.server == num then -- This request was to the broken server
808 o.server = o.server + 1 -- Use next server
809 if o.server > #self.server then
810 o.server = 1;
811 end
812
813 o.retries = (o.retries or 0) + 1;
814 if o.retries >= #self.server then
815 --print('timeout');
816 queries[question] = nil;
817 else
818 sock, err = self:getsocket(o.server);
819 if sock then sock:send(o.packet); end
820 end
821 end
822 end
823 if next(queries) == nil then
824 self.active[id] = nil;
825 end
826 end
827
828 if num == self.best_server then
829 self.best_server = self.best_server + 1;
830 if self.best_server > #self.server then
831 -- Exhausted all servers, try first again
832 self.best_server = 1;
833 end
834 end
835 return sock, err;
836 end
837
838 function resolver:settimeout(seconds)
839 self.timeout = seconds;
840 end
841
842 function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive
843 --print('receive'); print(self.socket);
844 self.time = socket.gettime();
845 rset = rset or self.socket;
846
847 local response;
848 for i,sock in pairs(rset) do
849
850 if self.socketset[sock] then
851 local packet = sock:receive();
852 if packet then
853 response = self:decode(packet);
854 if response and self.active[response.header.id]
855 and self.active[response.header.id][response.question.raw] then
856 --print('received response');
857 --self.print(response);
858
859 for j,rr in pairs(response.answer) do
860 if rr.name:sub(-#response.question[1].name, -1) == response.question[1].name then
861 self:remember(rr, response.question[1].type)
862 end
863 end
864
865 -- retire the query
866 local queries = self.active[response.header.id];
867 queries[response.question.raw] = nil;
868
869 if not next(queries) then self.active[response.header.id] = nil; end
870 if not next(self.active) then self:closeall(); end
871
872 -- was the query on the wanted list?
873 local q = response.question[1];
874 local cos = get(self.wanted, q.class, q.type, q.name);
875 if cos then
876 for co in pairs(cos) do
877 if coroutine.status(co) == "suspended" then coroutine.resume(co); end
878 end
879 set(self.wanted, q.class, q.type, q.name, nil);
880 end
881 end
882
883 end
884 end
885 end
886
887 return response;
888 end
889
890
891 function resolver:feed(sock, packet, force)
892 --print('receive'); print(self.socket);
893 self.time = socket.gettime();
894
895 local response = self:decode(packet, force);
896 if response and self.active[response.header.id]
897 and self.active[response.header.id][response.question.raw] then
898 --print('received response');
899 --self.print(response);
900
901 for j,rr in pairs(response.answer) do
902 self:remember(rr, response.question[1].type);
903 end
904
905 -- retire the query
906 local queries = self.active[response.header.id];
907 queries[response.question.raw] = nil;
908 if not next(queries) then self.active[response.header.id] = nil; end
909 if not next(self.active) then self:closeall(); end
910
911 -- was the query on the wanted list?
912 local q = response.question[1];
913 if q then
914 local cos = get(self.wanted, q.class, q.type, q.name);
915 if cos then
916 for co in pairs(cos) do
917 if coroutine.status(co) == "suspended" then coroutine.resume(co); end
918 end
919 set(self.wanted, q.class, q.type, q.name, nil);
920 end
921 end
922 end
923
924 return response;
925 end
926
927 function resolver:cancel(qclass, qtype, qname)
928 local cos = get(self.wanted, qclass, qtype, qname);
929 if cos then
930 for co in pairs(cos) do
931 if coroutine.status(co) == "suspended" then coroutine.resume(co); end
932 end
933 set(self.wanted, qclass, qtype, qname, nil);
934 end
935 end
936
937 function resolver:pulse() -- - - - - - - - - - - - - - - - - - - - - pulse
938 --print(':pulse');
939 while self:receive() do end
940 if not next(self.active) then return nil; end
941
942 self.time = socket.gettime();
943 for id,queries in pairs(self.active) do
944 for question,o in pairs(queries) do
945 if self.time >= o.retry then
946
947 o.server = o.server + 1;
948 if o.server > #self.server then
949 o.server = 1;
950 o.delay = o.delay + 1;
951 end
952
953 if o.delay > #self.delays then
954 --print('timeout');
955 queries[question] = nil;
956 if not next(queries) then self.active[id] = nil; end
957 if not next(self.active) then return nil; end
958 else
959 --print('retry', o.server, o.delay);
960 local _a = self.socket[o.server];
961 if _a then _a:send(o.packet); end
962 o.retry = self.time + self.delays[o.delay];
963 end
964 end
965 end
966 end
967
968 if next(self.active) then return true; end
969 return nil;
970 end
971
972
973 function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup
974 self:query (qname, qtype, qclass)
975 while self:pulse() do
976 local recvt = {}
977 for i, s in ipairs(self.socket) do
978 recvt[i] = s
979 end
980 socket.select(recvt, nil, 4)
981 end
982 --print(self.cache);
983 return self:peek(qname, qtype, qclass);
984 end
985
986 function resolver:lookupex(handler, qname, qtype, qclass) -- - - - - - - - - - lookup
987 return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass);
988 end
989
990 function resolver:tohostname(ip)
991 return dns.lookup(ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa."), "PTR");
992 end
993
994 --print ---------------------------------------------------------------- print
995
996
997 local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
998 qr = { [0]='query', 'response' },
999 opcode = { [0]='query', 'inverse query', 'server status request' },
1000 aa = { [0]='non-authoritative', 'authoritative' },
1001 tc = { [0]='complete', 'truncated' },
1002 rd = { [0]='recursion not desired', 'recursion desired' },
1003 ra = { [0]='recursion not available', 'recursion available' },
1004 z = { [0]='(reserved)' },
1005 rcode = { [0]='no error', 'format error', 'server failure', 'name error', 'not implemented' },
1006
1007 type = dns.type,
1008 class = dns.class
1009 };
1010
1011
1012 local function hint(p, s) -- - - - - - - - - - - - - - - - - - - - - - hint
1013 return (hints[s] and hints[s][p[s]]) or '';
1014 end
1015
1016
1017 function resolver.print(response) -- - - - - - - - - - - - - resolver.print
1018 for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
1019 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do
1020 print( string.format('%-30s', 'header.'..s), response.header[s], hint(response.header, s) );
1021 end
1022
1023 for i,question in ipairs(response.question) do
1024 print(string.format ('question[%i].name ', i), question.name);
1025 print(string.format ('question[%i].type ', i), question.type);
1026 print(string.format ('question[%i].class ', i), question.class);
1027 end
1028
1029 local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 };
1030 local tmp;
1031 for s,s in pairs({'answer', 'authority', 'additional'}) do
1032 for i,rr in pairs(response[s]) do
1033 for j,t in pairs({ 'name', 'type', 'class', 'ttl', 'rdlength' }) do
1034 tmp = string.format('%s[%i].%s', s, i, t);
1035 print(string.format('%-30s', tmp), rr[t], hint(rr, t));
1036 end
1037 for j,t in pairs(rr) do
1038 if not common[j] then
1039 tmp = string.format('%s[%i].%s', s, i, j);
1040 print(string.format('%-30s %s', tostring(tmp), tostring(t)));
1041 end
1042 end
1043 end
1044 end
1045 end
1046
1047
1048 -- module api ------------------------------------------------------ module api
1049
1050
1051 function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver
1052 -- this function seems to be redundant with resolver.new ()
1053
1054 local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, best_server = 1 };
1055 setmetatable (r, resolver);
1056 setmetatable (r.cache, cache_metatable);
1057 setmetatable (r.unsorted, { __mode = 'kv' });
1058 return r;
1059 end
1060
1061 local _resolver = dns.resolver();
1062 dns._resolver = _resolver;
1063
1064 function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup
1065 return _resolver:lookup(...);
1066 end
1067
1068 function dns.tohostname(...)
1069 return _resolver:tohostname(...);
1070 end
1071
1072 function dns.purge(...) -- - - - - - - - - - - - - - - - - - - - - - purge
1073 return _resolver:purge(...);
1074 end
1075
1076 function dns.peek(...) -- - - - - - - - - - - - - - - - - - - - - - - peek
1077 return _resolver:peek(...);
1078 end
1079
1080 function dns.query(...) -- - - - - - - - - - - - - - - - - - - - - - query
1081 return _resolver:query(...);
1082 end
1083
1084 function dns.feed(...) -- - - - - - - - - - - - - - - - - - - - - - - feed
1085 return _resolver:feed(...);
1086 end
1087
1088 function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel
1089 return _resolver:cancel(...);
1090 end
1091
1092 function dns.settimeout(...)
1093 return _resolver:settimeout(...);
1094 end
1095
1096 function dns.cache()
1097 return _resolver.cache;
1098 end
1099
1100 function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set
1101 return _resolver:socket_wrapper_set(...);
1102 end
1103
1104 return dns;

mercurial