|
1 /*-------------------------------------------------------------------------- |
|
2 * LuaSec 0.4 |
|
3 * Copyright (C) 2006-2009 Bruno Silvestre |
|
4 * |
|
5 *--------------------------------------------------------------------------*/ |
|
6 |
|
#include <limits.h>
#include <string.h>

#include <openssl/ssl.h>
#include <openssl/err.h>

#include <lua.h>
#include <lauxlib.h>

#include "io.h"
#include "buffer.h"
#include "timeout.h"
#include "socket.h"
#include "ssl.h"
|
20 |
|
21 /** |
|
22 * Map error code into string. |
|
23 */ |
|
24 static const char *ssl_ioerror(void *ctx, int err) |
|
25 { |
|
26 if (err == IO_SSL) { |
|
27 p_ssl ssl = (p_ssl) ctx; |
|
28 switch(ssl->error) { |
|
29 case SSL_ERROR_NONE: return "No error"; |
|
30 case SSL_ERROR_ZERO_RETURN: return "closed"; |
|
31 case SSL_ERROR_WANT_READ: return "wantread"; |
|
32 case SSL_ERROR_WANT_WRITE: return "wantwrite"; |
|
33 case SSL_ERROR_WANT_CONNECT: return "'connect' not completed"; |
|
34 case SSL_ERROR_WANT_ACCEPT: return "'accept' not completed"; |
|
35 case SSL_ERROR_WANT_X509_LOOKUP: return "Waiting for callback"; |
|
36 case SSL_ERROR_SYSCALL: return "System error"; |
|
37 case SSL_ERROR_SSL: return ERR_reason_error_string(ERR_get_error()); |
|
38 default: return "Unknown SSL error"; |
|
39 } |
|
40 } |
|
41 return socket_strerror(err); |
|
42 } |
|
43 |
|
44 /** |
|
45 * Close the connection before the GC collect the object. |
|
46 */ |
|
47 static int meth_destroy(lua_State *L) |
|
48 { |
|
49 p_ssl ssl = (p_ssl) lua_touserdata(L, 1); |
|
50 if (ssl->ssl) { |
|
51 socket_setblocking(&ssl->sock); |
|
52 SSL_shutdown(ssl->ssl); |
|
53 socket_destroy(&ssl->sock); |
|
54 SSL_free(ssl->ssl); |
|
55 ssl->ssl = NULL; |
|
56 } |
|
57 return 0; |
|
58 } |
|
59 |
|
60 /** |
|
61 * Perform the TLS/SSL handshake |
|
62 */ |
|
63 static int handshake(p_ssl ssl) |
|
64 { |
|
65 int err; |
|
66 p_timeout tm = timeout_markstart(&ssl->tm); |
|
67 if (ssl->state == ST_SSL_CLOSED) |
|
68 return IO_CLOSED; |
|
69 for ( ; ; ) { |
|
70 ERR_clear_error(); |
|
71 err = SSL_do_handshake(ssl->ssl); |
|
72 ssl->error = SSL_get_error(ssl->ssl, err); |
|
73 switch(ssl->error) { |
|
74 case SSL_ERROR_NONE: |
|
75 ssl->state = ST_SSL_CONNECTED; |
|
76 return IO_DONE; |
|
77 case SSL_ERROR_WANT_READ: |
|
78 err = socket_waitfd(&ssl->sock, WAITFD_R, tm); |
|
79 if (err == IO_TIMEOUT) return IO_SSL; |
|
80 if (err != IO_DONE) return err; |
|
81 break; |
|
82 case SSL_ERROR_WANT_WRITE: |
|
83 err = socket_waitfd(&ssl->sock, WAITFD_W, tm); |
|
84 if (err == IO_TIMEOUT) return IO_SSL; |
|
85 if (err != IO_DONE) return err; |
|
86 break; |
|
87 case SSL_ERROR_SYSCALL: |
|
88 if (ERR_peek_error()) { |
|
89 ssl->error = SSL_ERROR_SSL; |
|
90 return IO_SSL; |
|
91 } |
|
92 if (err == 0) |
|
93 return IO_CLOSED; |
|
94 return socket_error(); |
|
95 default: |
|
96 return IO_SSL; |
|
97 } |
|
98 } |
|
99 return IO_UNKNOWN; |
|
100 } |
|
101 |
|
102 /** |
|
103 * Send data |
|
104 */ |
|
105 static int ssl_send(void *ctx, const char *data, size_t count, size_t *sent, |
|
106 p_timeout tm) |
|
107 { |
|
108 int err; |
|
109 p_ssl ssl = (p_ssl) ctx; |
|
110 if (ssl->state == ST_SSL_CLOSED) |
|
111 return IO_CLOSED; |
|
112 *sent = 0; |
|
113 for ( ; ; ) { |
|
114 ERR_clear_error(); |
|
115 err = SSL_write(ssl->ssl, data, (int) count); |
|
116 ssl->error = SSL_get_error(ssl->ssl, err); |
|
117 switch(ssl->error) { |
|
118 case SSL_ERROR_NONE: |
|
119 *sent = err; |
|
120 return IO_DONE; |
|
121 case SSL_ERROR_WANT_READ: |
|
122 err = socket_waitfd(&ssl->sock, WAITFD_R, tm); |
|
123 if (err == IO_TIMEOUT) return IO_SSL; |
|
124 if (err != IO_DONE) return err; |
|
125 break; |
|
126 case SSL_ERROR_WANT_WRITE: |
|
127 err = socket_waitfd(&ssl->sock, WAITFD_W, tm); |
|
128 if (err == IO_TIMEOUT) return IO_SSL; |
|
129 if (err != IO_DONE) return err; |
|
130 break; |
|
131 case SSL_ERROR_SYSCALL: |
|
132 if (ERR_peek_error()) { |
|
133 ssl->error = SSL_ERROR_SSL; |
|
134 return IO_SSL; |
|
135 } |
|
136 if (err == 0) |
|
137 return IO_CLOSED; |
|
138 return socket_error(); |
|
139 default: |
|
140 return IO_SSL; |
|
141 } |
|
142 } |
|
143 return IO_UNKNOWN; |
|
144 } |
|
145 |
|
146 /** |
|
147 * Receive data |
|
148 */ |
|
149 static int ssl_recv(void *ctx, char *data, size_t count, size_t *got, |
|
150 p_timeout tm) |
|
151 { |
|
152 int err; |
|
153 p_ssl ssl = (p_ssl) ctx; |
|
154 if (ssl->state == ST_SSL_CLOSED) |
|
155 return IO_CLOSED; |
|
156 *got = 0; |
|
157 for ( ; ; ) { |
|
158 ERR_clear_error(); |
|
159 err = SSL_read(ssl->ssl, data, (int) count); |
|
160 ssl->error = SSL_get_error(ssl->ssl, err); |
|
161 switch(ssl->error) { |
|
162 case SSL_ERROR_NONE: |
|
163 *got = err; |
|
164 return IO_DONE; |
|
165 case SSL_ERROR_ZERO_RETURN: |
|
166 *got = err; |
|
167 return IO_CLOSED; |
|
168 case SSL_ERROR_WANT_READ: |
|
169 err = socket_waitfd(&ssl->sock, WAITFD_R, tm); |
|
170 if (err == IO_TIMEOUT) return IO_SSL; |
|
171 if (err != IO_DONE) return err; |
|
172 break; |
|
173 case SSL_ERROR_WANT_WRITE: |
|
174 err = socket_waitfd(&ssl->sock, WAITFD_W, tm); |
|
175 if (err == IO_TIMEOUT) return IO_SSL; |
|
176 if (err != IO_DONE) return err; |
|
177 break; |
|
178 case SSL_ERROR_SYSCALL: |
|
179 if (ERR_peek_error()) { |
|
180 ssl->error = SSL_ERROR_SSL; |
|
181 return IO_SSL; |
|
182 } |
|
183 if (err == 0) |
|
184 return IO_CLOSED; |
|
185 return socket_error(); |
|
186 default: |
|
187 return IO_SSL; |
|
188 } |
|
189 } |
|
190 return IO_UNKNOWN; |
|
191 } |
|
192 |
|
193 /** |
|
194 * Create a new TLS/SSL object and mark it as new. |
|
195 */ |
|
196 static int meth_create(lua_State *L) |
|
197 { |
|
198 p_ssl ssl; |
|
199 int mode = ctx_getmode(L, 1); |
|
200 SSL_CTX *ctx = ctx_getcontext(L, 1); |
|
201 |
|
202 if (mode == MD_CTX_INVALID) { |
|
203 lua_pushnil(L); |
|
204 lua_pushstring(L, "invalid mode"); |
|
205 return 2; |
|
206 } |
|
207 ssl = (p_ssl) lua_newuserdata(L, sizeof(t_ssl)); |
|
208 if (!ssl) { |
|
209 lua_pushnil(L); |
|
210 lua_pushstring(L, "error creating SSL object"); |
|
211 return 2; |
|
212 } |
|
213 ssl->ssl = SSL_new(ctx); |
|
214 if (!ssl->ssl) { |
|
215 lua_pushnil(L); |
|
216 lua_pushstring(L, "error creating SSL object"); |
|
217 return 2;; |
|
218 } |
|
219 ssl->state = ST_SSL_NEW; |
|
220 SSL_set_fd(ssl->ssl, (int) SOCKET_INVALID); |
|
221 SSL_set_mode(ssl->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE | |
|
222 SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); |
|
223 if (mode == MD_CTX_SERVER) |
|
224 SSL_set_accept_state(ssl->ssl); |
|
225 else |
|
226 SSL_set_connect_state(ssl->ssl); |
|
227 |
|
228 io_init(&ssl->io, (p_send) ssl_send, (p_recv) ssl_recv, |
|
229 (p_error) ssl_ioerror, ssl); |
|
230 timeout_init(&ssl->tm, -1, -1); |
|
231 buffer_init(&ssl->buf, &ssl->io, &ssl->tm); |
|
232 |
|
233 luaL_getmetatable(L, "SSL:Connection"); |
|
234 lua_setmetatable(L, -2); |
|
235 return 1; |
|
236 } |
|
237 |
|
238 /** |
|
239 * Buffer send function |
|
240 */ |
|
241 static int meth_send(lua_State *L) { |
|
242 p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); |
|
243 return buffer_meth_send(L, &ssl->buf); |
|
244 } |
|
245 |
|
246 /** |
|
247 * Buffer receive function |
|
248 */ |
|
249 static int meth_receive(lua_State *L) { |
|
250 p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); |
|
251 return buffer_meth_receive(L, &ssl->buf); |
|
252 } |
|
253 |
|
254 /** |
|
255 * Select support methods |
|
256 */ |
|
257 static int meth_getfd(lua_State *L) |
|
258 { |
|
259 p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); |
|
260 lua_pushnumber(L, ssl->sock); |
|
261 return 1; |
|
262 } |
|
263 |
|
264 /** |
|
265 * Set the TLS/SSL file descriptor. |
|
266 * This is done *before* the handshake. |
|
267 */ |
|
268 static int meth_setfd(lua_State *L) |
|
269 { |
|
270 p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); |
|
271 if (ssl->state != ST_SSL_NEW) |
|
272 luaL_argerror(L, 1, "invalid SSL object state"); |
|
273 ssl->sock = luaL_checkint(L, 2); |
|
274 socket_setnonblocking(&ssl->sock); |
|
275 SSL_set_fd(ssl->ssl, (int)ssl->sock); |
|
276 return 0; |
|
277 } |
|
278 |
|
279 /** |
|
280 * Lua handshake function. |
|
281 */ |
|
282 static int meth_handshake(lua_State *L) |
|
283 { |
|
284 p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); |
|
285 int err = handshake(ssl); |
|
286 if (err == IO_DONE) { |
|
287 lua_pushboolean(L, 1); |
|
288 return 1; |
|
289 } |
|
290 lua_pushboolean(L, 0); |
|
291 lua_pushstring(L, ssl_ioerror((void*)ssl, err)); |
|
292 return 2; |
|
293 } |
|
294 |
|
295 /** |
|
296 * Close the connection. |
|
297 */ |
|
298 static int meth_close(lua_State *L) |
|
299 { |
|
300 p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); |
|
301 meth_destroy(L); |
|
302 ssl->state = ST_SSL_CLOSED; |
|
303 return 0; |
|
304 } |
|
305 |
|
306 /** |
|
307 * Set timeout. |
|
308 */ |
|
309 static int meth_settimeout(lua_State *L) |
|
310 { |
|
311 p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); |
|
312 return timeout_meth_settimeout(L, &ssl->tm); |
|
313 } |
|
314 |
|
315 /** |
|
316 * Check if there is data in the buffer. |
|
317 */ |
|
318 static int meth_dirty(lua_State *L) |
|
319 { |
|
320 int res = 0; |
|
321 p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); |
|
322 if (ssl->state != ST_SSL_CLOSED) |
|
323 res = !buffer_isempty(&ssl->buf) || SSL_pending(ssl->ssl); |
|
324 lua_pushboolean(L, res); |
|
325 return 1; |
|
326 } |
|
327 |
|
328 /** |
|
329 * Return the state information about the SSL object. |
|
330 */ |
|
331 static int meth_want(lua_State *L) |
|
332 { |
|
333 p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); |
|
334 int code = (ssl->state == ST_SSL_CLOSED) ? SSL_NOTHING : SSL_want(ssl->ssl); |
|
335 switch(code) { |
|
336 case SSL_NOTHING: lua_pushstring(L, "nothing"); break; |
|
337 case SSL_READING: lua_pushstring(L, "read"); break; |
|
338 case SSL_WRITING: lua_pushstring(L, "write"); break; |
|
339 case SSL_X509_LOOKUP: lua_pushstring(L, "x509lookup"); break; |
|
340 } |
|
341 return 1; |
|
342 } |
|
343 |
|
344 /** |
|
345 * Return a pointer to SSL structure. |
|
346 */ |
|
347 static int meth_rawconn(lua_State *L) |
|
348 { |
|
349 p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); |
|
350 lua_pushlightuserdata(L, (void*)ssl->ssl); |
|
351 return 1; |
|
352 } |
|
353 |
|
354 /*---------------------------------------------------------------------------*/ |
|
355 |
|
356 |
|
357 /** |
|
358 * SSL metamethods |
|
359 */ |
|
360 static luaL_Reg meta[] = { |
|
361 {"close", meth_close}, |
|
362 {"getfd", meth_getfd}, |
|
363 {"dirty", meth_dirty}, |
|
364 {"dohandshake", meth_handshake}, |
|
365 {"receive", meth_receive}, |
|
366 {"send", meth_send}, |
|
367 {"settimeout", meth_settimeout}, |
|
368 {"want", meth_want}, |
|
369 {NULL, NULL} |
|
370 }; |
|
371 |
|
372 /** |
|
373 * SSL functions |
|
374 */ |
|
375 static luaL_Reg funcs[] = { |
|
376 {"create", meth_create}, |
|
377 {"setfd", meth_setfd}, |
|
378 {"rawconnection", meth_rawconn}, |
|
379 {NULL, NULL} |
|
380 }; |
|
381 |
|
382 /** |
|
383 * Initialize modules |
|
384 */ |
|
385 LUASEC_API int luaopen_ssl_core(lua_State *L) |
|
386 { |
|
387 /* Initialize SSL */ |
|
388 if (!SSL_library_init()) { |
|
389 lua_pushstring(L, "unable to initialize SSL library"); |
|
390 lua_error(L); |
|
391 } |
|
392 SSL_load_error_strings(); |
|
393 |
|
394 /* Initialize internal library */ |
|
395 socket_open(); |
|
396 |
|
397 /* Registre the functions and tables */ |
|
398 luaL_newmetatable(L, "SSL:Connection"); |
|
399 lua_newtable(L); |
|
400 luaL_register(L, NULL, meta); |
|
401 lua_setfield(L, -2, "__index"); |
|
402 lua_pushcfunction(L, meth_destroy); |
|
403 lua_setfield(L, -2, "__gc"); |
|
404 |
|
405 luaL_register(L, "ssl.core", funcs); |
|
406 lua_pushnumber(L, SOCKET_INVALID); |
|
407 lua_setfield(L, -2, "invalidfd"); |
|
408 |
|
409 return 1; |
|
410 } |
|
411 |