aeslua/aes.lua

changeset 0
598d09faf89c
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/aeslua/aes.lua	Wed Feb 16 20:29:33 2011 +0000
@@ -0,0 +1,509 @@
+local bit = require("bit");
+
+local gf = require("aeslua.gf");
+local util = require("aeslua.util");
+
+--
+-- Implementation of AES with nearly pure lua (only bitlib is needed) 
+--
+-- AES with lua is slow, really slow :-)
+--
+
+-- exported API table; ROUNDS and KEY_TYPE are the field names under which a
+-- key schedule stores its round count and its kind (encryption/decryption)
+local public = {
+    ROUNDS = "rounds",
+    KEY_TYPE = "type",
+    ENCRYPTION_KEY = 1,
+    DECRYPTION_KEY = 2,
+};
+
+-- file-internal state; all lookup tables start empty and are filled by the
+-- calc* functions when this file is loaded
+local private = {
+    -- aes SBOX and its inverse
+    SBox = {},
+    iSBox = {},
+
+    -- aes round tables for encryption
+    table0 = {},
+    table1 = {},
+    table2 = {},
+    table3 = {},
+
+    -- aes round tables for decryption
+    tableInv0 = {},
+    tableInv1 = {},
+    tableInv2 = {},
+    tableInv3 = {},
+};
+
+-- publish this module as aeslua.aes
+local aeslua = require("aeslua");
+aeslua.aes = public;
+
+-- round constants for the key schedule, indexed from 1;
+-- each constant occupies the most significant byte of its word
+private.rCon = {
+    0x01000000, 0x02000000, 0x04000000, 0x08000000,
+    0x10000000, 0x20000000, 0x40000000, 0x80000000,
+    0x1b000000, 0x36000000, 0x6c000000, 0xd8000000,
+    0xab000000, 0x4d000000, 0x9a000000, 0x2f000000,
+};
+
+--
+-- affine transformation used to build the S-Box of AES
+--
+-- The result is assembled in 8 steps: each step shifts the result left by
+-- one, adds util.byteParity(byte AND mask) and rotates the 8 bit mask
+-- (starting at 0xf8) right by one. The assembled value is xored with 0x63.
+--
+function private.affinMap(byte)
+    local rowMask = 0xf8;
+    local acc = 0;
+    for _ = 1, 8 do
+        acc = bit.lshift(acc, 1) + util.byteParity(bit.band(byte, rowMask));
+
+        -- rotate rowMask right by one position within 8 bits
+        local wrapped = (bit.band(rowMask, 1) ~= 0);
+        rowMask = bit.band(bit.rshift(rowMask, 1), 0xff);
+        if (wrapped) then
+            rowMask = bit.bor(rowMask, 0x80);
+        else
+            rowMask = bit.band(rowMask, 0x7f);
+        end
+    end
+
+    return bit.bxor(acc, 0x63);
+end
+
+--
+-- fill private.SBox and private.iSBox (both indexed 0..255)
+-- SBox[i] is the affine map of the inverse of i in the finite field 2^8;
+-- iSBox is the reverse lookup of SBox
+--
+function private.calcSBox()
+    for value = 0, 255 do
+        -- 0 is not inverted, it is fed to the affine map unchanged
+        local inverse = value;
+        if (value ~= 0) then
+            inverse = gf.invert(value);
+        end
+
+        local mapped = private.affinMap(inverse);
+        private.SBox[value] = mapped;
+        private.iSBox[mapped] = value;
+    end
+end
+
+--
+-- Calculate the encryption round tables table0..table3 (indexed 0..255)
+-- Each entry packs the S-Box value of its index, multiplied by 0x01, 0x02
+-- and 0x03 in the finite field, into one word. With these tables
+-- shiftRow, MixColumn and SubBytes need 4 table lookups and
+-- 4 xor operations.
+--
+function private.calcRoundTables()
+    local putByte = util.putByte;
+
+    for x = 0, 255 do
+        local s  = private.SBox[x];
+        local s2 = gf.mul(0x02, s);
+        local s3 = gf.mul(0x03, s);
+
+        -- the four tables hold the same bytes at rotated byte positions
+        private.table0[x] = putByte(s3, 0) + putByte(s , 1) + putByte(s , 2) + putByte(s2, 3);
+        private.table1[x] = putByte(s , 0) + putByte(s , 1) + putByte(s2, 2) + putByte(s3, 3);
+        private.table2[x] = putByte(s , 0) + putByte(s2, 1) + putByte(s3, 2) + putByte(s , 3);
+        private.table3[x] = putByte(s2, 0) + putByte(s3, 1) + putByte(s , 2) + putByte(s , 3);
+    end
+end
+
+--
+-- Calculate the decryption round tables tableInv0..tableInv3 (indexed 0..255)
+-- Counterpart of the encryption round tables for the equivalent
+-- decryption algorithm: each entry packs the inverse S-Box value of its
+-- index, multiplied by 0x09, 0x0b, 0x0d and 0x0e, into one word.
+--
+function private.calcInvRoundTables()
+    local putByte = util.putByte;
+
+    for x = 0, 255 do
+        local s  = private.iSBox[x];
+        local sb = gf.mul(0x0b, s);
+        local sd = gf.mul(0x0d, s);
+        local s9 = gf.mul(0x09, s);
+        local se = gf.mul(0x0e, s);
+
+        -- the four tables hold the same bytes at rotated byte positions
+        private.tableInv0[x] = putByte(sb, 0) + putByte(sd, 1) + putByte(s9, 2) + putByte(se, 3);
+        private.tableInv1[x] = putByte(sd, 0) + putByte(s9, 1) + putByte(se, 2) + putByte(sb, 3);
+        private.tableInv2[x] = putByte(s9, 0) + putByte(se, 1) + putByte(sb, 2) + putByte(sd, 3);
+        private.tableInv3[x] = putByte(se, 0) + putByte(sb, 1) + putByte(sd, 2) + putByte(s9, 3);
+    end
+end
+
+
+--
+-- rotate a word left by one byte: 0xaabbccdd becomes 0xbbccddaa
+-- used for the key schedule
+--
+function private.rotWord(word)
+    -- isolate the top byte and move it down to the lowest byte position
+    local topByte = bit.rshift(bit.band(word, 0xff000000), 24);
+    return (bit.lshift(word, 8) + topByte);
+end
+
+--
+-- substitute each of the 4 bytes of a word by its S-Box value,
+-- keeping every byte at its position
+-- used for the key schedule
+--
+function private.subWord(word)
+    local substituted = 0;
+    for pos = 0, 3 do
+        local sboxByte = private.SBox[util.getByte(word, pos)];
+        substituted = substituted + util.putByte(sboxByte, pos);
+    end
+    return substituted;
+end
+
+--
+-- generate key schedule for aes encryption
+--
+-- key      array of key bytes, read from key[1] up to key[#key];
+--          #key has to be 16, 24 or 32
+-- returns  table with all round key words at the indices
+--          0 .. (rounds + 1) * 4 - 1, the necessary number of rounds saved
+--          in [public.ROUNDS] and public.ENCRYPTION_KEY saved in
+--          [public.KEY_TYPE]
+-- raises   an error for every other key length; the number in the message
+--          is the key length in whole 4 byte words
+--
+function public.expandEncryptionKey(key)
+    local keySchedule = {};
+    local keyWords = math.floor(#key / 4);
+
+    -- only 4, 6 or 8 complete words are valid
+    if ((keyWords ~= 4 and keyWords ~= 6 and keyWords ~= 8) or (keyWords * 4 ~= #key)) then
+        -- error() never returns, so no statement may follow it
+        error("Invalid key size: "..keyWords);
+    end
+
+    keySchedule[public.ROUNDS] = keyWords + 6;
+    keySchedule[public.KEY_TYPE] = public.ENCRYPTION_KEY;
+
+    -- the first keyWords words are the key itself; of every 4 key bytes
+    -- the first one goes to putByte position 3, the last one to position 0
+    for i = 0,keyWords - 1 do
+        keySchedule[i] = util.putByte(key[i*4+1], 3)
+                       + util.putByte(key[i*4+2], 2)
+                       + util.putByte(key[i*4+3], 1)
+                       + util.putByte(key[i*4+4], 0);
+    end
+
+    -- every further word is the xor of the word keyWords positions back
+    -- and the (possibly transformed) previous word
+    for i = keyWords, (keySchedule[public.ROUNDS] + 1)*4 - 1 do
+        local tmp = keySchedule[i-1];
+
+        if ( i % keyWords == 0) then
+            -- first word of a key sized group: rotate, substitute and
+            -- xor with the round constant of that group
+            tmp = private.rotWord(tmp);
+            tmp = private.subWord(tmp);
+
+            local index = math.floor(i/keyWords);
+            tmp = bit.bxor(tmp,private.rCon[index]);
+        elseif (keyWords > 6 and i % keyWords == 4) then
+            -- 8 word keys substitute once more in the middle of a group
+            tmp = private.subWord(tmp);
+        end
+
+        keySchedule[i] = bit.bxor(keySchedule[(i-keyWords)],tmp);
+    end
+
+    return keySchedule;
+end
+
+--
+-- one output byte of the inverse mix column:
+-- 0x0b*a + 0x0d*b + 0x09*c + 0x0e*d, calculated with gf.mul and gf.add
+--
+local function invMixByte(a, b, c, d)
+    local sum = gf.add(gf.mul(0x0b, a), gf.mul(0x0d, b));
+    sum = gf.add(sum, gf.mul(0x09, c));
+    return gf.add(sum, gf.mul(0x0e, d));
+end
+
+--
+-- Inverse mix column of one word
+-- used for key schedule of decryption key
+--
+function private.invMixColumnOld(word)
+    local b0, b1, b2, b3 = util.getByte(word,3),
+                           util.getByte(word,2),
+                           util.getByte(word,1),
+                           util.getByte(word,0);
+
+    -- every output byte uses the same coefficients on rotated input bytes
+    return util.putByte(invMixByte(b1, b2, b3, b0), 3)
+         + util.putByte(invMixByte(b2, b3, b0, b1), 2)
+         + util.putByte(invMixByte(b3, b0, b1, b2), 1)
+         + util.putByte(invMixByte(b0, b1, b2, b3), 0);
+end
+
+-- Optimized inverse mix column, meant to replace invMixColumnOld.
+-- look at http://fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf
+-- TODO: make it work. expandDecryptionKey calls invMixColumnOld, not this function.
+-- NOTE(review): output bytes 0 and 2 both use gf.mul(0x02, b0 xor b3); possibly one of them is wrong - verify against the reference.
+--
+function private.invMixColumn(word)
+    local b0 = util.getByte(word,3);
+    local b1 = util.getByte(word,2);
+    local b2 = util.getByte(word,1);
+    local b3 = util.getByte(word,0);
+    
+    local t = bit.bxor(b3,b2);
+    local u = bit.bxor(b1,b0);
+    local v = bit.bxor(t,u);
+    v = bit.bxor(v,gf.mul(0x08,v));
+    v = bit.bxor(v,gf.mul(0x04, bit.bxor(b3,b1)));
+    local w = bit.bxor(v,gf.mul(0x04, bit.bxor(b2,b0)));
+    
+    return util.putByte( bit.bxor(bit.bxor(b3,v), gf.mul(0x02, bit.bxor(b0,b3))), 0)
+         + util.putByte( bit.bxor(bit.bxor(b2,w), gf.mul(0x02, t              )), 1)
+         + util.putByte( bit.bxor(bit.bxor(b1,v), gf.mul(0x02, bit.bxor(b0,b3))), 2)
+         + util.putByte( bit.bxor(bit.bxor(b0,w), gf.mul(0x02, u              )), 3);
+end
+
+--
+-- generate key schedule for aes decryption
+--
+-- Builds the encryption key schedule, marks it as decryption key and
+-- applies the inverse mix column to every word of the round keys
+-- 1 .. rounds - 1. The words of the first and the last round key
+-- stay unchanged.
+--
+function public.expandDecryptionKey(key)
+    local schedule = public.expandEncryptionKey(key);
+
+    if (schedule ~= nil) then
+        schedule[public.KEY_TYPE] = public.DECRYPTION_KEY;
+
+        -- word 4 is the first word of round key 1,
+        -- word rounds*4 - 1 the last word of round key rounds - 1
+        local lastWord = schedule[public.ROUNDS]*4 - 1;
+        for i = 4, lastWord do
+            schedule[i] = private.invMixColumnOld(schedule[i]);
+        end
+    end
+
+    return schedule;
+end
+
+--
+-- xor the 4 words of round key number 'round' to the state (in place)
+--
+function private.addRoundKey(state, key, round)
+    -- the round key occupies 4 consecutive words of the key schedule
+    local firstWord = round*4;
+    for column = 0, 3 do
+        state[column] = bit.bxor(state[column], key[firstWord+column]);
+    end
+end
+
+--
+-- do encryption round (ShiftRow, SubBytes, MixColumn together)
+-- reads the 4 words of origState and writes the 4 words of dstState;
+-- every output word is the xor of one lookup in each round table
+--
+function private.doRound(origState, dstState)
+    local getByte, bxor = util.getByte, bit.bxor;
+    local t0, t1 = private.table0, private.table1;
+    local t2, t3 = private.table2, private.table3;
+
+    dstState[0] = bxor(bxor(bxor(t0[getByte(origState[0], 3)],
+                                 t1[getByte(origState[1], 2)]),
+                                 t2[getByte(origState[2], 1)]),
+                                 t3[getByte(origState[3], 0)]);
+
+    dstState[1] = bxor(bxor(bxor(t0[getByte(origState[1], 3)],
+                                 t1[getByte(origState[2], 2)]),
+                                 t2[getByte(origState[3], 1)]),
+                                 t3[getByte(origState[0], 0)]);
+
+    dstState[2] = bxor(bxor(bxor(t0[getByte(origState[2], 3)],
+                                 t1[getByte(origState[3], 2)]),
+                                 t2[getByte(origState[0], 1)]),
+                                 t3[getByte(origState[1], 0)]);
+
+    dstState[3] = bxor(bxor(bxor(t0[getByte(origState[3], 3)],
+                                 t1[getByte(origState[0], 2)]),
+                                 t2[getByte(origState[1], 1)]),
+                                 t3[getByte(origState[2], 0)]);
+end
+
+--
+-- do last encryption round (ShiftRow and SubBytes)
+-- reads the 4 words of origState and writes the 4 words of dstState
+--
+function private.doLastRound(origState, dstState)
+    local sbox = private.SBox;
+
+    -- output word c takes byte 3 from input word c, byte 2 from word c+1,
+    -- byte 1 from word c+2 and byte 0 from word c+3 (word index modulo 4)
+    for col = 0, 3 do
+        dstState[col] = util.putByte(sbox[util.getByte(origState[col], 3)], 3)
+                      + util.putByte(sbox[util.getByte(origState[(col + 1) % 4], 2)], 2)
+                      + util.putByte(sbox[util.getByte(origState[(col + 2) % 4], 1)], 1)
+                      + util.putByte(sbox[util.getByte(origState[(col + 3) % 4], 0)], 0);
+    end
+end
+
+--
+-- do decryption round
+-- reads the 4 words of origState and writes the 4 words of dstState;
+-- every output word is the xor of one lookup in each inverse round table
+--
+function private.doInvRound(origState, dstState)
+    local getByte, bxor = util.getByte, bit.bxor;
+    local i0, i1 = private.tableInv0, private.tableInv1;
+    local i2, i3 = private.tableInv2, private.tableInv3;
+
+    dstState[0] = bxor(bxor(bxor(i0[getByte(origState[0], 3)],
+                                 i1[getByte(origState[3], 2)]),
+                                 i2[getByte(origState[2], 1)]),
+                                 i3[getByte(origState[1], 0)]);
+
+    dstState[1] = bxor(bxor(bxor(i0[getByte(origState[1], 3)],
+                                 i1[getByte(origState[0], 2)]),
+                                 i2[getByte(origState[3], 1)]),
+                                 i3[getByte(origState[2], 0)]);
+
+    dstState[2] = bxor(bxor(bxor(i0[getByte(origState[2], 3)],
+                                 i1[getByte(origState[1], 2)]),
+                                 i2[getByte(origState[0], 1)]),
+                                 i3[getByte(origState[3], 0)]);
+
+    dstState[3] = bxor(bxor(bxor(i0[getByte(origState[3], 3)],
+                                 i1[getByte(origState[2], 2)]),
+                                 i2[getByte(origState[1], 1)]),
+                                 i3[getByte(origState[0], 0)]);
+end
+
+--
+-- do last decryption round (inverse S-Box lookups only)
+-- reads the 4 words of origState and writes the 4 words of dstState
+--
+function private.doInvLastRound(origState, dstState)
+    local isbox = private.iSBox;
+
+    -- output word c takes byte 3 from input word c, byte 2 from word c+3,
+    -- byte 1 from word c+2 and byte 0 from word c+1 (word index modulo 4)
+    for col = 0, 3 do
+        dstState[col] = util.putByte(isbox[util.getByte(origState[col], 3)], 3)
+                      + util.putByte(isbox[util.getByte(origState[(col + 3) % 4], 2)], 2)
+                      + util.putByte(isbox[util.getByte(origState[(col + 2) % 4], 1)], 1)
+                      + util.putByte(isbox[util.getByte(origState[(col + 1) % 4], 0)], 0);
+    end
+end
+
+--
+-- encrypts 16 Bytes
+-- key           encryption key schedule
+-- input         array with input data
+-- inputOffset   start index for input (default 1)
+-- output        array for encrypted data (default: new table)
+-- outputOffset  start index for output (default 1)
+--
+-- returns the result of util.intsToBytes(state, output, outputOffset)
+-- raises an error if key is not marked as encryption key
+--
+function public.encrypt(key, input, inputOffset, output, outputOffset) 
+    --default parameters
+    inputOffset = inputOffset or 1;
+    output = output or {};
+    outputOffset = outputOffset or 1;
+
+    if (key[public.KEY_TYPE] ~= public.ENCRYPTION_KEY) then
+        -- tostring keeps the message intact if the type field is missing;
+        -- error() never returns, so no statement may follow it
+        error("No encryption key: "..tostring(key[public.KEY_TYPE]));
+    end
+
+    local state = util.bytesToInts(input, inputOffset, 4);
+    local tmpState = {};
+
+    private.addRoundKey(state, key, 0);
+
+    local round = 1;
+    while (round < key[public.ROUNDS] - 1) do
+        -- do a double round to save temporary assignments
+        private.doRound(state, tmpState);
+        private.addRoundKey(tmpState, key, round);
+        round = round + 1;
+
+        private.doRound(tmpState, state);
+        private.addRoundKey(state, key, round);
+        round = round + 1;
+    end
+
+    -- the loop leaves the last two rounds: one more full round ...
+    private.doRound(state, tmpState);
+    private.addRoundKey(tmpState, key, round);
+    round = round +1;
+
+    -- ... and the final round without MixColumn
+    private.doLastRound(tmpState, state);
+    private.addRoundKey(state, key, round);
+
+    return util.intsToBytes(state, output, outputOffset);
+end
+
+--
+-- decrypt 16 bytes
+-- key           decryption key schedule
+-- input         array with input data
+-- inputOffset   start index for input (default 1)
+-- output        array for decrypted data (default: new table)
+-- outputOffset  start index for output (default 1)
+--
+-- returns the result of util.intsToBytes(state, output, outputOffset)
+-- raises an error if key is not marked as decryption key
+--
+function public.decrypt(key, input, inputOffset, output, outputOffset) 
+    -- default arguments
+    inputOffset = inputOffset or 1;
+    output = output or {};
+    outputOffset = outputOffset or 1;
+
+    if (key[public.KEY_TYPE] ~= public.DECRYPTION_KEY) then
+        -- tostring keeps the message intact if the type field is missing;
+        -- error() never returns, so no statement may follow it
+        error("No decryption key: "..tostring(key[public.KEY_TYPE]));
+    end
+
+    local state = util.bytesToInts(input, inputOffset, 4);
+    local tmpState = {};
+
+    -- the round keys are applied in reverse order, starting with the last
+    private.addRoundKey(state, key, key[public.ROUNDS]);
+
+    local round = key[public.ROUNDS] - 1;
+    while (round > 2) do
+        -- do a double round to save temporary assignments
+        private.doInvRound(state, tmpState);
+        private.addRoundKey(tmpState, key, round);
+        round = round - 1;
+
+        private.doInvRound(tmpState, state);
+        private.addRoundKey(state, key, round);
+        round = round - 1;
+    end
+
+    -- the loop leaves the last two rounds: one more full round ...
+    private.doInvRound(state, tmpState);
+    private.addRoundKey(tmpState, key, round);
+    round = round - 1;
+
+    -- ... and the final round with round key 0
+    private.doInvLastRound(tmpState, state);
+    private.addRoundKey(state, key, round);
+
+    return util.intsToBytes(state, output, outputOffset);
+end
+
+-- calculate all tables when loading this file; the S-Boxes come first because both round table functions read them
+private.calcSBox();
+private.calcRoundTables();
+private.calcInvRoundTables();
+
+return public;

mercurial