
/* This is an independent implementation of the Twofish encryption  */
/* algorithm designed by Bruce Schneier and colleagues and offered  */
/* as a candidate algorithm for the US NIST Advanced Encryption     */
/* Standard (AES) effort.                                           */
/*                                                                  */
/* Copyright in this implementation is held by Dr B R Gladman but   */
/* I hereby give permission for its free direct or derivative use   */
/* subject to acknowledgment of its origin.                         */
/*                                                                  */
/* My thanks to Niels Ferguson and the Twofish team for suggesting  */
/* an additional optimisation for this code                         */ 
/*                                                                  */
/* Dr Brian Gladman (gladman@seven77.demon.co.uk) 18th October 1998 */
/*                                                                  */
/* Timing data:

Algorithm: twofish (twofish3.c)
128 bit key:
Key Setup:   16333 cycles
Encrypt:       396 cycles =    64.6 mbits/sec
Decrypt:       397 cycles =    64.5 mbits/sec
Mean:          396 cycles =    64.6 mbits/sec
192 bit key:
Key Setup:   23378 cycles
Encrypt:       396 cycles =    64.6 mbits/sec
Decrypt:       398 cycles =    64.3 mbits/sec
Mean:          397 cycles =    64.5 mbits/sec
256 bit key:
Key Setup:   24791 cycles
Encrypt:       396 cycles =    64.6 mbits/sec
Decrypt:       395 cycles =    64.8 mbits/sec
Mean:          396 cycles =    64.7 mbits/sec


*/

#include "../std_defs.h"

#define Q_TABLES
#define M_TABLE
#define MK_TABLE
#define ONE_STEP

/* Algorithm identification reported to the caller: the algorithm    */
/* name followed by the name of this source file.                     */
static char *alg_name[] = { "twofish", "twofish3.c" };

/* Return the two-entry alg_name table.  The (void) parameter list    */
/* makes this definition a prototype (an empty () list does not).     */
char **cipher_name(void)
{
    return alg_name;
}

u4byte  k_len;          /* key length in 64-bit words, set by set_key() (2, 3 or 4)   */
u4byte  l_key[40];      /* expanded subkeys: [0..7] whitening, [8..39] round keys     */
u4byte  s_key[4];       /* S-box key words from mds_rem(), stored in reverse order    */

/* Finite field arithmetic in GF(2**8) with the modular polynomial  */
/* x**8 + x**6 + x**5 + x**3 + 1 (0x169), as used by the MDS matrix. */

#define G_M 0x0169

/* Correction terms indexed by the two low bits of the operand: they */
/* supply the multiples of G_M that are needed when dividing by x    */
/* (>> 1) and by x**2 (>> 2) would otherwise lose set low bits.      */
u1byte  tab_5b[4] =
{   0,
    G_M >> 2,
    G_M >> 1,
    (G_M >> 1) ^ (G_M >> 2)
};

u1byte  tab_ef[4] =
{   0,
    (G_M >> 1) ^ (G_M >> 2),
    G_M >> 1,
    G_M >> 2
};

/* Multiplication by the MDS constants 0x01, 0x5b and 0xef, using    */
/* 0x5b = 1 + 1/x**2 and 0xef = 1 + 1/x + 1/x**2 in this field.       */
/* Each macro evaluates its argument more than once.                 */
#define ffm_01(x)    (x)
#define ffm_5b(x)   ((x) ^ ((x) >> 2) ^ tab_5b[(x) & 3])
#define ffm_ef(x)   ((x) ^ ((x) >> 1) ^ ((x) >> 2) ^ tab_ef[(x) & 3])

/* 4-bit helpers for qp(): ror4[v] is v rotated right by one bit     */
/* within a nibble, ashx[v] is v ^ (8 * v mod 16).                   */
u1byte ror4[16] =
{   0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15 };

u1byte ashx[16] =
{   0, 9, 2, 11, 4, 13, 6, 15, 8, 1, 10, 3, 12, 5, 14, 7 };

/* The 4-bit substitution tables ("t-tables") t0..t3 from which qp() */
/* builds the two 8-bit q-permutations: qtK[n] is table tK of the    */
/* permutation q_n (n = 0 or 1), indexed by a 4-bit value.           */

u1byte qt0[2][16] = 
{   { 8, 1, 7, 13, 6, 15, 3, 2, 0, 11, 5, 9, 14, 12, 10, 4 },
    { 2, 8, 11, 13, 15, 7, 6, 14, 3, 1, 9, 4, 0, 10, 12, 5 }
};

u1byte qt1[2][16] =
{   { 14, 12, 11, 8, 1, 2, 3, 5, 15, 4, 10, 6, 7, 0, 9, 13 }, 
    { 1, 14, 2, 11, 4, 12, 3, 7, 6, 13, 10, 5, 15, 9, 0, 8 }
};

u1byte qt2[2][16] = 
{   { 11, 10, 5, 14, 6, 13, 9, 0, 12, 8, 15, 3, 2, 4, 7, 1 },
    { 4, 12, 7, 5, 1, 6, 9, 10, 0, 14, 13, 8, 2, 11, 3, 15 }
};

u1byte qt3[2][16] = 
{   { 13, 7, 15, 4, 1, 2, 6, 14, 9, 11, 3, 0, 8, 5, 12, 10 },
    { 11, 9, 5, 1, 12, 3, 13, 14, 6, 4, 7, 15, 2, 0, 8, 10 }
};
 
u1byte  qp(const u4byte n, const u1byte x)
{   u1byte  a0, a1, a2, a3, a4, b0, b1, b2, b3, b4;

    a0 = x >> 4; b0 = x & 15;
    a1 = a0 ^ b0; b1 = ror4[b0] ^ ashx[a0];
    a2 = qt0[n][a1]; b2 = qt1[n][b1];
    a3 = a2 ^ b2; b3 = ror4[b2] ^ ashx[a2];
    a4 = qt2[n][a3]; b4 = qt3[n][b3];
    return (b4 << 4) | a4;
};

#ifdef  Q_TABLES

/* Byte tables for both q-permutations, q_tab[n][v] = q_n(v), built  */
/* once by gen_qtab(); qt_gen flags that this has been done.         */
u4byte  qt_gen = 0;
u1byte  q_tab[2][256];

#define q(n,x)  q_tab[n][x]

void gen_qtab(void)
{   u4byte  perm, v;

    for(perm = 0; perm < 2; ++perm)
        for(v = 0; v < 256; ++v)
            q(perm, v) = qp(perm, (u1byte)v);
}

#else

/* without the tables every q lookup recomputes the permutation      */
#define q(n,x)  qp(n, x)

#endif

#ifdef  M_TABLE

/* m_tab[n][v] is the contribution of byte v in position n of h()'s  */
/* final stage: the last q-permutation (q1 for n = 0 and 2, q0 for   */
/* n = 1 and 3) followed by multiplication by column n of the MDS    */
/* matrix, packed with row 0 in the low byte.  gen_mtab() fills it;  */
/* mt_gen flags that this has been done.                             */

u4byte  mt_gen = 0;
u4byte  m_tab[4][256];

void gen_mtab(void)
{   u4byte  i, f01, f5b, fef;
    
    for(i = 0; i < 256; ++i)
    {
        /* MDS columns 0 and 2 are applied after q1                  */
        f01 = q(1,i); f5b = ffm_5b(f01); fef = ffm_ef(f01);
        m_tab[0][i] = f01 + (f5b << 8) + (fef << 16) + (fef << 24);
        m_tab[2][i] = f5b + (fef << 8) + (f01 << 16) + (fef << 24);

        /* MDS columns 1 and 3 are applied after q0                  */
        f01 = q(0,i); f5b = ffm_5b(f01); fef = ffm_ef(f01);
        m_tab[1][i] = fef + (fef << 8) + (f5b << 16) + (f01 << 24);
        m_tab[3][i] = f5b + (f01 << 8) + (fef << 16) + (f5b << 24);
    }
}

#define mds(n,x)    (m_tab[n][x])

#else

/* fm_rc is the MDS matrix entry in row r, column c and q_c is the   */
/* q-permutation applied to the byte in position c before the matrix */

#define fm_00   ffm_01
#define fm_10   ffm_5b
#define fm_20   ffm_ef
#define fm_30   ffm_ef
#define q_0(x)  q(1,x)

#define fm_01   ffm_ef
#define fm_11   ffm_ef
#define fm_21   ffm_5b
#define fm_31   ffm_01
#define q_1(x)  q(0,x)

#define fm_02   ffm_5b
#define fm_12   ffm_ef
#define fm_22   ffm_01
#define fm_32   ffm_ef
#define q_2(x)  q(1,x)

#define fm_03   ffm_5b
#define fm_13   ffm_01
#define fm_23   ffm_ef
#define fm_33   ffm_5b
#define q_3(x)  q(0,x)

#define f_0(n,x)    ((u4byte)fm_0##n(x))
#define f_1(n,x)    ((u4byte)fm_1##n(x) << 8)
#define f_2(n,x)    ((u4byte)fm_2##n(x) << 16)
#define f_3(n,x)    ((u4byte)fm_3##n(x) << 24)

/* parenthesised so the XOR chain binds as one operand wherever used */
#define mds(n,x)    (f_0(n,q_##n(x)) ^ f_1(n,q_##n(x)) ^ f_2(n,q_##n(x)) ^ f_3(n,q_##n(x)))

#endif

/* The Twofish h() function.  The four bytes of x are passed through */
/* a chain of q-permutations whose length depends on k_len (2 to 4), */
/* with bytes of key[] XORed in between stages, and the result is    */
/* multiplied by the MDS matrix.  key[] must hold at least           */
/* 4 * k_len bytes.  For any other k_len the key stages are skipped. */
u4byte h_fun(const u4byte x, const u1byte key[])
{   u1byte  y0, y1, y2, y3;

    y0 = byte(x, 0);
    y1 = byte(x, 1);
    y2 = byte(x, 2);
    y3 = byte(x, 3);

    /* stages for the longer keys come first; each case falls        */
    /* through into the stages shared with the shorter key lengths   */
    switch(k_len)
    {
    case 4:
        y0 = q(1, y0) ^ key[12];
        y1 = q(0, y1) ^ key[13];
        y2 = q(0, y2) ^ key[14];
        y3 = q(1, y3) ^ key[15];
        /* fall through */
    case 3:
        y0 = q(1, y0) ^ key[ 8];
        y1 = q(1, y1) ^ key[ 9];
        y2 = q(0, y2) ^ key[10];
        y3 = q(0, y3) ^ key[11];
        /* fall through */
    case 2:
        y0 = q(0, q(0, y0) ^ key[4]) ^ key[0];
        y1 = q(0, q(1, y1) ^ key[5]) ^ key[1];
        y2 = q(1, q(0, y2) ^ key[6]) ^ key[2];
        y3 = q(1, q(1, y3) ^ key[7]) ^ key[3];
        break;
    default:
        break;
    }

#ifdef  M_TABLE

    /* m_tab already folds in the final q-permutation of each byte   */
    return  mds(0, y0) ^ mds(1, y1) ^ mds(2, y2) ^ mds(3, y3);

#else

    /* final q-permutation, then the MDS multiply done in GF(2**8)   */
    y0 = q(1, y0); y1 = q(0, y1); y2 = q(1, y2); y3 = q(0, y3);

    return   (u4byte)(       y0  ^ ffm_ef(y1) ^ ffm_5b(y2) ^ ffm_5b(y3))
          ^ ((u4byte)(ffm_5b(y0) ^ ffm_ef(y1) ^ ffm_ef(y2) ^        y3 ) <<  8)
          ^ ((u4byte)(ffm_ef(y0) ^ ffm_5b(y1) ^        y2  ^ ffm_ef(y3)) << 16)
          ^ ((u4byte)(ffm_ef(y0) ^        y1  ^ ffm_ef(y2) ^ ffm_5b(y3)) << 24);

#endif
}

#ifdef  MK_TABLE

/* With MK_TABLE the key-dependent part of g() is precomputed by     */
/* gen_mk_tab() from the S-box key bytes.  With ONE_STEP, mk_tab[n]  */
/* holds mds(n, s_n(v)), the complete MDS contribution of byte v in  */
/* position n.  Otherwise sb[n] holds only the key-dependent S-box   */
/* output, and the MDS step is done on every lookup.                 */

#ifdef  ONE_STEP
u4byte  mk_tab[4][256];
#else
u1byte  sb[4][256];
#endif

/* Build the key-dependent tables from key[], which must hold        */
/* 4 * k_len bytes.  The stages mirror h_fun() for every byte value. */
void gen_mk_tab(const u1byte key[])
{   u1byte  b0, b1, b2, b3;
    u4byte  i;

    for(i = 0; i < 256; ++i)
    {
        b0 = b1 = b2 = b3 = (u1byte)i;

        switch(k_len)
        {
            case 4:
                b0 = q(1,b0) ^ key[12]; b1 = q(0,b1) ^ key[13];
                b2 = q(0,b2) ^ key[14]; b3 = q(1,b3) ^ key[15];
                /* fall through */
            case 3:
                b0 = q(1,b0) ^ key[ 8]; b1 = q(1,b1) ^ key[ 9];
                b2 = q(0,b2) ^ key[10]; b3 = q(0,b3) ^ key[11];
                /* fall through */
            case 2:
                b0 = q(0,q(0,b0) ^ key[4]) ^ key[0];
                b1 = q(0,q(1,b1) ^ key[5]) ^ key[1];
                b2 = q(1,q(0,b2) ^ key[6]) ^ key[2];
                b3 = q(1,q(1,b3) ^ key[7]) ^ key[3];
                break;
            default:
                break;
        }
#ifdef ONE_STEP
        mk_tab[0][i] = mds(0, b0);
        mk_tab[1][i] = mds(1, b1);
        mk_tab[2][i] = mds(2, b2);
        mk_tab[3][i] = mds(3, b3);
#else
        sb[0][i] = b0;
        sb[1][i] = b1;
        sb[2][i] = b2;
        sb[3][i] = b3;
#endif
    }
}

/* g0_fun(x) = g(x); g1_fun(x) = g(rotl(x, 8)), realised by rotating */
/* which byte of x feeds each table.  Expansions are parenthesised   */
/* so they behave as a single operand.                               */
#  ifdef ONE_STEP
#    define g0_fun(x)   (mk_tab[0][byte(x,0)] ^ mk_tab[1][byte(x,1)] \
                       ^ mk_tab[2][byte(x,2)] ^ mk_tab[3][byte(x,3)])
#    define g1_fun(x)   (mk_tab[0][byte(x,3)] ^ mk_tab[1][byte(x,0)] \
                       ^ mk_tab[2][byte(x,1)] ^ mk_tab[3][byte(x,2)])
#  else
#    define g0_fun(x)   (mds(0, sb[0][byte(x,0)]) ^ mds(1, sb[1][byte(x,1)]) \
                       ^ mds(2, sb[2][byte(x,2)]) ^ mds(3, sb[3][byte(x,3)]))
#    define g1_fun(x)   (mds(0, sb[0][byte(x,3)]) ^ mds(1, sb[1][byte(x,0)]) \
                       ^ mds(2, sb[2][byte(x,1)]) ^ mds(3, sb[3][byte(x,2)]))
#  endif

#else

/* NOTE(review): this path reads the s_key words as bytes through a  */
/* pointer cast, i.e. in host byte order; that matches the byte      */
/* order the MK_TABLE path uses only on little-endian hosts.         */
#define g0_fun(x)   h_fun(x,(u1byte*)s_key)
#define g1_fun(x)   h_fun(rotl(x,8),(u1byte*)s_key)

#endif

/* The (12,8) Reed-Solomon code has the generator polynomial

  g(x) = x**4 + (a + 1/a) * x**3 + a * x**2 + (a + 1/a) * x + 1

where the coefficients are in the finite field GF(2**8) with the
modular polynomial a**8 + a**6 + a**3 + a**2 + 1. To generate the
remainder we start with a polynomial of degree 11 (12 coefficients)
whose coefficients for the x**4 to x**11 terms are our eight input
bytes. That is:

  m[7] * x**11 + m[6] * x**10 ... + m[0] * x**4 + 0 * x**3 +... + 0
  
We then multiply the generator polynomial by m[7] * x**7 and subtract
it (xor in GF(2**8)) from the above to eliminate the x**11 term; all
arithmetic on the coefficients is done in GF(2**8). We then multiply 
the generator polynomial by x**6 * coeff(x**10) and use this to remove
the x**10 term. We carry on in this way until the x**4 term is removed
so that we are left with:

  r[3] * x**3 + r[2] * x**2 + r[1] * x**1 + r[0]

which gives the resulting 4 bytes of the remainder. This is equivalent 
to the matrix multiplication in the Twofish description but much faster 
to implement.

*/

#define G_MOD   0x0000014d

/* Reduce the 64-bit key chunk (p1:p0, p1 most significant) modulo   */
/* the Reed-Solomon generator g(x) and return the 4-byte remainder.  */
/* Each of the eight steps cancels the current leading byte t by     */
/* subtracting t * g(x) aligned under it.                            */
u4byte mds_rem(u4byte p0, u4byte p1)
{   u4byte  step, t, t_a, t_ai;

    for(step = 0; step < 8; ++step)
    {
        /* take the leading coefficient and shift the next byte in   */
        t  = p1 >> 24;
        p1 = (p1 << 8) | (p0 >> 24);
        p0 = p0 << 8;

        /* t_a = a * t: shift left, reduce by G_MOD on overflow      */
        t_a = t << 1;
        if(t & 0x80)
            t_a ^= G_MOD;

        /* t_ai = t / a: shift right, add G_MOD / 2 on underflow     */
        t_ai = t >> 1;
        if(t & 0x01)
            t_ai ^= G_MOD >> 1;

        /* subtract t * g(x): the constant term takes t, x**2 takes  */
        /* a * t, and x**3 and x take (a + 1/a) * t                  */
        p1 ^= t ^ (t_a << 16) ^ ((t_a ^ t_ai) << 8) ^ ((t_a ^ t_ai) << 24);
    }

    return p1;
}

/* initialise the key schedule from the user supplied key   */

u4byte *set_key(const u4byte in_key[], const u4byte key_len)
{   u4byte  i, a, b, me_key[4], mo_key[4];

#ifdef Q_TABLES
    if(!qt_gen)
    {
        gen_qtab(); qt_gen = 1;
    }
#endif

#ifdef M_TABLE
    if(!mt_gen)
    {
        gen_mtab(); mt_gen = 1;
    }
#endif

    k_len = key_len / 64;   /* 2, 3 or 4 */

    for(i = 0; i < k_len; ++i)
    {
        a = in_key[i + i];     me_key[i] = a;
        b = in_key[i + i + 1]; mo_key[i] = b;
        s_key[k_len - i - 1] = mds_rem(a, b);
    }

    for(i = 0; i < 40; i += 2)
    {
        a = 0x01010101 * i; b = a + 0x01010101;
        a = h_fun(a, (u1byte*)me_key);
        b = rotl(h_fun(b, (u1byte*)mo_key), 8);
        l_key[i] = a + b;
        l_key[i + 1] = rotl(a + 2 * b, 9);
    }

#ifdef MK_TABLE
    gen_mk_tab((u1byte*)s_key);
#endif

    return l_key;
};

/* encrypt a block of text  */

/* Two consecutive Feistel rounds (the pair numbered i) using subkeys */
/* l_key[4i + 8 .. 4i + 11].  Instead of swapping halves after each  */
/* round, the two rounds alternate which half of blk[] feeds g(), so */
/* no data movement is needed.  Wrapped in do/while(0) so it acts as */
/* a single statement.                                                */
#define f_rnd(i)                                                        \
    do                                                                  \
    {   t0 = g0_fun(blk[0]); t1 = g1_fun(blk[1]);                       \
        blk[2] = rotr(blk[2] ^ (t0 + t1 + l_key[4 * (i) + 8]), 1);      \
        blk[3] = rotl(blk[3], 1) ^ (t0 + 2 * t1 + l_key[4 * (i) + 9]);  \
        t0 = g0_fun(blk[2]); t1 = g1_fun(blk[3]);                       \
        blk[0] = rotr(blk[0] ^ (t0 + t1 + l_key[4 * (i) + 10]), 1);     \
        blk[1] = rotl(blk[1], 1) ^ (t0 + 2 * t1 + l_key[4 * (i) + 11]); \
    } while(0)

/* Encrypt one 128-bit block; in_blk and out_blk may be the same.    */
void encrypt(const u4byte in_blk[4], u4byte out_blk[])
{   u4byte  t0, t1, blk[4], w;

    /* input whitening with l_key[0..3]                              */
    for(w = 0; w < 4; ++w)
        blk[w] = in_blk[w] ^ l_key[w];

    f_rnd(0); f_rnd(1); f_rnd(2); f_rnd(3);
    f_rnd(4); f_rnd(5); f_rnd(6); f_rnd(7);

    /* undo the last swap and apply output whitening (l_key[4..7])   */
    for(w = 0; w < 4; ++w)
        out_blk[w] = blk[(w + 2) & 3] ^ l_key[w + 4];
}

/* decrypt a block of text  */

/* Inverse of f_rnd(i): undoes round pair i using the same subkeys   */
/* l_key[4i + 8 .. 4i + 11], applied in reverse order.  Wrapped in   */
/* do/while(0) so it acts as a single statement.                     */
#define i_rnd(i)                                                        \
    do                                                                  \
    {   t0 = g0_fun(blk[0]); t1 = g1_fun(blk[1]);                       \
        blk[2] = rotl(blk[2], 1) ^ (t0 + t1 + l_key[4 * (i) + 10]);     \
        blk[3] = rotr(blk[3] ^ (t0 + 2 * t1 + l_key[4 * (i) + 11]), 1); \
        t0 = g0_fun(blk[2]); t1 = g1_fun(blk[3]);                       \
        blk[0] = rotl(blk[0], 1) ^ (t0 + t1 + l_key[4 * (i) +  8]);     \
        blk[1] = rotr(blk[1] ^ (t0 + 2 * t1 + l_key[4 * (i) +  9]), 1); \
    } while(0)

/* Decrypt one 128-bit block; in_blk and out_blk may be the same.    */
void decrypt(const u4byte in_blk[4], u4byte out_blk[4])
{   u4byte  t0, t1, blk[4], w;

    /* remove the output whitening (l_key[4..7])                     */
    for(w = 0; w < 4; ++w)
        blk[w] = in_blk[w] ^ l_key[w + 4];

    i_rnd(7); i_rnd(6); i_rnd(5); i_rnd(4);
    i_rnd(3); i_rnd(2); i_rnd(1); i_rnd(0);

    /* undo the last swap and remove input whitening (l_key[0..3])   */
    for(w = 0; w < 4; ++w)
        out_blk[w] = blk[(w + 2) & 3] ^ l_key[w];
}
