aeslua/aes.lua

Wed, 16 Feb 2011 20:29:33 +0000

author
Matthew Wild <mwild1@gmail.com>
date
Wed, 16 Feb 2011 20:29:33 +0000
changeset 0
598d09faf89c
permissions
-rw-r--r--

There are no secrets better kept than the secrets that everybody guesses.

local bit = require("bit");

local gf = require("aeslua.gf");
local util = require("aeslua.util");

--
-- Implementation of AES with nearly pure lua (only bitlib is needed) 
--
-- AES with lua is slow, really slow :-)
--

-- exported API table; the string constants are the keys under which a key
-- schedule stores its round count and its type
local public = {
    ROUNDS = "rounds",
    KEY_TYPE = "type",
    ENCRYPTION_KEY = 1,
    DECRYPTION_KEY = 2,
};

-- internal state; every lookup table starts empty and is filled when the
-- file is loaded
local private = {
    -- aes SBOX and its inverse
    SBox = {},
    iSBox = {},

    -- aes round tables (encryption)
    table0 = {},
    table1 = {},
    table2 = {},
    table3 = {},

    -- aes round tables (decryption)
    tableInv0 = {},
    tableInv1 = {},
    tableInv2 = {},
    tableInv3 = {},

    -- round constants, indexed from 1, stored in the top byte of each word
    rCon = {
        0x01000000, 0x02000000, 0x04000000, 0x08000000,
        0x10000000, 0x20000000, 0x40000000, 0x80000000,
        0x1b000000, 0x36000000, 0x6c000000, 0xd8000000,
        0xab000000, 0x4d000000, 0x9a000000, 0x2f000000,
    },
};

-- publish this module on the parent package table
local aeslua = require("aeslua");
aeslua.aes = public;

--
-- affine transformation used to build the AES S-Box
--
-- byte   input value
--
-- Produces eight output bits, most significant first. Each bit is the
-- parity of the input bits selected by an 8 bit mask that starts at 0xf8
-- and is rotated right by one position per step. The collected bits are
-- finally xored with 0x63.
--
function private.affinMap(byte)
    local rowMask = 0xf8;
    local acc = 0;
    for _ = 1, 8 do
        -- append the parity of the currently selected bits
        acc = bit.lshift(acc, 1) + util.byteParity(bit.band(byte, rowMask));

        -- rotate the mask right by one bit inside 8 bits
        local carry = bit.band(rowMask, 1);
        rowMask = bit.band(bit.rshift(rowMask, 1), 0xff);
        if (carry == 0) then
            rowMask = bit.band(rowMask, 0x7f);
        else
            rowMask = bit.bor(rowMask, 0x80);
        end
    end

    return bit.bxor(acc, 0x63);
end

--
-- fill private.SBox and private.iSBox
--
-- For every value 0..255 the S-Box entry is the affine map of gf.invert
-- of that value; 0 is not passed to gf.invert and is mapped as 0.
-- The inverse S-Box is filled with the reverse mapping in the same pass.
--
function private.calcSBox()
    for value = 0, 255 do
        local inv = value;
        if (value ~= 0) then
            inv = gf.invert(value);
        end

        local substituted = private.affinMap(inv);
        private.SBox[value] = substituted;
        private.iSBox[substituted] = value;
    end
end

--
-- Build the four encryption round tables
-- With these tables one round (ShiftRow, MixColumn and SubBytes) needs only
-- 4 table lookups and 4 xor operations per state word.
--
-- Each entry packs the S-Box value of x together with its GF products by
-- 0x02 and 0x03; the four tables hold the same bytes at different byte
-- positions.
--
function private.calcRoundTables()
    local put, mul = util.putByte, gf.mul;
    for x = 0, 255 do
        local s1 = private.SBox[x];
        local s2 = mul(0x02, s1);
        local s3 = mul(0x03, s1);

        private.table0[x] = put(s3, 0) + put(s1, 1) + put(s1, 2) + put(s2, 3);
        private.table1[x] = put(s1, 0) + put(s1, 1) + put(s2, 2) + put(s3, 3);
        private.table2[x] = put(s1, 0) + put(s2, 1) + put(s3, 2) + put(s1, 3);
        private.table3[x] = put(s2, 0) + put(s3, 1) + put(s1, 2) + put(s1, 3);
    end
end

--
-- Build the four decryption round tables
-- Counterpart of the encryption round tables, used by the equivalent
-- decryption algorithm.
--
-- Each entry packs the GF products of the inverse S-Box value of x with
-- 0x09, 0x0b, 0x0d and 0x0e; the four tables hold the same bytes at
-- different byte positions.
--
function private.calcInvRoundTables()
    local put, mul = util.putByte, gf.mul;
    for x = 0, 255 do
        local s = private.iSBox[x];
        local s9 = mul(0x09, s);
        local sb = mul(0x0b, s);
        local sd = mul(0x0d, s);
        local se = mul(0x0e, s);

        private.tableInv0[x] = put(sb, 0) + put(sd, 1) + put(s9, 2) + put(se, 3);
        private.tableInv1[x] = put(sd, 0) + put(s9, 1) + put(se, 2) + put(sb, 3);
        private.tableInv2[x] = put(s9, 0) + put(se, 1) + put(sb, 2) + put(sd, 3);
        private.tableInv3[x] = put(se, 0) + put(sb, 1) + put(sd, 2) + put(s9, 3);
    end
end


--
-- rotate a word left by one byte: 0xaabbccdd becomes 0xbbccddaa
-- used for key schedule
--
function private.rotWord(word)
    -- the top byte moves down to the lowest byte position
    local topByte = bit.rshift(bit.band(word, 0xff000000), 24);
    local shifted = bit.lshift(word, 8);
    return (shifted + topByte);
end

--
-- run every byte of a word through the SBox, keeping its byte position
-- used for key schedule
--
function private.subWord(word)
    local sbox = private.SBox;
    local result = 0;
    for pos = 0, 3 do
        result = result + util.putByte(sbox[util.getByte(word, pos)], pos);
    end
    return result;
end

--
-- generate key schedule for aes encryption
--
-- key   array of key byte values, indexed from 1; its length must be
--       16, 24 or 32
--
-- returns table with all round keys (one word per index, starting at 0),
-- the necessary number of rounds saved in [public.ROUNDS] and
-- public.ENCRYPTION_KEY saved in [public.KEY_TYPE]
--
-- raises an error if the key length is not 16, 24 or 32 bytes
--
function public.expandEncryptionKey(key)
    local keySchedule = {};
    local keyBytes = #key;
    local keyWords = math.floor(keyBytes / 4);

    -- report the real byte length: the floored word count would show a
    -- valid looking size for e.g. a 17 byte key
    if ((keyWords ~= 4 and keyWords ~= 6 and keyWords ~= 8) or (keyWords * 4 ~= keyBytes)) then
        error("Invalid key size: "..keyBytes.." bytes (expected 16, 24 or 32)");
    end

    keySchedule[public.ROUNDS] = keyWords + 6;
    keySchedule[public.KEY_TYPE] = public.ENCRYPTION_KEY;

    -- the first keyWords words are the key itself, four bytes per word,
    -- first byte at byte position 3
    for i = 0,keyWords - 1 do
        keySchedule[i] = util.putByte(key[i*4+1], 3)
                       + util.putByte(key[i*4+2], 2)
                       + util.putByte(key[i*4+3], 1)
                       + util.putByte(key[i*4+4], 0);
    end

    -- every further word is the word keyWords positions back xored with a
    -- (possibly transformed) copy of the previous word
    for i = keyWords, (keySchedule[public.ROUNDS] + 1)*4 - 1 do
        local tmp = keySchedule[i-1];

        if ( i % keyWords == 0) then
            tmp = private.rotWord(tmp);
            tmp = private.subWord(tmp);

            local index = math.floor(i/keyWords);
            tmp = bit.bxor(tmp,private.rCon[index]);
        elseif (keyWords > 6 and i % keyWords == 4) then
            -- extra substitution step that only applies to 8 word keys
            tmp = private.subWord(tmp);
        end

        keySchedule[i] = bit.bxor(keySchedule[(i-keyWords)],tmp);
    end

    return keySchedule;
end

--
-- Inverse mix column (straightforward version)
-- used for key schedule of decryption key
--
-- word   one column, byte position 3 holding the first byte
--
-- Every output byte is a GF combination of the four input bytes with the
-- factors 0x0b, 0x0d, 0x09 and 0x0e; from one output byte to the next the
-- input bytes are rotated by one.
--
function private.invMixColumnOld(word)
    local mul, add = gf.mul, gf.add;

    local b0 = util.getByte(word,3);
    local b1 = util.getByte(word,2);
    local b2 = util.getByte(word,1);
    local b3 = util.getByte(word,0);

    -- 0x0b*p + 0x0d*q + 0x09*r + 0x0e*s in the finite field
    local function combine(p, q, r, s)
        local sum = add(mul(0x0b, p), mul(0x0d, q));
        sum = add(sum, mul(0x09, r));
        return add(sum, mul(0x0e, s));
    end

    return util.putByte(combine(b1, b2, b3, b0), 3)
         + util.putByte(combine(b2, b3, b0, b1), 2)
         + util.putByte(combine(b3, b0, b1, b2), 1)
         + util.putByte(combine(b0, b1, b2, b3), 0);
end

--
-- Optimized inverse mix column
-- look at http://fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf
--
-- word   one column, byte position 3 holding the first byte (b0)
--
-- Intended to return the same word as private.invMixColumnOld while using
-- fewer field multiplications. With S = b0+b1+b2+b3 (xor) every output
-- byte is
--     b[i] + 0x09*S + 0x04*(b[i] + b[i+2]) + 0x02*(b[i] + b[i+1])
-- (indices modulo 4), which expands to the factors 0x0e, 0x0b, 0x0d, 0x09
-- used by private.invMixColumnOld.
--
function private.invMixColumn(word)
    local b0 = util.getByte(word,3);
    local b1 = util.getByte(word,2);
    local b2 = util.getByte(word,1);
    local b3 = util.getByte(word,0);

    -- 0x09*S, the part shared by all four output bytes
    local sum = bit.bxor(bit.bxor(b0,b1), bit.bxor(b2,b3));
    local nine = bit.bxor(sum, gf.mul(0x08, sum));

    -- the 0x04 term only depends on whether the byte index is even or odd;
    -- both variants must be derived from 0x09*S, not from each other
    local even = bit.bxor(nine, gf.mul(0x04, bit.bxor(b0,b2)));
    local odd  = bit.bxor(nine, gf.mul(0x04, bit.bxor(b1,b3)));

    return util.putByte( bit.bxor(bit.bxor(b3,odd ), gf.mul(0x02, bit.bxor(b3,b0))), 0)
         + util.putByte( bit.bxor(bit.bxor(b2,even), gf.mul(0x02, bit.bxor(b2,b3))), 1)
         + util.putByte( bit.bxor(bit.bxor(b1,odd ), gf.mul(0x02, bit.bxor(b1,b2))), 2)
         + util.putByte( bit.bxor(bit.bxor(b0,even), gf.mul(0x02, bit.bxor(b0,b1))), 3);
end

--
-- generate key schedule for aes decryption
--
-- Starts from the encryption key schedule of the same key, marks it as a
-- decryption key and passes the round key words through inverse mix
-- column. The first four and the last four words are left as they are.
--
function public.expandDecryptionKey(key)
    local schedule = public.expandEncryptionKey(key);
    if (schedule == nil) then
        return nil;
    end

    schedule[public.KEY_TYPE] = public.DECRYPTION_KEY;

    local firstWord = 4;
    local lastWord = (schedule[public.ROUNDS] + 1)*4 - 5;
    for idx = firstWord, lastWord do
        schedule[idx] = private.invMixColumnOld(schedule[idx]);
    end

    return schedule;
end

--
-- xor the four words of one round key into the state (in place)
--
-- state   table with the state words at indices 0..3
-- key     key schedule
-- round   number of the round key; its words start at index round*4
--
function private.addRoundKey(state, key, round)
    local base = round*4;
    for col = 0, 3 do
        state[col] = bit.bxor(state[col], key[base + col]);
    end
end

--
-- do encryption round (ShiftRow, SubBytes, MixColumn together)
--
-- origState   state words 0..3 to read
-- dstState    table that receives the new state words 0..3
--
-- Output word c combines byte 3 of word c, byte 2 of word c+1, byte 1 of
-- word c+2 and byte 0 of word c+3 (word indices modulo 4), each looked up
-- in its own round table.
--
function private.doRound(origState, dstState)
    local getByte = util.getByte;
    local t0, t1, t2, t3 = private.table0, private.table1,
                           private.table2, private.table3;

    for col = 0, 3 do
        local a = t0[getByte(origState[col], 3)];
        local b = t1[getByte(origState[(col + 1) % 4], 2)];
        local c = t2[getByte(origState[(col + 2) % 4], 1)];
        local d = t3[getByte(origState[(col + 3) % 4], 0)];
        dstState[col] = bit.bxor(bit.bxor(bit.bxor(a, b), c), d);
    end
end

--
-- do last encryption round (ShiftRow and SubBytes)
--
-- origState   state words 0..3 to read
-- dstState    table that receives the new state words 0..3
--
-- Same byte selection as a normal encryption round, but the bytes only go
-- through the SBox and are put back at their byte position.
--
function private.doLastRound(origState, dstState)
    local getByte, putByte = util.getByte, util.putByte;
    local sbox = private.SBox;

    for col = 0, 3 do
        dstState[col] = putByte(sbox[getByte(origState[col], 3)], 3)
                      + putByte(sbox[getByte(origState[(col + 1) % 4], 2)], 2)
                      + putByte(sbox[getByte(origState[(col + 2) % 4], 1)], 1)
                      + putByte(sbox[getByte(origState[(col + 3) % 4], 0)], 0);
    end
end

--
-- do decryption round
--
-- origState   state words 0..3 to read
-- dstState    table that receives the new state words 0..3
--
-- Output word c combines byte 3 of word c, byte 2 of word c+3, byte 1 of
-- word c+2 and byte 0 of word c+1 (word indices modulo 4), each looked up
-- in its own inverse round table.
--
function private.doInvRound(origState, dstState)
    local getByte = util.getByte;
    local t0, t1, t2, t3 = private.tableInv0, private.tableInv1,
                           private.tableInv2, private.tableInv3;

    for col = 0, 3 do
        local a = t0[getByte(origState[col], 3)];
        local b = t1[getByte(origState[(col + 3) % 4], 2)];
        local c = t2[getByte(origState[(col + 2) % 4], 1)];
        local d = t3[getByte(origState[(col + 1) % 4], 0)];
        dstState[col] = bit.bxor(bit.bxor(bit.bxor(a, b), c), d);
    end
end

--
-- do last decryption round
--
-- origState   state words 0..3 to read
-- dstState    table that receives the new state words 0..3
--
-- Same byte selection as a normal decryption round, but the bytes only go
-- through the inverse SBox and are put back at their byte position.
--
function private.doInvLastRound(origState, dstState)
    local getByte, putByte = util.getByte, util.putByte;
    local isbox = private.iSBox;

    for col = 0, 3 do
        dstState[col] = putByte(isbox[getByte(origState[col], 3)], 3)
                      + putByte(isbox[getByte(origState[(col + 3) % 4], 2)], 2)
                      + putByte(isbox[getByte(origState[(col + 2) % 4], 1)], 1)
                      + putByte(isbox[getByte(origState[(col + 1) % 4], 0)], 0);
    end
end

--
-- encrypts 16 Bytes
-- key           encryption key schedule
-- input         array with input data
-- inputOffset   start index for input (default 1)
-- output        array for encrypted data (default: a new table)
-- outputOffset  start index for output (default 1)
--
-- returns the result of util.intsToBytes for the final state
-- raises an error if key is not marked as an encryption key schedule
--
function public.encrypt(key, input, inputOffset, output, outputOffset) 
    --default parameters
    inputOffset = inputOffset or 1;
    output = output or {};
    outputOffset = outputOffset or 1;

    if (key[public.KEY_TYPE] ~= public.ENCRYPTION_KEY) then
        -- tostring keeps this message intact when the type field is missing;
        -- concatenating nil directly would raise an unrelated error instead
        error("No encryption key: "..tostring(key[public.KEY_TYPE]));
    end

    local state = util.bytesToInts(input, inputOffset, 4);
    local tmpState = {};

    private.addRoundKey(state, key, 0);

    local round = 1;
    while (round < key[public.ROUNDS] - 1) do
        -- do a double round to save temporary assignments
        private.doRound(state, tmpState);
        private.addRoundKey(tmpState, key, round);
        round = round + 1;

        private.doRound(tmpState, state);
        private.addRoundKey(state, key, round);
        round = round + 1;
    end
    
    -- one remaining normal round, then the last round
    private.doRound(state, tmpState);
    private.addRoundKey(tmpState, key, round);
    round = round +1;

    private.doLastRound(tmpState, state);
    private.addRoundKey(state, key, round);
    
    return util.intsToBytes(state, output, outputOffset);
end

--
-- decrypt 16 bytes
-- key           decryption key schedule
-- input         array with input data
-- inputOffset   start index for input (default 1)
-- output        array for decrypted data (default: a new table)
-- outputOffset  start index for output (default 1)
--
-- returns the result of util.intsToBytes for the final state
-- raises an error if key is not marked as a decryption key schedule
--
function public.decrypt(key, input, inputOffset, output, outputOffset) 
    -- default arguments
    inputOffset = inputOffset or 1;
    output = output or {};
    outputOffset = outputOffset or 1;

    if (key[public.KEY_TYPE] ~= public.DECRYPTION_KEY) then
        -- tostring keeps this message intact when the type field is missing;
        -- concatenating nil directly would raise an unrelated error instead
        error("No decryption key: "..tostring(key[public.KEY_TYPE]));
    end

    local state = util.bytesToInts(input, inputOffset, 4);
    local tmpState = {};

    -- round keys are applied in reverse order, starting with the last one
    private.addRoundKey(state, key, key[public.ROUNDS]);

    local round = key[public.ROUNDS] - 1;
    while (round > 2) do
        -- do a double round to save temporary assignments
        private.doInvRound(state, tmpState);
        private.addRoundKey(tmpState, key, round);
        round = round - 1;

        private.doInvRound(tmpState, state);
        private.addRoundKey(state, key, round);
        round = round - 1;
    end
    
    -- one remaining normal round, then the last round
    private.doInvRound(state, tmpState);
    private.addRoundKey(tmpState, key, round);
    round = round - 1;

    private.doInvLastRound(tmpState, state);
    private.addRoundKey(state, key, round);
    
    return util.intsToBytes(state, output, outputOffset);
end

-- calculate all tables when loading this file
-- the order matters: calcSBox fills private.SBox and private.iSBox, which
-- the two round table builders read
private.calcSBox();
private.calcRoundTables();
private.calcInvRoundTables();

-- hand the public API table to the caller of require
return public;

mercurial