Challenge 23: Clone an MT19937 RNG from its output

The internal state of MT19937 consists of 624 32-bit integers.

For each batch of 624 outputs, MT regenerates ("twists") that internal state. By regenerating its state regularly, MT19937 achieves a period of 2**19937 - 1, which is Big.
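To make the state update concrete, here is a minimal sketch of MT19937 seeding and the twist, written from the reference algorithm's published constants (the same ones that appear in the cells below). The names seed_mt, twist and temper are hypothetical helpers for illustration, not part of the challenge or of libmatasano. Each state word is rebuilt from its own top bit, the low 31 bits of the next word, and the word 397 positions ahead.

```python
def seed_mt(seed):
    """Reference MT19937 initialisation from a 32-bit seed."""
    mt = [0] * 624
    mt[0] = seed & 0xFFFFFFFF
    for i in range(1, 624):
        mt[i] = (1812433253 * (mt[i - 1] ^ (mt[i - 1] >> 30)) + i) & 0xFFFFFFFF
    return mt

def twist(mt):
    """Regenerate all 624 state words in place, once per batch of 624 outputs."""
    n, m, a = 624, 397, 0x9908B0DF
    for i in range(n):
        # top bit of mt[i] joined with the low 31 bits of the next word
        y = (mt[i] & 0x80000000) | (mt[(i + 1) % n] & 0x7FFFFFFF)
        mt[i] = mt[(i + m) % n] ^ (y >> 1) ^ (a if y & 1 else 0)

def temper(x):
    """Same tempering as MT_temper below."""
    y = x ^ (x >> 11)
    y ^= (y << 7) & 0x9D2C5680
    y ^= (y << 15) & 0xEFC60000
    return y ^ (y >> 18)

mt = seed_mt(5489)    # 5489 is the reference implementation's default seed
twist(mt)
print(temper(mt[0]))  # 3499211612, the well-known first output for seed 5489
```

Outputs i = 0..623 of a batch are simply temper(mt[i]); after the 624th output the generator calls twist again. That one-to-one mapping between outputs and state words is what the attack below exploits.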

Each time MT19937 is tapped, an element of its internal state is subjected to a tempering function that diffuses bits through the result.

The tempering function is invertible; you can write an "untemper" function that takes an MT19937 output and transforms it back into the corresponding element of the MT19937 state array.

To invert the temper transform, apply the inverse of each of the operations in the temper transform in reverse order. There are two kinds of operations in the temper transform each applied twice; one is an XOR against a right-shifted value, and the other is an XOR against a left-shifted value AND'd with a magic number. So you'll need code to invert the "right" and the "left" operation.
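A quick sanity check that the "right" operation really is invertible, in a small sketch (right_xor is a hypothetical helper name). For the last tempering step the shift is l = 18, and because 2 * 18 >= 32, simply applying the same operation a second time undoes it. With a smaller shift such as u = 11, one re-application is not enough, which is why a general inverse needs more work.

```python
import random

def right_xor(x, shift):
    """The "right" tempering operation without a mask: x ^ (x >> shift)."""
    return x ^ (x >> shift)

# (x ^ (x >> 18)) >> 18 == (x >> 18) ^ (x >> 36), and x >> 36 == 0 for 32-bit x,
# so applying the step twice gives back x.
for _ in range(1000):
    x = random.getrandbits(32)
    assert right_xor(right_xor(x, 18), 18) == x

# With shift 11, re-applying once still leaves the lowest bits wrong.
once_more = right_xor(right_xor(0xFFFFFFFF, 11), 11)
print(hex(once_more))  # 0xfffffc00, not 0xffffffff
```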

Once you have "untemper" working, create a new MT19937 generator, tap it for 624 outputs, untemper each of them to recreate the state of the generator, and splice that state into a new instance of the MT19937 generator.

The new "spliced" generator should predict the values of the original.

Stop and think for a second: How would you modify MT19937 to make this attack hard? What would happen if you subjected each tempered output to a cryptographic hash?
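One way to explore the hashing question is a small sketch that wraps an MT19937 stream and returns the first 32 bits of the SHA-256 of each raw output. It assumes CPython's random.Random as the underlying MT19937 (the Python documentation states it uses the Mersenne Twister); the class name HashedMT is hypothetical.

```python
import hashlib
import random

class HashedMT:
    """Hypothetical hardened generator: each MT19937 output goes through SHA-256."""

    def __init__(self, seed):
        # CPython's random.Random is an MT19937 generator
        self._mt = random.Random(seed)

    def next32(self):
        raw = self._mt.getrandbits(32)            # one tempered MT19937 output
        digest = hashlib.sha256(raw.to_bytes(4, "big")).digest()
        return int.from_bytes(digest[:4], "big")  # first 32 bits of the hash

gen = HashedMT(2024)
print([gen.next32() for _ in range(3)])
```

Untempering these values no longer yields state words: an attacker would first have to recover each raw output from its hash. Since each raw output has only 2**32 possible values, that is still a feasible brute-force search, so this illustrates the idea rather than a sound design.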

Here is the tempering function from MT19937:

In [1]:
def MT_temper(x):
    (w, n, m, r) = (32, 624, 397, 31)
    a = 0x9908B0DF
    (u, d) = (11, 0xFFFFFFFF)
    (s, b) = (7, 0x9D2C5680)
    (t, c) = (15, 0xEFC60000)
    l = 18
    f = 1812433253
    
    y = x ^ ((x >> u) & d)
    y = y ^ ((y << s) & b)
    y = y ^ ((y << t) & c)
    return y ^ (y >> l)

As you can see, it consists of four similar operations: the value is shifted, a mask is applied, and the result is XORed with the value from before that step.
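To see that the four steps really are one pattern, here is a sketch that expresses tempering through a single generic helper. The name shift_mask_xor and its direction argument are hypothetical choices for illustration; the constants are those of MT_temper above.

```python
def shift_mask_xor(x, shift, mask, direction):
    """One tempering step: x XOR ((x shifted by `shift`) AND mask)."""
    shifted = x << shift if direction == "left" else x >> shift
    return x ^ (shifted & mask)

def temper(x):
    # The four steps of MT_temper, with the same constants
    y = shift_mask_xor(x, 11, 0xFFFFFFFF, "right")     # (u, d)
    y = shift_mask_xor(y, 7, 0x9D2C5680, "left")       # (s, b)
    y = shift_mask_xor(y, 15, 0xEFC60000, "left")      # (t, c)
    return shift_mask_xor(y, 18, 0xFFFFFFFF, "right")  # l, full mask = no masking

print(hex(temper(0xDEADBEEF)))
```

Python integers do not overflow, but every mask fits in 32 bits, so each left-shift step still yields a 32-bit result.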

The instructions tell us these operations are invertible, so let's have a look.

Let $ y = x \oplus ((x \ll s) \land m) $, where $ s $ is the shift amount and $ m $ the mask, and let's try to find $ x $ from $ y $ and the known parameters.

We will represent the bits of a number as $ y = Y_0 Y_1 \dots Y_{w-1} $ (written Y0 Y1 ... Yw-1 in plain text), where $ Y_0 $ is the most significant bit.

With $ w = 6 $ and a left shift by $ s = 2 $, the operation to invert can be laid out as:

  X2 X3 X4 X5 0  0 
& M0 M1 M2 M3 M4 M5
⊕ X0 X1 X2 X3 X4 X5
= Y0 Y1 Y2 Y3 Y4 Y5

This can be written in a more mathematical way as:

$$ Y_n = \begin{cases} X_n \oplus ( X_{n + ds} \land M_n ) & \text{ if } 0 \leq n+ds \leq w-1\\ X_n & \text{ otherwise} \end{cases} $$

where $ d = 1 $ for a left shift and $ d = -1 $ for a right shift: since $ X_0 $ is the most significant bit, bit $ n $ of $ x \ll s $ is $ X_{n+s} $.

Since XOR is its own inverse, we can rearrange this expression to solve for $ X_n $:

$$ X_n = \begin{cases} Y_n \oplus ( X_{n + ds} \land M_n ) & \text{ if } 0 \leq n+ds \leq w-1\\ Y_n & \text{ otherwise} \end{cases} $$

The strategy is then to recover $ x $ bit by bit. We start with the bits that are trivial to recover because the shift guarantees they are identical in $ x $ and $ y $ (X4 and X5 in our example).

We can then recover the next bits (X2 and X3 in our example), because the formula only needs bits we have already recovered, and so on until all of $ x $ is known.
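The bit-by-bit recovery can be traced directly from the formula for $ X_n $. This is a sketch using plain integer bit access rather than the list approach used later; the helper names bit and invert_bit_by_bit, and the example values x = 101101 and mask m = 110110 (which give y = 011001 for the 6-bit left shift by 2), are illustrative assumptions.

```python
def bit(v, n, w):
    """Bit n of the w-bit value v, where n = 0 is the most significant bit."""
    return (v >> (w - 1 - n)) & 1

def invert_bit_by_bit(y, s, m, d, w):
    """Recover x from y = x XOR ((x shifted by s) AND m), one bit at a time.

    d = 1 for a left shift, d = -1 for a right shift, as in the formulas above.
    """
    X = [0] * w
    # Visit bits so that X[n + d*s] is always known before X[n]: a left shift
    # starts from the least significant end, a right shift from the most significant.
    order = range(w - 1, -1, -1) if d == 1 else range(w)
    for n in order:
        k = n + d * s
        if 0 <= k <= w - 1:
            X[n] = bit(y, n, w) ^ (X[k] & bit(m, n, w))
        else:
            X[n] = bit(y, n, w)  # no shifted bit lands here, so X_n = Y_n
    return int("".join(map(str, X)), 2)

# The 6-bit example: left shift by s = 2
print(bin(invert_bit_by_bit(0b011001, 2, 0b110110, 1, 6)))  # 0b101101
```

For a left shift the loop first sets X5 and X4 straight from Y5 and Y4, then X3 and X2 from them, exactly as described above.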

I tried to find a "more elegant" way to do this, i.e. not bit by bit but by combining $ y $ and $ m $ directly. It turns out to be possible, but not at all more elegant, and I spent quite some time on it for nothing.

Looking at other solutions, everyone seems to use Python's bitwise operations on integers, which leads to code that I find not very elegant. There are several things I wanted to improve:

  • make the function work for any set of parameters (instead of relying on hardcoded constants)
  • avoid having one function to invert left shifts and another one for right shifts
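For comparison, here is a sketch of the integer-based style, written so that it also covers both goals listed above (one function for both directions, any width). It uses fixed-point iteration: starting from x = y, each pass recomputes x = y XOR ((x shifted) AND m), and every pass fixes at least s more bits, starting from the end the shift vacates. The name invert_by_iteration is hypothetical; this is not the approach used in this notebook.

```python
def invert_by_iteration(y, shift, mask, direction, width=32):
    """Recover x from y = x ^ ((x shifted by `shift`) & mask) using whole-integer ops."""
    full = (1 << width) - 1
    x = y
    # ceil(width / shift) passes are enough for every bit to become correct
    for _ in range(-(-width // shift)):
        shifted = (x << shift) & full if direction == "left" else x >> shift
        x = y ^ (shifted & mask)
    return x

# Undo the (s, b) tempering step for one value
x = 0xDEADBEEF
y = x ^ ((x << 7) & 0x9D2C5680)
print(invert_by_iteration(y, 7, 0x9D2C5680, "left") == x)  # True
```

The correctness argument is less direct than the bit-by-bit formula, which is part of why the list-based version below reads more closely like the math.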

So I used the following strategy: numbers are first converted to lists of bits, because lists are much easier to work with (especially when you need to access bits individually).

Note how closely the core logic (the if n < shift: branch) resembles the math formula above! Working with lists makes it much easier to reason about individual bits.

In [2]:
# using a function that declares internal functions every time
# is not the most elegant thing to do,
# but it just felt wrong to create an object
# for something (untempering) that does not require any state
def untemper(y):
    (w, n, m, r) = (32, 624, 397, 31)
    a = 0x9908B0DF
    (u, d) = (11, 0xFFFFFFFF)
    (s, b) = (7, 0x9D2C5680)
    (t, c) = (15, 0xEFC60000)
    l = 18
    f = 1812433253

    def int_to_bit_list(x):
        return [int(b) for b in '{:032b}'.format(x)]

    def bit_list_to_int(l):
        return int(''.join(str(x) for x in l), base=2)

    def invert_shift_mask_xor(y, direction, shift, mask=0xFFFFFFFF):
        y = int_to_bit_list(y)
        mask = int_to_bit_list(mask)

        if direction == 'left':
            y.reverse()
            mask.reverse()
        else:
            assert direction == 'right'

        x = [None]*32
        for n in range(32):
            if n < shift:
                x[n] = y[n]
            else:
                x[n] = y[n] ^ (mask[n] & x[n-shift])

        if direction == 'left':
            x.reverse()

        return bit_list_to_int(x)

    xx = y
    xx = invert_shift_mask_xor(xx, direction='right', shift=l)
    xx = invert_shift_mask_xor(xx, direction='left', shift=t, mask=c)
    xx = invert_shift_mask_xor(xx, direction='left', shift=s, mask=b)
    xx = invert_shift_mask_xor(xx, direction='right', shift=u, mask=d)

    return xx

# testing
from random import randint
from libmatasano import html_test
for _ in range(10):
    x = randint(0, 0xFFFFFFFF)  # any 32-bit value
    y = MT_temper(x)
    assert untemper(y) == x
    
html_test(True)
OK
In [3]:
# cloning time!
from libmatasano import MT19937_32
prng = MT19937_32()

state = [untemper(next(prng)) for _ in range(624)] 

cloned_prng = MT19937_32(state=state)

for _ in range(20):
    assert next(prng) == next(cloned_prng)

html_test(True)
OK
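The libmatasano generator used above is not shown here. As a self-contained check of the same attack, here is a sketch that clones CPython's own random.Random, which the Python documentation describes as a Mersenne Twister. It assumes two CPython details: getrandbits(32) returns exactly one tempered output, and setstate accepts a version-3 state of 624 words plus an index. The helper names undo_step and untemper32 are hypothetical, and untempering here uses the integer fixed-point style rather than the list-based function above.

```python
import random

def undo_step(y, shift, mask, direction, width=32):
    """Invert y = x ^ ((x shifted by `shift`) & mask) by fixed-point iteration."""
    full = (1 << width) - 1
    x = y
    for _ in range(-(-width // shift)):  # ceil(width / shift) passes suffice
        shifted = (x << shift) & full if direction == "left" else x >> shift
        x = y ^ (shifted & mask)
    return x

def untemper32(y):
    # Undo MT19937's four tempering steps in reverse order
    y = undo_step(y, 18, 0xFFFFFFFF, "right")
    y = undo_step(y, 15, 0xEFC60000, "left")
    y = undo_step(y, 7, 0x9D2C5680, "left")
    return undo_step(y, 11, 0xFFFFFFFF, "right")

victim = random.Random(12345)  # freshly seeded, so outputs start at a twist boundary
outputs = [victim.getrandbits(32) for _ in range(624)]  # each call = one tempered output

# CPython-specific state layout (version 3): 624 state words plus the index;
# index 624 means "twist before the next output", matching the victim.
state = tuple(untemper32(o) for o in outputs) + (624,)
clone = random.Random()
clone.setstate((3, state, None))

predicted = [clone.getrandbits(32) for _ in range(10)]
actual = [victim.getrandbits(32) for _ in range(10)]
print(predicted == actual)  # True
```

Starting from a freshly seeded generator matters: the 624 observed outputs must line up with one full batch, otherwise the untempered words mix two different states.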