src/ssl.c

Fri, 05 Nov 2010 16:38:10 +0000

author
Matthew Wild <mwild1@gmail.com>
date
Fri, 05 Nov 2010 16:38:10 +0000
changeset 11
8d7698d3fd26
parent 10
a4a1fd8c1b43
child 12
ac943b31f40c
permissions
-rw-r--r--

Refactoring of :getpeercertificate(), support for subjectAltName extensions

/*--------------------------------------------------------------------------
 * LuaSec 0.4
 * Copyright (C) 2006-2009 Bruno Silvestre
 *
 *--------------------------------------------------------------------------*/

#include <stdlib.h>
#include <string.h>

#include <openssl/ssl.h>
#include <openssl/x509v3.h>
#include <openssl/err.h>

#include <lua.h>
#include <lauxlib.h>

#include "io.h"
#include "buffer.h"
#include "timeout.h"
#include "socket.h"
#include "ssl.h"

/* Smaller of a and b. Arguments may be evaluated twice: no side effects. */
#define min(a, b) (((a) < (b)) ? (a) : (b))

/**
 * Map an I/O status code into a human-readable string.
 *
 * For IO_SSL the text is derived from the last OpenSSL status stored in the
 * connection (ctx is a p_ssl); every other code is delegated to
 * socket_strerror(). Never returns NULL, because callers push the result
 * directly as a Lua string.
 */
static const char *ssl_ioerror(void *ctx, int err)
{
  if (err == IO_SSL) {
    p_ssl ssl = (p_ssl) ctx;
    switch(ssl->error) {
    case SSL_ERROR_NONE: return "No error";
    case SSL_ERROR_ZERO_RETURN: return "closed";
    case SSL_ERROR_WANT_READ: return "wantread";
    case SSL_ERROR_WANT_WRITE: return "wantwrite";
    case SSL_ERROR_WANT_CONNECT: return "'connect' not completed";
    case SSL_ERROR_WANT_ACCEPT: return "'accept' not completed";
    case SSL_ERROR_WANT_X509_LOOKUP: return "Waiting for callback";
    case SSL_ERROR_SYSCALL: return "System error";
    case SSL_ERROR_SSL: {
      /* The error queue may be empty, or the reason code may have no
         registered string; ERR_reason_error_string() then yields NULL. */
      const char *reason = ERR_reason_error_string(ERR_get_error());
      return reason ? reason : "Unknown SSL error";
    }
    default: return "Unknown SSL error";
    }
  }
  return socket_strerror(err);
}

/**
 * Release the OpenSSL object and the socket of a connection.
 *
 * Registered as the "__gc" metamethod and also invoked by close(). It is a
 * no-op when the connection was already released (conn->ssl == NULL).
 */
static int meth_destroy(lua_State *L)
{
  p_ssl conn = (p_ssl) lua_touserdata(L, 1);

  if (conn->ssl == NULL)
    return 0;

  /* Put the socket back into blocking mode before SSL_shutdown(). */
  socket_setblocking(&conn->sock);
  SSL_shutdown(conn->ssl);
  socket_destroy(&conn->sock);
  SSL_free(conn->ssl);
  conn->ssl = NULL;
  return 0;
}

/**
 * Drive the TLS/SSL handshake until it completes or fails.
 *
 * Whenever OpenSSL reports WANT_READ/WANT_WRITE, wait on the socket within
 * the connection's timeout. A timeout while waiting is reported as IO_SSL,
 * with ssl->error left at the WANT_* value.
 *
 * Returns IO_DONE (state becomes ST_SSL_CONNECTED), IO_CLOSED, IO_SSL, or a
 * socket error code.
 */
static int handshake(p_ssl ssl)
{
  p_timeout tm = timeout_markstart(&ssl->tm);
  int ret;
  int waited;

  if (ssl->state == ST_SSL_CLOSED)
    return IO_CLOSED;

  for (;;) {
    ERR_clear_error();
    ret = SSL_do_handshake(ssl->ssl);
    ssl->error = SSL_get_error(ssl->ssl, ret);

    if (ssl->error == SSL_ERROR_NONE) {
      ssl->state = ST_SSL_CONNECTED;
      return IO_DONE;
    }

    if (ssl->error == SSL_ERROR_WANT_READ || ssl->error == SSL_ERROR_WANT_WRITE) {
      waited = socket_waitfd(&ssl->sock,
                             (ssl->error == SSL_ERROR_WANT_READ) ? WAITFD_R : WAITFD_W,
                             tm);
      if (waited == IO_TIMEOUT)
        return IO_SSL;
      if (waited != IO_DONE)
        return waited;
      continue;                 /* socket ready: retry the handshake */
    }

    if (ssl->error != SSL_ERROR_SYSCALL)
      return IO_SSL;

    /* SYSCALL: an error queued by OpenSSL is reported as an SSL error;
       otherwise a 0 return means EOF, anything else is a socket error. */
    if (ERR_peek_error()) {
      ssl->error = SSL_ERROR_SSL;
      return IO_SSL;
    }
    return (ret == 0) ? IO_CLOSED : socket_error();
  }
}

/**
 * io-layer send callback.
 *
 * Writes up to `count` bytes of `data` through the connection in `ctx`
 * (a p_ssl) and stores the number of bytes accepted in *sent. While OpenSSL
 * wants I/O, waits on the socket; a timeout is reported as IO_SSL, with
 * ssl->error left at WANT_READ/WANT_WRITE.
 */
static int ssl_send(void *ctx, const char *data, size_t count, size_t *sent,
   p_timeout tm)
{
  p_ssl conn = (p_ssl) ctx;
  int ret;
  int waited;

  if (conn->state == ST_SSL_CLOSED)
    return IO_CLOSED;
  *sent = 0;

  for (;;) {
    ERR_clear_error();
    ret = SSL_write(conn->ssl, data, (int) count);
    conn->error = SSL_get_error(conn->ssl, ret);

    if (conn->error == SSL_ERROR_NONE) {
      *sent = ret;
      return IO_DONE;
    }

    if (conn->error == SSL_ERROR_WANT_READ || conn->error == SSL_ERROR_WANT_WRITE) {
      waited = socket_waitfd(&conn->sock,
                             (conn->error == SSL_ERROR_WANT_READ) ? WAITFD_R : WAITFD_W,
                             tm);
      if (waited == IO_TIMEOUT)
        return IO_SSL;
      if (waited != IO_DONE)
        return waited;
      continue;                 /* socket ready: retry the write */
    }

    if (conn->error != SSL_ERROR_SYSCALL)
      return IO_SSL;

    /* SYSCALL: queued OpenSSL error wins; 0 means EOF; else socket error. */
    if (ERR_peek_error()) {
      conn->error = SSL_ERROR_SSL;
      return IO_SSL;
    }
    return (ret == 0) ? IO_CLOSED : socket_error();
  }
}

/**
 * io-layer receive callback.
 *
 * Reads up to `count` bytes into `data` from the connection in `ctx`
 * (a p_ssl) and stores the number of bytes read in *got. A clean TLS
 * close_notify (SSL_ERROR_ZERO_RETURN) is reported as IO_CLOSED. While
 * OpenSSL wants I/O, waits on the socket; a timeout is reported as IO_SSL,
 * with ssl->error left at WANT_READ/WANT_WRITE.
 */
static int ssl_recv(void *ctx, char *data, size_t count, size_t *got,
  p_timeout tm)
{
  p_ssl conn = (p_ssl) ctx;
  int ret;
  int waited;

  if (conn->state == ST_SSL_CLOSED)
    return IO_CLOSED;
  *got = 0;

  for (;;) {
    ERR_clear_error();
    ret = SSL_read(conn->ssl, data, (int) count);
    conn->error = SSL_get_error(conn->ssl, ret);

    if (conn->error == SSL_ERROR_NONE) {
      *got = ret;
      return IO_DONE;
    }

    if (conn->error == SSL_ERROR_ZERO_RETURN) {
      *got = ret;
      return IO_CLOSED;
    }

    if (conn->error == SSL_ERROR_WANT_READ || conn->error == SSL_ERROR_WANT_WRITE) {
      waited = socket_waitfd(&conn->sock,
                             (conn->error == SSL_ERROR_WANT_READ) ? WAITFD_R : WAITFD_W,
                             tm);
      if (waited == IO_TIMEOUT)
        return IO_SSL;
      if (waited != IO_DONE)
        return waited;
      continue;                 /* socket ready: retry the read */
    }

    if (conn->error != SSL_ERROR_SYSCALL)
      return IO_SSL;

    /* SYSCALL: queued OpenSSL error wins; 0 means EOF; else socket error. */
    if (ERR_peek_error()) {
      conn->error = SSL_ERROR_SSL;
      return IO_SSL;
    }
    return (ret == 0) ? IO_CLOSED : socket_error();
  }
}

/**
 * Create a new TLS/SSL connection object and mark it as new (ST_SSL_NEW).
 *
 * Uses the SSL context at stack index 1: its mode selects server (accept)
 * or client (connect) state. The socket is attached later via setfd().
 *
 * Returns the "SSL:Connection" userdata, or nil plus an error message.
 */
static int meth_create(lua_State *L)
{
  p_ssl ssl;
  int mode = ctx_getmode(L, 1);
  SSL_CTX *ctx = ctx_getcontext(L, 1);

  if (mode == MD_CTX_INVALID) {
    lua_pushnil(L);
    lua_pushstring(L, "invalid mode");
    return 2;
  }
  ssl = (p_ssl) lua_newuserdata(L, sizeof(t_ssl));
  if (!ssl) {
    lua_pushnil(L);
    lua_pushstring(L, "error creating SSL object");
    return 2;
  }
  /* Userdata memory is not initialised by Lua. Give the socket a defined
     "no descriptor" value so that getfd() and __gc/close() (meth_destroy)
     never act on an indeterminate descriptor when setfd() was not called. */
  ssl->sock = SOCKET_INVALID;
  ssl->error = SSL_ERROR_NONE;
  ssl->ssl = SSL_new(ctx);
  if (!ssl->ssl) {
    lua_pushnil(L);
    lua_pushstring(L, "error creating SSL object");
    return 2;
  }
  ssl->state = ST_SSL_NEW;
  SSL_set_fd(ssl->ssl, (int) SOCKET_INVALID);
  SSL_set_mode(ssl->ssl,   SSL_MODE_ENABLE_PARTIAL_WRITE |
    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);

#ifdef SSL_MODE_RELEASE_BUFFERS
  SSL_set_mode(ssl->ssl, SSL_MODE_RELEASE_BUFFERS);
#endif

  if (mode == MD_CTX_SERVER)
    SSL_set_accept_state(ssl->ssl);
  else
    SSL_set_connect_state(ssl->ssl);

  io_init(&ssl->io, (p_send) ssl_send, (p_recv) ssl_recv, 
    (p_error) ssl_ioerror, ssl);
  timeout_init(&ssl->tm, -1, -1);
  buffer_init(&ssl->buf, &ssl->io, &ssl->tm);

  luaL_getmetatable(L, "SSL:Connection");
  lua_setmetatable(L, -2);
  return 1;
}

/**
 * Lua send(): delegate to the buffered-I/O layer of this connection.
 */
static int meth_send(lua_State *L)
{
  p_ssl conn = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");
  int nret = buffer_meth_send(L, &conn->buf);
  return nret;
}

/**
 * Lua receive(): delegate to the buffered-I/O layer of this connection.
 */
static int meth_receive(lua_State *L)
{
  p_ssl conn = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");
  int nret = buffer_meth_receive(L, &conn->buf);
  return nret;
}

/**
 * Lua getfd(): return the underlying socket descriptor as a number, so the
 * connection can be used with select().
 */
static int meth_getfd(lua_State *L)
{
  p_ssl conn = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");
  lua_Number fd = conn->sock;
  lua_pushnumber(L, fd);
  return 1;
}

/**
 * Lua setfd(conn, fd): attach a socket descriptor to a connection.
 *
 * Only allowed while the object is still new, i.e. before the handshake.
 * The socket is switched to non-blocking mode.
 */
static int meth_setfd(lua_State *L)
{
  p_ssl conn = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");

  if (conn->state != ST_SSL_NEW)
    return luaL_argerror(L, 1, "invalid SSL object state");   /* raises */

  conn->sock = luaL_checkint(L, 2);
  socket_setnonblocking(&conn->sock);
  SSL_set_fd(conn->ssl, (int) conn->sock);
  return 0;
}

/**
 * Lua dohandshake(): returns true on success, or false plus an error
 * string (as produced by ssl_ioerror) on failure.
 */
static int meth_handshake(lua_State *L)
{
  p_ssl conn = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");
  int status = handshake(conn);
  int ok = (status == IO_DONE);

  lua_pushboolean(L, ok);
  if (ok)
    return 1;
  lua_pushstring(L, ssl_ioerror((void*)conn, status));
  return 2;
}

/**
 * Lua close(): release the OpenSSL object and socket (via meth_destroy),
 * then mark the connection as closed.
 */
static int meth_close(lua_State *L)
{
  p_ssl conn = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");
  meth_destroy(L);
  conn->state = ST_SSL_CLOSED;
  return 0;
}

/**
 * Lua settimeout(): delegate to the timeout helper for this connection.
 */
static int meth_settimeout(lua_State *L)
{
  p_ssl conn = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");
  int nret = timeout_meth_settimeout(L, &conn->tm);
  return nret;
}

/**
 * Lua dirty(): true when data is already available without touching the
 * socket, either in the Lua-side buffer or decrypted inside OpenSSL.
 * A closed connection is never dirty.
 */
static int meth_dirty(lua_State *L)
{
  p_ssl conn = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");
  int pending = 0;

  if (conn->state != ST_SSL_CLOSED) {
    if (!buffer_isempty(&conn->buf))
      pending = 1;
    else
      pending = (SSL_pending(conn->ssl) != 0);
  }
  lua_pushboolean(L, pending);
  return 1;
}

/**
 * Lua want(): report what the connection is waiting for.
 *
 * Returns "nothing", "read", "write", "x509lookup", or "unknown" for any
 * other SSL_want() code. A closed connection always reports "nothing".
 */
static int meth_want(lua_State *L)
{
  p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");
  int code = (ssl->state == ST_SSL_CLOSED) ? SSL_NOTHING : SSL_want(ssl->ssl);
  switch(code) {
  case SSL_NOTHING: lua_pushstring(L, "nothing"); break;
  case SSL_READING: lua_pushstring(L, "read"); break;
  case SSL_WRITING: lua_pushstring(L, "write"); break;
  case SSL_X509_LOOKUP: lua_pushstring(L, "x509lookup"); break;
  /* SSL_want() may return codes not listed above. Without a default,
     nothing would be pushed and "return 1" would hand back whatever is on
     top of the stack (the connection object itself). */
  default: lua_pushstring(L, "unknown"); break;
  }
  return 1;
}
  
/**
 * Lua rawconnection(): expose the raw OpenSSL SSL* as a light userdata
 * (NULL once the connection has been closed).
 */
static int meth_rawconn(lua_State *L)
{
  p_ssl conn = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection");
  void *raw = (void*) conn->ssl;
  lua_pushlightuserdata(L, raw);
  return 1;
}

/**
 * Lua compression(): return the name of the compression method in use, or
 * false when there is none.
 *
 * After close(), returns nil plus "closed": close() frees the OpenSSL
 * object and sets ssl->ssl to NULL, which must not be passed to OpenSSL.
 */
static int meth_compression(lua_State *L)
{
  const COMP_METHOD *comp;
  p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection");

  if (ssl->state == ST_SSL_CLOSED) {
    lua_pushnil(L);
    lua_pushstring(L, "closed");
    return 2;
  }
  comp = SSL_get_current_compression(ssl->ssl);
  if (comp == NULL) {
    lua_pushboolean(L, 0);
    return 1;
  }
  lua_pushstring(L, SSL_COMP_get_name(comp));
  return 1;
}

/**
 * Push the text form of an ASN.1 object identifier.
 *
 * no_name == 1 yields the numeric dotted OID. no_name == 0 lets OpenSSL use
 * the short/long name when it knows one. Names longer than the internal
 * buffer are truncated; on failure an empty string is pushed.
 */
void luasec_push_asn1_objname(lua_State* L, ASN1_OBJECT *object, int no_name)
{
  char buffer[256];
  int len = OBJ_obj2txt(buffer, sizeof(buffer), object, no_name);

  /* OBJ_obj2txt() returns the full text length, which can exceed the buffer
     (the output is then truncated and NUL-terminated), or a negative value
     on error. Clamp so we never push the terminator or unwritten bytes. */
  if (len < 0)
    len = 0;
  else if ((size_t)len >= sizeof(buffer))
    len = (int)sizeof(buffer) - 1;
  lua_pushlstring(L, buffer, (size_t)len);
}

/**
 * Push the raw bytes of an ASN1_STRING as a Lua string, or nil when the
 * string pointer is NULL.
 */
void luasec_push_asn1_string(lua_State* L, ASN1_STRING *string)
{
  if (string == NULL) {
    lua_pushnil(L);
    return;
  }
  lua_pushlstring(L, (char*)ASN1_STRING_data(string), ASN1_STRING_length(string));
}

/**
 * Replace the key on top of the stack with t[key], where t is the table at
 * stack index idx. If t[key] is nil, a new empty table is stored there
 * first.
 *
 * idx is interpreted against the stack as it is on entry (key on top), so
 * callers typically pass -2 for the table just below the key. Positive
 * indices and pseudo-indices are accepted too.
 *
 * Returns 1 if the subtable was created, 0 if it already existed.
 */
int luasec_push_subtable(lua_State* L, int idx)
{
  /* Convert a relative index to an absolute one so the pushes below cannot
     shift it. Pseudo-indices (registry, globals, upvalues) are stable. */
  if (idx < 0 && idx > LUA_REGISTRYINDEX)
    idx = lua_gettop(L) + idx + 1;

  lua_pushvalue(L, -1);            /* key, key */
  lua_gettable(L, idx);            /* key, t[key] */
  if (!lua_isnil(L, -1)) {
    lua_replace(L, -2);            /* t[key] */
    return 0;
  }

  lua_pop(L, 1);                   /* key */
  lua_newtable(L);                 /* key, sub */
  lua_pushvalue(L, -2);            /* key, sub, key */
  lua_pushvalue(L, -2);            /* key, sub, key, sub */
  lua_settable(L, idx);            /* key, sub   -- t[key] = sub */
  lua_replace(L, -2);              /* sub */
  return 1;
}

/**
 * Return the peer certificate.
 */
static int meth_getpeercertificate(lua_State *L)
{
  X509 *peer;
  int i, j;
  p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection");
  peer = SSL_get_peer_certificate(ssl->ssl);
  if (peer == NULL) {
    /* No client certificate available */
    lua_pushboolean(L, 0);
    return 1;
  }
  else
  {
    X509_NAME *subject;
    int n_entries;

    lua_newtable(L); /* ret */

    lua_pushboolean(L, (SSL_get_verify_result(ssl->ssl) == X509_V_OK));
    lua_setfield(L, -2, "trusted");

    subject = X509_get_subject_name(peer);

    n_entries = X509_NAME_entry_count(subject);

    lua_newtable(L); /* {} */
    lua_pushvalue(L, -1);
    lua_setfield(L, -3, "subject"); /* ret.subject = {} */
    for(i = 0; i <= n_entries; i++)
    {
      X509_NAME_ENTRY *entry;
      ASN1_OBJECT *object;

      entry = X509_NAME_get_entry(subject, i);
      object = X509_NAME_ENTRY_get_object(entry);

      luasec_push_asn1_objname(L, object, 1);

      if(luasec_push_subtable(L, -2))
      {
        /* Get short/long name of the entry */
        luasec_push_asn1_objname(L, object, 0);
        lua_setfield(L, -2, "name");
      }

      luasec_push_asn1_string(L, X509_NAME_ENTRY_get_data(entry));
      lua_rawseti(L, -2, lua_objlen(L, -2)+1);

      lua_pop(L, 1);
    }
  }
  lua_pop(L, 1); /* ret.subject */

  lua_newtable(L); /* {} */
  lua_pushvalue(L, -1);
  lua_setfield(L, -3, "extensions"); /* ret.extensions = {} */
  
  i = -1;
  while((i = X509_get_ext_by_NID(peer, NID_subject_alt_name, i)) != -1)
  {
    X509_EXTENSION *extension;
    STACK_OF(GENERAL_NAME) *values;
    int n_general_names;
    
    extension = X509_get_ext(peer, i);
    if(extension == NULL)
      break;
    
    values = X509V3_EXT_d2i(extension);
    if(values == NULL)
      break;

    /* Push ret.extensions[oid] */
    luasec_push_asn1_objname(L, extension->object, 1);
    luasec_push_subtable(L, -2);
    /* Set ret.extensions[oid].name = name */
    luasec_push_asn1_objname(L, extension->object, 0);
    lua_setfield(L, -2, "name");

    n_general_names = sk_GENERAL_NAME_num(values);
    for(j = 0; j < n_general_names; j++)
    {
      GENERAL_NAME *general_name;

      general_name = sk_GENERAL_NAME_value(values, j);

      switch(general_name->type)
      {
      case GEN_OTHERNAME:
      {
        OTHERNAME *otherName = general_name->d.otherName;

        luasec_push_asn1_objname(L, otherName->type_id, 1);

        if(luasec_push_subtable(L, -2))
        {
          luasec_push_asn1_objname(L, otherName->type_id, 0);
          lua_setfield(L, -2, "name");
        }

        luasec_push_asn1_string(L, otherName->value->value.asn1_string);
        lua_rawseti(L, -2, lua_objlen(L, -2)+1);

        lua_pop(L, 1);
        break;
      }
      case GEN_DNS:
      {
        lua_pushstring(L, "dNSName");
	luasec_push_subtable(L, -2);
        luasec_push_asn1_string(L, general_name->d.dNSName);
        lua_rawseti(L, -2, lua_objlen(L, -2)+1);
        lua_pop(L, 1);
        break;
      }
      default:
        break;
      }
    }

    lua_pop(L, 1); /* array */
    i++; /* Next extension */
  }
  lua_pop(L, 1); /* ret.extensions */

  return 1;
}

static int meth_getfinished(lua_State *L)
{
  p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection");
  SSL *conn = ssl->ssl;
  char *buffer = NULL;
  size_t len = 0;
  if ((len = SSL_get_finished(conn, NULL, 0)) != 0) {
    buffer = malloc(len);
    if (buffer == NULL) return 0;
    len = SSL_get_finished(conn, buffer, len);
    lua_pushlstring(L, buffer, len);
    free(buffer);
    return 1;
  } else {
    return 0;
  }
}

static int meth_getpeerfinished(lua_State *L)
{
  p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection");
  SSL *conn = ssl->ssl;
  char *buffer = NULL;
  size_t len = 0;
  if ((len = SSL_get_peer_finished(conn, NULL, 0)) != 0) {
    buffer = malloc(len);
    if (buffer == NULL) return 0;
    len = SSL_get_peer_finished(conn, buffer, len);
    lua_pushlstring(L, buffer, len);
    free(buffer);
    return 1;
  } else {
    return 0;
  }
}

/*---------------------------------------------------------------------------*/


/**
 * SSL connection methods, installed as the "__index" table of the
 * "SSL:Connection" metatable. Read-only, so declared const (luaL_register
 * takes a const luaL_Reg *).
 */
static const luaL_Reg meta[] = {
  {"close",             meth_close},
  {"getfd",             meth_getfd},
  {"dirty",             meth_dirty},
  {"dohandshake",       meth_handshake},
  {"receive",           meth_receive},
  {"send",              meth_send},
  {"settimeout",        meth_settimeout},
  {"want",              meth_want},
  {"compression",       meth_compression},
  {"getpeercertificate",meth_getpeercertificate},
  {"getfinished",       meth_getfinished},
  {"getpeerfinished",   meth_getpeerfinished},
  {NULL,                NULL}
};

/**
 * Module-level functions registered in the "ssl.core" table. Read-only,
 * so declared const (luaL_register takes a const luaL_Reg *).
 */
static const luaL_Reg funcs[] = {
  {"create",        meth_create},
  {"setfd",         meth_setfd},
  {"rawconnection", meth_rawconn},
  {NULL,            NULL}
};

/**
 * Build the "SSL:Connection" metatable: methods from `meta` go into its
 * "__index" table, and "__gc" is meth_destroy. The metatable is left on the
 * stack.
 */
static void register_connection_meta(lua_State *L)
{
  luaL_newmetatable(L, "SSL:Connection");

  lua_newtable(L);
  luaL_register(L, NULL, meta);
  lua_setfield(L, -2, "__index");

  lua_pushcfunction(L, meth_destroy);
  lua_setfield(L, -2, "__gc");
}

/**
 * Module entry point for "ssl.core".
 *
 * Initialises OpenSSL (raising a Lua error if that fails) and the internal
 * socket layer, registers the connection metatable, and returns the module
 * table with `funcs` plus the "invalidfd" constant.
 */
LUASEC_API int luaopen_ssl_core(lua_State *L)
{
  if (SSL_library_init() == 0) {
    lua_pushstring(L, "unable to initialize SSL library");
    lua_error(L);
  }
  SSL_load_error_strings();

  /* Internal socket library */
  socket_open();

  /* Register the connection metatable and the module functions. */
  register_connection_meta(L);
  luaL_register(L, "ssl.core", funcs);

  lua_pushnumber(L, SOCKET_INVALID);
  lua_setfield(L, -2, "invalidfd");
  return 1;
}

mercurial