Added initial Oracle driver support - functionality is complete, but may be too buggy in its current state for any serious use.

Sat, 06 Dec 2008 00:32:37 +0000

author
nrich@ii.net
date
Sat, 06 Dec 2008 00:32:37 +0000
changeset 17
21c4feaeafe7
parent 16
318e5dfd03b8
child 18
b705ba343e94

Added initial Oracle driver support - functionality is complete, but may be too buggy in its current state for any serious use.

DBI.lua file | annotate | diff | comparison | revisions
Makefile file | annotate | diff | comparison | revisions
dbd/common.c file | annotate | diff | comparison | revisions
dbd/common.h file | annotate | diff | comparison | revisions
dbd/db2/statement.c file | annotate | diff | comparison | revisions
dbd/oracle/connection.c file | annotate | diff | comparison | revisions
dbd/oracle/dbd_oracle.h file | annotate | diff | comparison | revisions
dbd/oracle/main.c file | annotate | diff | comparison | revisions
dbd/oracle/statement.c file | annotate | diff | comparison | revisions
dbd/postgresql/statement.c file | annotate | diff | comparison | revisions
--- a/DBI.lua	Fri Dec 05 09:20:31 2008 +0000
+++ b/DBI.lua	Sat Dec 06 00:32:37 2008 +0000
@@ -8,6 +8,7 @@
     PostgreSQL = 'dbdpostgresql',
     SQLite3 = 'dbdsqlite3',
     DB2 = 'dbddb2',
+    Oracle = 'dbdoracle',
 }
 
 local string = require('string')
--- a/Makefile	Fri Dec 05 09:20:31 2008 +0000
+++ b/Makefile	Sat Dec 06 00:32:37 2008 +0000
@@ -1,5 +1,5 @@
 CC=gcc
-CFLAGS=-g -pedantic -O2 -Wall -shared -fpic -I /usr/include/lua5.1 -I /usr/include/mysql -I /usr/include/postgresql/ -I /opt/ibm/db2exc/V9.5/include/ -I . 
+CFLAGS=-g -pedantic -O2 -Wall -shared -fpic -I /usr/include/lua5.1 -I /usr/include/mysql -I /usr/include/postgresql/ -I /opt/ibm/db2exc/V9.5/include/ -I /usr/lib/oracle/xe/app/oracle/product/10.2.0/client/rdbms/public/ -I . 
 AR=ar rcu
 RANLIB=ranlib
 RM=rm -f
@@ -9,20 +9,24 @@
 PSQL_LDFLAGS=$(COMMON_LDFLAGS) -lpq 
 SQLITE3_LDFLAGS=$(COMMON_LDFLAGS) -lsqlite3 
 DB2_LDFLAGS=$(COMMON_LDFLAGS) -L/opt/ibm/db2exc/V9.5/lib32 -ldb2 
+ORACLE_LDFLAGS=$(COMMON_LDFLAGS) -L/usr/lib/oracle/xe/app/oracle/product/10.2.0/client/lib/ -locixe 
 
 DBDMYSQL=dbdmysql.so
 DBDPSQL=dbdpostgresql.so
 DBDSQLITE3=dbdsqlite3.so
 DBDDB2=dbddb2.so
+DBDORACLE=dbdoracle.so
 
-MYSQL_OBJS=build/dbd_mysql_main.o build/dbd_mysql_connection.o build/dbd_mysql_statement.o
-PSQL_OBJS=build/dbd_postgresql_main.o build/dbd_postgresql_connection.o build/dbd_postgresql_statement.o
-SQLITE3_OBJS=build/dbd_sqlite3_main.o build/dbd_sqlite3_connection.o build/dbd_sqlite3_statement.o
-DB2_OBJS=build/dbd_db2_main.o build/dbd_db2_connection.o build/dbd_db2_statement.o
+OBJS=build/dbd_common.o
+MYSQL_OBJS=$(OBJS) build/dbd_mysql_main.o build/dbd_mysql_connection.o build/dbd_mysql_statement.o
+PSQL_OBJS=$(OBJS) build/dbd_postgresql_main.o build/dbd_postgresql_connection.o build/dbd_postgresql_statement.o
+SQLITE3_OBJS=$(OBJS) build/dbd_sqlite3_main.o build/dbd_sqlite3_connection.o build/dbd_sqlite3_statement.o
+DB2_OBJS=$(OBJS) build/dbd_db2_main.o build/dbd_db2_connection.o build/dbd_db2_statement.o
+ORACLE_OBJS=$(OBJS) build/dbd_oracle_main.o build/dbd_oracle_connection.o build/dbd_oracle_statement.o
 
-free: dbdmysql  dbdpsql dbdsqlite3
+free: dbdmysql dbdpsql dbdsqlite3
 
-all:  dbdmysql	dbdpsql	dbdsqlite3 dbddb2
+all:  dbdmysql dbdpsql dbdsqlite3 dbddb2 dbdoracle
 
 dbdmysql: $(MYSQL_OBJS)
 	$(CC) $(CFLAGS) $(MYSQL_OBJS) -o $(DBDMYSQL) $(MYSQL_LDFLAGS)
@@ -36,8 +40,14 @@
 dbddb2: $(DB2_OBJS)
 	$(CC) $(CFLAGS) $(DB2_OBJS) -o $(DBDDB2) $(DB2_LDFLAGS)
 
+dbdoracle: $(ORACLE_OBJS)
+	$(CC) $(CFLAGS) $(ORACLE_OBJS) -o $(DBDORACLE) $(ORACLE_LDFLAGS)
+
 clean:
-	$(RM) $(MYSQL_OBJS) $(PSQL_OBJS) $(SQLITE3_OBJS) $(DB2_OBJS) $(DBDMYSQL) $(DBDPSQL) $(DBDSQLITE3) $(DBDDB2)
+	$(RM) $(MYSQL_OBJS) $(PSQL_OBJS) $(SQLITE3_OBJS) $(DB2_OBJS) $(ORACLE_OBJS) $(DBDMYSQL) $(DBDPSQL) $(DBDSQLITE3) $(DBDDB2) $(DBDORACLE) 
+
+build/dbd_common.o: dbd/common.c dbd/common.h
+	$(CC) -c -o $@ $< $(CFLAGS)
 
 build/dbd_mysql_connection.o: dbd/mysql/connection.c dbd/mysql/dbd_mysql.h dbd/common.h 
 	$(CC) -c -o $@ $< $(CFLAGS)
@@ -67,3 +77,10 @@
 build/dbd_db2_statement.o: dbd/db2/statement.c dbd/db2/dbd_db2.h dbd/common.h
 	$(CC) -c -o $@ $< $(CFLAGS)
 
+build/dbd_oracle_connection.o: dbd/oracle/connection.c dbd/oracle/dbd_oracle.h dbd/common.h 
+	$(CC) -c -o $@ $< $(CFLAGS)
+build/dbd_oracle_main.o: dbd/oracle/main.c dbd/oracle/dbd_oracle.h dbd/common.h
+	$(CC) -c -o $@ $< $(CFLAGS)
+build/dbd_oracle_statement.o: dbd/oracle/statement.c dbd/oracle/dbd_oracle.h dbd/common.h
+	$(CC) -c -o $@ $< $(CFLAGS)
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/dbd/common.c	Sat Dec 06 00:32:37 2008 +0000
@@ -0,0 +1,99 @@
+#include <dbd/common.h>
+
+const char *strlower(char *in) {
+    char *s = in;
+
+    while(*s) {
+	*s= (*s <= 'Z' && *s >= 'A') ? (*s - 'A') + 'a' : *s;
+	s++;
+    }
+
+    return in;
+}
+
+/*
+ * replace '?' placeholders with $\d+ placeholders
+ * to be compatible with PSQL API
+ */
+char *replace_placeholders(lua_State *L, char native_prefix, const char *sql) {
+    size_t len = strlen(sql);
+    int num_placeholders = 0;
+    int extra_space = 0;
+    int i;
+    char *newsql;
+    int newpos = 1;
+    int ph_num = 1;
+    int in_quote = 0;
+    char format_str[4];
+
+    format_str[0] = native_prefix;
+    format_str[1] = '%';
+    format_str[2] = 'u';
+    format_str[3] = '\0';
+
+    /*
+     * dumb count of all '?'
+     * this will match more placeholders than necessesary
+     * but it's safer to allocate more placeholders at the
+     * cost of a few bytes than risk a buffer overflow
+     */ 
+    for (i = 1; i < len; i++) {
+	if (sql[i] == '?') {
+	    num_placeholders++;
+	}
+    }
+    
+    /*
+     * this is MAX_PLACEHOLDER_SIZE-1 because the '?' is 
+     * replaced with '$'
+     */ 
+    extra_space = num_placeholders * (MAX_PLACEHOLDER_SIZE-1); 
+
+    /*
+     * allocate a new string for the converted SQL statement
+     */
+    newsql = malloc(sizeof(char) * (len+extra_space+1));
+    memset(newsql, 0, sizeof(char) * (len+extra_space+1));
+    
+    /* 
+     * copy first char. In valid SQL this cannot be a placeholder
+     */
+    newsql[0] = sql[0];
+
+    /* 
+     * only replace '?' not in a single quoted string
+     */
+    for (i = 1; i < len; i++) {
+	/*
+	 * don't change the quote flag if the ''' is preceded 
+	 * by a '\' to account for escaping
+	 */
+	if (sql[i] == '\'' && sql[i-1] != '\\') {
+	    in_quote = !in_quote;
+	}
+
+	if (sql[i] == '?' && !in_quote) {
+	    size_t n;
+
+	    if (ph_num > MAX_PLACEHOLDERS) {
+		luaL_error(L, "Sorry, you are using more than %d placeholders. Use ${num} format instead", MAX_PLACEHOLDERS);
+	    }
+
+	    n = snprintf(&newsql[newpos], MAX_PLACEHOLDER_SIZE, format_str, ph_num++);
+
+	    newpos += n;
+	} else {
+	    newsql[newpos] = sql[i];
+	    newpos++;
+	}
+    }
+
+    /* 
+     * terminate string on the last position 
+     */
+    newsql[newpos] = '\0';
+
+    /* fprintf(stderr, "[%s]\n", newsql); */
+    return newsql;
+}
+
--- a/dbd/common.h	Fri Dec 05 09:20:31 2008 +0000
+++ b/dbd/common.h	Sat Dec 06 00:32:37 2008 +0000
@@ -88,6 +88,13 @@
 } lua_push_type_t;
 
 /*
+ * used for placeholder translations
+ * from '?' to the .\d{4}
+ */
+#define MAX_PLACEHOLDERS        9999 
+#define MAX_PLACEHOLDER_SIZE    (1+4) /* .\d{4} */
+
+/*
  *
  * Common error strings
  * defined here for consistency in driver implementations
@@ -112,3 +119,15 @@
 #define DBI_ERR_ALLOC_RESULT	    "Error allocating result set: %s"
 #define DBI_ERR_DESC_RESULT	    "Error describing result set: %s"
 #define DBI_ERR_BINDING_TYPE_ERR    "Unknown or unsupported type `%s'"
+
+/*
+ * convert string to lower case
+ */
+const char *strlower(char *in);
+
+/*
+ * replace '?' placeholders with .\d+ placeholders
+ * to be compatible with the native driver API
+ */
+char *replace_placeholders(lua_State *L, char native_prefix, const char *sql);
+
--- a/dbd/db2/statement.c	Fri Dec 05 09:20:31 2008 +0000
+++ b/dbd/db2/statement.c	Sat Dec 06 00:32:37 2008 +0000
@@ -1,22 +1,5 @@
 #include "dbd_db2.h"
 
-#define MAX_COLUMNS	255
-
-#ifndef max
-#define max(a,b) (a > b ? a : b)
-#endif
-
-static const char *strlower(char *in) {
-    char *s = in;
-
-    while(*s) {
-	*s= (*s <= 'Z' && *s >= 'A') ? (*s - 'A') + 'a' : *s;
-	s++;
-    }
-
-    return in;
-}
-
 static lua_push_type_t db2_to_lua_push(unsigned int db2_type, int len) {
     lua_push_type_t lua_type;
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/dbd/oracle/connection.c	Sat Dec 06 00:32:37 2008 +0000
@@ -0,0 +1,244 @@
+#include "dbd_oracle.h"
+
+int dbd_oracle_statement_create(lua_State *L, connection_t *conn, const char *sql_query);
+
+static int commit(connection_t *conn) {
+    int rc = OCITransCommit(conn->svc, conn->err, OCI_DEFAULT);
+    return rc;
+}
+
+static int rollback(connection_t *conn) {
+    int rc = OCITransRollback(conn->svc, conn->err, OCI_DEFAULT);
+    return rc;
+}
+
+
+/* 
+ * connection,err = DBD.Oracle.New(dbfile)
+ */
+static int connection_new(lua_State *L) {
+    int n = lua_gettop(L);
+
+    int rc = 0;
+
+    const char *user = NULL;
+    const char *password = NULL;
+    const char *db = NULL;
+
+    OCIEnv *env = NULL;
+    OCIError *err = NULL;
+    OCISvcCtx *svc = NULL;
+
+    connection_t *conn = NULL;
+
+    /* db, user, password */
+    switch(n) {
+    case 5:
+    case 4:
+    case 3:
+	if (lua_isnil(L, 3) == 0) 
+	    password = luaL_checkstring(L, 3);
+    case 2:
+	if (lua_isnil(L, 2) == 0) 
+	    user = luaL_checkstring(L, 2);
+    case 1:
+        /*
+         * db is the only mandatory parameter
+         */
+	db = luaL_checkstring(L, 1);
+    }
+
+    /*
+     * initialise OCI
+     */
+    OCIInitialize((ub4) OCI_DEFAULT, (dvoid *)0, (dvoid * (*)(dvoid *, size_t))0, (dvoid * (*)(dvoid *, dvoid *, size_t))0, (void (*)(dvoid *, dvoid *))0);
+
+    /*
+     * initialise environment
+     */
+    OCIEnvInit((OCIEnv **)&env, OCI_DEFAULT, 0, (dvoid **)0);
+
+    /* 
+     * server contexts 
+     */
+    OCIHandleAlloc((dvoid *)env, (dvoid **)&err, OCI_HTYPE_ERROR, 0, (dvoid **)0);
+    OCIHandleAlloc((dvoid *)env, (dvoid **)&svc, OCI_HTYPE_SVCCTX, 0, (dvoid **)0);
+
+    /*
+     * connect to database server
+     */
+    rc = OCILogon(env, err, &svc, user, strlen(user), password, strlen(password), db, strlen(db));
+    if (rc != 0) {
+	char errbuf[100];
+	int errcode;
+
+	OCIErrorGet((dvoid *)err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4) sizeof(errbuf), OCI_HTYPE_ERROR);
+	
+	lua_pushnil(L);
+	lua_pushfstring(L, DBI_ERR_CONNECTION_FAILED, errbuf);
+
+	return 2;
+    }
+
+    conn = (connection_t *)lua_newuserdata(L, sizeof(connection_t));
+    conn->oracle = env;
+    conn->err = err;
+    conn->svc = svc;
+    conn->autocommit = 0;
+
+    luaL_getmetatable(L, DBD_ORACLE_CONNECTION);
+    lua_setmetatable(L, -2);
+
+    return 1;
+}
+
+/*
+ * success = connection:autocommit(on)
+ */
+static int connection_autocommit(lua_State *L) {
+    connection_t *conn = (connection_t *)luaL_checkudata(L, 1, DBD_ORACLE_CONNECTION);
+    int on = lua_toboolean(L, 2); 
+    int err = 1;
+
+    if (conn->oracle) {
+	if (on)
+	    rollback(conn);
+
+	conn->autocommit = on;
+	err = 0;
+    }
+
+    lua_pushboolean(L, !err);
+    return 1;
+}
+
+
+/*
+ * success = connection:close()
+ */
+static int connection_close(lua_State *L) {
+    connection_t *conn = (connection_t *)luaL_checkudata(L, 1, DBD_ORACLE_CONNECTION);
+    int disconnect = 0;   
+
+    if (conn->oracle) {
+	rollback(conn);
+
+	OCILogoff(conn->svc, conn->err);
+	
+	if (conn->svc)
+	    OCIHandleFree((dvoid *)conn->svc, OCI_HTYPE_ENV);
+        if (conn->err)
+            OCIHandleFree((dvoid *)conn->err, OCI_HTYPE_ERROR);
+
+	disconnect = 1;
+	conn->oracle = NULL;
+    }
+
+    lua_pushboolean(L, disconnect);
+    return 1;
+}
+
+/*
+ * success = connection:commit()
+ */
+static int connection_commit(lua_State *L) {
+    connection_t *conn = (connection_t *)luaL_checkudata(L, 1, DBD_ORACLE_CONNECTION);
+    int err = 1;
+
+    if (conn->oracle) {
+	err = commit(conn);
+    }
+
+    lua_pushboolean(L, !err);
+    return 1;
+}
+
+/*
+ * ok = connection:ping()
+ */
+static int connection_ping(lua_State *L) {
+    connection_t *conn = (connection_t *)luaL_checkudata(L, 1, DBD_ORACLE_CONNECTION);
+    int ok = 0;   
+
+    if (conn->oracle) {
+	ok = 1;
+    }
+
+    lua_pushboolean(L, ok);
+    return 1;
+}
+
+/*
+ * statement,err = connection:prepare(sql_str)
+ */
+static int connection_prepare(lua_State *L) {
+    connection_t *conn = (connection_t *)luaL_checkudata(L, 1, DBD_ORACLE_CONNECTION);
+
+    if (conn->oracle) {
+	return dbd_oracle_statement_create(L, conn, luaL_checkstring(L, 2));
+    }
+
+    lua_pushnil(L);    
+    lua_pushstring(L, DBI_ERR_DB_UNAVAILABLE);
+    return 2;
+}
+
+/*
+ * success = connection:rollback()
+ */
+static int connection_rollback(lua_State *L) {
+    connection_t *conn = (connection_t *)luaL_checkudata(L, 1, DBD_ORACLE_CONNECTION);
+    int err = 1;
+
+    if (conn->oracle) {
+	err = rollback(conn);
+    }
+
+    lua_pushboolean(L, !err);
+    return 1;
+}
+
+/*
+ * __gc 
+ */
+static int connection_gc(lua_State *L) {
+    /* always close the connection */
+    connection_close(L);
+
+    return 0;
+}
+
+int dbd_oracle_connection(lua_State *L) {
+    /*
+     * instance methods
+     */
+    static const luaL_Reg connection_methods[] = {
+	{"autocommit", connection_autocommit},
+	{"close", connection_close},
+	{"commit", connection_commit},
+	{"ping", connection_ping},
+	{"prepare", connection_prepare},
+	{"rollback", connection_rollback},
+	{NULL, NULL}
+    };
+
+    /*
+     * class methods
+     */
+    static const luaL_Reg connection_class_methods[] = {
+	{"New", connection_new},
+	{NULL, NULL}
+    };
+
+    luaL_newmetatable(L, DBD_ORACLE_CONNECTION);
+    luaL_register(L, 0, connection_methods);
+    lua_pushvalue(L,-1);
+    lua_setfield(L, -2, "__index");
+
+    lua_pushcfunction(L, connection_gc);
+    lua_setfield(L, -2, "__gc");
+
+    luaL_register(L, DBD_ORACLE_CONNECTION, connection_class_methods);
+
+    return 1;    
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/dbd/oracle/dbd_oracle.h	Sat Dec 06 00:32:37 2008 +0000
@@ -0,0 +1,38 @@
+#include <oci.h>
+#include <dbd/common.h>
+
+#define DBD_ORACLE_CONNECTION	"DBD.Oracle.Connection"
+#define DBD_ORACLE_STATEMENT	"DBD.Oracle.Statement"
+
+typedef struct _bindparams {
+    OCIParam *param;
+    text *name;
+    ub4 name_len;
+    ub2 data_type;
+    ub2 max_len;
+    char *data;
+    OCIDefine *define;
+    sb2 null;
+} bindparams_t;
+
+/*
+ * connection object
+ */
+typedef struct _connection {
+    OCIEnv *oracle;
+    OCISvcCtx *svc;
+    OCIError *err;
+    OCIServer *srv;
+    OCISession *auth;
+    int autocommit;
+} connection_t;
+
+/*
+ * statement object
+ */
+typedef struct _statement {
+    OCIStmt *stmt;
+    connection_t *conn;
+    int num_columns;
+} statement_t;
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/dbd/oracle/main.c	Sat Dec 06 00:32:37 2008 +0000
@@ -0,0 +1,15 @@
+#include "dbd_oracle.h"
+
+int dbd_oracle_connection(lua_State *L);
+int dbd_oracle_statement(lua_State *L);
+
+/* 
+ * library entry point
+ */
+int luaopen_dbdoracle(lua_State *L) {
+    dbd_oracle_connection(L);
+    dbd_oracle_statement(L); 
+
+    return 1;
+}
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/dbd/oracle/statement.c	Sat Dec 06 00:32:37 2008 +0000
@@ -0,0 +1,432 @@
+#include "dbd_oracle.h"
+
+/*
+ * Converts SQLite types to Lua types
+ */
+static lua_push_type_t oracle_to_lua_push(unsigned int oracle_type, int null) {
+    lua_push_type_t lua_type;
+
+    if (null)
+	return LUA_PUSH_NIL;
+
+    switch(oracle_type) {
+    case SQLT_NUM:
+    case SQLT_FLT:
+	lua_type = LUA_PUSH_NUMBER;
+	break;
+    case SQLT_INT:
+	lua_type = LUA_PUSH_INTEGER;
+	break;
+    default:
+        lua_type = LUA_PUSH_STRING;
+    }
+
+    return lua_type;
+}
+
+/*
+ * success = statement:close()
+ */
+int statement_close(lua_State *L) {
+    statement_t *statement = (statement_t *)luaL_checkudata(L, 1, DBD_ORACLE_STATEMENT);
+    int ok = 0;
+
+    if (statement->stmt) {
+	int rc;
+
+	rc = OCIHandleFree((dvoid *)statement->stmt, OCI_HTYPE_STMT);    /* Free handles */	
+
+	statement->stmt = NULL;
+    }
+
+    lua_pushboolean(L, ok);
+    return 1;
+}
+
+/*
+ * success,err = statement:execute(...)
+ */
+int statement_execute(lua_State *L) {
+    int n = lua_gettop(L);
+    statement_t *statement = (statement_t *)luaL_checkudata(L, 1, DBD_ORACLE_STATEMENT);
+    int p;
+    int errflag = 0;
+    const char *errstr = NULL;
+    int expected_params;
+    int num_bind_params = n - 1;
+    int num_columns;
+    int rc;
+
+    char errbuf[100];
+    int errcode;
+
+    ub2 type;
+
+    if (!statement->stmt) {
+	lua_pushboolean(L, 0);
+	lua_pushstring(L, DBI_ERR_EXECUTE_INVALID);
+	return 2;
+    }
+
+    for (p = 2; p <= n; p++) {
+	int i = p - 1;
+	int type = lua_type(L, p);
+	char err[64];
+	const char *value;
+
+	OCIBind *bnd = (OCIBind *)0;
+
+	switch(type) {
+	case LUA_TNIL:
+	    errflag = OCIBindByPos(
+		statement->stmt, 
+		&bnd, 
+		statement->conn->err, 
+		i, 
+		NULL, 
+		0, 
+		SQLT_CHR, 
+		(dvoid *)0, 
+		(ub2 *)0, 
+		(ub2 *)0, 
+		(ub4)0, 
+		(ub4 *)0,
+		OCI_DEFAULT);
+	    break;
+	case LUA_TNUMBER:
+	case LUA_TSTRING:
+	case LUA_TBOOLEAN:
+	    value = lua_tostring(L, p);
+
+	    errflag = OCIBindByPos(
+		statement->stmt, 
+		&bnd, 
+		statement->conn->err, 
+		i, 
+		value, 
+		strlen(value), 
+		SQLT_CHR, 
+		(dvoid *)0, 
+		(ub2 *)0, 
+		(ub2 *)0, 
+		(ub4)0, 
+		(ub4 *)0,
+		(ub4)OCI_DEFAULT);
+	    break;
+	default:
+	    /*
+	     * Unknown/unsupported value type
+	     */
+	    errflag = 1;
+            snprintf(err, sizeof(err)-1, DBI_ERR_BINDING_TYPE_ERR, lua_typename(L, type));
+            errstr = err;
+	}
+
+	if (errflag)
+	    break;
+    }   
+
+    if (errflag) {
+	lua_pushboolean(L, 0);
+	if (errstr)
+	    lua_pushfstring(L, DBI_ERR_BINDING_PARAMS, errstr);
+	else {
+	    OCIErrorGet((dvoid *)statement->conn->err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4) sizeof(errbuf), OCI_HTYPE_ERROR);
+
+	    lua_pushfstring(L, DBI_ERR_BINDING_PARAMS, errbuf);
+	}
+    
+	return 2;
+    }
+
+    /* 
+     * statement type 
+     */
+    rc = OCIAttrGet(
+	(dvoid *)statement->stmt, 
+	(ub4)OCI_HTYPE_STMT, 
+	(dvoid *)&type, 
+	(ub4 *)0, 
+	(ub4)OCI_ATTR_STMT_TYPE, 
+	statement->conn->err
+    );
+
+    if (rc) {
+	OCIErrorGet((dvoid *)statement->conn->err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4) sizeof(errbuf), OCI_HTYPE_ERROR);
+
+	lua_pushboolean(L, 0);
+	lua_pushfstring(L, "Error getting type: %s", errbuf);
+
+	return 2;
+    }
+
+    /*
+     * execute statement
+     */
+    rc = OCIStmtExecute(
+	statement->conn->svc, 
+	statement->stmt, 
+	statement->conn->err, 
+	type == OCI_STMT_SELECT ? 0 : 1, 
+	(ub4)0, 
+	(CONST OCISnapshot *)NULL, 
+	(OCISnapshot *)NULL, 
+	statement->conn->autocommit ? OCI_COMMIT_ON_SUCCESS : OCI_DEFAULT
+    );
+
+    if (rc) {
+	OCIErrorGet((dvoid *)statement->conn->err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4) sizeof(errbuf), OCI_HTYPE_ERROR);
+
+	lua_pushboolean(L, 0);
+	lua_pushfstring(L, DBI_ERR_BINDING_PARAMS, errbuf);
+
+	return 2;
+    }
+
+    /* 
+     * get number of columns 
+     */
+    rc = OCIAttrGet(
+	(dvoid *)statement->stmt, 
+	(ub4)OCI_HTYPE_STMT,
+        (dvoid *)&num_columns, 
+	(ub4 *)0, 
+	(ub4)OCI_ATTR_PARAM_COUNT,
+        statement->conn->err
+    );
+
+    if (rc) {
+	OCIErrorGet((dvoid *)statement->conn->err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4) sizeof(errbuf), OCI_HTYPE_ERROR);
+
+	lua_pushboolean(L, 0);
+	lua_pushfstring(L, DBI_ERR_BINDING_PARAMS, errbuf);
+
+	return 2;
+    }
+
+    statement->num_columns = num_columns;
+
+    lua_pushboolean(L, 1);
+    return 1;
+}
+
+/*
+ * must be called after an execute
+ */
+static int statement_fetch_impl(lua_State *L, statement_t *statement, int named_columns) {
+    int rc;
+    sword status;
+    int i;
+    bindparams_t *bind;
+
+    char errbuf[100];
+    int errcode;
+
+    if (!statement->stmt) {
+	luaL_error(L, DBI_ERR_FETCH_INVALID);
+	return 0;
+    }
+
+    bind = (bindparams_t *)malloc(sizeof(bindparams_t) * statement->num_columns);
+    memset(bind, 0, sizeof(bindparams_t) * statement->num_columns);
+
+    for (i = 0; i < statement->num_columns; i++) {
+	rc = OCIParamGet(statement->stmt, OCI_HTYPE_STMT, statement->conn->err, (dvoid **)&bind[i].param, i+1);
+	if (rc) {
+	    OCIErrorGet((dvoid *)statement->conn->err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4) sizeof(errbuf), OCI_HTYPE_ERROR);
+	    luaL_error(L, "param get %s", errbuf);
+	}
+
+	rc = OCIAttrGet(bind[i].param, OCI_DTYPE_PARAM, (dvoid *)&(bind[i].name), (ub4 *)&(bind[i].name_len), OCI_ATTR_NAME, statement->conn->err);
+	if (rc) {
+	    OCIErrorGet((dvoid *)statement->conn->err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4) sizeof(errbuf), OCI_HTYPE_ERROR);
+	    luaL_error(L, "name get %s", errbuf);
+	}
+
+	rc = OCIAttrGet(bind[i].param, OCI_DTYPE_PARAM, (dvoid *)&(bind[i].data_type), (ub4 *)0, OCI_ATTR_DATA_TYPE, statement->conn->err);
+	if (rc) {
+	    OCIErrorGet((dvoid *)statement->conn->err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4) sizeof(errbuf), OCI_HTYPE_ERROR);
+	    luaL_error(L, "datatype get %s", errbuf);
+	}
+
+	rc = OCIAttrGet(bind[i].param, OCI_DTYPE_PARAM, (dvoid *)&(bind[i].max_len), 0, OCI_ATTR_DATA_SIZE, statement->conn->err);
+	if (rc) {
+	    OCIErrorGet((dvoid *)statement->conn->err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4) sizeof(errbuf), OCI_HTYPE_ERROR);
+	    luaL_error(L, "datasize get %s", errbuf);
+	}
+
+	bind[i].data = calloc(bind[i].max_len+1, sizeof(char));
+	rc = OCIDefineByPos(statement->stmt, &bind[i].define, statement->conn->err, (ub4)i+1, bind[i].data, bind[i].max_len, SQLT_STR, (dvoid *)&(bind[i].null), (ub2 *)0, (ub2 *)0, (ub4)OCI_DEFAULT);
+	if (rc) {
+	    OCIErrorGet((dvoid *)statement->conn->err, (ub4) 1, (text *) NULL, &errcode, errbuf, (ub4)sizeof(errbuf), OCI_HTYPE_ERROR);
+	    luaL_error(L, "define by pos %s", errbuf);
+	}
+    }
+
+    status = OCIStmtFetch(statement->stmt, statement->conn->err, 1, OCI_FETCH_NEXT, OCI_DEFAULT);
+
+    if (status == OCI_NO_DATA) {
+	/* No more rows */
+        lua_pushnil(L);
+        return 1;
+    } else if (status != OCI_SUCCESS) {
+	OCIErrorGet((dvoid *)statement->conn->err, (ub4)1, (text *)NULL, &errcode, errbuf, (ub4)sizeof(errbuf), OCI_HTYPE_ERROR);
+	luaL_error(L, DBI_ERR_FETCH_FAILED, errbuf);
+    }
+
+    if (statement->num_columns) {
+	int i;
+	int d = 1;
+
+	lua_newtable(L);
+
+	for (i = 0; i < statement->num_columns; i++) {
+	    lua_push_type_t lua_push = oracle_to_lua_push(bind[i].data_type, bind[i].null);
+	    const char *name = strlower(bind[i].name);
+	    const char *data = bind[i].data;
+
+	    if (lua_push == LUA_PUSH_NIL) {
+                if (named_columns) {
+                    LUA_PUSH_ATTRIB_NIL(name);
+                } else {
+                    LUA_PUSH_ARRAY_NIL(d);
+                }
+            } else if (lua_push == LUA_PUSH_INTEGER) {
+		int val = atoi(data);
+
+                if (named_columns) {
+                    LUA_PUSH_ATTRIB_INT(name, val);
+                } else {
+                    LUA_PUSH_ARRAY_INT(d, val);
+                }
+            } else if (lua_push == LUA_PUSH_NUMBER) {
+		double val = strtod(data, NULL);
+
+                if (named_columns) {
+                    LUA_PUSH_ATTRIB_FLOAT(name, val);
+                } else {
+                    LUA_PUSH_ARRAY_FLOAT(d, val);
+                }
+            } else if (lua_push == LUA_PUSH_STRING) {
+                if (named_columns) {
+                    LUA_PUSH_ATTRIB_STRING(name, data);
+                } else {
+                    LUA_PUSH_ARRAY_STRING(d, data);
+                }
+            } else if (lua_push == LUA_PUSH_BOOLEAN) {
+		int val = 1;
+
+                if (named_columns) {
+                    LUA_PUSH_ATTRIB_BOOL(name, val);
+                } else {
+                    LUA_PUSH_ARRAY_BOOL(d, val);
+                }
+            } else {
+                luaL_error(L, DBI_ERR_UNKNOWN_PUSH);
+            }
+	}
+    } else {
+	/* 
+         * no columns returned by statement?
+         */ 
+	lua_pushnil(L);
+    }
+
+    return 1;    
+}
+
+static int next_iterator(lua_State *L) {
+    statement_t *statement = (statement_t *)luaL_checkudata(L, lua_upvalueindex(1), DBD_ORACLE_STATEMENT);
+    int named_columns = lua_toboolean(L, lua_upvalueindex(2));
+
+    return statement_fetch_impl(L, statement, named_columns);
+}
+
+/*
+ * table = statement:fetch(named_indexes)
+ */
+static int statement_fetch(lua_State *L) {
+    statement_t *statement = (statement_t *)luaL_checkudata(L, 1, DBD_ORACLE_STATEMENT);
+    int named_columns = lua_toboolean(L, 2);
+
+    return statement_fetch_impl(L, statement, named_columns);
+}
+
+/*
+ * iterfunc = statement:rows(named_indexes)
+ */
+static int statement_rows(lua_State *L) {
+    if (lua_gettop(L) == 1) {
+        lua_pushvalue(L, 1);
+        lua_pushboolean(L, 0);
+    } else {
+        lua_pushvalue(L, 1);
+        lua_pushboolean(L, lua_toboolean(L, 2));
+    }
+
+    lua_pushcclosure(L, next_iterator, 2);
+    return 1;
+}
+
+/*
+ * __gc
+ */
+static int statement_gc(lua_State *L) {
+    /* always free the handle */
+    statement_close(L);
+
+    return 0;
+}
+
+int dbd_oracle_statement_create(lua_State *L, connection_t *conn, const char *sql_query) { 
+    int rc;
+    statement_t *statement = NULL;
+    OCIStmt *stmt;
+    char *new_sql;
+
+    /*
+     * convert SQL string into a Oracle API compatible SQL statement
+     */
+    new_sql = replace_placeholders(L, ':', sql_query);
+
+    rc = OCIHandleAlloc((dvoid *)conn->oracle, (dvoid **)&stmt, OCI_HTYPE_STMT, 0, (dvoid **)0);
+    rc = OCIStmtPrepare(stmt, conn->err, new_sql, strlen(new_sql), (ub4)OCI_NTV_SYNTAX, (ub4)OCI_DEFAULT);
+
+    free(new_sql);
+
+    statement = (statement_t *)lua_newuserdata(L, sizeof(statement_t));
+    statement->conn = conn;
+    statement->stmt = stmt;
+    statement->num_columns = 0;
+
+    luaL_getmetatable(L, DBD_ORACLE_STATEMENT);
+    lua_setmetatable(L, -2);
+
+    return 1;
+} 
+
+int dbd_oracle_statement(lua_State *L) {
+    static const luaL_Reg statement_methods[] = {
+	{"close", statement_close},
+	{"execute", statement_execute},
+	{"fetch", statement_fetch},
+	{"rows", statement_rows},
+	{NULL, NULL}
+    };
+
+    static const luaL_Reg statement_class_methods[] = {
+	{NULL, NULL}
+    };
+
+    luaL_newmetatable(L, DBD_ORACLE_STATEMENT);
+    luaL_register(L, 0, statement_methods);
+    lua_pushvalue(L,-1);
+    lua_setfield(L, -2, "__index");
+
+    lua_pushcfunction(L, statement_gc);
+    lua_setfield(L, -2, "__gc");
+
+    luaL_register(L, DBD_ORACLE_STATEMENT, statement_class_methods);
+
+    return 1;    
+}
--- a/dbd/postgresql/statement.c	Fri Dec 05 09:20:31 2008 +0000
+++ b/dbd/postgresql/statement.c	Sat Dec 06 00:32:37 2008 +0000
@@ -1,8 +1,5 @@
 #include "dbd_postgresql.h"
 
-#define MAX_PLACEHOLDERS	9999 
-#define MAX_PLACEHOLDER_SIZE	(1+4) /* $\d{4} */
-
 static lua_push_type_t postgresql_to_lua_push(unsigned int postgresql_type) {
     lua_push_type_t lua_type;
 
@@ -28,85 +25,6 @@
     return lua_type;
 }
 
-/*
- * replace '?' placeholders with $\d+ placeholders
- * to be compatible with PSQL API
- */
-static char *replace_placeholders(lua_State *L, const char *sql) {
-    size_t len = strlen(sql);
-    int num_placeholders = 0;
-    int extra_space = 0;
-    int i;
-    char *newsql;
-    int newpos = 1;
-    int ph_num = 1;
-    int in_quote = 0;
-
-    /*
-     * dumb count of all '?'
-     * this will match more placeholders than necessesary
-     * but it's safer to allocate more placeholders at the
-     * cost of a few bytes than risk a buffer overflow
-     */ 
-    for (i = 1; i < len; i++) {
-	if (sql[i] == '?') {
-	    num_placeholders++;
-	}
-    }
-    
-    /*
-     * this is MAX_PLACEHOLDER_SIZE-1 because the '?' is 
-     * replaced with '$'
-     */ 
-    extra_space = num_placeholders * (MAX_PLACEHOLDER_SIZE-1); 
-
-    /*
-     * allocate a new string for the converted SQL statement
-     */
-    newsql = malloc(sizeof(char) * (len+extra_space+1));
-    memset(newsql, 0, sizeof(char) * (len+extra_space+1));
-    
-    /* 
-     * copy first char. In valid SQL this cannot be a placeholder
-     */
-    newsql[0] = sql[0];
-
-    /* 
-     * only replace '?' not in a single quoted string
-     */
-    for (i = 1; i < len; i++) {
-	/*
-	 * don't change the quote flag if the ''' is preceded 
-	 * bt a '\' to account for escaping
-	 */
-	if (sql[i] == '\'' && sql[i-1] != '\\') {
-	    in_quote = !in_quote;
-	}
-
-	if (sql[i] == '?' && !in_quote) {
-	    size_t n;
-
-	    if (ph_num > MAX_PLACEHOLDERS) {
-		luaL_error(L, "Sorry, you are using more than %d placeholders. Use ${num} format instead", MAX_PLACEHOLDERS);
-	    }
-
-	    n = snprintf(&newsql[newpos], MAX_PLACEHOLDER_SIZE, "$%u", ph_num++);
-
-	    newpos += n;
-	} else {
-	    newsql[newpos] = sql[i];
-	    newpos++;
-	}
-    }
-
-    /* 
-     * terminate string on the last position 
-     */
-    newsql[newpos] = '\0';
-
-    /* fprintf(stderr, "[%s]\n", newsql); */
-    return newsql;
-}
 
 /*
  * success = statement:close()
@@ -359,7 +277,7 @@
     /*
      * convert SQL string into a PSQL API compatible SQL statement
      */ 
-    new_sql = replace_placeholders(L, sql_query);
+    new_sql = replace_placeholders(L, '$', sql_query);
 
     snprintf(name, IDLEN, "%017u", ++conn->statement_id);
 

mercurial