net/dns.lua

Sun, 27 Sep 2009 12:10:50 +0100

author
Matthew Wild <mwild1@gmail.com>
date
Sun, 27 Sep 2009 12:10:50 +0100
changeset 1832
5ae3209fefa2
parent 1824
fc21e8832848
child 1833
f4c88dd32724
child 1949
e3d777d76b1a
permissions
-rw-r--r--

Merge with waqas

-- Prosody IM
-- This file is included with Prosody IM. It has modifications,
-- which are hereby placed in the public domain.

-- public domain 20080404 lua@ztact.com


-- todo: quick (default) header generation
-- todo: nxdomain, error handling
-- todo: cache results of encodeName


-- reference: http://tools.ietf.org/html/rfc1035
-- reference: http://tools.ietf.org/html/rfc1876 (LOC)


require 'socket'
local ztact = require 'util.ztact'
local require = require
local os = os;

local coroutine, io, math, socket, string, table =
      coroutine, io, math, socket, string, table

local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
      ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack

local get, set = ztact.get, ztact.set


-------------------------------------------------- module dns
-- Declare this chunk as the module 'dns' (Lua 5.1 `module` function).  The
-- standard libraries and globals used below were captured in locals above,
-- before this call.
module ('dns')
-- `_M` is the module table created by module(); `dns` is a local alias for it.
local dns = _M;


-- dns type & class codes ------------------------------ dns type & class codes


-- Shorthand used throughout the file to append a value to an array.
local append = table.insert


-- Returns the high-order byte of `i`: the quotient of `i` divided by 0x100,
-- obtained by subtracting the low byte first so no rounding is involved.
local function highbyte (i)
  local low = i % 0x100
  return (i - low) / 0x100
end


-- Builds a lookup table from a code -> NAME table.  The result maps the
-- numeric code, the NAME itself and the lower-cased name all to NAME.
local function augment (t)
  local lookup = {}
  for code, name in pairs (t) do
    lookup[code] = name
    lookup[name] = name
    lookup[string.lower (name)] = name
  end
  return lookup
end


-- Builds a lookup table from a code -> NAME table.  The result maps the
-- numeric code, the NAME and the lower-cased name to the code packed as a
-- two-byte big-endian string.
local function encode (t)
  local wire = {}
  for number, name in pairs (t) do
    local hi, lo = highbyte (number), number % 0x100
    local packed = string.char (hi, lo)
    wire[number] = packed
    wire[name] = packed
    wire[string.lower (name)] = packed
  end
  return wire
end


-- Record type names indexed by numeric type code.  The first sixteen names
-- take codes 1..16 from their position; the remaining codes are explicit.
dns.types = {
  'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
  'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
  [ 28] = 'AAAA', [ 29] = 'LOC',   [ 33] = 'SRV',
  [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }


-- Class names indexed by numeric class code (1..4 by position, 255 = any).
dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }


-- dns.type / dns.class: code, NAME or lower-case name -> canonical NAME.
-- dns.typecode / dns.classcode: same keys -> two-byte big-endian code string.
dns.type      = augment (dns.types)
dns.class     = augment (dns.classes)
dns.typecode  = encode  (dns.types)
dns.classcode = encode  (dns.classes)



-- Normalises a query triple: the name gets a trailing dot if it lacks one
-- and is lower-cased; type and class are mapped to their canonical names,
-- defaulting to 'A' and 'IN'.
local function standardize (qname, qtype, qclass)
  local DOT = 0x2E    -- byte value of '.'
  if string.byte (qname, -1) ~= DOT then
    qname = qname .. '.'
  end
  local canonical_type  = dns.type[qtype or 'A']
  local canonical_class = dns.class[qclass or 'IN']
  return string.lower (qname), canonical_type, canonical_class
end


-- Removes expired records from the array `rrs`, in place.
--
--   rrs   array of resource records
--   time  current time; defaults to socket.gettime ()
--   soft  when equal to 'soft', records without a time-of-death are removed
--         as well (rr () only sets `tod` for records whose ttl is non-zero,
--         so these are the ttl == 0 records)
--
-- For every record carrying `tod`, `ttl` is refreshed to the whole number of
-- time units left and the record is dropped once that reaches zero.
-- Returns nothing.
--
-- The array is walked backwards and entries are taken out with table.remove,
-- so removing an element never disturbs the indices still to be visited and
-- the array never ends up with holes (which would confuse `#`, ipairs and
-- table.sort on the cached list).
local function prune (rrs, time, soft)

  time = time or socket.gettime ()

  for i = #rrs, 1, -1 do
    local rr = rrs[i]

    if rr == nil then
      -- tolerate a hole left behind by older code; nothing to examine here

    elseif rr.tod then
      -- rr.tod = rr.tod - 50    -- accelerated decrepitude (debugging aid)
      rr.ttl = math.floor (rr.tod - time)
      if rr.ttl <= 0 then
        table.remove (rrs, i)
      end

    elseif soft == 'soft' then
      assert (rr.ttl == 0)
      table.remove (rrs, i)
    end
  end
end


-- metatables & co. ------------------------------------------ metatables & co.


-- Prototype table for resolver objects; instances use it as their metatable,
-- so methods defined on `resolver` are found through __index.
local resolver = {}
resolver.__index = resolver


-- Forward declaration: the function is assigned further down, but
-- rr_metatable.__tostring below needs the name in scope.
local SRV_tostring


-- Metatable for a single resource record; provides its printable form.
local rr_metatable = {}

-- Renderers for the type-specific part of a record, keyed by record type.
-- The LOC and SRV entries call their helpers lazily because those helpers
-- are only defined later in the file.
local rdata_formatters = {
  A     = function (rr)  return rr.a  end,
  MX    = function (rr)  return string.format ('%2i %s', rr.pref, rr.mx)  end,
  CNAME = function (rr)  return rr.cname  end,
  LOC   = function (rr)  return resolver.LOC_tostring (rr)  end,
  NS    = function (rr)  return rr.ns  end,
  SRV   = function (rr)  return SRV_tostring (rr)  end,
  TXT   = function (rr)  return rr.txt  end,
}

-- Formats a record as "class type ttl name rdata"; types without a renderer
-- are shown with a placeholder instead of their data.
function rr_metatable.__tostring (rr)
  local prefix = string.format (
    '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name )

  local render = rdata_formatters[rr.type]
  local rdata
  if render then
    rdata = render (rr)
  else
    rdata = '<UNKNOWN RDATA TYPE>'
  end

  return prefix .. ' ' .. rdata
end


-- Metatable for a list of resource records; provides its printable form.
local rrs_metatable = {}

-- Renders every record in the list on its own newline-terminated line.
function rrs_metatable.__tostring (rrs)
  local pieces = {}
  for _, rr in pairs (rrs) do
    append (pieces, tostring (rr))
    append (pieces, '\n')
  end
  return table.concat (pieces)
end


-- Metatable for the resolver cache; provides its printable form.
local cache_metatable = {}

-- Walks the cache (class -> type -> name -> record list), prunes each list
-- against the current time and concatenates the printable form of each list.
function cache_metatable.__tostring (cache)
  local now = socket.gettime ()
  local out = {}
  for _, types in pairs (cache) do
    for _, names in pairs (types) do
      for _, rrs in pairs (names) do
        prune (rrs, now)
        append (out, tostring (rrs))
      end
    end
  end
  return table.concat (out)
end


-- Creates a new resolver object.
--
-- The object starts with empty `active`, `cache`, `unsorted`, `wanted` and
-- `yielded` tables and `best_server` set to 1.  `wanted`, `yielded` and
-- `best_server` are read by query (), receive () and feed (); they are
-- initialised here exactly as dns.resolver () does, so that a resolver
-- obtained through either constructor can issue queries.
--
-- Returns the new resolver.
function resolver:new ()
  local r = { active = {}, cache = {}, unsorted = {}, wanted = {},
              yielded = {}, best_server = 1 }
  setmetatable (r, resolver)
  setmetatable (r.cache, cache_metatable)
  -- weak keys and values: an entry vanishes once its record list is
  -- no longer referenced elsewhere
  setmetatable (r.unsorted, { __mode = 'kv' })
  return r
end


-- packet layer -------------------------------------------------- packet layer


-- First-call wrapper around math.random: seeds the generator from the
-- current time, then replaces itself with math.random so later calls go
-- straight to it.  Arguments and results are those of math.random.
function dns.random (...)
  local seed = 10000 * socket.gettime ()
  math.randomseed (seed)
  dns.random = math.random
  return math.random (...)
end


-- Encodes a DNS message header (reference: RFC 1035).
--
--   o  optional table of header fields; missing fields are filled in (the
--      table is modified in place).  A fresh table is used when omitted.
--
-- Field widths and defaults:
--   id       16b  random when absent
--   rd        1b  1  recursion desired
--   tc        1b  0  truncated response
--   aa        1b  0  authoritative response
--   opcode    4b  0  (0 query, 1 inverse query, 2 server status request)
--   qr        1b  0  (0 query, 1 response)
--   rcode     4b  0  (0 no error, 1 format error, 2 server failure,
--                     3 name error, 4 not implemented, 5 refused)
--   z         3b  0  reserved
--   ra        1b  0  recursion available
--   qdcount  16b  1  number of question RRs
--   ancount  16b  0  number of answer RRs
--   nscount  16b  0  number of nameserver RRs
--   arcount  16b  0  number of additional RRs
--
-- Returns the encoded header string and the id that was used.
local function encodeHeader (o)

  o = o or {}

  -- only draw a random id when the caller did not provide one
  if not o.id then  o.id = dns.random (0, 0xffff)  end

  local defaults = {
    rd = 1, tc = 0, aa = 0, opcode = 0, qr = 0,
    rcode = 0, z = 0, ra = 0,
    qdcount = 1, ancount = 0, nscount = 0, arcount = 0,
  }
  for field, value in pairs (defaults) do
    o[field] = o[field] or value
  end

  -- a 16-bit value as two bytes, most significant first
  local function u16 (v)
    return string.char (highbyte (v), v % 0x100)
  end

  local flags_hi = o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr
  local flags_lo = o.rcode + 16*o.z + 128*o.ra

  local header = u16 (o.id)
    .. string.char (flags_hi, flags_lo)
    .. u16 (o.qdcount)
    .. u16 (o.ancount)
    .. u16 (o.nscount)
    .. u16 (o.arcount)

  return header, o.id
end


-- Encodes a dotted domain name as a series of length-prefixed labels
-- followed by a terminating zero byte.  Empty labels (consecutive or
-- leading/trailing dots) are skipped.
local function encodeName (name)
  local wire = {}
  for label in string.gmatch (name, '[^.]+') do
    append (wire, string.char (#label))
    append (wire, label)
  end
  append (wire, '\0')
  return table.concat (wire)
end


-- Encodes one question entry: the encoded name followed by the two-byte
-- type and class codes.  Type defaults to 'a', class to 'in'.
local function encodeQuestion (qname, qtype, qclass)
  local name_wire  = encodeName (qname)
  local type_wire  = dns.typecode[qtype or 'a']
  local class_wire = dns.classcode[qclass or 'in']
  return name_wire .. type_wire .. class_wire
end


-- Reads `len` bytes (default 1) from self.packet at self.offset, advances
-- the offset past them and returns their numeric values as separate results.
-- Raises an error when the read would run past the end of the packet.
function resolver:byte (len)
  local count = len or 1
  local first = self.offset
  local last  = first + count - 1
  local size  = #self.packet
  if last > size then
    error (string.format ('out of bounds: %i>%i', last, size))
  end
  self.offset = first + count
  return string.byte (self.packet, first, last)
end


-- Reads a 16-bit big-endian unsigned integer from the packet.
function resolver:word ()
  local hi, lo = self:byte (2)
  return hi * 0x100 + lo
end


-- Reads a 32-bit big-endian unsigned integer from the packet.
function resolver:dword ()
  local b1, b2, b3, b4 = self:byte (4)
  -- accumulate most significant byte first
  return ((b1 * 0x100 + b2) * 0x100 + b3) * 0x100 + b4
end


-- Returns the next `len` bytes (default 1) of the packet as a string and
-- advances the offset by `len`.  No bounds check is made: near the end of
-- the packet the returned string may be shorter than requested.
function resolver:sub (len)
  local count = len or 1
  local first = self.offset
  local slice = string.sub (self.packet, first, first + count - 1)
  self.offset = first + count
  return slice
end


-- Decodes the message header at the current offset.
--
--   force  when true, decode even if the id matches no active query
--
-- Returns a table with id, rd, tc, aa, opcode, qr, rcode, z, ra and the four
-- section counts, or nil when the id is not active and `force` is not set
-- (in which case only the id has been consumed).
function resolver:header (force)

  local id = self:word ()
  if not (self.active[id] or force) then  return nil  end

  -- the bit fields are extracted by division, which leaves a fractional
  -- part; this strips it
  local function whole (v)  return v - v % 1  end

  local flags_hi, flags_lo = self:byte (2)

  local h = { id = whole (id) }

  h.rd      = whole (flags_hi % 2)
  h.tc      = whole (flags_hi / 2 % 2)
  h.aa      = whole (flags_hi / 4 % 2)
  h.opcode  = whole (flags_hi / 8 % 16)
  h.qr      = whole (flags_hi / 128)

  h.rcode   = whole (flags_lo % 16)
  h.z       = whole (flags_lo / 16 % 8)
  h.ra      = whole (flags_lo / 128)

  h.qdcount = whole (self:word ())
  h.ancount = whole (self:word ())
  h.nscount = whole (self:word ())
  h.arcount = whole (self:word ())

  return h
end


-- Decodes a domain name at the current offset, following compression
-- pointers.  Each label is returned with a trailing dot, so the root name
-- decodes to the empty string.
--
-- After the call the offset sits just past the name as it appears at the
-- original position: when a pointer was followed, that is right after the
-- first pointer.  Raises an error once 20 pointers have been followed.
function resolver:name ()
  local labels = {}
  local jumps = 0
  local resume_at = nil    -- offset to restore once the name is complete

  local len = self:byte ()
  while len > 0 do
    if len < 0xc0 then
      -- ordinary label of `len` bytes
      append (labels, self:sub (len)..'.')
    else
      -- compression pointer: the low 6 bits plus the next byte form the
      -- zero-based position in the packet where the name continues
      jumps = jumps + 1
      if jumps >= 20 then  error ('dns error: 20 pointers')  end
      local target = (len - 0xc0) * 0x100 + self:byte ()
      if not resume_at then  resume_at = self.offset  end
      self.offset = target + 1    -- +1: Lua strings are 1-based
    end
    len = self:byte ()
  end

  if resume_at then  self.offset = resume_at  end
  return table.concat (labels)
end


-- Decodes one question entry at the current offset and returns it as a
-- table with `name`, `type` and `class` (type and class as canonical names,
-- nil for codes that are not in the lookup tables).
function resolver:question ()
  local name   = self:name ()
  local qtype  = dns.type[self:word ()]
  local qclass = dns.class[self:word ()]
  return { name = name, type = qtype, class = qclass }
end


-- Parses A record data: four bytes stored in rr.a as a dotted-quad string.
function resolver:A (rr)
  rr.a = string.format ('%i.%i.%i.%i', self:byte (4))
end


-- Parses CNAME record data: a domain name, stored in rr.cname.
function resolver:CNAME (rr)
  local canonical = self:name ()
  rr.cname = canonical
end


-- Parses MX record data: a 16-bit preference (rr.pref) followed by the
-- exchange host name (rr.mx).
function resolver:MX (rr)
  local preference = self:word ()
  local exchange   = self:name ()
  rr.pref = preference
  rr.mx   = exchange
end


-- Reads one byte and returns (high nibble) * 10 ^ (low nibble), the encoding
-- used for the size and precision fields of LOC records.
function resolver:LOC_nibble_power ()
  local b = self:byte ()
  local exponent = b % 0x10
  local mantissa = (b - exponent) / 0x10
  return mantissa * (10 ^ exponent)
end


-- Parses LOC record data (reference: RFC 1876).  The version byte is stored
-- in rr.version; only version 0 is decoded further, filling rr.loc with
-- size, horiz_pre, vert_pre, latitude, longitude and altitude.
function resolver:LOC (rr)
  rr.version = self:byte ()
  if rr.version ~= 0 then  return  end

  local loc = rr.loc or {}
  rr.loc = loc

  loc.size      = self:LOC_nibble_power ()
  loc.horiz_pre = self:LOC_nibble_power ()
  loc.vert_pre  = self:LOC_nibble_power ()
  loc.latitude  = self:dword ()
  loc.longitude = self:dword ()
  loc.altitude  = self:dword ()
end


-- Formats a LOC latitude/longitude value as "deg min sec hemisphere".
--
--   f    raw value, offset by 0x80000000; after removing the offset it is
--        treated as a count of 1/1000 arc-seconds
--   pos  hemisphere letter used for non-negative values
--   neg  hemisphere letter used for negative values
local function LOC_tostring_degrees (f, pos, neg)
  local hemisphere = pos
  local magnitude = f - 0x80000000
  if magnitude < 0 then
    hemisphere = neg
    magnitude = -magnitude
  end

  local msec          = magnitude % 60000
  local total_minutes = (magnitude - msec) / 60000
  local min           = total_minutes % 60
  local deg           = (total_minutes - min) / 60

  return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, hemisphere)
end


-- Renders the decoded LOC data in rr.loc as one line: latitude, longitude,
-- then altitude, size, horizontal precision and vertical precision, each
-- divided by 100 and suffixed with "m".  The altitude has an offset of
-- 10000000 removed before the division.
function resolver.LOC_tostring (rr)
  local loc = rr.loc

  local latitude  = LOC_tostring_degrees (loc.latitude, 'N', 'S')
  local longitude = LOC_tostring_degrees (loc.longitude, 'E', 'W')
  local altitude  = (loc.altitude - 10000000) / 100

  return string.format (
    '%s    %s    %.2fm %.2fm %.2fm %.2fm',
    latitude,
    longitude,
    altitude,
    loc.size / 100,
    loc.horiz_pre / 100,
    loc.vert_pre / 100 )
end


-- Parses NS record data: a name server host name, stored in rr.ns.
function resolver:NS (rr)
  local nameserver = self:name ()
  rr.ns = nameserver
end


-- Parser hook for SOA records.  Intentionally empty: no SOA fields are
-- decoded, so an SOA record only carries the generic fields and the raw
-- `rdata` string that resolver:rr () stores for every record.
function resolver:SOA (rr)    -- - - - - - - - - - - - - - - - - - - - - -  SOA
  end


-- Parses SRV record data into rr.srv: three 16-bit fields (priority,
-- weight, port) followed by the target host name.
function resolver:SRV (rr)
  local srv = {}
  rr.srv = srv
  srv.priority = self:word ()
  srv.weight   = self:word ()
  srv.port     = self:word ()
  srv.target   = self:name ()
end


-- Renders the SRV data of a record as "priority weight port target", the
-- three numbers right-aligned in five-character columns.
-- (Assigns to the `SRV_tostring` local declared near the top of the file.)
function SRV_tostring (rr)
  local srv = rr.srv
  local priority, weight, port = srv.priority, srv.weight, srv.port
  return string.format ('%5d %5d %5d %s', priority, weight, port, srv.target)
end


-- Parses TXT record data and stores the text in rr.txt.
--
-- TXT RDATA is one or more <character-string>s, each a length byte followed
-- by that many bytes of text (RFC 1035, section 3.3.14).  The length bytes
-- are framing, not content, so they are stripped and the strings are joined.
-- A length byte that would run past the end of the record data is clamped
-- to what remains.
--
-- The unmodified record data stays available in rr.rdata, which
-- resolver:rr () fills in for every record.
function resolver:TXT (rr)
  local stop = self.offset + rr.rdlength    -- first offset past the rdata
  local parts = {}
  while self.offset < stop do
    local len = self:byte ()
    local remaining = stop - self.offset
    if len > remaining then  len = remaining  end
    parts[#parts + 1] = self:sub (len)
  end
  rr.txt = table.concat (parts)
end


-- Decodes one resource record at the current offset.
--
-- The returned table has rr_metatable and carries name, type, class, ttl,
-- rdlength and the raw rdata string.  Records with a non-zero ttl also get
-- `tod` (self.time + ttl).  When the resolver has a method named after the
-- record type, it is called to decode type-specific fields; the offset is
-- then rewound so that the raw rdata is always read from its start.
function resolver:rr ()
  local rr = setmetatable ({}, rr_metatable)

  rr.name  = self:name ()
  rr.type  = dns.type[self:word ()] or rr.type
  rr.class = dns.class[self:word ()] or rr.class

  -- 32-bit ttl, read as two 16-bit halves
  local ttl_hi = self:word ()
  local ttl_lo = self:word ()
  rr.ttl = 0x10000*ttl_hi + ttl_lo

  rr.rdlength = self:word ()

  if rr.ttl ~= 0 then
    rr.tod = self.time + rr.ttl
  end

  local rdata_start = self.offset
  local parse = self[dns.type[rr.type]]
  if parse then  parse (self, rr)  end

  self.offset = rdata_start
  rr.rdata = self:sub (rr.rdlength)
  return rr
end


-- Decodes `count` consecutive resource records and returns them as an array.
function resolver:rrs (count)
  local list = {}
  for i = 1, count do
    list[i] = self:rr ()
  end
  return list
end


-- Decodes a complete DNS message.
--
--   packet  the raw message string
--   force   when true, decode even if the message matches no active query
--
-- Returns a table with header, question (an array that also carries the raw
-- encoded question section in `.raw`), answer, authority and additional, or
-- nil when the message does not correspond to an active query and `force`
-- is not set.
function resolver:decode (packet, force)

  self.packet, self.offset = packet, 1

  local header = self:header (force)
  if not header then  return nil  end

  local questions = {}
  local first = self.offset
  for _ = 1, header.qdcount do
    append (questions, self:question ())
  end
  questions.raw = string.sub (self.packet, first, self.offset - 1)

  if not force then
    -- only accept answers to a question we actually asked under this id
    local pending = self.active[header.id]
    if not (pending and pending[questions.raw]) then  return nil  end
  end

  local response = { header = header, question = questions }
  response.answer     = self:rrs (header.ancount)
  response.authority  = self:rrs (header.nscount)
  response.additional = self:rrs (header.arcount)

  return response
end


-- socket layer -------------------------------------------------- socket layer


-- Retry schedule shared by all resolvers: query () schedules the first retry
-- delays[1] after sending, and pulse () uses delays[o.delay] for later ones
-- and gives up on a query once o.delay exceeds the number of entries.
-- Values are added to socket.gettime () -- presumably seconds; TODO confirm.
resolver.delays = { 1, 3  }


-- Appends `address` to the resolver's list of name servers, creating the
-- list on first use.
function resolver:addnameserver (address)
  local servers = self.server
  if not servers then
    servers = {}
    self.server = servers
  end
  append (servers, address)
end


-- Replaces the resolver's list of name servers with the single `address`.
function resolver:setnameserver (address)
  -- start from an empty list, then reuse the normal append path
  self.server = {}
  self:addnameserver (address)
end


-- Populates the name server list with system defaults.
--
-- When /etc/resolv.conf can be opened, every "nameserver <dotted quad>" line
-- is added (anything from '#' to the end of a line is ignored).  Otherwise,
-- when the WINDIR environment variable is set, two fixed public resolver
-- addresses are added.  If the list is still empty afterwards, 127.0.0.1 is
-- used.
--
-- The file handle is closed once the file has been read, so it is not left
-- open until the garbage collector gets to it.
function resolver:adddefaultnameservers ()
  local resolv_conf = io.open ("/etc/resolv.conf")
  if resolv_conf then
    for line in resolv_conf:lines () do
      local address = line:gsub ("#.*$", ""):match ('^%s*nameserver%s+(%d+%.%d+%.%d+%.%d+)%s*$')
      if address then  self:addnameserver (address)  end
    end
    resolv_conf:close ()
  elseif os.getenv ("WINDIR") then
    self:addnameserver ("208.67.222.222")
    self:addnameserver ("208.67.220.220")
  end
  if not self.server or #self.server == 0 then
    self:addnameserver ("127.0.0.1")
  end
end


-- Returns the UDP socket used to talk to server number `servernum`,
-- creating it on first use.
--
-- A new socket is passed through self.socket_wrapper when one is set, made
-- non-blocking, bound to a local port chosen by the system and connected to
-- port 53 of the server.  It is recorded in self.socket (number -> socket)
-- and self.socketset (socket -> number).
function resolver:getsocket (servernum)

  self.socket = self.socket or {}
  self.socketset = self.socketset or {}

  local existing = self.socket[servernum]
  if existing then  return existing  end

  local sock = socket.udp ()
  local wrap = self.socket_wrapper
  if wrap then  sock = wrap (sock, self)  end

  sock:settimeout (0)
  -- todo: attempt to use a random port, fallback to 0
  sock:setsockname ('*', 0)
  sock:setpeername (self.server[servernum], 53)

  self.socket[servernum] = sock
  self.socketset[sock] = servernum
  return sock
end

-- Forgets a socket without closing it.  `sock` may be either a server
-- number or a socket object; the matching entries are removed from both
-- self.socket and self.socketset.  Unknown values are ignored.
function resolver:voidsocket (sock)
  local by_number, by_socket = self.socket, self.socketset

  local mapped = by_number[sock]
  if mapped then
    -- `sock` is a server number
    by_socket[mapped] = nil
    by_number[sock] = nil
    return
  end

  local number = by_socket[sock]
  if number then
    -- `sock` is a socket object
    by_number[number] = nil
    by_socket[sock] = nil
  end
end

-- Installs `func` as the socket wrapper.  getsocket () calls it as
-- func (sock, resolver) for each newly created socket and uses the value
-- it returns in place of the raw socket.
function resolver:socket_wrapper_set (func)
  self.socket_wrapper = func
end


-- Closes every socket the resolver holds and forgets them all.
--
-- pairs is used rather than ipairs because voidsocket () can remove a
-- socket from the middle of self.socket, and ipairs would stop at that gap
-- and leave the sockets after it open.  self.socketset is cleared as well,
-- so that closed sockets are no longer treated as belonging to the
-- resolver.  Calling this before any socket was created is a no-op.
function resolver:closeall ()
  for _, sock in pairs (self.socket or {}) do
    sock:close ()
  end
  self.socket = {}
  self.socketset = {}
end


-- Stores a record in the cache.
--
--   rr    the resource record
--   type  type of the question being answered.  For '*' the record is
--         filed under '*'; for anything else it is filed under its own
--         type and also appended to an existing '*' list for the same
--         class and name, if there is one.
--
-- Lists of MX records are flagged in self.unsorted so that peek () sorts
-- them before handing them out.
function resolver:remember (rr, type)

  if type ~= '*' then
    type = rr.type
    local wildcard = get (self.cache, rr.class, '*', rr.name)
    if wildcard then  append (wildcard, rr)  end
  end

  self.cache = self.cache or setmetatable ({}, cache_metatable)

  local rrs = get (self.cache, rr.class, type, rr.name)
  if not rrs then
    rrs = set (self.cache, rr.class, type, rr.name,
               setmetatable ({}, rrs_metatable))
  end
  append (rrs, rr)

  if type == 'MX' then  self.unsorted[rrs] = true  end
end


-- Ordering for MX records, suitable for table.sort: lower preference first,
-- ties broken by exchange name.
local function comp_mx (a, b)
  if a.pref ~= b.pref then
    return a.pref < b.pref
  end
  return a.mx < b.mx
end


-- Looks a question up in the cache without sending anything.
--
-- Returns the cached record list after pruning expired records, or nil when
-- nothing is cached or nothing is left after pruning (the cache entry is
-- then cleared).  Lists flagged in self.unsorted are sorted with comp_mx
-- first.
function resolver:peek (qname, qtype, qclass)
  qname, qtype, qclass = standardize (qname, qtype, qclass)

  local rrs = get (self.cache, qclass, qtype, qname)
  if not rrs then  return nil  end

  local pruned = prune (rrs, socket.gettime ())
  local stale = (pruned and qtype == '*') or not next (rrs)
  if stale then
    set (self.cache, qclass, qtype, qname, nil)
    return nil
  end

  if self.unsorted[rrs] then  table.sort (rrs, comp_mx)  end
  return rrs
end


-- Empties the cache.
--
--   soft  when equal to 'soft', only prune: every cached list is passed to
--         prune () in soft mode with the current time (also stored in
--         self.time).  Any other value discards the whole cache.
--
-- A discarded cache is replaced by a fresh table that carries
-- cache_metatable, like the cache a new resolver starts with, so
-- tostring (resolver.cache) keeps working after a purge.
function resolver:purge (soft)
  if soft == 'soft' then
    self.time = socket.gettime ()
    for _, types in pairs (self.cache or {}) do
      for _, names in pairs (types) do
        for _, rrs in pairs (names) do
          prune (rrs, self.time, 'soft')
        end
      end
    end
  else
    self.cache = setmetatable ({}, cache_metatable)
  end
end


-- Starts a query.
--
-- Returns the cached record list when the answer is already in the cache.
-- Otherwise a query packet is built, recorded in self.active under its id
-- and encoded question, the calling coroutine (if any) is registered in
-- self.wanted, the packet is sent to the current best server, and nothing
-- is returned.  Default name servers are loaded first if none are set.
function resolver:query (qname, qtype, qclass)

  qname, qtype, qclass = standardize (qname, qtype, qclass)

  if not self.server then  self:adddefaultnameservers ()  end

  local question = encodeQuestion (qname, qtype, qclass)

  local cached = self:peek (qname, qtype, qclass)
  if cached then  return cached  end

  local header, id = encodeHeader ()
  local request = {
    packet = header..question,
    server = self.best_server,
    delay  = 1,
    retry  = socket.gettime () + self.delays[1],
  }

  -- remember the query
  local pending = self.active[id]
  if not pending then
    pending = {}
    self.active[id] = pending
  end
  pending[question] = request

  -- remember which coroutine wants the answer
  local co = coroutine.running ()
  if co then
    set (self.wanted, qclass, qtype, qname, co, true)
  end

  self:getsocket (request.server):send (request.packet)
end

-- Handles the failure of the socket `sock`.
--
-- The socket is forgotten, and every active query that was addressed to its
-- server is moved to the next server (wrapping around) and resent.  A query
-- that has been moved as many times as there are servers is dropped.  If
-- the failed server was the preferred one, the preference moves on to the
-- next server as well.
function resolver:servfail(sock)
  local failed = self.socketset[sock]

  -- the socket is dead now
  self:voidsocket (sock)

  self.time = socket.gettime ()
  for _, queries in pairs (self.active) do
    for question, o in pairs (queries) do
      if o.server == failed then
        -- this request went to the broken server: move to the next one
        o.server = o.server + 1
        if o.server > #self.server then
          o.server = 1
        end

        o.retries = (o.retries or 0) + 1
        if o.retries < #self.server then
          local retry_sock = self:getsocket (o.server)
          if retry_sock then  retry_sock:send (o.packet)  end
        else
          -- every server has been tried: give up on this question
          queries[question] = nil
        end
      end
    end
  end

  if failed == self.best_server then
    self.best_server = self.best_server + 1
    if self.best_server > #self.server then
      -- exhausted all servers, start over with the first
      self.best_server = 1
    end
  end
end

-- Reads and processes pending responses.
--
--   rset  table of sockets to read from; defaults to self.socket.  Sockets
--         that are not registered in self.socketset are skipped.
--
-- For each packet that decodes to a response to an active query: the
-- records of all three sections are cached, the query is retired (all
-- sockets are closed once no query is left), and the coroutines waiting for
-- the question are resumed if suspended.
--
-- The waiting coroutines are looked up through the first question entry
-- (response.question[1]), as feed () does.  `response.question` itself is
-- the array of questions and has no class/type/name fields, so using it
-- directly never finds the waiters.
--
-- Returns the last response decoded, or nil when there was none.
function resolver:receive (rset)

  self.time = socket.gettime ()
  rset = rset or self.socket

  local response
  for _, sock in pairs (rset) do
    if self.socketset[sock] then
      local packet = sock:receive ()
      if packet then
        response = self:decode (packet)
        if response then

          for _, section in pairs { 'answer', 'authority', 'additional' } do
            for _, rr in pairs (response[section]) do
              self:remember (rr, response.question[1].type)
            end
          end

          -- retire the query
          local id = response.header.id
          local queries = self.active[id]
          if queries[response.question.raw] then
            queries[response.question.raw] = nil
          end
          if not next (queries) then  self.active[id] = nil  end
          if not next (self.active) then  self:closeall ()  end

          -- was the query on the wanted list?
          local q = response.question[1]
          if q then
            local cos = get (self.wanted, q.class, q.type, q.name)
            if cos then
              for co in pairs (cos) do
                set (self.yielded, co, q.class, q.type, q.name, nil)
                if coroutine.status (co) == "suspended" then
                  coroutine.resume (co)
                end
              end
              set (self.wanted, q.class, q.type, q.name, nil)
            end
          end

        end
      end
    end
  end

  return response
end


-- Processes one packet that was read elsewhere.
--
--   sock    the socket the packet arrived on (not used here)
--   packet  the raw message string
--
-- When the packet decodes to a response to an active query, its records
-- are cached, the query is retired (all sockets are closed once no query
-- is left) and the coroutines waiting for the first question are resumed
-- if suspended.  Returns the decoded response, or nil.
function resolver:feed(sock, packet)
  self.time = socket.gettime ()

  local response = self:decode (packet)
  if not response then  return response  end

  local sections = { 'answer', 'authority', 'additional' }
  for _, section in pairs (sections) do
    for _, rr in pairs (response[section]) do
      self:remember (rr, response.question[1].type)
    end
  end

  -- retire the query
  local id = response.header.id
  local raw = response.question.raw
  local queries = self.active[id]
  if queries[raw] then  queries[raw] = nil  end
  if not next (queries) then  self.active[id] = nil  end
  if not next (self.active) then  self:closeall ()  end

  -- was the query on the wanted list?
  local q = response.question[1]
  if q then
    local waiting = get (self.wanted, q.class, q.type, q.name)
    if waiting then
      for co in pairs (waiting) do
        set (self.yielded, co, q.class, q.type, q.name, nil)
        if coroutine.status (co) == "suspended" then
          coroutine.resume (co)
        end
      end
      set (self.wanted, q.class, q.type, q.name, nil)
    end
  end

  return response
end

-- Removes a waiter from the wanted list.  `data` is an array whose first
-- three entries are the keys into self.wanted (query () files waiters under
-- class, type, name) and whose fourth entry is the waiter to remove.
function resolver:cancel(data)
	local class_key, type_key, name_key = unpack (data, 1, 3)
	local waiting = get (self.wanted, class_key, type_key, name_key)
	if not waiting then  return  end
	waiting[data[4]] = nil
end

-- Drives outstanding queries: drains pending responses, then retries or
-- times out every active query whose retry time has passed.
--
-- A due query moves on to the next server; after the last server it wraps
-- to the first and advances to the next entry of self.delays.  Once the
-- delays are exhausted the query is dropped.
--
-- Retries are sent through getsocket (), as servfail () does, so a socket
-- is opened for a server that has not been contacted yet.  Looking the
-- socket up in self.socket directly would silently skip such a server,
-- since query () only opens a socket for the preferred server.
--
-- Returns true while queries remain active, nil otherwise.
function resolver:pulse ()

  while self:receive () do end
  if not next (self.active) then  return nil  end

  self.time = socket.gettime ()
  for id, queries in pairs (self.active) do
    for question, o in pairs (queries) do
      if self.time >= o.retry then

        o.server = o.server + 1
        if o.server > #self.server then
          o.server = 1
          o.delay = o.delay + 1
        end

        if o.delay > #self.delays then
          -- timed out
          queries[question] = nil
          if not next (queries) then  self.active[id] = nil  end
          if not next (self.active) then  return nil  end
        else
          local sock = self:getsocket (o.server)
          if sock then  sock:send (o.packet)  end
          o.retry = self.time + self.delays[o.delay]
        end

      end
    end
  end

  if next (self.active) then  return true  end
  return nil
end


-- Blocking lookup: sends the query, then alternates between pulse () and
-- waiting on the resolver's sockets (up to 4 per wait) until no query is
-- active, and finally returns whatever the cache holds for the question.
function resolver:lookup (qname, qtype, qclass)
  self:query (qname, qtype, qclass)
  while true do
    if not self:pulse () then  break  end
    socket.select (self.socket, nil, 4)
  end
  return self:peek (qname, qtype, qclass)
end

-- Non-blocking lookup: returns the cached answer when there is one,
-- otherwise starts a query and returns its result.  `handler` is accepted
-- but not used here.
function resolver:lookupex (handler, qname, qtype, qclass)
  local cached = self:peek (qname, qtype, qclass)
  if cached then  return cached  end
  -- parentheses keep this to exactly one result
  return (self:query (qname, qtype, qclass))
end


--print ---------------------------------------------------------------- print


-- Human-readable descriptions for header and record field values, used by
-- hint () when printing a response.  Each entry maps a field name to a
-- table indexed by the field's value; `type` and `class` reuse the
-- dns.type / dns.class lookup tables.
local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
  qr = { [0]='query', 'response' },
  opcode = { [0]='query', 'inverse query', 'server status request' },
  aa = { [0]='non-authoritative', 'authoritative' },
  tc = { [0]='complete', 'truncated' },
  rd = { [0]='recursion not desired', 'recursion desired' },
  ra = { [0]='recursion not available', 'recursion available' },
  z  = { [0]='(reserved)' },
  rcode = { [0]='no error', 'format error', 'server failure', 'name error',
            'not implemented' },

  type = dns.type,
  class = dns.class, }


-- Returns the description of field `s` of table `p` from the hints table,
-- or the empty string when the field or its value has no description.
local function hint (p, s)
  local descriptions = hints[s]
  if not descriptions then  return ''  end
  return descriptions[p[s]] or ''
end


-- Prints a decoded response: every header field with its description, each
-- question, then for each record of the three sections the common fields
-- followed by whatever type-specific fields the record carries.
function resolver.print (response)

  local header_fields = {
    'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
    'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' }
  for _, field in pairs (header_fields) do
    print ( string.format ('%-30s', 'header.'..field),
            response.header[field], hint (response.header, field) )
  end

  for i, question in ipairs (response.question) do
    print (string.format ('question[%i].name         ', i), question.name)
    print (string.format ('question[%i].type         ', i), question.type)
    print (string.format ('question[%i].class        ', i), question.class)
  end

  -- fields printed for every record; anything else is type-specific
  local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }
  local label
  for _, section in pairs {'answer', 'authority', 'additional'} do
    for i, rr in pairs (response[section]) do

      for _, field in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do
        label = string.format ('%s[%i].%s', section, i, field)
        print (string.format ('%-30s', label), rr[field], hint (rr, field))
      end

      for key, value in pairs (rr) do
        if not common[key] then
          label = string.format ('%s[%i].%s', section, i, key)
          print (string.format ('%-30s  %s', tostring(label), tostring(value)))
        end
      end

    end
  end
end


-- module api ------------------------------------------------------ module api


-- Calls `func` with the module's shared resolver as first argument,
-- followed by the remaining arguments.  The shared resolver is created on
-- first use and kept in dns._resolver.
local function resolve (func, ...)
  local shared = dns._resolver
  if not shared then
    shared = dns.resolver ()
    dns._resolver = shared
  end
  return func (shared, ...)
end


-- Creates and returns a new resolver object with empty active, cache,
-- unsorted, wanted and yielded tables and the first server preferred.
-- (This overlaps with resolver.new ().)
function dns.resolver ()
  local r = {
    active      = {},
    cache       = setmetatable ({}, cache_metatable),
    -- weak keys and values: entries go away with their record lists
    unsorted    = setmetatable ({}, { __mode = 'kv' }),
    wanted      = {},
    yielded     = {},
    best_server = 1,
  }
  return setmetatable (r, resolver)
end


-- Module-level lookup: resolver:lookup () on the shared resolver.
function dns.lookup (...)
  return resolve (resolver.lookup, ...)
end


-- Module-level purge: resolver:purge () on the shared resolver.
function dns.purge (...)
  return resolve (resolver.purge, ...)
end

-- Module-level peek: resolver:peek () on the shared resolver.
function dns.peek (...)
  return resolve (resolver.peek, ...)
end


-- Module-level query: resolver:query () on the shared resolver.
function dns.query (...)
  return resolve (resolver.query, ...)
end

-- Module-level feed: resolver:feed () on the shared resolver.
function dns.feed (...)
  return resolve (resolver.feed, ...)
end

-- Module-level cancel: resolver:cancel () on the shared resolver.
function dns.cancel(...)
  return resolve (resolver.cancel, ...)
end

-- Module-level socket_wrapper_set on the shared resolver.  This one is
-- defined with a colon, so the first argument of a call is taken as `self`
-- and only the arguments after it are forwarded.
function dns:socket_wrapper_set (...)
  return resolve (resolver.socket_wrapper_set, ...)
end


return dns

mercurial