The Speed of Python

You should be aware that Python is slow compared to compiled code such as Java (which is compiled to native code at runtime). For loop-heavy numerical code like the examples below, the factor is often more than 10 and can be more than 100. So be warned!

Often, it is possible to improve the runtime by switching the algorithm. Let us study the Collatz problem as an example. There is no need to explain this problem here; you can easily find a description on Wikipedia. For a start, we want to find the longest sequence for any starting point less than a million, and later for a hundred million.

The following is a recursive function to get the length of the Collatz sequence starting at n.

def get_collatz_length (n):
    if n==1:
        return 0
    elif n%2:
        return get_collatz_length(3*n+1)+1
    return get_collatz_length(n//2)+1
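
For comparison, the same count can be written as a loop. This is only a sketch; the name collatz_length_loop is not part of the original code. A loop avoids the recursion depth limit of CPython (1000 frames by default), which the longest chain below one million (524 steps) does not reach, but longer chains might.

```python
def collatz_length_loop(n):
    """Count the Collatz steps needed to get from n down to 1."""
    steps = 0
    while n != 1:
        # odd: 3n+1, even: n/2 (same rule as the recursive version)
        n = 3 * n + 1 if n % 2 else n // 2
        steps += 1
    return steps

print(collatz_length_loop(27))
```

It should return the same values as get_collatz_length for every n>=1.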

Computing the lengths up to N=1’000’000 takes around 20 seconds. We added code to print the 11 starting points with the longest sequences.

import time
start = time.process_time()
N=1_000_000
lengths = [(n,get_collatz_length(n)) for n in range(1,N)]
print(f"{time.process_time()-start:5.4} seconds")

records = sorted(lengths,key=lambda x:x[1],reverse=True)
for k in range(0,11):
    print(f"{records[k][0]}: {records[k][1]}")
18.72 seconds
837799: 524
626331: 508
939497: 506
704623: 503
910107: 475
927003: 475
511935: 469
767903: 467
796095: 467
970599: 457
546681: 451

Comparable Java code is the following.

public class Main 
{
	
	public static int get_collatz_length (long n)
	{
		int res = 0;
		if (n>1)
		{
			if (n%2 == 1)
				res = get_collatz_length(3*n+1)+1;
			else
				res = get_collatz_length(n/2)+1;
		}
		return res;
	}

	public static void main(String[] args) 
	{
		long time = System.currentTimeMillis();
		int N=1000000;
		int L[] = new int[N];
		for (int k=0; k<N; k++)
			L[k] = get_collatz_length(k);
		int max=0,imax=0;
		for (int k=1; k<N; k++)
			if (L[k]>max)
			{
				max=L[k]; imax=k;
			}
		System.out.println("Maximal value " + max + " at "+imax);
		time = System.currentTimeMillis()-time;
		System.out.println((time/1000.0) + " seconds.");
		
	}

}
Maximal value 524 at 837799
0.498 seconds.

This takes about half a second, so Java is faster by a factor of roughly 40.

What we also see is that the Python code looks more elegant and powerful. E.g., sorting the array of lengths and printing the highest entries is done with three lines of code. In Java, the same requires more ceremony, typically a Comparator.
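
If only the few highest entries are needed, a full sort is not required. The following sketch uses heapq.nlargest from the standard library; the helper name longest_starts and the small sample list are made up for illustration and are not part of the original code.

```python
import heapq

def longest_starts(lengths, k):
    """Return the k (start, length) pairs with the largest length, longest first."""
    return heapq.nlargest(k, lengths, key=lambda pair: pair[1])

# illustrative sample of (start, length) pairs
sample = [(27, 111), (837799, 524), (626331, 508), (97, 118)]
for start, length in longest_starts(sample, 2):
    print(f"{start}: {length}")
```

Whether this is noticeably faster than sorted for a million entries would have to be measured.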

Consequently, I see Python as an interactive scripting language to do math. If speed is needed, native code must be used, as in NumPy. Unfortunately, the user cannot easily write a function in C or Java.

Can we speed up the process? We could try to avoid the repeated computations by remembering the already known lengths in a dictionary and reusing them. The following code does this.

collatz_lengths = {}

def get_collatz_length (n, nmax):
    res = collatz_lengths.get(n)
    if res is not None:  # a stored length of 0 (for n==1) is a valid hit too
        return res
    elif n==1:
        res = 0
    elif n%2:
        res = get_collatz_length(3*n+1,nmax)+1
    else: 
        res = get_collatz_length(n//2,nmax)+1
    if n<nmax:
        collatz_lengths[n] = res
    return res

The lengths are only stored if n is not too high. For the call, we have set this maximal value equal to N=1’000’000. This code reduces the time to about 1 second. We gained a factor of 20 just by this simple trick. Evidently, a good idea can be worth more than compiled code.

Maybe we can do even better if we do not use a dictionary. Instead, we can use a list and store the known lengths in it. Let us try.

def get_collatz_length (n, L, N):
    if n<N and L[n]>=0:
        return L[n]
    if n==1:
        res = 0
    elif n%2:
        res = get_collatz_length(3*n+1,L,N)+1
    else: 
        res = get_collatz_length(n//2,L,N)+1
    if n<N:
        L[n]=res
    return res

import time
start = time.process_time()
N=1_000_000
L = [-1]*N  # -1 marks a length that is not yet known
lengths = [(n,get_collatz_length(n,L,N)) for n in range(1,N)]
records = sorted(lengths,key=lambda x:x[1],reverse=True)
for k in range(11):
    print(f"{records[k][0]}: {records[k][1]}")
print(f"{time.process_time()-start:5.4} seconds")

This reduces the time only slightly. It is still one second.
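
Both memoized versions above still make one recursive call per step. As a sketch of a variant without recursion, the following loop walks from each starting point until it reaches a value whose length is already stored, and then adds the stored length. The name collatz_lengths_iterative is not from the original code, and no timing is claimed for it. It relies on the assumption that every sequence in the range eventually drops below its starting point.

```python
def collatz_lengths_iterative(limit):
    """Return a list in which entry n holds the Collatz length of n (for 1 <= n < limit)."""
    lengths = [-1] * limit  # -1 marks a length that is not yet known
    if limit > 1:
        lengths[1] = 0
    for start in range(2, limit):
        n = start
        steps = 0
        # walk until we hit a value inside the table with a known length
        while n >= limit or lengths[n] < 0:
            n = 3 * n + 1 if n % 2 else n // 2
            steps += 1
        lengths[start] = lengths[n] + steps
    return lengths

table = collatz_lengths_iterative(100)
print(table[27], table[97])
```

Since all smaller starting points are already filled in, the walk stops as soon as the sequence falls below its starting point.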

However, implementing the same trick in Java, we get the lengths up to one million in 0.033 seconds, i.e., 30 times faster. The reason for this is partly the arbitrary-precision integer type of Python, and, in the first variant, also the use of a dictionary instead of an array. Here is the Java code. In the listing, N is already set to 100’000’000, the largest size used below.

public class Main 
{
	
	public static int get_collatz_length (long n, int L[], int N)
	{
		if (n<N && L[(int)n]>=0)
			return L[(int)n];
		int res = 0;
		if (n>1)
		{
			if (n%2 == 1)
				res = get_collatz_length(3*n+1,L,N)+1;
			else
				res = get_collatz_length(n/2,L,N)+1;
		}
		if (n<N) L[(int)n] = res;
		return res;
	}

	public static void main(String[] args) 
	{
		long time = System.currentTimeMillis();
		int N=100000000;
		int L[] = new int[N];
		for (int k=0; k<N; k++)
			L[k] = -1;
		for (int k=0; k<N; k++)
			L[k] = get_collatz_length(k,L,N);
		int max=0,imax=0;
		for (int k=1; k<N; k++)
			if (L[k]>max)
			{
				max=L[k]; imax=k;
			}
		System.out.println("Maximal value " + max + " at "+imax);
		time = System.currentTimeMillis()-time;
		System.out.println((time/1000.0) + " seconds.");
		
	}

}
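
The Java listing takes n as a long because the intermediate values of a sequence can leave the 32-bit int range even for starting points below one million. Python integers grow as needed, which is convenient but costs time. The following sketch makes this visible; the names peak_value, exceeds_int32 and INT32_MAX are chosen here for illustration.

```python
INT32_MAX = 2**31 - 1  # largest value of a Java int

def peak_value(n):
    """Return the largest value on the Collatz path from n down to 1."""
    peak = n
    while n != 1:
        n = 3 * n + 1 if n % 2 else n // 2
        peak = max(peak, n)
    return peak

def exceeds_int32(n):
    """True if the path from n contains a value that does not fit into a Java int."""
    return peak_value(n) > INT32_MAX

print(peak_value(27), exceeds_int32(27))
```

Calling exceeds_int32 for the starting points of interest shows which of them would overflow an int in Java.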

Increasing N to 10’000’000, the Java computation time rises to 0.15 seconds. The cleverer Python code took 9 seconds, i.e., it was 60 times slower.

Going further to N=100’000’000 made Java run for 1.44 seconds. The more clever code in Python, however, broke down and took more than 2 minutes, again a factor of over 50.

Finally, I wanted a plot of the frequencies of the lengths. I had to write a procedure for the count, not trusting the performance of the Counter class.

lengths1 = [x[1] for x in lengths]
M = max(lengths1)
def get_count (L,M):
    count = [0 for k in range(M)]
    for x in L:
        if x<M:
            count[x] += 1
    return count
count = get_count(lengths1,M+1)
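
For comparison, the same count can be sketched with collections.Counter from the standard library. The helper name get_count_with_counter is made up here, and whether it is slower than the hand-written loop for this amount of data was not measured.

```python
from collections import Counter

def get_count_with_counter(values):
    """Return a list in which entry x is the number of times x occurs in values."""
    counter = Counter(values)
    top = max(counter) if counter else -1
    # a Counter returns 0 for keys it has never seen
    return [counter[x] for x in range(top + 1)]

print(get_count_with_counter([0, 1, 1, 3]))
```

For a list of non-negative lengths, the result should match get_count(lengths1,M+1).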

This allows us to plot the frequencies. The vertical axis is the number of times a length occurs among the first 100 million numbers. The same plot can be found on Wikipedia.

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

fig,ax = plt.subplots(figsize=(14,6))
ax.grid(True)
ax.plot(count);

In summary, Python did well. But it seems to be so flexible that it is not predictable in terms of performance and memory consumption. That is the case with many modern software systems, which seem to develop a mind of their own from the perspective of a user.
