
Challenge 21: Implement the MT19937 Mersenne Twister RNG

You can get the pseudocode for this from Wikipedia.

If you're writing in Python, Ruby, or (gah) PHP, your language is probably already giving you MT19937 as "rand()"; don't use rand(). Write the RNG yourself.

Here is the Wikipedia article (see Section “Algorithmic detail”):

https://en.m.wikipedia.org/wiki/Mersenne_Twister

The description may look a bit intimidating, but note that, first, we don't have to pick the parameters $ w, n, m, r, \dots $, since they are already fixed for MT19937; second, we don't have to actually build the A and T matrices, because applying them is equivalent to a few bitwise operations that the Wikipedia article gives explicitly.
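To convince ourselves that skipping the matrices is safe, here is a small sketch that builds A and T as actual bit matrices over GF(2) and compares them with the shift/XOR shortcuts. The representation is an assumption made for this sketch: bit $ i $ of a word is `(x >> i) & 1`, and a 32x32 bit matrix is stored as a list of 32 integers, entry $ i $ being the image of the word with only bit $ i $ set. The names `apply_matrix`, `A_ROWS` and `T_ROWS` are hypothetical.

```python
# Sketch: the A and T "matrices" of MT19937-32 over GF(2), compared with the
# bitwise shortcuts. Assumed representation: bit i of a word is (x >> i) & 1,
# and a matrix is a list of 32 ints, entry i = image of the word with only bit i set.
W = 32
A_CONST = 0x9908B0DF


def apply_matrix(rows, x):
    """Multiply the bit vector x by the matrix `rows` over GF(2)."""
    result = 0
    for i in range(W):
        if (x >> i) & 1:
            result ^= rows[i]  # addition over GF(2) is XOR
    return result


def twist(x):
    """Shortcut for x*A given in the Wikipedia article."""
    return (x >> 1) ^ A_CONST if x & 1 else x >> 1


def temper(y):
    """Tempering transform (the T matrix) as bitwise operations."""
    y ^= y >> 11
    y ^= (y << 7) & 0x9D2C5680
    y ^= (y << 15) & 0xEFC60000
    y ^= y >> 18
    return y


# Bit 0 maps to the constant a; every other bit i moves down to bit i-1.
A_ROWS = [A_CONST] + [1 << (i - 1) for i in range(1, W)]
# Tempering is GF(2)-linear, so its matrix is given by the images of single bits.
T_ROWS = [temper(1 << i) for i in range(W)]

samples = list(range(0, 1 << W, 4_294_967)) + [0xFFFFFFFF, 0xDEADBEEF]
assert all(apply_matrix(A_ROWS, x) == twist(x) for x in samples)
assert all(apply_matrix(T_ROWS, x) == temper(x) for x in samples)
print("matrix form and bitwise form agree on", len(samples), "samples")
```

The second check only passes because every tempering step is a shift, an AND with a constant, or an XOR, all of which distribute over XOR; that is what allows the article to describe tempering as a matrix T in the first place.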

After removing all this complexity, this is what we are left with:

  • MT19937-32 keeps its state as an array of $ n $ 32-bit integers ($ n = 624 $), which must be initialized before we output the first pseudo-random number.
  • Initializing the state is done as follows: the first value is the seed, then each value is computed from the previous one using a recurrence involving XOR, a right shift, an integer multiplication and the addition of the index $ i $ (with the result truncated to fit into 32 bits).
  • Then we can finally output our first pseudo-random number: we compute a value noted $ x $ from the highest $ w - r $ bits (here a single bit) of the first value in the state (at first, that's the seed), the lowest $ r $ bits of the second value, and a third value further along in the state (which one is set by the parameter $ m $, equal to 397 for MT19937-32).
  • This $ x $ value is not the number that we output: before outputting it we apply what the article calls a tempering transform (the T matrix, i.e. a few more bitwise operations), and it is the result of this transform that we output.
  • The value $ x $ itself (not its tempered version) is appended at the end of the state, while the first value of the state is popped out. This way, the next $ x $ value will be computed from different values (the former second, the former third and the former $ (m+1) $-th value of the state).
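The initialization step from the list can be sketched on its own. `init_state` is a hypothetical helper name, and the parameters default to the MT19937-32 values ($ n = 624 $, $ w = 32 $, $ f = 1812433253 $).

```python
def init_state(seed, n=624, w=32, f=1812433253):
    """Sketch of the MT19937-32 state initialization (hypothetical helper)."""
    mask = (1 << w) - 1           # keeps the lowest w = 32 bits
    state = [seed & mask]         # the first value is the seed itself
    for i in range(1, n):
        prev = state[-1]
        # XOR with a right shift by w-2 = 30, multiply by f, add the index,
        # then truncate to 32 bits
        state.append((f * (prev ^ (prev >> (w - 2))) + i) & mask)
    return state


state = init_state(1)
print(len(state), state[0], state[1])   # 624 1 1812433254
```

With seed 1 the second value is simply $ f \cdot 1 + 1 $, which still fits in 32 bits; from the third value on, the truncation starts to matter.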

Example code on Wikipedia and in some other write-ups uses what seems to be some sort of "optimization": right after initialization (or when the generator is first asked for an output), it computes the next $ n $ values of $ x $ in one batch, and then regenerates the whole array every $ n $ outputs.

It may help performance (I'm not even sure it makes a big difference), but I am more focused on readability and elegance here, so instead I chose to generate a new $ x $ only when needed, that is, at each call.
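For comparison, here is a sketch of that batch approach, modelled on the pseudocode in the Wikipedia article; the function name `mt19937_batch` is hypothetical. The array is regenerated in place with indices taken modulo $ n $, and the $ n $ new words are then tempered and output one at a time.

```python
from itertools import islice


def mt19937_batch(seed=5489):
    """Sketch of the batch variant: regenerate all n words every n outputs."""
    n, m = 624, 397
    a = 0x9908B0DF
    upper_mask, lower_mask, mask32 = 0x80000000, 0x7FFFFFFF, 0xFFFFFFFF

    # same initialization as in the per-call version
    mt = [seed & mask32]
    for i in range(1, n):
        mt.append((1812433253 * (mt[-1] ^ (mt[-1] >> 30)) + i) & mask32)

    while True:
        # Regenerate the array in place. When (k + 1) or (k + m) wraps
        # around modulo n, that entry has already been replaced by a new
        # word, which is exactly the word the recurrence needs.
        for k in range(n):
            y = (mt[k] & upper_mask) | (mt[(k + 1) % n] & lower_mask)
            mt[k] = mt[(k + m) % n] ^ (y >> 1) ^ (a if y & 1 else 0)
        # then temper and output the n new words one by one
        for y in mt:
            y ^= y >> 11
            y ^= (y << 7) & 0x9D2C5680
            y ^= (y << 15) & 0xEFC60000
            y ^= y >> 18
            yield y


# The C++ standard requires the 10000th output of a default-seeded
# std::mt19937 (seed 5489) to be 4123659995.
print(next(islice(mt19937_batch(), 9999, None)))
```

One concrete cost of the per-call generator in this notebook is `state.pop(0)`, which shifts the 623 remaining list elements on every output; the batch version does constant work per output on average. For an exercise like this one, either is fast enough.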

This is also the perfect opportunity to use Python's yield keyword, which makes it easy to create "generators": functions whose execution pauses at each yield and resumes on the next call to next(), keeping their local state between two calls. For more information about generators and the yield keyword, see the Python documentation and the related PEP (PEP 255, "Simple Generators").
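A toy example of that behaviour, unrelated to MT19937 itself (`counter` is just an illustrative name): the local variable survives between two calls to `next()`, and an infinite generator can be cut down to a finite list with `itertools.islice`.

```python
from itertools import islice


def counter(start=0):
    """Toy generator: `value` survives between two calls to next()."""
    value = start
    while True:
        yield value   # execution pauses here and hands `value` to the caller
        value += 1    # ...and resumes here on the following next()


gen = counter(10)
print(next(gen), next(gen), next(gen))   # 10 11 12
# an infinite generator can be truncated with itertools.islice
print(list(islice(counter(), 5)))        # [0, 1, 2, 3, 4]
```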

I generated some test data using the MT19937 implementation provided with C++ (std::mt19937), because Python does not seem to give direct access to the numbers its own MT19937 outputs.

Actually I used the code from here:

https://www.guyrutenberg.com/2014/05/03/c-mt19937-example/

I put a for loop in it and ran it in the online C++ shell:

http://www.cpp.sh/
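If you would rather stay in Python to produce such reference values, there is one possible route. CPython's `random` module is itself an MT19937, and `random.getrandbits(32)` returns one raw 32-bit output. The catch is that `random.seed()` initializes the state with a different routine (`init_by_array`), so seeding it with an integer does not reproduce `std::mt19937(seed)`. A workaround, sketched here under the assumption that you are running CPython, is to build the state with the standard seeding and load it with `setstate()`. The state layout used is a CPython implementation detail rather than a documented API. The helper names are hypothetical, and the built-in generator is only used for producing reference values, not for solving the challenge.

```python
import random
from itertools import islice


def init_genrand(seed):
    """Standard MT19937 seeding (same recurrence as in this notebook)."""
    mt = [seed & 0xFFFFFFFF]
    for i in range(1, 624):
        mt.append((1812433253 * (mt[-1] ^ (mt[-1] >> 30)) + i) & 0xFFFFFFFF)
    return mt


def reference_outputs(seed):
    """Yield MT19937 outputs from CPython's built-in generator.

    Relies on a CPython implementation detail: a version-3 state is
    (3, 624 state words + current index, gauss_next). Setting the index
    to 624 makes the next draw regenerate the whole array first.
    """
    rng = random.Random()
    rng.setstate((3, tuple(init_genrand(seed)) + (624,), None))
    while True:
        yield rng.getrandbits(32)   # one raw, tempered 32-bit output


print(list(islice(reference_outputs(5489), 3)))
```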

from libmatasano import html_test
def MT19937_32(seed=5489):
    '''Mersenne-Twister PRNG, 32-bit version'''
    # parameters for MT19937-32
    (w, n, m, r) = (32, 624, 397, 31)
    a = 0x9908B0DF
    (u, d) = (11, 0xFFFFFFFF)
    (s, b) = (7, 0x9D2C5680)
    (t, c) = (15, 0xEFC60000)
    l = 18
    f = 1812433253

    # masks (to apply with an '&' operator)
    # ---------------------------------------
    # zeroes out all bits except "the w-r highest bits"
    # (i.e. with our parameters the single highest bit, since w-r=1)
    high_mask = ((1<<w) - 1) - ((1<<r) - 1)
    # zeroes out all bits except "the r lowest bits"
    low_mask = (1<<r)-1

    def twist(x):
        return (x >> 1)^a if (x % 2 == 1) else x >> 1

    # initialization (populating the state)
    # d = 0xFFFFFFFF, so "& d" keeps only the lowest 32 bits
    # (also applied to the seed, in case it is larger than 32 bits)
    state = list()
    state.append(seed & d)
    for i in range(1, n):
        prev = state[-1]
        x = (f * (prev ^ (prev >> (w-2))) + i) & d
        state.append(x)

    while True:
        # the two masks do not overlap, so '|' glues the two pieces together
        x = state[m] ^ twist((state[0] & high_mask) | (state[1] & low_mask))

        # tempering transform and output
        y = x ^ ((x >> u) & d)
        y = y ^ ((y << s) & b)
        y = y ^ ((y << t) & c)
        yield y ^ (y >> l)

        # note that it's the 'x' value (not the tempered one)
        # that we insert in the state; list.pop(0) is O(n),
        # which we accept here for readability
        state.pop(0)
        state.append(x)
import json

with open('data/21.json') as f:
    test_data = json.load(f)

for key in test_data:
    seed = int(key)
    assert all(x==y for (x,y) in zip(MT19937_32(seed), test_data[key]))
    
html_test(True)
OK

Tadaa!

I had to debug my code a little, though, because at first I did a bad job of "selecting the w-r highest bits of one value and the r lowest bits of another". It was nice to get some practice with bitwise operations.
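A small sketch of that selection with the MT19937-32 parameters ($ w = 32 $, $ r = 31 $); `splice` is a hypothetical helper name. Because the two masks do not overlap, combining the pieces with `|` or with `+` gives the same result.

```python
W, R = 32, 31

# the w - r = 1 highest bit(s): 0x80000000
HIGH_MASK = ((1 << W) - 1) & ~((1 << R) - 1)
# the r = 31 lowest bits: 0x7FFFFFFF
LOW_MASK = (1 << R) - 1


def splice(hi_src, lo_src):
    """Highest w-r bits of hi_src, lowest r bits of lo_src."""
    return (hi_src & HIGH_MASK) | (lo_src & LOW_MASK)


print(hex(HIGH_MASK), hex(LOW_MASK))        # 0x80000000 0x7fffffff
print(hex(splice(0xDEADBEEF, 0x12345678)))  # 0x92345678
print(hex(splice(0x12345678, 0xDEADBEEF)))  # 0x5eadbeef
```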

I was lucky to be able to take someone else's Python implementation (https://github.com/akalin/cryptopals-python3) and run it in the Python debugger to see exactly where our values were diverging and spot my mistakes. This is why some specifications give all the intermediate values for an example input, for instance the AES specification (FIPS 197, see Appendix C – Example Vectors).
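Short of a debugger, a hypothetical helper like the one below can at least report where two output streams first disagree. Since `zip` stops at the shorter input, it also works when one side is an infinite generator, for example our generator for a given seed against the matching list from the test file.

```python
def first_divergence(ours, reference):
    """Return (index, our_value, ref_value) at the first mismatch, or None."""
    for i, (a, b) in enumerate(zip(ours, reference)):
        if a != b:
            return i, a, b
    return None   # no mismatch within the shorter of the two inputs


print(first_divergence([1, 2, 3, 4], [1, 2, 5, 4]))   # (2, 3, 5)
print(first_divergence([1, 2, 3], [1, 2, 3]))          # None
```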