The Algorithms logo
The Algorithms
Acerca de | Donar

Prim's Algorithm (Adjacency Matrix)

using System;

namespace Algorithms.Graph.MinimumSpanningTree;

/// <summary>
///     Class that uses Prim's (Jarnik's algorithm) to determine the minimum
///     spanning tree (MST) of a given graph. Prim's algorithm is a greedy
///     algorithm that can determine the MST of a weighted undirected graph
///     in O(V^2) time where V is the number of nodes/vertices when using an
///     adjacency matrix representation.
///     Missing edges are represented by <see cref="float.PositiveInfinity" />,
///     both in the input matrix and in the returned MST matrix.
///     More information: https://en.wikipedia.org/wiki/Prim%27s_algorithm
///     Pseudocode and runtime analysis: https://www.personal.kent.edu/~rmuhamma/Algorithms/MyAlgorithms/GraphAlgor/primAlgor.htm .
/// </summary>
public static class PrimMatrix
{
    /// <summary>
    ///     Largest difference between mirrored entries that is still treated as symmetric.
    /// </summary>
    private const double SymmetryTolerance = 1e-6;

    /// <summary>
    ///     Determine the minimum spanning tree for a given weighted undirected graph.
    /// </summary>
    /// <param name="adjacencyMatrix">Adjacency matrix for graph to find MST of.</param>
    /// <param name="start">Node to start search from.</param>
    /// <returns>
    ///     Adjacency matrix of the found MST. Entries that are not tree edges
    ///     (including the diagonal) are <see cref="float.PositiveInfinity" />.
    /// </returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="adjacencyMatrix" /> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="start" /> is not a valid node index.</exception>
    /// <exception cref="ArgumentException">Thrown when the matrix is not square, not symmetric, or the graph is not connected.</exception>
    public static float[,] Solve(float[,] adjacencyMatrix, int start)
    {
        ValidateMatrix(adjacencyMatrix);

        var numNodes = adjacencyMatrix.GetLength(0);

        // Reject an invalid start node up front instead of failing later on an array index
        if (start < 0 || start >= numNodes)
        {
            throw new ArgumentOutOfRangeException(nameof(start), start, "Start node must be a valid node index!");
        }

        // Create array to represent minimum spanning tree
        var mst = new float[numNodes, numNodes];

        // Create array to keep track of which nodes are in the MST already
        var added = new bool[numNodes];

        // Create array to keep track of smallest edge weight for node
        var key = new float[numNodes];

        // Create array to store parent of node
        var parent = new int[numNodes];

        // Start with a tree that has no edges and with no known edge to any node
        for (var i = 0; i < numNodes; i++)
        {
            key[i] = float.PositiveInfinity;

            for (var j = 0; j < numNodes; j++)
            {
                mst[i, j] = float.PositiveInfinity;
            }
        }

        // Ensures that the starting node is added first
        key[start] = 0;

        // Select every node exactly once. Selecting the final node does not change
        // any key or parent, but it guarantees that a node which cannot be reached
        // from the start node is detected instead of being silently left out.
        for (var i = 0; i < numNodes; i++)
        {
            GetNextNode(adjacencyMatrix, key, added, parent);
        }

        // Build adjacency matrix for tree
        for (var i = 0; i < numNodes; i++)
        {
            // The start node is the root of the tree and has no parent edge
            if (i == start)
            {
                continue;
            }

            mst[i, parent[i]] = adjacencyMatrix[i, parent[i]];
            mst[parent[i], i] = adjacencyMatrix[i, parent[i]];
        }

        return mst;
    }

    /// <summary>
    ///     Ensure that the given adjacency matrix represents a weighted undirected graph.
    ///     This is a quick structural check: it verifies that the matrix is square and
    ///     symmetric and that every row has at least one finite entry. Full connectivity
    ///     is verified while the tree is being built.
    /// </summary>
    /// <param name="adjacencyMatrix">Adjacency matrix to check.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="adjacencyMatrix" /> is null.</exception>
    /// <exception cref="ArgumentException">Thrown when the matrix is not square, not symmetric, or a row has no finite entry.</exception>
    private static void ValidateMatrix(float[,] adjacencyMatrix)
    {
        if (adjacencyMatrix is null)
        {
            throw new ArgumentNullException(nameof(adjacencyMatrix));
        }

        var numNodes = adjacencyMatrix.GetLength(0);

        // Matrix should be square
        if (numNodes != adjacencyMatrix.GetLength(1))
        {
            throw new ArgumentException("Adjacency matrix must be square!");
        }

        // Graph needs to be undirected and connected
        for (var i = 0; i < numNodes; i++)
        {
            var connection = false;
            for (var j = 0; j < numNodes; j++)
            {
                // Two infinite entries subtract to NaN, which never exceeds the
                // tolerance, so a pair of missing edges counts as symmetric.
                if (Math.Abs(adjacencyMatrix[i, j] - adjacencyMatrix[j, i]) > SymmetryTolerance)
                {
                    throw new ArgumentException("Adjacency matrix must be symmetric!");
                }

                if (!connection && float.IsFinite(adjacencyMatrix[i, j]))
                {
                    connection = true;
                }
            }

            if (!connection)
            {
                throw new ArgumentException("Graph must be connected!");
            }
        }
    }

    /// <summary>
    ///     Determine which node should be added next to the MST, add it, and update
    ///     the best known edge of every node that is still outside the tree.
    /// </summary>
    /// <param name="adjacencyMatrix">Adjacency matrix of graph.</param>
    /// <param name="key">Currently known minimum edge weight connected to each node.</param>
    /// <param name="added">Whether or not a node has been added to the MST.</param>
    /// <param name="parent">The node that added the node to the MST. Used for building MST adjacency matrix.</param>
    /// <exception cref="ArgumentException">Thrown when no remaining node can be reached from the tree.</exception>
    private static void GetNextNode(float[,] adjacencyMatrix, float[] key, bool[] added, int[] parent)
    {
        var numNodes = adjacencyMatrix.GetLength(0);
        var minWeight = float.PositiveInfinity;

        var node = -1;

        // Find the node outside the tree with the smallest known edge weight.
        // On the first call this is always the starting node.
        for (var i = 0; i < numNodes; i++)
        {
            if (!added[i] && key[i] < minWeight)
            {
                minWeight = key[i];
                node = i;
            }
        }

        // Every remaining node still has an infinite key, so none of them is
        // reachable from the tree built so far: the graph is not connected.
        if (node < 0)
        {
            throw new ArgumentException("Graph must be connected!");
        }

        // Add node to mst
        added[node] = true;

        // Update smallest found edge weights and parent for adjacent nodes
        for (var i = 0; i < numNodes; i++)
        {
            if (!added[i] && adjacencyMatrix[node, i] < key[i])
            {
                key[i] = adjacencyMatrix[node, i];
                parent[i] = node;
            }
        }
    }
}