Add DocTests to diffie.py (#10156)

* diffie doctest

* fix ut

* update doctest

---------

Co-authored-by: Harsha Kottapalli <skottapalli@microsoft.com>
This commit is contained in:
Sai Harsha Kottapalli 2023-10-09 20:49:05 +05:30 committed by GitHub
parent 53d78b9cc0
commit c0da015b7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,11 +1,28 @@
from __future__ import annotations from __future__ import annotations
def find_primitive(n: int) -> int | None: def find_primitive(modulus: int) -> int | None:
for r in range(1, n): """
Find a primitive root modulo modulus, if one exists.
Args:
modulus : The modulus for which to find a primitive root.
Returns:
The primitive root if one exists, or None if there is none.
Examples:
>>> find_primitive(7) # Modulo 7 has primitive root 3
3
>>> find_primitive(11) # Modulo 11 has primitive root 2
2
>>> find_primitive(8) == None # Modulo 8 has no primitive root
True
"""
for r in range(1, modulus):
li = [] li = []
for x in range(n - 1): for x in range(modulus - 1):
val = pow(r, x, n) val = pow(r, x, modulus)
if val in li: if val in li:
break break
li.append(val) li.append(val)
@ -15,18 +32,22 @@ def find_primitive(n: int) -> int | None:
if __name__ == "__main__": if __name__ == "__main__":
q = int(input("Enter a prime number q: ")) import doctest
a = find_primitive(q)
if a is None: doctest.testmod()
print(f"Cannot find the primitive for the value: {a!r}")
prime = int(input("Enter a prime number q: "))
primitive_root = find_primitive(prime)
if primitive_root is None:
print(f"Cannot find the primitive for the value: {primitive_root!r}")
else: else:
a_private = int(input("Enter private key of A: ")) a_private = int(input("Enter private key of A: "))
a_public = pow(a, a_private, q) a_public = pow(primitive_root, a_private, prime)
b_private = int(input("Enter private key of B: ")) b_private = int(input("Enter private key of B: "))
b_public = pow(a, b_private, q) b_public = pow(primitive_root, b_private, prime)
a_secret = pow(b_public, a_private, q) a_secret = pow(b_public, a_private, prime)
b_secret = pow(a_public, b_private, q) b_secret = pow(a_public, b_private, prime)
print("The key value generated by A is: ", a_secret) print("The key value generated by A is: ", a_secret)
print("The key value generated by B is: ", b_secret) print("The key value generated by B is: ", b_secret)