local treap = {}
function treap.new(keys, less_than)
keys = keys or {}
if #keys == 0 then
return { less_than = less_than }
end
local function heapify_down(node)
local min_weight_child = node[true] or node[false]
if not min_weight_child then
return
end
if node[true] and node[false] and node[false].weight < node[true].weight then
min_weight_child = node[false]
end
if min_weight_child.weight < node.weight then
node.weight, min_weight_child.weight = min_weight_child.weight, node.weight
return heapify_down(min_weight_child)
end
end
local function build(from, to)
if to < from then
return nil
end
local mid = math.floor((from + to) / 2)
local node = {
key = keys[mid],
weight = math.random(),
[true] = build(from, mid - 1),
[false] = build(mid + 1, to),
}
heapify_down(node)
return node
end
return { less_than = less_than, _root = build(1, #keys) }
end
function treap:copy()
local function copy(node)
if node == nil then
return nil
end
return {
key = node.key,
weight = node.weight,
[true] = copy(node[true]),
[false] = copy(node[false]),
}
end
return setmetatable({ less_than = self.less_than, _root = copy(self._root) }, getmetatable(self))
end
function treap:insert(key, upsert)
local weight = math.random()
local prev_key
local function insert(node)
if not node then
return { key = key, weight = weight }
end
local lt = self.less_than(key, node.key)
if not (lt or self.less_than(node.key, key)) then
prev_key = node.key
if upsert then
node.key = key
end
return node
end
local child = insert(node[lt])
node[lt] = child
if child.weight < node.weight then
node[lt], child[not lt] = child[not lt], node
return child
end
return node
end
self._root = insert(self._root)
return prev_key
end
function treap:remove(key)
local function remove_adjacent(node)
local lt = math.random() < 0.5
if not node[lt] then
lt = not lt
end
local parent, side = node, lt
node = assert(parent[side])
while node[not lt] do
parent, side = node, not lt
node = parent[side]
end
parent[side] = node[lt]
return node
end
local function heapify_down(node)
if not (node[true] or node[false]) then
return node
end
local min_key
if not node[true] then
min_key = false
elseif not node[false] then
min_key = true
elseif node[true].weight < node[false].weight then
min_key = true
else
min_key = false
end
local child = node[min_key]
if node.weight <= child.weight then
return node
end
node[min_key] = child[not min_key]
child[not min_key] = heapify_down(node)
return child
end
local prev_key
local function remove(node)
if node == nil then
return node
end
local lt = self.less_than(key, node.key)
local eq = not (lt or self.less_than(node.key, key))
if eq then
prev_key = node.key
if node[true] and node[false] then
node.key = remove_adjacent(node).key
return heapify_down(node)
end
return node[true] or node[false]
end
node[lt] = remove(node[lt])
return node
end
self._root = remove(self._root)
return prev_key
end
return require("class")(treap, require("data_structures.sorted_set.binary_search_tree"))