Why is pow(a, d, n) so much faster than a**d % n?

I was trying to implement a Miller-Rabin primality test, and was puzzled why it was taking so long (> 20 seconds) for midsize numbers (~7 digits). I eventually found the following line of code to be the source of the problem:

x = a**d % n

(where a, d, and n are all similar, but unequal, midsize numbers, ** is the exponentiation operator, and % is the modulo operator)

I then tried replacing it with the following:

x = pow(a, d, n)

and by comparison it is almost instantaneous.

For context, here is the original function:

from random import randint

def primalityTest(n, k):
    if n < 2:
        return False
    if n in (2, 3):         # small primes; also keeps randint(2, n - 2) valid
        return True
    if n % 2 == 0:
        return False
    s = 0
    d = n - 1
    while d % 2 == 0:
        s += 1
        d >>= 1
    for i in range(k):
        rand = randint(2, n - 2)
        x = rand**d % n         # offending line
        if x == 1 or x == n - 1:
            continue
        for r in range(s):
            toReturn = True
            x = pow(x, 2, n)
            if x == 1:
                return False
            if x == n - 1:
                toReturn = False
                break
        if toReturn:
            return False
    return True

print(primalityTest(2700643,1))
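
For comparison, here is a minimal sketch of the same test with the three-argument pow in place of the offending line. The name primality_test and the for/else structure are illustrative choices, not part of the original function, and the early return for 2 and 3 is an added assumption so that randint(2, n - 2) always gets a valid range.

```python
from random import randint

def primality_test(n, k):
    """Miller-Rabin test with k random bases (illustrative sketch)."""
    if n < 2:
        return False
    if n in (2, 3):
        return True
    if n % 2 == 0:
        return False
    # Write n - 1 as d * 2**s with d odd.
    s, d = 0, n - 1
    while d % 2 == 0:
        s += 1
        d >>= 1
    for _ in range(k):
        a = randint(2, n - 2)
        x = pow(a, d, n)        # modular exponentiation, no huge intermediate
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False        # never reached n - 1, so n is composite
    return True                 # probably prime

print(primality_test(7919, 5))  # 7919 is prime, so this prints True
```

As with any Miller-Rabin test, a True result means "probably prime"; raising k lowers the chance of a composite slipping through.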

An example timed calculation:

from timeit import timeit

a = 2505626
d = 1520321
n = 2700643

def testA():
    print(a**d % n)

def testB():
    print(pow(a, d, n))

print("time: %(time)fs" % {"time":timeit("testA()", setup="from __main__ import testA", number=1)})
print("time: %(time)fs" % {"time":timeit("testB()", setup="from __main__ import testB", number=1)})

Output (run with PyPy 1.9.0):

2642565
time: 23.785543s
2642565
time: 0.000030s

Output (run with Python 3.3.0, 2.7.2 returns very similar times):

2642565
time: 14.426975s
2642565
time: 0.000021s

A related question: why is this calculation almost twice as fast when run with Python 2 or 3 as with PyPy, when PyPy is usually much faster?

See the Wikipedia article on modular exponentiation. Basically, when you do a**d % n, you actually have to calculate a**d, which could be quite large. But there are ways of computing a**d % n without having to compute a**d itself, and that is what pow does. The ** operator can’t do this because it can’t “see into the future” to know that you are going to immediately take the modulus.
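
To get a feel for "quite large", the size of a**d can be estimated without computing it. This is a rough sketch using the values from the question; the variable names digits_full and digits_mod are only for illustration.

```python
import math

a = 2505626
d = 1520321
n = 2700643

# Decimal digits of a**d, estimated from logarithms instead of computing a**d.
digits_full = math.floor(d * math.log10(a)) + 1

# Any value reduced modulo n has at most as many digits as n itself.
digits_mod = len(str(n))

print(digits_full)  # on the order of 9.7 million digits
print(digits_mod)   # 7
```

So a**d % n builds an integer with millions of digits only to throw almost all of it away, while a modular exponentiation can work with values of about the size of n throughout.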

BrenBarn answered your main question. For your aside:

why is it almost twice as fast when run with Python 2 or 3 than PyPy, when usually PyPy is much faster?

If you read PyPy’s performance page, this is exactly the kind of thing PyPy is not good at—in fact, the very first example they give:

Bad examples include doing computations with large longs – which is performed by unoptimizable support code.

Theoretically, turning a huge exponentiation followed by a mod into a modular exponentiation (at least after the first pass) is a transformation a JIT might be able to make… but not PyPy’s JIT.

As a side note, if you need to do calculations with huge integers, you may want to look at third-party modules like gmpy, which can be much faster than CPython’s native implementation in some cases outside the mainstream uses. It also has a lot of additional functionality that you’d otherwise have to write yourself, at the cost of being less convenient.

There are shortcuts to doing modular exponentiation: for instance, you can compute a**(2**i) mod n for every i from 0 to log2(d) by repeated squaring, and multiply together (mod n) the intermediate results you need, namely those for which bit i of d is set. A dedicated modular-exponentiation function like 3-argument pow() can leverage such tricks because it knows you’re doing modular arithmetic. Python can’t recognize this given the bare expression a**d % n, so it will perform the full calculation (which will take much longer).
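
The repeated-squaring idea can be sketched in a few lines. This is a plain illustration of the technique, assuming d >= 0 and n > 0; modpow is a hypothetical helper name, and this is not a claim about how CPython implements pow internally.

```python
def modpow(a, d, n):
    """Square-and-multiply modular exponentiation (illustrative sketch)."""
    result = 1 % n          # 1 % n handles the n == 1 case
    base = a % n            # base holds a**(2**i) mod n for the current bit i
    while d > 0:
        if d & 1:           # this bit of d is set, so include the current power
            result = (result * base) % n
        base = (base * base) % n   # square to get the next power of two
        d >>= 1
    return result

a, d, n = 2505626, 1520321, 2700643
print(modpow(a, d, n) == pow(a, d, n))  # True
```

The loop runs once per bit of d (about 21 times for the numbers in the question), and every product is reduced modulo n before it is used again, so no intermediate value exceeds roughly n**2.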

The way x = a**d % n is calculated is to raise a to the d power, then take that result modulo n. If a and d are large, this creates a huge number which is only then reduced. However, x = pow(a, d, n) most likely reduces every intermediate product modulo n, so no intermediate value grows beyond roughly n**2, which is all that is required for multiplication modulo a number.


The answers/resolutions are collected from stackoverflow, are licensed under cc by-sa 2.5 , cc by-sa 3.0 and cc by-sa 4.0 .
