Thu, 03 Dec 2020 17:05:27 +0000
Initial commit
0 | 1 | local logger = require "util.logger"; |
2 | local log = logger.init("util.async"); | |
3 | local new_id = require "util.id".short; | |
4 | local xpcall = require "util.xpcall".xpcall; | |
5 | ||
--- Return the coroutine currently executing.
-- Raises an error when called from the main thread, i.e. outside any async context.
-- Works with both the Lua 5.1 (`thread` or nil) and Lua 5.2+ (`thread, ismain`)
-- forms of coroutine.running().
local function checkthread()
	local current, is_main = coroutine.running();
	if current and not is_main then
		return current;
	end
	error("Not running in an async context, see https://prosody.im/doc/developers/util/async");
end
13 | ||
--- Recover the runner object that owns `thread`, or nil.
-- Runner coroutines are started with the runner as the first argument of their
-- outermost function, and that parameter is named `self` (see runner_create_thread).
-- This walks to the outermost stack frame of `thread` and inspects its first local.
-- It only accepts the value if it is a table whose `.thread` is this very coroutine.
local function runner_from_thread(thread)
	-- Count stack frames: debug.getinfo returns nil once past the outermost one.
	local depth = 0;
	repeat
		local frame = debug.getinfo(thread, depth, "");
		if frame then depth = depth + 1; end
	until not frame
	local var_name, candidate = debug.getlocal(thread, depth - 1, 1);
	if var_name == "self" and type(candidate) == "table" and candidate.thread == thread then
		return candidate;
	end
	return nil;
end
24 | ||
--- Invoke the watcher registered on `runner` under `watcher_name`, if any.
-- The watcher receives the runner followed by any extra arguments. Errors it
-- raises are caught, logged at "error" level and returned.
-- @return false when no watcher is registered; true when it ran cleanly;
--   nil plus the error (as processed by debug.traceback) when it raised.
local function call_watcher(runner, watcher_name, ...)
	local callback = runner.watchers[watcher_name];
	if not callback then return false; end

	runner:log("debug", "Calling '%s' watcher", watcher_name);
	local succeeded, failure = xpcall(callback, debug.traceback, runner, ...);
	if succeeded then
		return true;
	end
	runner:log("error", "Error in '%s' watcher: %s", watcher_name, failure);
	return nil, failure;
end
38 | ||
--- Resume a runner coroutine that is suspended in a wait (used by waiter/guarder).
-- Returns false if the thread is not suspended. Also returns false if the resumed
-- code raised an error and the owning runner cannot be located.
-- If the resumed code raised an error, the runner's "error" watcher is called.
-- The runner is then marked "ready" and run() is invoked, and its results are returned.
-- Otherwise returns true. If the coroutine went back to "ready", runner.state is
-- updated and run() is invoked so further queued items get processed.
local function runner_continue(thread)
	-- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
	if coroutine.status(thread) ~= "suspended" then -- This should suffice
		log("error", "unexpected async state: thread not suspended");
		return false;
	end

	local resumed, state, runner = coroutine.resume(thread);
	if resumed then
		if state == "ready" then
			-- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
			-- We also have to :run(), because the queue might have further items that will not be
			-- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
			runner.state = "ready";
			runner:run();
		end
		return true;
	end

	-- Running the coroutine failed, which means we have to find the runner manually,
	-- in order to inform the error handler
	local err = state;
	local owner = runner_from_thread(thread);
	if not owner then
		log("error", "unexpected async state: unable to locate runner during error handling");
		return false;
	end
	call_watcher(owner, "error", debug.traceback(thread, err));
	owner.state = "ready";
	return owner:run();
end
67 | ||
--- Create a wait/done pair bound to the current async thread.
-- `wait()` suspends the thread until `done()` has been called `num` times.
-- If that has already happened, `wait()` returns immediately.
-- Calling `done()` more than `num` times raises an error.
-- Must be called from inside an async context.
-- @param num number of done() calls to wait for (default 1)
-- @return wait function, done function
local function waiter(num)
	local owner = checkthread();
	local remaining = num or 1;
	local suspended = false;

	local function wait()
		if remaining == 0 then return; end -- already done
		suspended = true;
		coroutine.yield("wait");
	end

	local function done()
		remaining = remaining - 1;
		if remaining == 0 then
			-- Only resume if wait() actually suspended the thread
			if suspended then runner_continue(owner); end
		elseif remaining < 0 then
			error("done() called too many times");
		end
	end

	return wait, done;
end
85 | ||
--- Create a guard function that lets only one async thread at a time through per id.
-- The returned function takes (id, func). `id` defaults to a private key shared by
-- all id-less calls. The first caller for an id proceeds immediately. Later callers
-- with the same id are queued and suspended until the guard is released; they are
-- resumed one at a time in FIFO order.
-- If `func` is given, it runs under the guard and the guard is released afterwards.
-- Otherwise the release function is returned for the caller to invoke.
-- NOTE(review): func() is not protected. If it raises, the guard is never released
-- and threads queued on that id remain suspended.
local function guarder()
	local active = {};
	local fallback_id = {};
	return function (id, func)
		local key = id or fallback_id;
		local current = checkthread();
		local queue = active[key];
		if queue then
			table.insert(queue, current);
			log("debug", "Guarded. %d threads waiting.", #queue)
			coroutine.yield("wait");
		else
			queue = {};
			active[key] = queue;
			log("debug", "New guard!");
		end

		-- Hand the guard to the next queued thread, or drop it when nobody is waiting
		local function release()
			local successor = table.remove(queue, 1);
			if not successor then
				log("debug", "Guard off duty.")
				active[key] = nil;
				return;
			end
			log("debug", "guard: Executing next waiting thread (%d left)", #queue)
			runner_continue(successor);
		end

		if not func then
			return release;
		end
		func();
		release();
	end;
end
120 | ||
--- Method table shared by all runner objects.
-- It is also installed as each runner's metatable, and __index points back at
-- itself so that instances find their methods here.
local runner_mt = {};
runner_mt["__index"] = runner_mt;
123 | ||
--- Create the coroutine that executes a runner's work items.
-- The coroutine loops forever. It yields ("ready", self) to signal it is idle, and
-- the values passed to the next resume are handed to `func`. The caller's current
-- debug hook settings are copied onto the new coroutine.
-- @param func function invoked with each work item
-- @param self the runner, passed to the coroutine as its first argument
-- @return the coroutine, already started and parked at its first yield
local function runner_create_thread(func, self)
	-- NB: the outermost function must take the runner as its first parameter named
	-- `self`; runner_from_thread() relies on that to locate the runner.
	local function body(self) -- luacheck: ignore 432/self
		repeat
			func(coroutine.yield("ready", self));
		until false
	end
	local co = coroutine.create(body);
	debug.sethook(co, debug.gethook());
	assert(coroutine.resume(co, self)); -- Start it up, it will return instantly to wait for the first input
	return co;
end
134 | ||
--- Error watcher used when a runner is created without a watchers table.
-- Logs the failure through the runner's logger, then re-raises it.
-- @param runner the runner that hit the error
-- @param err the error value reported by the runner
local function default_error_watcher(runner, err)
	local failure = err;
	runner:log("error", "Encountered error: %s", failure);
	error(failure);
end
--- Default work function: each queued item is itself a function, and is simply called.
local function default_func(task)
	task();
end
--- Construct a new runner.
-- @param func work function called with each queued item (default: call the item itself)
-- @param watchers table of state callbacks keyed by state name ("ready", "waiting",
--   "error"); defaults to a table holding only the default error watcher
-- @param data arbitrary value stored on the runner as `data`
-- @return the runner object, with runner_mt as its metatable
local function runner(func, watchers, data)
	local runner_id = new_id();
	local fields = {
		func = func or default_func;
		thread = false;
		state = "ready";
		notified_state = "ready";
		queue = {};
		watchers = watchers or { error = default_error_watcher };
		data = data;
		id = runner_id;
		_log = logger.init("runner" .. runner_id);
	};
	return setmetatable(fields, runner_mt);
end
147 | ||
-- Add a task item for the runner to process
--- Queue `input` (if non-nil), then process queued items when the runner is idle.
-- If the runner is not "ready" (it is "running" or "waiting"), the item is only queued.
-- Otherwise queued items are resumed in order on the runner's coroutine. Processing
-- stops when the queue is drained, when an item waits, or when an item raises an error.
-- Watchers are notified on error, or when the resulting state differs from the last
-- state reported to them.
-- @param input optional work item
-- @return true, the runner state ("error" if an item failed), and the number of items still queued
function runner_mt:run(input)
	if input ~= nil then
		table.insert(self.queue, input);
		--self:log("debug", "queued new work item, %d items queued", #self.queue);
	end
	if self.state ~= "ready" then
		-- The runner is busy. Indicate that the task item has been
		-- queued, and return information about the current runner state
		return true, self.state, #self.queue;
	end

	local q, thread = self.queue, self.thread;
	if not thread or coroutine.status(thread) == "dead" then
		--luacheck: ignore 143/coroutine
		-- coroutine.close was added in Lua 5.4; use it when available to release the dead coroutine
		if thread and coroutine.close then
			coroutine.close(thread);
		end
		self:log("debug", "creating new coroutine");
		-- Create a new coroutine for this runner
		thread = runner_create_thread(self.func, self);
		self.thread = thread;
	end

	-- Process task item(s) while the queue is not empty, and we're not blocked
	local n, state, err = #q, self.state, nil;
	-- Mark busy so that re-entrant run() calls (e.g. from a work item) only enqueue and return
	self.state = "running";
	--self:log("debug", "running main loop");
	while n > 0 and state == "ready" and not err do
		local consumed;
		-- Loop through queue items, and attempt to run them
		for i = 1,n do
			local queued_input = q[i];
			local ok, new_state = coroutine.resume(thread, queued_input);
			if not ok then
				-- There was an error running the coroutine, save the error, mark runner as ready to begin again
				-- (clearing self.thread makes the next run() create a fresh coroutine)
				consumed, state, err = i, "ready", debug.traceback(thread, new_state);
				self.thread = nil;
				break;
			elseif new_state == "wait" then
				-- Runner is blocked on waiting for a task item to complete
				-- (the item counts as consumed; it is resumed later via runner_continue())
				consumed, state = i, "waiting";
				break;
			end
		end
		-- Loop ended - either queue empty because all tasks passed without blocking (consumed == nil)
		-- or runner is blocked/errored, and consumed will contain the number of tasks processed so far
		if not consumed then consumed = n; end
		-- Remove consumed items from the queue array
		-- If items were appended while processing (q[n+1] is set), shift the whole queue, not just the first n slots
		if q[n+1] ~= nil then
			n = #q;
		end
		-- Shift remaining items down by `consumed`; trailing slots read past the end and become nil
		for i = 1, n do
			q[i] = q[consumed+i];
		end
		n = #q;
	end
	-- Runner processed all items it can, so save current runner state
	self.state = state;
	if err or state ~= self.notified_state then
		self:log("debug", "changed state from %s to %s", self.notified_state, err and ("error ("..state..")") or state);
		if err then
			-- Dispatch to the "error" watcher; notified_state is not updated for errors
			state = "error"
		else
			self.notified_state = state;
		end
		-- NOTE(review): the watcher is called unprotected. If it raises (as default_error_watcher
		-- does), the error propagates out of run() and any remaining items stay queued.
		local handler = self.watchers[state];
		if handler then handler(self, err); end
	end
	-- Items remain (runner now waiting, or an item errored): re-enter run().
	-- It returns immediately if the runner is not "ready".
	if n > 0 then
		return self:run();
	end
	return true, state, n;
end
222 | ||
--- Append a work item to the queue without starting the runner, even if it is idle.
-- @param input work item to queue
-- @return the runner itself, for chaining
function runner_mt:enqueue(input)
	local pending = self.queue;
	table.insert(pending, input);
	self:log("debug", "queued new work item, %d items queued", #pending);
	return self;
end
229 | ||
--- Log a message through this runner's own logger (created in the constructor).
function runner_mt:log(level, fmt, ...)
	local emit = self._log;
	return emit(level, fmt, ...);
end
233 | ||
-- Store `f` as the runner's watcher for `state_name` and return the runner for chaining.
local function set_watcher(self, state_name, f)
	self.watchers[state_name] = f;
	return self;
end

--- Set the watcher called when the runner becomes "ready". Returns the runner.
function runner_mt:onready(f)
	return set_watcher(self, "ready", f);
end

--- Set the watcher called when the runner enters "waiting". Returns the runner.
function runner_mt:onwaiting(f)
	return set_watcher(self, "waiting", f);
end

--- Set the watcher called when a work item raises an error. Returns the runner.
function runner_mt:onerror(f)
	return set_watcher(self, "error", f);
end
248 | ||
--- Check, without raising, whether the caller is inside an async context.
-- @return true plus the current thread, or false plus checkthread()'s error message
local function ready()
	local inside, thread_or_err = pcall(checkthread);
	return inside, thread_or_err;
end
252 | ||
--- Suspend the current async thread until `promise` settles.
-- Must be called from inside an async context, because it creates a waiter.
-- `promise` must provide :next(on_fulfilled, on_rejected), returning an object
-- that provides :finally(callback).
-- Settlement is tracked explicitly rather than by the truthiness of the value, so a
-- promise fulfilled with false is reported as false instead of being
-- indistinguishable from a rejection.
-- @return the fulfilment value if the promise was fulfilled;
--   nil plus the rejection reason if it was rejected
local function wait_for(promise)
	local async_wait, async_done = waiter();
	local settled_ok, ret, err = nil, nil, nil;
	promise:next(
		function (r) settled_ok, ret = true, r; end,
		function (e) settled_ok, err = false, e; end)
		:finally(async_done);
	async_wait();
	if settled_ok then
		return ret;
	end
	return nil, err;
end
267 | ||
-- Public interface of util.async.
local exports = {
	ready = ready;
	waiter = waiter;
	guarder = guarder;
	runner = runner;
	wait_for = wait_for;
	wait = wait_for; -- COMPAT w/trunk pre-0.12
};
return exports;