#### Minimum Spanning Tree Prims 2

R
```"""
Prim's (also known as Jarník's) algorithm is a greedy algorithm that finds a minimum
spanning tree for a weighted undirected graph. This means it finds a subset of the
edges that forms a tree that includes every vertex, where the total weight of all the
edges in the tree is minimized. The algorithm operates by building this tree one vertex
at a time, from an arbitrary starting vertex, at each step adding the cheapest possible
connection from the tree to another vertex.
"""

from __future__ import annotations

from sys import maxsize
from typing import Generic, TypeVar

T = TypeVar("T")

def get_parent_position(position: int) -> int:
"""
heap helper function get the position of the parent of the current node

>>> get_parent_position(1)
0
>>> get_parent_position(2)
0
"""
return (position - 1) // 2

def get_child_left_position(position: int) -> int:
"""
heap helper function get the position of the left child of the current node

>>> get_child_left_position(0)
1
"""
return (2 * position) + 1

def get_child_right_position(position: int) -> int:
"""
heap helper function get the position of the right child of the current node

>>> get_child_right_position(0)
2
"""
return (2 * position) + 2

class MinPriorityQueue(Generic[T]):
"""
Minimum Priority Queue Class

Functions:
is_empty: function to check if the priority queue is empty
push: function to add an element with given priority to the queue
extract_min: function to remove and return the element with lowest weight (highest
priority)
update_key: function to update the weight of the given key
_bubble_up: helper function to place a node at the proper position (upward
movement)
_bubble_down: helper function to place a node at the proper position (downward
movement)
_swap_nodes: helper function to swap the nodes at the given positions

>>> queue = MinPriorityQueue()

>>> queue.push(1, 1000)
>>> queue.push(2, 100)
>>> queue.push(3, 4000)
>>> queue.push(4, 3000)

>>> queue.extract_min()
2

>>> queue.update_key(4, 50)

>>> queue.extract_min()
4
>>> queue.extract_min()
1
>>> queue.extract_min()
3
"""

def __init__(self) -> None:
self.heap: list[tuple[T, int]] = []
self.position_map: dict[T, int] = {}
self.elements: int = 0

def __len__(self) -> int:
return self.elements

def __repr__(self) -> str:
return str(self.heap)

def is_empty(self) -> bool:
# Check if the priority queue is empty
return self.elements == 0

def push(self, elem: T, weight: int) -> None:
# Add an element with given priority to the queue
self.heap.append((elem, weight))
self.position_map[elem] = self.elements
self.elements += 1
self._bubble_up(elem)

def extract_min(self) -> T:
# Remove and return the element with lowest weight (highest priority)
if self.elements > 1:
self._swap_nodes(0, self.elements - 1)
elem, _ = self.heap.pop()
del self.position_map[elem]
self.elements -= 1
if self.elements > 0:
bubble_down_elem, _ = self.heap[0]
self._bubble_down(bubble_down_elem)
return elem

def update_key(self, elem: T, weight: int) -> None:
# Update the weight of the given key
position = self.position_map[elem]
self.heap[position] = (elem, weight)
if position > 0:
parent_position = get_parent_position(position)
_, parent_weight = self.heap[parent_position]
if parent_weight > weight:
self._bubble_up(elem)
else:
self._bubble_down(elem)
else:
self._bubble_down(elem)

def _bubble_up(self, elem: T) -> None:
# Place a node at the proper position (upward movement) [to be used internally
# only]
curr_pos = self.position_map[elem]
if curr_pos == 0:
return None
parent_position = get_parent_position(curr_pos)
_, weight = self.heap[curr_pos]
_, parent_weight = self.heap[parent_position]
if parent_weight > weight:
self._swap_nodes(parent_position, curr_pos)
return self._bubble_up(elem)
return None

def _bubble_down(self, elem: T) -> None:
# Place a node at the proper position (downward movement) [to be used
# internally only]
curr_pos = self.position_map[elem]
_, weight = self.heap[curr_pos]
child_left_position = get_child_left_position(curr_pos)
child_right_position = get_child_right_position(curr_pos)
if child_left_position < self.elements and child_right_position < self.elements:
_, child_left_weight = self.heap[child_left_position]
_, child_right_weight = self.heap[child_right_position]
if child_right_weight < child_left_weight and child_right_weight < weight:
self._swap_nodes(child_right_position, curr_pos)
return self._bubble_down(elem)
if child_left_position < self.elements:
_, child_left_weight = self.heap[child_left_position]
if child_left_weight < weight:
self._swap_nodes(child_left_position, curr_pos)
return self._bubble_down(elem)
else:
return None
if child_right_position < self.elements:
_, child_right_weight = self.heap[child_right_position]
if child_right_weight < weight:
self._swap_nodes(child_right_position, curr_pos)
return self._bubble_down(elem)
return None

def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
# Swap the nodes at the given positions
node1_elem = self.heap[node1_pos][0]
node2_elem = self.heap[node2_pos][0]
self.heap[node1_pos], self.heap[node2_pos] = (
self.heap[node2_pos],
self.heap[node1_pos],
)
self.position_map[node1_elem] = node2_pos
self.position_map[node2_elem] = node1_pos

class GraphUndirectedWeighted(Generic[T]):
"""
Graph Undirected Weighted Class

Functions:
add_edge: function to add an edge between 2 nodes in the graph
"""

def __init__(self) -> None:
self.connections: dict[T, dict[T, int]] = {}
self.nodes: int = 0

def __repr__(self) -> str:
return str(self.connections)

def __len__(self) -> int:
return self.nodes

def add_node(self, node: T) -> None:
# Add a node in the graph if it is not in the graph
if node not in self.connections:
self.connections[node] = {}
self.nodes += 1

def add_edge(self, node1: T, node2: T, weight: int) -> None:
# Add an edge between 2 nodes in the graph
self.connections[node1][node2] = weight
self.connections[node2][node1] = weight

def prims_algo(
graph: GraphUndirectedWeighted[T],
) -> tuple[dict[T, int], dict[T, T | None]]:
"""
>>> graph = GraphUndirectedWeighted()

>>> dist, parent = prims_algo(graph)

>>> abs(dist["a"] - dist["b"])
3
>>> abs(dist["d"] - dist["b"])
15
>>> abs(dist["a"] - dist["c"])
13
"""
# prim's algorithm for minimum spanning tree
dist: dict[T, int] = {node: maxsize for node in graph.connections}
parent: dict[T, T | None] = {node: None for node in graph.connections}

priority_queue: MinPriorityQueue[T] = MinPriorityQueue()
for node, weight in dist.items():
priority_queue.push(node, weight)

if priority_queue.is_empty():
return dist, parent

# initialization
node = priority_queue.extract_min()
dist[node] = 0
for neighbour in graph.connections[node]:
if dist[neighbour] > dist[node] + graph.connections[node][neighbour]:
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
priority_queue.update_key(neighbour, dist[neighbour])
parent[neighbour] = node

# running prim's algorithm
while not priority_queue.is_empty():
node = priority_queue.extract_min()
for neighbour in graph.connections[node]:
if dist[neighbour] > dist[node] + graph.connections[node][neighbour]:
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
priority_queue.update_key(neighbour, dist[neighbour])
parent[neighbour] = node
return dist, parent
```