The Algorithms logo
The Algorithms
About · Donate

Diffie

from __future__ import annotations


def find_primitive(modulus: int) -> int | None:
    """
    Find the smallest primitive root modulo modulus, if one exists.

    A primitive root is a unit (an integer coprime to the modulus) whose
    powers generate every unit, i.e. whose multiplicative order equals the
    number of units modulo the modulus.

    Args:
        modulus : The modulus for which to find a primitive root.

    Returns:
        The smallest primitive root if one exists, or None if there is none
        (including when modulus is less than 2).

    Examples:
    >>> find_primitive(7)  # Modulo 7 has primitive root 3
    3
    >>> find_primitive(11)  # Modulo 11 has primitive root 2
    2
    >>> find_primitive(8) is None  # Modulo 8 has no primitive root
    True
    >>> find_primitive(4)  # 2 is not coprime to 4, so the answer is 3
    3
    >>> find_primitive(9)  # Composite moduli can have primitive roots too
    2
    >>> find_primitive(1) is None
    True
    """
    from math import gcd

    if modulus < 2:
        return None

    # Only units can be primitive roots; their count is Euler's totient.
    units = [r for r in range(1, modulus) if gcd(r, modulus) == 1]
    group_order = len(units)

    for candidate in units:
        # Walk the powers of the candidate until they cycle back to 1; the
        # number of steps is its multiplicative order. The loop terminates
        # because the candidate is a unit.
        power = candidate
        order = 1
        while power != 1:
            power = power * candidate % modulus
            order += 1
        if order == group_order:
            return candidate
    return None


if __name__ == "__main__":
    import doctest

    doctest.testmod()

    # Interactive Diffie-Hellman demo: both parties derive the same secret
    # from a shared modulus and a primitive root of it.
    prime = int(input("Enter a prime number q: "))
    primitive_root = find_primitive(prime)
    if primitive_root is None:
        # primitive_root is always None in this branch, so report the modulus
        # the user actually entered.
        print(f"Cannot find the primitive for the value: {prime!r}")
    else:
        # Each party raises the shared primitive root to its private key.
        a_private = int(input("Enter private key of A: "))
        a_public = pow(primitive_root, a_private, prime)
        b_private = int(input("Enter private key of B: "))
        b_public = pow(primitive_root, b_private, prime)

        # Each party raises the other's public value to its own private key;
        # both results equal primitive_root ** (a_private * b_private) % prime.
        a_secret = pow(b_public, a_private, prime)
        b_secret = pow(a_public, b_private, prime)

        print("The key value generated by A is: ", a_secret)
        print("The key value generated by B is: ", b_secret)