PageRenderTime 51ms CodeModel.GetById 13ms app.highlight 34ms RepoModel.GetById 1ms app.codeStats 0ms

/gdata/tlslite/utils/rijndael.py

http://radioappz.googlecode.com/
Python | 392 lines | 306 code | 36 blank | 50 comment | 100 complexity | b556d51d40ce5d9b76b633faa14961ab MD5 | raw file
  1"""
A pure Python (slow) implementation of Rijndael with a simple interface.

To import:

from rijndael import rijndael

To set up a key:

r = rijndael(key, block_size = 16)

key must be a string of length 16, 24, or 32.
block_size must be 16, 24, or 32. The default is 16.

To use:

ciphertext = r.encrypt(plaintext)
plaintext = r.decrypt(ciphertext)

If the key or a block has the wrong length, or block_size is unsupported,
a ValueError is raised.
 21"""
 22
 23# ported from the Java reference code by Bram Cohen, bram@gawth.com, April 2001
 24# this code is public domain, unless someone makes
 25# an intellectual property claim against the reference
 26# code, in which case it can be made public domain by
 27# deleting all the comments and renaming all the variables
 28
 29import copy
 30import string
 31
 32
 33
 34#-----------------------
 35#TREV - ADDED BECAUSE THERE'S WARNINGS ABOUT INT OVERFLOW BEHAVIOR CHANGING IN
 36#2.4.....
 37import os
 38if os.name != "java":
 39    import exceptions
 40    if hasattr(exceptions, "FutureWarning"):
 41        import warnings
 42        warnings.filterwarnings("ignore", category=FutureWarning, append=1)
 43#-----------------------
 44
 45
 46
 47shifts = [[[0, 0], [1, 3], [2, 2], [3, 1]],
 48          [[0, 0], [1, 5], [2, 4], [3, 3]],
 49          [[0, 0], [1, 7], [3, 5], [4, 4]]]
 50
 51# [keysize][block_size]
 52num_rounds = {16: {16: 10, 24: 12, 32: 14}, 24: {16: 12, 24: 12, 32: 14}, 32: {16: 14, 24: 14, 32: 14}}
 53
 54A = [[1, 1, 1, 1, 1, 0, 0, 0],
 55     [0, 1, 1, 1, 1, 1, 0, 0],
 56     [0, 0, 1, 1, 1, 1, 1, 0],
 57     [0, 0, 0, 1, 1, 1, 1, 1],
 58     [1, 0, 0, 0, 1, 1, 1, 1],
 59     [1, 1, 0, 0, 0, 1, 1, 1],
 60     [1, 1, 1, 0, 0, 0, 1, 1],
 61     [1, 1, 1, 1, 0, 0, 0, 1]]
 62
 63# produce log and alog tables, needed for multiplying in the
 64# field GF(2^m) (generator = 3)
 65alog = [1]
 66for i in xrange(255):
 67    j = (alog[-1] << 1) ^ alog[-1]
 68    if j & 0x100 != 0:
 69        j ^= 0x11B
 70    alog.append(j)
 71
 72log = [0] * 256
 73for i in xrange(1, 255):
 74    log[alog[i]] = i
 75
 76# multiply two elements of GF(2^m)
 77def mul(a, b):
 78    if a == 0 or b == 0:
 79        return 0
 80    return alog[(log[a & 0xFF] + log[b & 0xFF]) % 255]
 81
 82# substitution box based on F^{-1}(x)
 83box = [[0] * 8 for i in xrange(256)]
 84box[1][7] = 1
 85for i in xrange(2, 256):
 86    j = alog[255 - log[i]]
 87    for t in xrange(8):
 88        box[i][t] = (j >> (7 - t)) & 0x01
 89
 90B = [0, 1, 1, 0, 0, 0, 1, 1]
 91
 92# affine transform:  box[i] <- B + A*box[i]
 93cox = [[0] * 8 for i in xrange(256)]
 94for i in xrange(256):
 95    for t in xrange(8):
 96        cox[i][t] = B[t]
 97        for j in xrange(8):
 98            cox[i][t] ^= A[t][j] * box[i][j]
 99
# S-boxes and inverse S-boxes: pack each transformed bit vector cox[i]
# (index 0 = most significant bit) into a byte S[i], and record the reverse
# mapping in Si. range() replaces xrange() so this also runs on Python 3.
S = [0] * 256
Si = [0] * 256
for i in range(256):
    S[i] = cox[i][0] << 7
    for t in range(1, 8):
        S[i] ^= cox[i][t] << (7 - t)
    Si[S[i] & 0xFF] = i
108
# T-boxes
# G is the coefficient matrix used to build the T tables below. Row 0 is
# (2, 1, 1, 3) and every following row is the one above rotated right by one.
G = [[[2, 1, 1, 3][(j - i) % 4] for j in range(4)] for i in range(4)]
114
# Invert G over GF(2^8) by Gauss-Jordan elimination on the augmented 4x8
# matrix AA = [G | I]. When done, the right half of AA holds G^-1.
# range() replaces xrange() so this also runs on Python 3.
AA = [[0] * 8 for i in range(4)]

for i in range(4):
    for j in range(4):
        AA[i][j] = G[i][j]
    AA[i][i + 4] = 1

for i in range(4):
    pivot = AA[i][i]
    if pivot == 0:
        # Find a lower row with a non-zero entry in column i and swap it up.
        # Check the bound before indexing, so a singular G trips the assert
        # instead of raising IndexError. The swap happens once, after the
        # search, not on every step of it.
        t = i + 1
        while t < 4 and AA[t][i] == 0:
            t += 1
        assert t != 4, 'G matrix must be invertible'
        for j in range(8):
            AA[i][j], AA[t][j] = AA[t][j], AA[i][j]
        pivot = AA[i][i]
    # Divide row i by the pivot (log subtraction) so AA[i][i] becomes 1.
    for j in range(8):
        if AA[i][j] != 0:
            AA[i][j] = alog[(255 + log[AA[i][j] & 0xFF] - log[pivot & 0xFF]) % 255]
    # Eliminate column i from every other row.
    for t in range(4):
        if i != t:
            for j in range(i + 1, 8):
                AA[t][j] ^= mul(AA[i][j], AA[t][i])
            AA[t][i] = 0
140
# G^-1 is the right-hand half (columns 4..7) of the reduced matrix AA.
# range() replaces xrange() so this also runs on Python 3.
iG = [AA[i][4:8] for i in range(4)]
146
def mul4(a, bs):
    """Multiply a by each coefficient in bs and pack the products into one int.

    Each product (a GF(2^8) value from mul) occupies one byte, with the first
    coefficient in the most significant position. A zero a, or a zero
    coefficient, contributes 0 without calling mul.
    """
    if not a:
        return 0
    packed = 0
    for coeff in bs:
        packed <<= 8
        if coeff:
            packed |= mul(a, coeff)
    return packed
156
# Round lookup tables, one 32-bit word per byte value t (built with mul4):
#   T1..T4[t] = mul4(S[t],  G[k])  -- S-box output times row k of G
#   T5..T8[t] = mul4(Si[t], iG[k]) -- inverse S-box output times row k of G^-1
#   U1..U4[t] = mul4(t,     iG[k]) -- t itself times row k of G^-1
# range() replaces xrange() so this also runs on Python 3.
T1, T2, T3, T4 = [], [], [], []
T5, T6, T7, T8 = [], [], [], []
U1, U2, U3, U4 = [], [], [], []

for t in range(256):
    s = S[t]
    T1.append(mul4(s, G[0]))
    T2.append(mul4(s, G[1]))
    T3.append(mul4(s, G[2]))
    T4.append(mul4(s, G[3]))

    s = Si[t]
    T5.append(mul4(s, iG[0]))
    T6.append(mul4(s, iG[1]))
    T7.append(mul4(s, iG[2]))
    T8.append(mul4(s, iG[3]))

    U1.append(mul4(t, iG[0]))
    U2.append(mul4(t, iG[1]))
    U3.append(mul4(t, iG[2]))
    U4.append(mul4(t, iG[3]))
187
# Round constants for the key schedule: rcon[k] = 2**k in GF(2^8),
# 30 entries. range() replaces xrange() so this also runs on Python 3.
rcon = [1]
r = 1
for t in range(1, 30):
    r = mul(2, r)
    rcon.append(r)
194
# Drop the helpers and scratch variables used only to build the tables
# above, so they do not linger as module attributes. Names are removed
# left to right, exactly as separate del statements would.
del A, AA, pivot, B, G, box, log, alog, i, j, r, s, t, mul, mul4, cox, iG
212
class rijndael:
    """Rijndael cipher bound to one key and one block size.

    Build with rijndael(key, block_size=16). len(key) and block_size must
    each be 16, 24 or 32. The constructor expands the key once into
    encryption round keys (self.Ke) and decryption round keys (self.Kd).
    encrypt() and decrypt() then transform exactly one block of block_size
    characters. Input characters are read with ord() and output is rebuilt
    with chr(), so blocks are handled as strings of 8-bit characters.
    """

    def __init__(self, key, block_size=16):
        """Validate the sizes and precompute the round keys.

        Raises ValueError if block_size or len(key) is not 16, 24 or 32.
        """
        if block_size not in (16, 24, 32):
            raise ValueError('Invalid block size: ' + str(block_size))
        if len(key) not in (16, 24, 32):
            raise ValueError('Invalid key size: ' + str(len(key)))
        self.block_size = block_size

        ROUNDS = num_rounds[len(key)][block_size]
        # Floor division keeps BC/KC integers on Python 3 as well as Python 2.
        # A plain '/' yields floats there and breaks list sizing and indexing.
        BC = block_size // 4
        KC = len(key) // 4
        # encryption round keys
        Ke = [[0] * BC for _ in range(ROUNDS + 1)]
        # decryption round keys, filled in reverse round order
        Kd = [[0] * BC for _ in range(ROUNDS + 1)]
        ROUND_KEY_COUNT = (ROUNDS + 1) * BC

        # copy user material bytes into temporary ints
        tk = self._to_words(key, KC)

        def store(t):
            # Copy the current tk words into Ke/Kd starting at word index t.
            # Stop once every round-key word is filled; return the next index.
            for word in tk:
                if t >= ROUND_KEY_COUNT:
                    break
                Ke[t // BC][t % BC] = word
                Kd[ROUNDS - t // BC][t % BC] = word
                t += 1
            return t

        t = store(0)
        rconpointer = 0
        while t < ROUND_KEY_COUNT:
            # Extrapolate using phi (the round key evolution function):
            # substitute the bytes of the rotated last word, add the round
            # constant to the top byte, and XOR the result into word 0.
            tt = tk[KC - 1]
            tk[0] ^= (((S[(tt >> 16) & 0xFF] & 0xFF) << 24) ^
                      ((S[(tt >> 8) & 0xFF] & 0xFF) << 16) ^
                      ((S[tt & 0xFF] & 0xFF) << 8) ^
                      (S[(tt >> 24) & 0xFF] & 0xFF) ^
                      ((rcon[rconpointer] & 0xFF) << 24))
            rconpointer += 1
            if KC != 8:
                for i in range(1, KC):
                    tk[i] ^= tk[i - 1]
            else:
                # 32-byte keys: extra byte substitution (no rotation) on the
                # word feeding the second half.
                half = KC // 2
                for i in range(1, half):
                    tk[i] ^= tk[i - 1]
                tt = tk[half - 1]
                tk[half] ^= ((S[tt & 0xFF] & 0xFF) ^
                             ((S[(tt >> 8) & 0xFF] & 0xFF) << 8) ^
                             ((S[(tt >> 16) & 0xFF] & 0xFF) << 16) ^
                             ((S[(tt >> 24) & 0xFF] & 0xFF) << 24))
                for i in range(half + 1, KC):
                    tk[i] ^= tk[i - 1]
            t = store(t)
        # Apply the inverse MixColumn tables (U1-U4) to the middle decryption
        # round keys; the first and last decryption keys are used as-is.
        for r in range(1, ROUNDS):
            for j in range(BC):
                tt = Kd[r][j]
                Kd[r][j] = (U1[(tt >> 24) & 0xFF] ^
                            U2[(tt >> 16) & 0xFF] ^
                            U3[(tt >> 8) & 0xFF] ^
                            U4[tt & 0xFF])
        self.Ke = Ke
        self.Kd = Kd

    def _to_words(self, data, count):
        """Pack data[0:4*count] into count big-endian 32-bit ints using ord()."""
        words = []
        for i in range(count):
            words.append((ord(data[4 * i]) << 24) |
                         (ord(data[4 * i + 1]) << 16) |
                         (ord(data[4 * i + 2]) << 8) |
                         ord(data[4 * i + 3]))
        return words

    def _shift_offsets(self, direction):
        """Return the row-shift offsets (s1, s2, s3) for state rows 1-3.

        direction selects the innermost index of `shifts`: 0 for encryption,
        1 for decryption. The table row depends on the block size in words.
        """
        BC = self.block_size // 4
        if BC == 4:
            SC = 0
        elif BC == 6:
            SC = 1
        else:
            SC = 2
        return (shifts[SC][1][direction],
                shifts[SC][2][direction],
                shifts[SC][3][direction])

    def encrypt(self, plaintext):
        """Encrypt one block and return the ciphertext string.

        plaintext must have exactly block_size characters; otherwise
        ValueError is raised.
        """
        if len(plaintext) != self.block_size:
            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(plaintext)))
        Ke = self.Ke

        BC = self.block_size // 4
        ROUNDS = len(Ke) - 1
        s1, s2, s3 = self._shift_offsets(0)
        a = [0] * BC
        # temporary work array: plaintext words XOR the first round key
        t = [w ^ k for w, k in zip(self._to_words(plaintext, BC), Ke[0])]
        # apply round transforms
        for r in range(1, ROUNDS):
            for i in range(BC):
                a[i] = (T1[(t[i] >> 24) & 0xFF] ^
                        T2[(t[(i + s1) % BC] >> 16) & 0xFF] ^
                        T3[(t[(i + s2) % BC] >> 8) & 0xFF] ^
                        T4[t[(i + s3) % BC] & 0xFF]) ^ Ke[r][i]
            t = copy.copy(a)
        # Last round is special: plain S-box lookups instead of the T tables.
        result = []
        for i in range(BC):
            tt = Ke[ROUNDS][i]
            result.append((S[(t[i] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
            result.append((S[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
            result.append((S[(t[(i + s2) % BC] >> 8) & 0xFF] ^ (tt >> 8)) & 0xFF)
            result.append((S[t[(i + s3) % BC] & 0xFF] ^ tt) & 0xFF)
        # ''.join works on Python 2 and 3; string.join no longer exists in 3.
        return ''.join(map(chr, result))

    def decrypt(self, ciphertext):
        """Decrypt one block and return the plaintext string.

        ciphertext must have exactly block_size characters; otherwise
        ValueError is raised.
        """
        if len(ciphertext) != self.block_size:
            # Report the length of the argument actually passed in; this used
            # to reference an undefined `plaintext` and raised NameError.
            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(ciphertext)))
        Kd = self.Kd

        BC = self.block_size // 4
        ROUNDS = len(Kd) - 1
        s1, s2, s3 = self._shift_offsets(1)
        a = [0] * BC
        # temporary work array: ciphertext words XOR the first round key
        t = [w ^ k for w, k in zip(self._to_words(ciphertext, BC), Kd[0])]
        # apply round transforms
        for r in range(1, ROUNDS):
            for i in range(BC):
                a[i] = (T5[(t[i] >> 24) & 0xFF] ^
                        T6[(t[(i + s1) % BC] >> 16) & 0xFF] ^
                        T7[(t[(i + s2) % BC] >> 8) & 0xFF] ^
                        T8[t[(i + s3) % BC] & 0xFF]) ^ Kd[r][i]
            t = copy.copy(a)
        # Last round is special: plain inverse S-box lookups.
        result = []
        for i in range(BC):
            tt = Kd[ROUNDS][i]
            result.append((Si[(t[i] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
            result.append((Si[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
            result.append((Si[(t[(i + s2) % BC] >> 8) & 0xFF] ^ (tt >> 8)) & 0xFF)
            result.append((Si[t[(i + s3) % BC] & 0xFF] ^ tt) & 0xFF)
        return ''.join(map(chr, result))
371
def encrypt(key, block):
    """Encrypt a single block under key; the block size is len(block)."""
    cipher = rijndael(key, len(block))
    return cipher.encrypt(block)


def decrypt(key, block):
    """Decrypt a single block under key; the block size is len(block)."""
    cipher = rijndael(key, len(block))
    return cipher.decrypt(block)
377
def test():
    """Check decrypt(encrypt(x)) == x for every key size / block size pair.

    Key sizes form the outer loop and block sizes the inner loop; the key is
    all 'a' characters and the block is all 'b' characters.
    """
    sizes = (16, 24, 32)
    for key_len in sizes:
        for block_len in sizes:
            block = 'b' * block_len
            cipher = rijndael('a' * key_len, block_len)
            assert cipher.decrypt(cipher.encrypt(block)) == block
392