aeslua/aes.lua

changeset 0
598d09faf89c
equal deleted inserted replaced
-1:000000000000 0:598d09faf89c
1 local bit = require("bit");
2
3 local gf = require("aeslua.gf");
4 local util = require("aeslua.util");
5
6 --
7 -- Implementation of AES with nearly pure lua (only bitlib is needed)
8 --
9 -- AES with lua is slow, really slow :-)
10 --
11
-- exported API table, including the constants used to tag key schedules
local public = {
    -- keys under which a key schedule stores its metadata
    ROUNDS = "rounds",
    KEY_TYPE = "type",
    -- values stored under KEY_TYPE
    ENCRYPTION_KEY = 1,
    DECRYPTION_KEY = 2,
};

-- publish this module as the `aes` field of the aeslua table
local aeslua = require("aeslua");
aeslua.aes = public;

-- module-internal helpers and lookup tables
local private = {
    -- aes SBOX and its inverse (filled by private.calcSBox)
    SBox = {},
    iSBox = {},

    -- aes round tables (filled by private.calcRoundTables)
    table0 = {},
    table1 = {},
    table2 = {},
    table3 = {},

    -- inverse round tables (filled by private.calcInvRoundTables)
    tableInv0 = {},
    tableInv1 = {},
    tableInv2 = {},
    tableInv3 = {},

    -- round constants, xored into the key schedule by
    -- public.expandEncryptionKey
    rCon = {
        0x01000000, 0x02000000, 0x04000000, 0x08000000,
        0x10000000, 0x20000000, 0x40000000, 0x80000000,
        0x1b000000, 0x36000000, 0x6c000000, 0xd8000000,
        0xab000000, 0x4d000000, 0x9a000000, 0x2f000000,
    },
};
56
--
-- affine transformation used to derive the AES S-Box
--
-- For each of the eight rotations of the mask 0xf8 the parity of
-- (byte AND mask) is appended to the result as its next lower bit;
-- the assembled value is finally xored with 0x63.
-- (util.byteParity is presumably returning 0 or 1 -- confirm in util)
--
function private.affinMap(byte)
    local mask = 0xf8;
    local result = 0;

    for _ = 1, 8 do
        result = bit.lshift(result, 1) + util.byteParity(bit.band(byte, mask));

        -- rotate the 8 bit mask right by one: the bit shifted out
        -- at the bottom re-enters at the top
        local carry = bit.band(mask, 1);
        mask = bit.rshift(mask, 1);
        if (carry ~= 0) then
            mask = bit.bor(mask, 0x80);
        end
    end

    return bit.bxor(result, 0x63);
end
81
--
-- fill private.SBox and private.iSBox
--
-- Every value 0..255 is inverted in the finite field 2^8 and then
-- passed through the affine transformation; the inverse table is
-- filled with the reverse mapping at the same time.
--
function private.calcSBox()
    for value = 0, 255 do
        -- 0 is not passed to gf.invert, it is used unchanged
        local inverse = value;
        if (value ~= 0) then
            inverse = gf.invert(value);
        end

        local substituted = private.affinMap(inverse);
        private.SBox[value] = substituted;
        private.iSBox[substituted] = value;
    end
end
99
--
-- Calculate round tables
-- round tables are used to calculate shiftRow, MixColumn and SubBytes
-- with 4 table lookups and 4 xor operations.
--
-- Each entry packs the S-Box value of x together with its products
-- by 0x02 and 0x03; the four tables hold the same bytes at rotated
-- byte positions. Requires private.SBox to be filled already.
--
function private.calcRoundTables()
    for x = 0,255 do
        local byte = private.SBox[x];
        -- both products are needed by all four tables: compute them
        -- once per entry instead of once per table
        -- (assumes gf.mul is a pure function of its arguments)
        local byte2 = gf.mul(0x02, byte);
        local byte3 = gf.mul(0x03, byte);

        private.table0[x] = util.putByte(byte3, 0)
                          + util.putByte(byte , 1)
                          + util.putByte(byte , 2)
                          + util.putByte(byte2, 3);
        private.table1[x] = util.putByte(byte , 0)
                          + util.putByte(byte , 1)
                          + util.putByte(byte2, 2)
                          + util.putByte(byte3, 3);
        private.table2[x] = util.putByte(byte , 0)
                          + util.putByte(byte2, 1)
                          + util.putByte(byte3, 2)
                          + util.putByte(byte , 3);
        private.table3[x] = util.putByte(byte2, 0)
                          + util.putByte(byte3, 1)
                          + util.putByte(byte , 2)
                          + util.putByte(byte , 3);
    end
end
126
--
-- Calculate inverse round tables
-- does the inverse of the normal roundtables for the equivalent
-- decryption algorithm.
--
-- Each entry packs the products of the inverse S-Box value of x by
-- 0x09, 0x0b, 0x0d and 0x0e; the four tables hold the same bytes at
-- rotated byte positions. Requires private.iSBox to be filled already.
--
function private.calcInvRoundTables()
    for x = 0,255 do
        local byte = private.iSBox[x];
        -- all four products are needed by every table: compute them
        -- once per entry instead of once per table
        -- (assumes gf.mul is a pure function of its arguments)
        local byte9 = gf.mul(0x09, byte);
        local byteB = gf.mul(0x0b, byte);
        local byteD = gf.mul(0x0d, byte);
        local byteE = gf.mul(0x0e, byte);

        private.tableInv0[x] = util.putByte(byteB, 0)
                             + util.putByte(byteD, 1)
                             + util.putByte(byte9, 2)
                             + util.putByte(byteE, 3);
        private.tableInv1[x] = util.putByte(byteD, 0)
                             + util.putByte(byte9, 1)
                             + util.putByte(byteE, 2)
                             + util.putByte(byteB, 3);
        private.tableInv2[x] = util.putByte(byte9, 0)
                             + util.putByte(byteE, 1)
                             + util.putByte(byteB, 2)
                             + util.putByte(byteD, 3);
        private.tableInv3[x] = util.putByte(byteE, 0)
                             + util.putByte(byteB, 1)
                             + util.putByte(byteD, 2)
                             + util.putByte(byte9, 3);
    end
end
153
154
--
-- rotate a word left by one byte: 0xaabbccdd becomes 0xbbccddaa
-- used for key schedule
--
function private.rotWord(word)
    -- isolate the top byte and move it down to the lowest position
    local topByte = bit.rshift(bit.band(word, 0xff000000), 24);
    -- the shifted word has an empty low byte, so adding fills it in
    return (bit.lshift(word, 8) + topByte);
end
163
--
-- substitute each of the four bytes of a word through the SBox,
-- keeping every byte at its position.
-- used for key schedule
--
function private.subWord(word)
    local result = 0;
    for position = 0, 3 do
        local substituted = private.SBox[util.getByte(word, position)];
        result = result + util.putByte(substituted, position);
    end
    return result;
end
174
--
-- generate key schedule for aes encryption
--
-- key  array of key bytes (key[1] .. key[#key]); must hold exactly
--      16, 24 or 32 bytes
--
-- returns table with all round keys (indices 0 .. (rounds+1)*4 - 1),
-- the necessary number of rounds saved in [public.ROUNDS] and
-- public.ENCRYPTION_KEY saved in [public.KEY_TYPE]
--
-- raises an error if the key length is not supported
--
function public.expandEncryptionKey(key)
    local keyBytes = #key;

    -- report the length in bytes: the floored word count would look
    -- valid for lengths such as 17, 25 or 33
    if (keyBytes ~= 16 and keyBytes ~= 24 and keyBytes ~= 32) then
        error("Invalid key size: "..keyBytes.." bytes (expected 16, 24 or 32)");
    end

    local keyWords = keyBytes / 4;
    local rounds = keyWords + 6;

    local keySchedule = {};
    keySchedule[public.ROUNDS] = rounds;
    keySchedule[public.KEY_TYPE] = public.ENCRYPTION_KEY;

    -- the first keyWords words are the key itself, first key byte
    -- in the highest byte position of each word
    for i = 0,keyWords - 1 do
        keySchedule[i] = util.putByte(key[i*4+1], 3)
                       + util.putByte(key[i*4+2], 2)
                       + util.putByte(key[i*4+3], 1)
                       + util.putByte(key[i*4+4], 0);
    end

    -- every further word is derived from its predecessor and the
    -- word keyWords positions back
    for i = keyWords, (rounds + 1)*4 - 1 do
        local tmp = keySchedule[i-1];

        if (i % keyWords == 0) then
            tmp = private.rotWord(tmp);
            tmp = private.subWord(tmp);

            local index = math.floor(i/keyWords);
            tmp = bit.bxor(tmp,private.rCon[index]);
        elseif (keyWords > 6 and i % keyWords == 4) then
            -- extra substitution step, only taken for 8 word keys
            tmp = private.subWord(tmp);
        end

        keySchedule[i] = bit.bxor(keySchedule[(i-keyWords)],tmp);
    end

    return keySchedule;
end
219
--
-- Inverse mix column (straightforward version)
-- used for key schedule of decryption key
--
-- The word is split into four bytes (s0 is the highest byte); every
-- output byte is the gf sum of the four input bytes multiplied by
-- 0x0e, 0x0b, 0x0d and 0x09, with the multipliers rotated by one
-- position per output byte.
--
function private.invMixColumnOld(word)
    local s0 = util.getByte(word,3);
    local s1 = util.getByte(word,2);
    local s2 = util.getByte(word,1);
    local s3 = util.getByte(word,0);

    local r0 = gf.add(gf.add(gf.add(gf.mul(0x0b, s1), gf.mul(0x0d, s2)),
                             gf.mul(0x09, s3)),
                      gf.mul(0x0e, s0));
    local r1 = gf.add(gf.add(gf.add(gf.mul(0x0b, s2), gf.mul(0x0d, s3)),
                             gf.mul(0x09, s0)),
                      gf.mul(0x0e, s1));
    local r2 = gf.add(gf.add(gf.add(gf.mul(0x0b, s3), gf.mul(0x0d, s0)),
                             gf.mul(0x09, s1)),
                      gf.mul(0x0e, s2));
    local r3 = gf.add(gf.add(gf.add(gf.mul(0x0b, s0), gf.mul(0x0d, s1)),
                             gf.mul(0x09, s2)),
                      gf.mul(0x0e, s3));

    return util.putByte(r0, 3)
         + util.putByte(r1, 2)
         + util.putByte(r2, 1)
         + util.putByte(r3, 0);
end
247
--
-- Optimized inverse mix column
-- look at http://fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf
--
-- Computes the same result as private.invMixColumnOld with fewer
-- multiplications. With b0 the highest byte of the word and
-- T = b0^b1^b2^b3 the output bytes are
--   b0' = b0 ^ w ^ 02*(b0^b1)
--   b1' = b1 ^ v ^ 02*(b1^b2)
--   b2' = b2 ^ w ^ 02*(b2^b3)
--   b3' = b3 ^ v ^ 02*(b3^b0)
-- where v = 09*T ^ 04*(b1^b3) and w = 09*T ^ 04*(b0^b2).
--
-- This function is not called by public.expandDecryptionKey, which
-- uses private.invMixColumnOld.
--
function private.invMixColumn(word)
    local b0 = util.getByte(word,3);
    local b1 = util.getByte(word,2);
    local b2 = util.getByte(word,1);
    local b3 = util.getByte(word,0);

    local t = bit.bxor(b3,b2);
    local u = bit.bxor(b1,b0);

    -- 09*T, shared by v and w
    local nine = bit.bxor(t,u);
    nine = bit.bxor(nine,gf.mul(0x08,nine));

    -- v and w must both start from 09*T: deriving w from the finished
    -- v would fold 04*(b1^b3) into w as well
    local v = bit.bxor(nine,gf.mul(0x04, bit.bxor(b3,b1)));
    local w = bit.bxor(nine,gf.mul(0x04, bit.bxor(b2,b0)));

    return util.putByte( bit.bxor(bit.bxor(b3,v), gf.mul(0x02, bit.bxor(b0,b3))), 0)
         + util.putByte( bit.bxor(bit.bxor(b2,w), gf.mul(0x02, t )), 1)
         + util.putByte( bit.bxor(bit.bxor(b1,v), gf.mul(0x02, bit.bxor(b2,b1))), 2)
         + util.putByte( bit.bxor(bit.bxor(b0,w), gf.mul(0x02, u )), 3);
end
271
--
-- generate key schedule for aes decryption
--
-- Builds the encryption key schedule, tags it as a decryption key
-- and runs the inverse mix column over every round key except the
-- first and the last one (words 4 .. (rounds+1)*4 - 5).
--
function public.expandDecryptionKey(key)
    local schedule = public.expandEncryptionKey(key);
    if (schedule == nil) then
        return nil;
    end

    schedule[public.KEY_TYPE] = public.DECRYPTION_KEY;

    local lastWord = (schedule[public.ROUNDS] + 1)*4 - 5;
    for index = 4, lastWord do
        schedule[index] = private.invMixColumnOld(schedule[index]);
    end

    return schedule;
end
292
--
-- xor the four words of round key number `round` into the state
-- (state is modified in place)
--
function private.addRoundKey(state, key, round)
    -- index of the first word of this round key inside the schedule
    local base = round*4;
    for column = 0, 3 do
        state[column] = bit.bxor(state[column], key[base + column]);
    end
end
301
--
-- one full encryption round (ShiftRow, SubBytes, MixColumn together)
--
-- Reads the four words of origState and writes the four words of
-- dstState. Each output word is the xor of one lookup per round
-- table, indexed by byte 3, 2, 1 and 0 of four consecutive
-- (wrapping) input words.
--
function private.doRound(origState, dstState)
    local getByte, bxor = util.getByte, bit.bxor;
    local t0, t1 = private.table0, private.table1;
    local t2, t3 = private.table2, private.table3;

    dstState[0] = bxor(bxor(bxor(t0[getByte(origState[0],3)],
                                 t1[getByte(origState[1],2)]),
                            t2[getByte(origState[2],1)]),
                       t3[getByte(origState[3],0)]);

    dstState[1] = bxor(bxor(bxor(t0[getByte(origState[1],3)],
                                 t1[getByte(origState[2],2)]),
                            t2[getByte(origState[3],1)]),
                       t3[getByte(origState[0],0)]);

    dstState[2] = bxor(bxor(bxor(t0[getByte(origState[2],3)],
                                 t1[getByte(origState[3],2)]),
                            t2[getByte(origState[0],1)]),
                       t3[getByte(origState[1],0)]);

    dstState[3] = bxor(bxor(bxor(t0[getByte(origState[3],3)],
                                 t1[getByte(origState[0],2)]),
                            t2[getByte(origState[1],1)]),
                       t3[getByte(origState[2],0)]);
end
330
--
-- last encryption round (ShiftRow and SubBytes only)
--
-- Reads the four words of origState and writes the four words of
-- dstState. Byte n of an output word is the SBox value of byte n
-- taken from four consecutive (wrapping) input words.
--
function private.doLastRound(origState, dstState)
    local getByte, putByte = util.getByte, util.putByte;
    local sbox = private.SBox;

    dstState[0] = putByte(sbox[getByte(origState[0],3)], 3)
                + putByte(sbox[getByte(origState[1],2)], 2)
                + putByte(sbox[getByte(origState[2],1)], 1)
                + putByte(sbox[getByte(origState[3],0)], 0);

    dstState[1] = putByte(sbox[getByte(origState[1],3)], 3)
                + putByte(sbox[getByte(origState[2],2)], 2)
                + putByte(sbox[getByte(origState[3],1)], 1)
                + putByte(sbox[getByte(origState[0],0)], 0);

    dstState[2] = putByte(sbox[getByte(origState[2],3)], 3)
                + putByte(sbox[getByte(origState[3],2)], 2)
                + putByte(sbox[getByte(origState[0],1)], 1)
                + putByte(sbox[getByte(origState[1],0)], 0);

    dstState[3] = putByte(sbox[getByte(origState[3],3)], 3)
                + putByte(sbox[getByte(origState[0],2)], 2)
                + putByte(sbox[getByte(origState[1],1)], 1)
                + putByte(sbox[getByte(origState[2],0)], 0);
end
355
--
-- one full decryption round
--
-- Reads the four words of origState and writes the four words of
-- dstState. Each output word is the xor of one lookup per inverse
-- round table, indexed by byte 3, 2, 1 and 0 of the input words
-- taken in descending (wrapping) order.
--
function private.doInvRound(origState, dstState)
    local getByte, bxor = util.getByte, bit.bxor;
    local inv0, inv1 = private.tableInv0, private.tableInv1;
    local inv2, inv3 = private.tableInv2, private.tableInv3;

    dstState[0] = bxor(bxor(bxor(inv0[getByte(origState[0],3)],
                                 inv1[getByte(origState[3],2)]),
                            inv2[getByte(origState[2],1)]),
                       inv3[getByte(origState[1],0)]);

    dstState[1] = bxor(bxor(bxor(inv0[getByte(origState[1],3)],
                                 inv1[getByte(origState[0],2)]),
                            inv2[getByte(origState[3],1)]),
                       inv3[getByte(origState[2],0)]);

    dstState[2] = bxor(bxor(bxor(inv0[getByte(origState[2],3)],
                                 inv1[getByte(origState[1],2)]),
                            inv2[getByte(origState[0],1)]),
                       inv3[getByte(origState[3],0)]);

    dstState[3] = bxor(bxor(bxor(inv0[getByte(origState[3],3)],
                                 inv1[getByte(origState[2],2)]),
                            inv2[getByte(origState[1],1)]),
                       inv3[getByte(origState[0],0)]);
end
384
--
-- last decryption round
--
-- Reads the four words of origState and writes the four words of
-- dstState. Byte n of an output word is the inverse SBox value of
-- byte n of the input words taken in descending (wrapping) order.
--
function private.doInvLastRound(origState, dstState)
    local getByte, putByte = util.getByte, util.putByte;
    local isbox = private.iSBox;

    dstState[0] = putByte(isbox[getByte(origState[0],3)], 3)
                + putByte(isbox[getByte(origState[3],2)], 2)
                + putByte(isbox[getByte(origState[2],1)], 1)
                + putByte(isbox[getByte(origState[1],0)], 0);

    dstState[1] = putByte(isbox[getByte(origState[1],3)], 3)
                + putByte(isbox[getByte(origState[0],2)], 2)
                + putByte(isbox[getByte(origState[3],1)], 1)
                + putByte(isbox[getByte(origState[2],0)], 0);

    dstState[2] = putByte(isbox[getByte(origState[2],3)], 3)
                + putByte(isbox[getByte(origState[1],2)], 2)
                + putByte(isbox[getByte(origState[0],1)], 1)
                + putByte(isbox[getByte(origState[3],0)], 0);

    dstState[3] = putByte(isbox[getByte(origState[3],3)], 3)
                + putByte(isbox[getByte(origState[2],2)], 2)
                + putByte(isbox[getByte(origState[1],1)], 1)
                + putByte(isbox[getByte(origState[0],0)], 0);
end
409
--
-- encrypts 16 Bytes
-- key           encryption key schedule
--               (as returned by public.expandEncryptionKey)
-- input         array with input data
-- inputOffset   start index for input (default 1)
-- output        array for encrypted data (default: new table)
-- outputOffset  start index for output (default 1)
--
-- returns the result of util.intsToBytes for the final state
-- (presumably the output array -- confirm in util)
--
-- raises an error if key is not tagged as an encryption key
--
function public.encrypt(key, input, inputOffset, output, outputOffset)
    --default parameters
    inputOffset = inputOffset or 1;
    output = output or {};
    outputOffset = outputOffset or 1;

    if (key[public.KEY_TYPE] ~= public.ENCRYPTION_KEY) then
        -- tostring: the key type is nil for a table that is no key
        -- schedule, and concatenating nil would hide this message
        error("No encryption key: "..tostring(key[public.KEY_TYPE]));
    end

    local state = util.bytesToInts(input, inputOffset, 4);
    local tmpState = {};

    private.addRoundKey(state, key, 0);

    local round = 1;
    while (round < key[public.ROUNDS] - 1) do
        -- do a double round to save temporary assignments
        private.doRound(state, tmpState);
        private.addRoundKey(tmpState, key, round);
        round = round + 1;

        private.doRound(tmpState, state);
        private.addRoundKey(state, key, round);
        round = round + 1;
    end

    -- the round count is even, so one full round is left over
    -- before the last round
    private.doRound(state, tmpState);
    private.addRoundKey(tmpState, key, round);
    round = round +1;

    private.doLastRound(tmpState, state);
    private.addRoundKey(state, key, round);

    return util.intsToBytes(state, output, outputOffset);
end
456
--
-- decrypt 16 bytes
-- key           decryption key schedule
--               (as returned by public.expandDecryptionKey)
-- input         array with input data
-- inputOffset   start index for input (default 1)
-- output        array for decrypted data (default: new table)
-- outputOffset  start index for output (default 1)
--
-- returns the result of util.intsToBytes for the final state
-- (presumably the output array -- confirm in util)
--
-- raises an error if key is not tagged as a decryption key
--
function public.decrypt(key, input, inputOffset, output, outputOffset)
    -- default arguments
    inputOffset = inputOffset or 1;
    output = output or {};
    outputOffset = outputOffset or 1;

    if (key[public.KEY_TYPE] ~= public.DECRYPTION_KEY) then
        -- tostring: the key type is nil for a table that is no key
        -- schedule, and concatenating nil would hide this message
        error("No decryption key: "..tostring(key[public.KEY_TYPE]));
    end

    local state = util.bytesToInts(input, inputOffset, 4);
    local tmpState = {};

    -- round keys are applied in reverse order, starting with the last
    private.addRoundKey(state, key, key[public.ROUNDS]);

    local round = key[public.ROUNDS] - 1;
    while (round > 2) do
        -- do a double round to save temporary assignments
        private.doInvRound(state, tmpState);
        private.addRoundKey(tmpState, key, round);
        round = round - 1;

        private.doInvRound(tmpState, state);
        private.addRoundKey(state, key, round);
        round = round - 1;
    end

    -- the round count is even, so one full round is left over
    -- before the last round
    private.doInvRound(state, tmpState);
    private.addRoundKey(tmpState, key, round);
    round = round - 1;

    private.doInvLastRound(tmpState, state);
    private.addRoundKey(state, key, round);

    return util.intsToBytes(state, output, outputOffset);
end
503
-- Build all lookup tables once while this file is loaded.
-- The S-Box has to come first: both round table builders read it.
local builders = {
    private.calcSBox,
    private.calcRoundTables,
    private.calcInvRoundTables,
};
for index = 1, #builders do
    builders[index]();
end

return public;

mercurial