Optimizing Lua Code – Longest Collatz Sequence

Lua is often praised for its speed and simplicity. However, unless you use LuaJIT or a similar JIT compiler, its performance is still far from that of compiled languages such as C/C++ and Rust, or JIT-compiled languages such as Java. In this article I am going to demonstrate some optimization techniques for the standard Lua implementation. I will use Lua 5.3 because it natively supports bitwise operators and has a separate integer subtype. Timings are measured by running each program several times with the Linux time command. Lua startup time is not subtracted unless it has a significant effect on the results, and the distinction between real, user and sys time is likewise ignored unless significant. All results are rounded to the first reliable figure.

The Problem

A Collatz sequence is the sequence of natural numbers obtained by repeatedly applying the following function to a given starting number: if n is odd, yield 3n+1; otherwise, yield n/2. The Collatz conjecture states that the sequence eventually reaches 1 for every starting number. Given a natural number N, output the length of the longest Collatz sequence among all starting numbers from 1 to N (inclusive), where the length is measured as the number of steps needed to reach 1.
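To make the definition concrete, here is a minimal sketch that builds the whole sequence for one starting number. The helper name collatz_sequence is hypothetical, and the sketch assumes Lua 5.3 or later because it uses the integer division operator //.

```lua
-- Sketch: build the full Collatz sequence for an integer n >= 1.
-- `collatz_sequence` is a hypothetical helper used only for illustration.
local function collatz_sequence(n)
	local seq = {n}
	while n ~= 1 do
		if n % 2 == 1 then
			n = 3 * n + 1
		else
			n = n // 2   -- integer division keeps n an integer
		end
		seq[#seq + 1] = n
	end
	return seq
end

local seq = collatz_sequence(6)
print(table.concat(seq, " "))   --> 6 3 10 5 16 8 4 2 1
print(#seq - 1)                 --> 8  (steps needed to reach 1)
```

The sequence starting at 6 has nine terms but takes eight steps; the step count is what the programs below compute.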

Naive Approach

We will start by writing a simple, iterative algorithm:

local N = 1000000

local function collatz(n)
	local len = 0
	while n ~= 1 do
		if n % 2 == 1 then
			n = 3*n + 1
		else
			n = n / 2
		end
		len = len + 1
	end
	return len
end

local max = 0
for n = 1, N do
	local c = collatz(n)
	if c > max then
		max = c
	end
end

print(max)

This program takes 26 seconds to finish. In comparison, the same algorithm written in C and compiled with clang -O9 (which clang treats as -O3) takes only 0.274 seconds. (Interestingly, the gcc-compiled binary is noticeably slower, at 0.448 seconds.)

Basic Optimizations

The code has some obvious performance issues. First, the modulo operator is a relatively expensive way to test whether a natural number is odd. The same result can be obtained by checking whether the least significant bit of the number is set (n & 1 == 1). Second, in Lua 5.3 the "/" operator always performs float division, so after the first halving n becomes a float and all subsequent arithmetic runs on floats. When control reaches the statement n = n / 2, n is known to be even, so the statement can be replaced by the integer division operator "//" or simply by a right shift by one bit, both of which keep n an integer.
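The following snippet is a minimal sketch, assuming Lua 5.3 or later, that shows how "/" changes the number subtype and checks that the bitwise replacements compute the same values for non-negative integers. The values are arbitrary examples.

```lua
-- Sketch (Lua 5.3+): effect of `/` on the number subtype.
local n = 6
print(math.type(n))        --> integer
print(math.type(n / 2))    --> float    (`/` always produces a float)
print(math.type(n // 2))   --> integer
print(math.type(n >> 1))   --> integer
print(n >> 1 == n // 2)    --> true     (same result for non-negative n)

-- Parity test: `&` binds tighter than `==` in Lua, so `7 & 1 == 1`
-- is parsed as (7 & 1) == 1.
print(7 & 1 == 1, 7 % 2 == 1)  --> true   true
print(6 & 1 == 1, 6 % 2 == 1)  --> false  false
```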

local N = 1000000

local function collatz(n)
	local len = 0
	while n ~= 1 do
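		if n & 1 == 1 then -- parsed as (n & 1) == 1: `&` binds tighter than `==`
			n = 3*n + 1
		else
			n = n >> 1 -- n is even, so this equals n // 2 and stays an integer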
		end
		len = len + 1
	end
	return len
end

local max = 0
for n = 1, N do
	local c = collatz(n)
	if c > max then
		max = c
	end
end

print(max)

These simple optimizations reduce the running time to just over 10 seconds, which is still far slower than the C version.

Memoization

When optimizing Lua code, one should push as much work as possible into the interpreter's underlying C code, such as table lookups, instead of executing many Lua instructions. Lua tables provide a very good memory-to-speed tradeoff, and it is often faster to cache even large quantities of data than to recompute it every time. The Collatz sequence length function can easily be made recursive, because the length for n is one more than the length for the next number in the sequence. Each result can therefore be cached and reused, just as in the classic memoized Fibonacci function.
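For reference, the Fibonacci analogy looks like this as a minimal sketch; fib and memo are illustrative names, not part of the article's code. The base cases are stored in the table itself, so the lookup doubles as the termination check.

```lua
-- Sketch of memoization applied to Fibonacci (`fib` and `memo` are
-- illustrative names). Base cases are seeded directly in the table.
local memo = {[0] = 0, [1] = 1}

local function fib(n)
	local cached = memo[n]
	if cached then
		return cached                 -- reuse a previously computed result
	end
	local value = fib(n - 1) + fib(n - 2)
	memo[n] = value                 -- store before returning
	return value
end

print(fib(10))   --> 55
print(fib(80))   --> 23416728348467685
```

Without the table, the naive recursive fib(80) would make an astronomical number of calls; with it, each value is computed once. The Collatz version that follows uses the same pattern, with cache in place of memo.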

local N = 1000000

local cache = {0} -- cache[1] = 0 is the base case: n = 1 needs no steps

local function collatz(n)
	local cached = cache[n]
	if cached then
		return cached
	end
	if n & 1 == 1 then
		local c = 1 + collatz(3 * n + 1)
		cache[n] = c
		return c
	else
		local c = 1 + collatz(n >> 1)
		cache[n] = c
		return c
	end
end

local max = 0
for n = 1, N do
	local c = collatz(n)
	if c > max then
		max = c
	end
end

print(max)

Caching gives a drastic performance improvement: the code now takes only 1.560 seconds. A small trick is used here. Instead of explicitly checking whether n == 1 and returning zero, the code relies on the table constructor {0}, which stores the entry cache[1] = 0. Avoiding the extra 'if' statement reduces the running time by about 50 milliseconds.
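To confirm that the seeded cache behaves exactly like an explicit n == 1 check, the memoized function can be compared against a plain iterative step counter. This is a testing sketch; steps_iterative, steps_memo and LIMIT are hypothetical names, and the bound of 10000 is arbitrary.

```lua
-- Sketch: cross-check the cache-seeded version against an explicit
-- `n ~= 1` loop. Names and LIMIT are hypothetical, for testing only.
local LIMIT = 10000

-- Reference: iterative step count with an explicit termination test.
local function steps_iterative(n)
	local len = 0
	while n ~= 1 do
		if n & 1 == 1 then n = 3 * n + 1 else n = n >> 1 end
		len = len + 1
	end
	return len
end

-- Memoized version: cache[1] = 0 acts as the base case.
local cache = {0}
local function steps_memo(n)
	local cached = cache[n]
	if cached then return cached end
	local c
	if n & 1 == 1 then
		c = 1 + steps_memo(3 * n + 1)
	else
		c = 1 + steps_memo(n >> 1)
	end
	cache[n] = c
	return c
end

for n = 1, LIMIT do
	assert(steps_memo(n) == steps_iterative(n), "mismatch at " .. n)
end
print("all lengths match up to " .. LIMIT)
```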

Predictable Bit Patterns

Looking at a few Collatz sequences, we can see that they never contain two consecutive odd numbers. This is easy to prove: if n is odd, 3n+1 is always even. The converse does not hold, since consecutive even numbers are possible: if n is a multiple of 4, n/2 is still even. It follows that when n is odd, 1 + collatz(3*n+1) can be replaced by 2 + collatz((3*n+1)/2). Writing an odd n as 2k+1 gives (3n+1)/2 = 3k+2 = n + k + 1, which the code computes without a multiplication as n + (n >> 1) + 1. Similarly, when n is a multiple of 4, 1 + collatz(n/2) is equivalent to 2 + collatz(n/4), i.e. 2 + collatz(n >> 2). Whether a number is a multiple of 4 can be checked quickly by inspecting its two least significant bits (n & 3 == 0).
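Both shortcuts can be checked numerically. This is a quick sanity-check sketch over an arbitrary range (1 to 100000), not a proof; the algebra for the odd case is repeated in the comments.

```lua
-- Sketch: numerically check the two shortcuts for small n.
-- Odd n = 2k + 1:  (3n + 1) / 2 = 3k + 2 = n + (n >> 1) + 1
-- n divisible by 4 (n & 3 == 0): two halvings equal n >> 2
for n = 1, 100000 do
	if n & 1 == 1 then
		assert((3 * n + 1) & 1 == 0)                  -- 3n + 1 is always even
		assert((3 * n + 1) // 2 == n + (n >> 1) + 1)  -- shortcut for odd n
	elseif n & 3 == 0 then
		assert((n >> 1) & 1 == 0)                     -- n / 2 is still even
		assert(n >> 2 == (n >> 1) >> 1)               -- shortcut for n % 4 == 0
	end
end
print("shortcut identities hold for n <= 100000")
```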

local N = 1000000

local cache = {0}

local function collatz(n)
	local cached = cache[n]
	if cached then
		return cached
	end
	if n & 1 == 1 then
		local c = 2 + collatz(n + (n >> 1) + 1) -- (3n+1)/2 for odd n
		cache[n] = c
		return c
	elseif n & 3 == 0 then
		local c = 2 + collatz(n >> 2) -- n/4, two halvings at once
		cache[n] = c
		return c
	else
		local c = 1 + collatz(n >> 1)
		cache[n] = c
		return c
	end
end

local max = 0
for n = 1, N do
	local c = collatz(n)
	if c > max then
		max = c
	end
end

print(max)

Results

The code now takes only about 1.000 seconds to run, which brings it within a factor of four of the C version of the naive algorithm. In total, the running time has been reduced by a factor of 26. There are, of course, possible optimizations to the algorithm itself, but they are out of the scope of this demonstration.