Optimizing Lua Code – Longest Collatz Sequence
Lua is often praised for its speed and simplicity. However, unless you use LuaJIT or an equivalent,
its performance is still far from that of low-level or JIT-compiled languages such as C/C++, Java or Rust.
In this article I am going to demonstrate some optimization techniques for the standard Lua
implementation. I will use Lua 5.3 because of its native support for bitwise operators.
Testing will be done by repeatedly running each program under the Linux time command. Lua startup
time will not be subtracted unless it has a significant effect on the results. The distinction
between real, user and sys times will also be ignored unless it is significant. All results are
rounded to the first reliable figure.
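As a cross-check on the time command, CPU time can also be sampled from inside a script with os.clock, which leaves interpreter startup out of the measurement. The sketch below is only an assumption about how one might do that; it is not the method used for the figures in this article, and the helper name time_it and the sum_to workload are hypothetical.

```lua
-- Hypothetical helper: runs f(...) once and returns the CPU seconds it
-- consumed, followed by whatever f returned.
local function time_it(f, ...)
  local start = os.clock()            -- CPU time used by the process so far
  local results = table.pack(f(...))
  local elapsed = os.clock() - start
  return elapsed, table.unpack(results, 1, results.n)
end

-- Illustrative workload: sum the integers 1..limit.
local function sum_to(limit)
  local total = 0
  for i = 1, limit do
    total = total + i
  end
  return total
end

local elapsed, total = time_it(sum_to, 1000000)
print(string.format("sum = %d, cpu time = %.3f s", total, elapsed))
```

Because os.clock reports processor time rather than wall-clock time, it corresponds roughly to the user and sys figures reported by time, not to real.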
The Problem
A Collatz sequence is the sequence of natural numbers obtained by repeatedly applying the following function to a given natural number: if n is odd, yield 3n+1, else yield n/2. The Collatz conjecture states that the sequence will eventually reach 1 for any given input. Given a natural number N, output the length of the longest Collatz sequence for inputs from 1 to N (inclusive). In the programs below, the length is counted as the number of steps needed to reach 1, so the sequence starting at 1 has length 0.
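To make the definition concrete, here is a minimal sketch that builds the whole sequence for one starting value. The function name collatz_sequence is chosen for illustration, and the code assumes Lua 5.3 for the "//" operator.

```lua
-- Returns the Collatz sequence starting at n as an array of integers.
-- Requires Lua 5.3 for the integer division operator "//".
local function collatz_sequence(n)
  local seq = {n}
  while n ~= 1 do
    if n % 2 == 1 then
      n = 3 * n + 1
    else
      n = n // 2
    end
    seq[#seq + 1] = n
  end
  return seq
end

local seq = collatz_sequence(6)
print(table.concat(seq, " "))  -- 6 3 10 5 16 8 4 2 1
print(#seq - 1)                -- 8 steps to reach 1
```

Starting from 6, the sequence has nine terms, so it takes eight steps to reach 1.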
Naive approach
We will start by writing a simple, iterative algorithm:
    local N = 1000000

    -- Count the steps needed to reach 1 from n.
    local function collatz(n)
        local len = 0
        while n ~= 1 do
            if n % 2 == 1 then
                n = 3*n + 1
            else
                n = n / 2
            end
            len = len + 1
        end
        return len
    end

    -- Find the longest sequence among all starting values 1..N.
    local max = 0
    for n = 1, N do
        local c = collatz(n)
        if c > max then
            max = c
        end
    end
    print(max)
This program takes 26 seconds to finish. In comparison, the same algorithm written in C and
compiled with clang -O9 takes only 0.274 seconds. (Interestingly, gcc does a noticeably worse
job of optimizing the code and takes 0.448 seconds.)
Basic optimizations
The code has some obvious performance issues. First, modulo is an inefficient way to
determine whether a natural number is even or odd. The same result can be obtained by checking
whether the least significant bit of the number is set. Second, when control reaches the
statement n = n / 2, we know that n is always an even number. In Lua 5.3 the "/" operator
always performs floating-point division, so this statement turns n into a float. It should be
replaced by the integer division operator "//" or simply by right-shifting by one bit.
    local N = 1000000

    local function collatz(n)
        local len = 0
        while n ~= 1 do
            if n & 1 == 1 then
                n = 3*n + 1
            else
                n = n >> 1
            end
            len = len + 1
        end
        return len
    end

    local max = 0
    for n = 1, N do
        local c = collatz(n)
        if c > max then
            max = c
        end
    end
    print(max)
These simple optimizations reduce the running time to just a bit over 10 seconds, which is still much slower than the C code.
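The difference between the operators can be seen directly with math.type, which Lua 5.3 provides for telling integers from floats. The snippet below is a small illustration rather than part of the benchmark; the helper name is_odd is hypothetical.

```lua
-- Lua 5.3: "/" always yields a float, while "//" and ">>" keep integers.
print(math.type(10 / 2))   -- float
print(math.type(10 // 2))  -- integer
print(math.type(10 >> 1))  -- integer

-- Once n is a float, later arithmetic on it stays in floating point.
local n = 6 / 2
print(math.type(3 * n + 1))  -- float

-- Bitwise operators bind tighter than comparisons in Lua, so
-- "n & 1 == 1" parses as "(n & 1) == 1" (unlike in C).
local function is_odd(n)
  return n & 1 == 1
end

print(is_odd(7), is_odd(10))  -- true    false
```

Keeping n an integer throughout also allows the bitwise test to be applied to it without any float-to-integer conversion.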
Divide and Conquer
When optimizing Lua code, one should lean on the underlying C implementation as much as possible. Lua tables provide a very good memory-to-speed tradeoff, and it is almost always faster to cache even large quantities of data than to recompute it every time. Looking at the Collatz sequence length function, we can see that it can easily be made recursive and that each result can be cached and reused, much like a memoized Fibonacci function.
    local N = 1000000
    local cache = {0}

    local function collatz(n)
        local cached = cache[n]
        if cached then
            return cached
        end
        if n & 1 == 1 then
            local c = 1 + collatz(3 * n + 1)
            cache[n] = c
            return c
        else
            local c = 1 + collatz(n >> 1)
            cache[n] = c
            return c
        end
    end

    local max = 0
    for n = 1, N do
        local c = collatz(n)
        if c > max then
            max = c
        end
    end
    print(max)
Caching results in a drastic performance improvement and the code now takes only 1.560 seconds. A little trick used here is that instead of explicitly checking if n == 1 and returning zero, the code relies on the entry 1 -> 0 being manually placed in the cache. Avoiding the 'if' statement reduces the time by about 50 milliseconds.
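One way to gain confidence in the cached version is to compare it with the plain loop on a small range. The sketch below does that for inputs up to 1000; the names steps_iterative and steps_cached and the value of LIMIT are chosen for illustration only.

```lua
local LIMIT = 1000  -- small illustrative range, not the article's N

-- Reference: plain loop using integer operations (Lua 5.3).
local function steps_iterative(n)
  local len = 0
  while n ~= 1 do
    if n & 1 == 1 then n = 3 * n + 1 else n = n >> 1 end
    len = len + 1
  end
  return len
end

-- Memoized recursion; the entry 1 -> 0 is seeded in the cache,
-- so no explicit base-case test is needed.
local cache = {0}
local function steps_cached(n)
  local cached = cache[n]
  if cached then return cached end
  local c
  if n & 1 == 1 then
    c = 1 + steps_cached(3 * n + 1)
  else
    c = 1 + steps_cached(n >> 1)
  end
  cache[n] = c
  return c
end

for n = 1, LIMIT do
  assert(steps_iterative(n) == steps_cached(n), "mismatch at " .. n)
end

-- Intermediate values such as 3*n + 1 are cached too, so the table
-- ends up holding more than LIMIT entries.
local entries = 0
for _ in pairs(cache) do entries = entries + 1 end
print(entries > LIMIT)  -- true
```

The final check illustrates the memory side of the tradeoff: the cache also stores values larger than the upper bound, because odd inputs jump to 3n+1 before coming back down.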
Predictable Bit Patterns
If we take a look at some of the Collatz sequences, we can see that there are no consecutive
odd numbers. This can be proven by pointing out that if n is odd, 3*n+1 will
always be even. The reverse is not true, because if n is even and a multiple of 4,
n/2 will still be even. From this we can derive that if n is odd, then
1 + collatz(3*n+1) can be replaced by 2 + collatz((3*n+1)/2). For odd n, (3*n+1)/2 is equal
to n + (n >> 1) + 1, which is the form used in the code.
We can add a similar special case when n is a multiple of 4: in this case
1 + collatz(n/2) is equivalent to 2 + collatz(n/4).
Remember that it is possible to quickly check whether a number is a multiple of 4 by inspecting
the two least significant bits.
    local N = 1000000
    local cache = {0}

    local function collatz(n)
        local cached = cache[n]
        if cached then
            return cached
        end
        if n & 1 == 1 then
            local c = 2 + collatz(n + (n >> 1) + 1)
            cache[n] = c
            return c
        elseif n & 3 == 0 then
            local c = 2 + collatz(n >> 2)
            cache[n] = c
            return c
        else
            local c = 1 + collatz(n >> 1)
            cache[n] = c
            return c
        end
    end

    local max = 0
    for n = 1, N do
        local c = collatz(n)
        if c > max then
            max = c
        end
    end
    print(max)
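The two shortcuts rest on simple arithmetic identities, which can be checked by brute force over a small range. The following sketch is only such a sanity check; the range of 10000 is an arbitrary choice.

```lua
-- For odd n, (3n + 1) / 2 can be computed as n + (n >> 1) + 1.
-- For n divisible by 4, two halvings collapse into one shift by two bits.
local checked = 0
for n = 1, 10000 do
  if n & 1 == 1 then
    assert(n + (n >> 1) + 1 == (3 * n + 1) // 2)
    checked = checked + 1
  elseif n & 3 == 0 then
    assert(n >> 2 == (n // 2) // 2)
    checked = checked + 1
  end
end
print(checked)  -- 7500
```

Of the 10000 values, 5000 are odd and 2500 are multiples of 4, so three quarters of all inputs take one of the two-step shortcuts.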
Results
The code now takes only about 1.000 seconds to run, which is within a factor of four of the naive C version (0.274 seconds). In total, we've reduced the running time by a factor of 26. There are, of course, possible optimizations to the algorithm itself, but these are outside the scope of this demonstration.