util/async.lua

changeset 0
550f506de75a
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/util/async.lua	Thu Dec 03 17:05:27 2020 +0000
@@ -0,0 +1,275 @@
+local logger = require "util.logger";
+local log = logger.init("util.async");
+local new_id = require "util.id".short;
+local xpcall = require "util.xpcall".xpcall;
+
+local function checkthread()
+	local thread, main = coroutine.running();
+	if not thread or main then
+		error("Not running in an async context, see https://prosody.im/doc/developers/util/async");
+	end
+	return thread;
+end
+
+local function runner_from_thread(thread)
+	local level = 0;
+	-- Find the 'level' of the top-most function (0 == current level, 1 == caller, ...)
+	while debug.getinfo(thread, level, "") do level = level + 1; end
+	local name, runner = debug.getlocal(thread, level-1, 1);
+	if name ~= "self" or type(runner) ~= "table" or runner.thread ~= thread then
+		return nil;
+	end
+	return runner;
+end
+
+local function call_watcher(runner, watcher_name, ...)
+	local watcher = runner.watchers[watcher_name];
+	if not watcher then
+		return false;
+	end
+	runner:log("debug", "Calling '%s' watcher", watcher_name);
+	local ok, err = xpcall(watcher, debug.traceback, runner, ...);
+	if not ok then
+		runner:log("error", "Error in '%s' watcher: %s", watcher_name, err);
+		return nil, err;
+	end
+	return true;
+end
+
+local function runner_continue(thread)
+	-- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
+	if coroutine.status(thread) ~= "suspended" then -- This should suffice
+		log("error", "unexpected async state: thread not suspended");
+		return false;
+	end
+	local ok, state, runner = coroutine.resume(thread);
+	if not ok then
+		local err = state;
+		-- Running the coroutine failed, which means we have to find the runner manually,
+		-- in order to inform the error handler
+		runner = runner_from_thread(thread);
+		if not runner then
+			log("error", "unexpected async state: unable to locate runner during error handling");
+			return false;
+		end
+		call_watcher(runner, "error", debug.traceback(thread, err));
+		runner.state = "ready";
+		return runner:run();
+	elseif state == "ready" then
+		-- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
+		-- We also have to :run(), because the queue might have further items that will not be
+		-- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
+		runner.state = "ready";
+		runner:run();
+	end
+	return true;
+end
+
+local function waiter(num)
+	local thread = checkthread();
+	num = num or 1;
+	local waiting;
+	return function ()
+		if num == 0 then return; end -- already done
+		waiting = true;
+		coroutine.yield("wait");
+	end, function ()
+		num = num - 1;
+		if num == 0 and waiting then
+			runner_continue(thread);
+		elseif num < 0 then
+			error("done() called too many times");
+		end
+	end;
+end
+
+local function guarder()
+	local guards = {};
+	local default_id = {};
+	return function (id, func)
+		id = id or default_id;
+		local thread = checkthread();
+		local guard = guards[id];
+		if not guard then
+			guard = {};
+			guards[id] = guard;
+			log("debug", "New guard!");
+		else
+			table.insert(guard, thread);
+			log("debug", "Guarded. %d threads waiting.", #guard)
+			coroutine.yield("wait");
+		end
+		local function exit()
+			local next_waiting = table.remove(guard, 1);
+			if next_waiting then
+				log("debug", "guard: Executing next waiting thread (%d left)", #guard)
+				runner_continue(next_waiting);
+			else
+				log("debug", "Guard off duty.")
+				guards[id] = nil;
+			end
+		end
+		if func then
+			func();
+			exit();
+			return;
+		end
+		return exit;
+	end;
+end
+
+local runner_mt = {};
+runner_mt.__index = runner_mt;
+
+local function runner_create_thread(func, self)
+	local thread = coroutine.create(function (self) -- luacheck: ignore 432/self
+		while true do
+			func(coroutine.yield("ready", self));
+		end
+	end);
+	debug.sethook(thread, debug.gethook());
+	assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input
+	return thread;
+end
+
+local function default_error_watcher(runner, err)
+	runner:log("error", "Encountered error: %s", err);
+	error(err);
+end
+local function default_func(f) f(); end
+local function runner(func, watchers, data)
+	local id = new_id();
+	local _log = logger.init("runner" .. id);
+	return setmetatable({ func = func or default_func, thread = false, state = "ready", notified_state = "ready",
+		queue = {}, watchers = watchers or { error = default_error_watcher }, data = data, id = id, _log = _log; }
+	, runner_mt);
+end
+
+-- Add a task item for the runner to process
+function runner_mt:run(input)
+	if input ~= nil then
+		table.insert(self.queue, input);
+		--self:log("debug", "queued new work item, %d items queued", #self.queue);
+	end
+	if self.state ~= "ready" then
+		-- The runner is busy. Indicate that the task item has been
+		-- queued, and return information about the current runner state
+		return true, self.state, #self.queue;
+	end
+
+	local q, thread = self.queue, self.thread;
+	if not thread or coroutine.status(thread) == "dead" then
+		--luacheck: ignore 143/coroutine
+		if thread and coroutine.close then
+			coroutine.close(thread);
+		end
+		self:log("debug", "creating new coroutine");
+		-- Create a new coroutine for this runner
+		thread = runner_create_thread(self.func, self);
+		self.thread = thread;
+	end
+
+	-- Process task item(s) while the queue is not empty, and we're not blocked
+	local n, state, err = #q, self.state, nil;
+	self.state = "running";
+	--self:log("debug", "running main loop");
+	while n > 0 and state == "ready" and not err do
+		local consumed;
+		-- Loop through queue items, and attempt to run them
+		for i = 1,n do
+			local queued_input = q[i];
+			local ok, new_state = coroutine.resume(thread, queued_input);
+			if not ok then
+				-- There was an error running the coroutine, save the error, mark runner as ready to begin again
+				consumed, state, err = i, "ready", debug.traceback(thread, new_state);
+				self.thread = nil;
+				break;
+			elseif new_state == "wait" then
+				 -- Runner is blocked on waiting for a task item to complete
+				consumed, state = i, "waiting";
+				break;
+			end
+		end
+		-- Loop ended - either queue empty because all tasks passed without blocking (consumed == nil)
+		-- or runner is blocked/errored, and consumed will contain the number of tasks processed so far
+		if not consumed then consumed = n; end
+		-- Remove consumed items from the queue array
+		if q[n+1] ~= nil then
+			n = #q;
+		end
+		for i = 1, n do
+			q[i] = q[consumed+i];
+		end
+		n = #q;
+	end
+	-- Runner processed all items it can, so save current runner state
+	self.state = state;
+	if err or state ~= self.notified_state then
+		self:log("debug", "changed state from %s to %s", self.notified_state, err and ("error ("..state..")") or state);
+		if err then
+			state = "error"
+		else
+			self.notified_state = state;
+		end
+		local handler = self.watchers[state];
+		if handler then handler(self, err); end
+	end
+	if n > 0 then
+		return self:run();
+	end
+	return true, state, n;
+end
+
+-- Add a task item to the queue without invoking the runner, even if it is idle
+function runner_mt:enqueue(input)
+	table.insert(self.queue, input);
+	self:log("debug", "queued new work item, %d items queued", #self.queue);
+	return self;
+end
+
+function runner_mt:log(level, fmt, ...)
+	return self._log(level, fmt, ...);
+end
+
+function runner_mt:onready(f)
+	self.watchers.ready = f;
+	return self;
+end
+
+function runner_mt:onwaiting(f)
+	self.watchers.waiting = f;
+	return self;
+end
+
+function runner_mt:onerror(f)
+	self.watchers.error = f;
+	return self;
+end
+
+local function ready()
+	return pcall(checkthread);
+end
+
+local function wait_for(promise)
+	local async_wait, async_done = waiter();
+	local ret, err = nil, nil;
+	promise:next(
+		function (r) ret = r; end,
+		function (e) err = e; end)
+		:finally(async_done);
+	async_wait();
+	if ret then
+		return ret;
+	else
+		return nil, err;
+	end
+end
+
+return {
+	ready = ready;
+	waiter = waiter;
+	guarder = guarder;
+	runner = runner;
+	wait = wait_for; -- COMPAT w/trunk pre-0.12
+	wait_for = wait_for;
+};

mercurial