#### Uint

d
```lua
-- Unsigned integers with a given base.
-- Support basic arithmetic (increment, decrement, addition, subtraction, multiplication, exponentiation),
-- with both mutating variants (via methods) and copying variants (via arithmetic operators),
-- comparisons, as well as efficient base conversion.

-- Cache class tables for given bases.
-- Important for equality comparisons to work for the same base.
-- The cache has weak values (`__mode = "v"`), so an entry in it does not by
-- itself keep a class table alive for the garbage collector.
local uints = setmetatable({}, { __mode = "v" })

return function(
-- number, base to be used; digits will be 0 to base (exclusive);
-- should be at most `2^26` for products to stay in exact float bounds,
-- default is `2^24` (3 bytes per digit)
base
)
-- Fall back to the default base when none was passed.
if not base then
	base = 2 ^ 24
end
-- The base must be an integer in the range [2, 2^26].
assert(base >= 2 and base % 1 == 0 and base <= 2 ^ 26)

-- Hand out the already built class table for this base, if there is one.
do
	local cached = uints[base]
	if cached then
		return cached
	end
end

-- `uint` is the class table, `mt` the metatable shared by all of its instances;
-- instances look up their methods in `uint` through `__index`.
local uint = { base = base }
local mt = {}
mt.__index = uint

-- Registers the binary metamethod `__<name>` on `mt`.
-- If the left operand is a plain number it is converted with `uint.from_number`;
-- otherwise, if the right operand is a plain number, that one is converted.
-- `f` is then called with both operands and its result is returned.
local function bin_op(name, f)
	local function metamethod(lhs, rhs)
		if type(lhs) == "number" then
			lhs = uint.from_number(lhs)
		elseif type(rhs) == "number" then
			rhs = uint.from_number(rhs)
		end
		return f(lhs, rhs)
	end
	mt["__" .. name] = metamethod
end

--> new uint representing 0 (an empty digit list)
function uint.zero()
	local digits = {}
	return setmetatable(digits, mt)
end

--> new uint representing 1 (the single digit `1`)
function uint.one()
	local digits = { 1 }
	return setmetatable(digits, mt)
end

-- Turns a digit list into a uint in place: the table passed in receives the
-- uint metatable and is itself returned; no copy is made.
function uint.from_digits(
	digits -- list of digits in the appropriate base, little endian; is consumed
)
	setmetatable(digits, mt)
	return digits
end

-- Builds a uint from a Lua number by repeatedly splitting off the least
-- significant digit (`% base`) until nothing is left. `0` yields an empty digit list.
--> new uint holding the digits of `number`, little endian
function uint.from_number(
	number -- exact integer >= 0
)
	local digits = {}
	assert(number >= 0 and number % 1 == 0)
	local len = 0
	local rest = number
	while rest > 0 do
		local digit = rest % base
		len = len + 1
		digits[len] = digit
		rest = (rest - digit) / base
	end
	return uint.from_digits(digits)
end

-- Converts this uint to a plain Lua number by summing `digit * base^(i-1)`.
--> the number; or nothing if `exact` is not `false` and the value exceeds 2^53
function uint:to_number(
	exact -- if not `false` (the default), return nothing rather than a lossy result
)
	local number = 0
	local pow = 1
	for _, digit in ipairs(self) do
		-- Zero digits contribute nothing. Skipping them also avoids `0 * inf = NaN`
		-- once `pow` has overflowed, which would poison the sum and defeat
		-- the exactness check below.
		if digit ~= 0 then
			local val = digit * pow
			-- Equivalent to `number + val > 2^53`, written so that it cannot round.
			if exact ~= false and 2 ^ 53 - val < number then
				return -- if not representable without loss of precision
			end
			number = number + val
		end
		pow = pow * base
	end
	return number -- exact number representing the same integer (approximate if `exact == false`)
end

--> new uint with the same digits as `self`; the two share no table
function uint:copy()
	local digits = {}
	local i = 1
	local digit = self[1]
	-- Walk the digit list up to the first missing entry.
	while digit ~= nil do
		digits[i] = digit
		i = i + 1
		digit = self[i]
	end
	return uint.from_digits(digits)
end

-- Overwrites the digits of `dst` with those of `src` (mutating `dst`).
-- Both must have the same base.
function uint.copy_from(dst, src)
	assert(dst.base == src.base)
	-- Copy every digit of `src` ...
	local i = 1
	while src[i] ~= nil do
		dst[i] = src[i]
		i = i + 1
	end
	-- ... then drop whatever digits `dst` still has beyond the length of `src`.
	for extra = #dst, #src + 1, -1 do
		dst[extra] = nil
	end
end

-- Three-way comparison of two uints of the same base: the number with more
-- digits is larger; on equal length the most significant differing digit decides.
--> sign(a - b)
--! For better efficiency, prefer a single `uint.compare` over multiple `<`/`>`/`<=`/`>=`/`==` comparisons
function uint.compare(a, b)
	assert(a.base == b.base)
	local len_a, len_b = #a, #b
	if len_a ~= len_b then
		return len_a < len_b and -1 or 1
	end
	for i = len_a, 1, -1 do
		local digit_a, digit_b = a[i], b[i]
		if digit_a ~= digit_b then
			return digit_a < digit_b and -1 or 1
		end
	end
	return 0
end

-- Comparison operators `==`, `<` and `<=`, all expressed through `compare`.
-- Note: These will only run if the metatables are equal,
-- so there is no risk of comparing numbers of different bases.
bin_op("eq", function(lhs, rhs)
	local sign = lhs:compare(rhs)
	return sign == 0
end)
bin_op("lt", function(lhs, rhs)
	local sign = lhs:compare(rhs)
	return sign < 0
end)
bin_op("le", function(lhs, rhs)
	local sign = lhs:compare(rhs)
	return sign <= 0
end)

-- Adds 1 to `self` in place: every maximal digit at the low end wraps to 0,
-- then the first other digit (or a new one) is bumped by one.
function uint:increment()
	local max_digit = base - 1
	local pos = 1
	local digit = self[pos]
	while digit == max_digit do
		self[pos] = 0
		pos = pos + 1
		digit = self[pos]
	end
	self[pos] = (digit or 0) + 1
end

-- Subtracts 1 from `self` in place.
-- Zero digits at the low end wrap to `base - 1` (borrowing), then the first
-- nonzero digit is reduced by one.
-- Raises "result < 0" if there is no digit left to borrow from (e.g. `self` is zero).
function uint:decrement()
	local i = 1
	while self[i] == 0 do
		self[i] = base - 1
		i = i + 1
	end
	local digit = assert(self[i], "result < 0") - 1
	if digit == 0 and self[i + 1] == nil then
		-- The most significant digit became zero: drop it (no leading zeros).
		self[i] = nil
	else
		-- A zero below the top digit must stay; removing it would leave a hole
		-- in the digit list and break `#self` / `ipairs(self)`.
		self[i] = digit
	end
end

local i, j = srcshift + 1, 1
local carry = 0
while dst[i] or src[j] or carry > 0 do
local digit_sum = (dst[i] or 0) + (src[j] or 0) + carry
dst[i] = digit_sum % base
carry = (digit_sum - dst[i]) / base
i, j = i + 1, j + 1
end
end

end

local res = a:copy()
return res
end)

local i = #dst
while dst[i] == 0 do
dst[i] = nil
i = i - 1
end
end

-- Subtracts `src` from `dst` in place (`dst = dst - src`).
-- Both must have the same base.
-- Raises "result < 0" if `dst` runs out of digits while `src` digits or a borrow remain.
--! On that error `dst` may already have been partially modified.
function uint.subtract(dst, src)
	assert(dst.base == src.base)
	local i = 1
	local borrow = 0
	while src[i] or borrow > 0 do
		local digit_diff = assert(dst[i], "result < 0") - (src[i] or 0) - borrow
		-- Works since Lua's remainder operator is special -
		-- it computes `(base + digit_diff) % base` for a negative `digit_diff`
		dst[i] = digit_diff % base
		-- `digit_diff` is greater than `-base`, so at most one unit has to be borrowed.
		borrow = digit_diff < 0 and 1 or 0
		i = i + 1
	end
	-- The loop only ends once no borrow is pending.
	-- Strip leading (most significant) zeros so that the digit count stays
	-- meaningful for `compare` and zero is again the empty digit list.
	local top = #dst
	while dst[top] == 0 do
		dst[top] = nil
		top = top - 1
	end
end

-- `a - b`: copying variant of `subtract`; neither operand is modified.
bin_op("sub", function(minuend, subtrahend)
	local difference = minuend:copy()
	difference:subtract(subtrahend)
	return difference
end)

-- Schoolbook multiplication: for every digit of the shorter factor, the other
-- factor is multiplied by that digit and added, shifted by the digit's position,
-- straight into the result.
-- Neither factor may have a leading (most significant) zero digit.
--> new uint `a * b`
local function product_naive(a, b)
	assert(a[#a] ~= 0 and b[#b] ~= 0)
	local res = uint.zero()
	-- Enforce #a <= #b
	if #b < #a then
		a, b = b, a
	end
	for i, a_digit in ipairs(a) do
		if a_digit == 0 then
			res[i] = res[i] or 0 -- no holes!
		else
			-- res = res + a_digit * b * base^(i - 1)
			local k = i
			local carry = 0
			for _, b_digit in ipairs(b) do
				-- At most (base-1) + (base-1)^2 + (base-1) = base^2 - 1,
				-- which is exactly representable for base <= 2^26.
				local sum = (res[k] or 0) + a_digit * b_digit + carry
				local digit = sum % base
				res[k] = digit
				carry = (sum - digit) / base
				k = k + 1
			end
			-- Propagate what is left of the carry into the higher digits.
			while carry > 0 do
				local sum = (res[k] or 0) + carry
				local digit = sum % base
				res[k] = digit
				carry = (sum - digit) / base
				k = k + 1
			end
		end
	end
	assert(res[#res] ~= 0)
	return res
end

-- Copies the digits `self[from]` .. `self[to]` into a new uint.
-- Positions at which `self` has no digit contribute nothing.
-- Note: Returns `{}` (zero) if `from > to`
local function slice(self, from, to)
	local part = {}
	for pos = from, to do
		part[#part + 1] = self[pos]
	end
	return uint.from_digits(part)
end

-- Helper for `product_karatsuba`: removes most significant zero digits in place.
--> `digits`
local function karatsuba_strip_zeros(digits)
	local top = #digits
	while digits[top] == 0 do
		digits[top] = nil
		top = top - 1
	end
	return digits
end

-- Helper for `product_karatsuba`: `dst = dst + src * base^shift`, in place.
-- `dst` must already have digits at all positions up to `shift` (no holes).
local function karatsuba_add_shifted(dst, src, shift)
	local i = shift + 1
	local carry = 0
	for j = 1, #src do
		local sum = (dst[i] or 0) + src[j] + carry
		local digit = sum % base
		dst[i] = digit
		carry = (sum - digit) / base
		i = i + 1
	end
	while carry > 0 do
		local sum = (dst[i] or 0) + carry
		local digit = sum % base
		dst[i] = digit
		carry = (sum - digit) / base
		i = i + 1
	end
end

-- Karatsuba multiplication: both factors are split at `mid` digits into
-- `low + high * base^mid`, and the product is assembled from three recursive
-- products instead of four:
-- `a * b = low_prod + (sum_prod - high_prod - low_prod) * base^mid + high_prod * base^(2 * mid)`
--> new uint `a * b`; the factors are not modified
local function product_karatsuba(a, b)
	-- Ensure that `b` is the longer number
	if #b < #a then
		a, b = b, a
	end
	if #a < 10 then -- base case: Naive multiplication, to be tweaked
		return product_naive(a, b)
	end
	local mid = math.floor(#b / 2) -- split at the middle of the longer of the two numbers
	local prod_high, prod_low, prod_sum
	do
		-- The low halves may end in zero digits, which the products must not see.
		local a_low = karatsuba_strip_zeros(slice(a, 1, mid))
		local a_high = slice(a, mid + 1, #a)
		local b_low = karatsuba_strip_zeros(slice(b, 1, mid))
		local b_high = slice(b, mid + 1, #b)
		prod_high = product_karatsuba(a_high, b_high)
		prod_low = product_karatsuba(a_low, b_low)
		-- Note: We can mutate a_low and b_low here since we already used them above.
		karatsuba_add_shifted(a_low, a_high, 0)
		karatsuba_add_shifted(b_low, b_high, 0)
		local a_sum = a_low
		local b_sum = b_low
		prod_sum = product_karatsuba(a_sum, b_sum)
		-- At this point we have `prod_sum = (a_low + a_high) * (b_low + b_high)`
		prod_sum:subtract(prod_high)
		prod_sum:subtract(prod_low)
		-- This leaves us with `prod_sum = a_high * b_low + b_high * a_low`
	end
	local res = prod_low
	-- Ensure that we produce no holes
	local min_len = (prod_high[1] and 2 * mid) or (prod_sum[1] and mid) or 0
	for i = #res + 1, min_len do
		res[i] = res[i] or 0
	end
	karatsuba_add_shifted(res, prod_sum, mid)
	karatsuba_add_shifted(res, prod_high, 2 * mid)
	-- Padding, or zero digits left in `prod_sum`, may have ended up on top.
	return karatsuba_strip_zeros(res)
end

-- `a * b`: registers `product_karatsuba` as the `__mul` metamethod (through `bin_op`).
bin_op("mul", product_karatsuba)

-- Multiplies `dst` by `src` in place (`dst = dst * src`).
-- Both must have the same base.
function uint.multiply(dst, src)
	assert(dst.base == src.base)
	local product = product_karatsuba(dst, src)
	dst:copy_from(product)
end

-- TODO division, modulo

-- Exponentiation by squaring. Some details have to be different for uints.
--> new uint `expbase ^ exp`; never the `expbase` object itself
local function fastpow(
	expbase, -- Base (uint)
	exp -- Exponent, non-negative integer
)
	if exp == 1 then
		-- Operators are the copying variants, so the result must not alias the operand.
		return expbase:copy()
	end
	local res = expbase.one()
	-- Loop invariant: `res * expbase^exp` equals the requested power
	-- (of the original base and exponent).
	while exp > 0 do
		if exp % 2 == 1 then
			-- Move one factor over: `res * expbase^exp = (res * expbase) * expbase^(exp - 1)`
			res = res * expbase
			exp = exp - 1
		else
			-- Halve the exponent: `expbase^exp = (expbase^2)^(exp / 2)`
			expbase = expbase * expbase
			exp = exp / 2
		end
	end
	return res
end

-- `expbase ^ exp`, where either operand may be a plain number or a uint.
-- Raises for `0^0`, for a numeric exponent that is not a non-negative integer,
-- and ("exponent too large") for a uint exponent `to_number` cannot convert.
--> new uint of this base
function mt.__pow(expbase, exp)
	if type(exp) == "number" then
		-- A fractional exponent would never reach zero by halving and decrementing;
		-- a negative one has no unsigned integer result.
		assert(exp >= 0 and exp % 1 == 0, "exponent must be a non-negative integer")
	end
	-- Zero is the number 0 or a uint without digits. Looking at the digits rather
	-- than using `==` keeps this working for a uint operand of another base.
	local function is_zero(value)
		return value == 0 or (type(value) == "table" and value[1] == nil)
	end
	local zero_base = is_zero(expbase)
	if is_zero(exp) then
		assert(not zero_base, "0^0")
		return uint.one()
	end
	if zero_base then
		return uint.zero()
	end
	if type(expbase) == "number" then
		expbase = uint.from_number(expbase)
	end
	-- Try conversion of exponent to number for consistency;
	-- taking n^m with n >= 2 and m >= 2^53 wouldn't fit into memory anyways
	if type(exp) ~= "number" then
		exp = assert(exp:to_number(), "exponent too large")
	end
	return fastpow(expbase, exp)
end

-- Divide-and-conquer base conversion: the digit list is split in the middle
-- into `low + high * base^mid`; both halves are converted recursively and then
-- recombined using the arithmetic of the target base.
--> uint instance of `other_uint` representing the same integer as `self`
local function convert_base(self, other_uint)
	if not self[1] then
		return other_uint.zero()
	end
	if not self[2] then
		return other_uint.from_number(self[1])
	end
	local mid = math.floor(#self / 2)
	local low = convert_base(slice(self, 1, mid), other_uint)
	local high = convert_base(slice(self, mid + 1, #self), other_uint)
	-- TODO this fastpow could be optimized with memoization, but shouldn't matter asymptotically
	-- `res = high * base^mid`; the product is a new uint, so it is safe to mutate below.
	local res = high * other_uint.from_number(base) ^ mid
	-- `res = res + low`, digit by digit in the target base.
	local other_base = other_uint.base
	local carry = 0
	local i = 1
	while low[i] or carry > 0 do
		local sum = (res[i] or 0) + (low[i] or 0) + carry
		local digit = sum % other_base
		res[i] = digit
		carry = (sum - digit) / other_base
		i = i + 1
	end
	return res
end

-- Converts `self` to the uint class `other_uint`. Zero and an identical base
-- are answered directly; everything else goes through `convert_base`.
--> uint instance of `other_uint`
function uint:convert_base_to(
	other_uint -- "class" table to convert to
)
	local is_zero = not self[1]
	if is_zero then
		return other_uint.zero()
	elseif uint.base == other_uint.base then
		-- Same base: a plain copy of the digits is all that is needed.
		return other_uint.copy(self)
	end
	return convert_base(self, other_uint)
end

-- Store the class table in the cache so that a later call with the same base
-- can return this table instead of building a new one.
uints[base] = uint

return uint
end
```