296 lines
9.0 KiB
C

#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
/**
 * @brief Determine if the given number is prime or not
 *
 * Uses trial division by 2 and then by every odd candidate up to the square
 * root of `num`. The loop bound is written as `i <= num / i` so that no
 * multiplication can overflow.
 *
 * @param num The number to test
 * @return `true` if the number is prime or `false` if not (0 and 1 are not
 * prime)
 */
bool isPrime(unsigned long num) {
    // 0 and 1 are not prime by definition
    if (num < 2) {
        return false;
    }
    // 2 and 3 are prime
    if (num < 4) {
        return true;
    }
    if (num % 2 == 0) {
        return false;
    }
    // Any composite number has a divisor no larger than its square root, so
    // only odd candidates up to that bound need checking.
    for (unsigned long i = 3; i <= num / i; i += 2) {
        if (num % i == 0) {
            return false;
        }
    }
    return true;
}
/**
 * @brief Generate a pseudo-random number within the half-open range
 * `[min, max)`
 *
 * The value is derived from `rand()`, so the caller is responsible for
 * seeding with `srand()`. Because `rand()` yields at most `RAND_MAX`, spans
 * wider than `RAND_MAX` are not fully covered.
 *
 * @param min The minimum (inclusive) value for the random number
 * @param max The upper (exclusive) bound for the random number
 * @return The newly created random number; `min` if the range is empty
 * (`max <= min`)
 */
unsigned long rand_in_range(unsigned long min, unsigned long max) {
    // An empty or inverted range would otherwise cause a modulo by zero (or
    // a wrapped-around span), so fall back to the lower bound.
    if (max <= min) {
        return min;
    }
    return ((unsigned long)rand() % (max - min)) + min;
}
#define MIN_PRIME 2
// This value is chosen because it is less than the square root of the largest
// `unsigned long` (USHRT_MAX is typically 65535), so the product of two such
// primes still fits in an `unsigned long`. We subtract 100 from it to ENSURE
// it stays in range for any following calculations.
//
// The expansion is parenthesized so the macro behaves as a single value when
// used inside a larger expression (e.g. `2 * MAX_PRIME`).
#define MAX_PRIME (USHRT_MAX - 100)
/**
 * @brief Generate a random prime number between `MIN_PRIME` & `MAX_PRIME`
 *
 * A random starting point is drawn with `rand_in_range`, forced to be odd,
 * and then advanced in steps of two until `isPrime` accepts it.
 *
 * Note that this implementation is REALLY slow for huge values!
 *
 * @return a prime number
 */
unsigned long genPrime() {
    const unsigned long upper = MAX_PRIME;
    // Setting the lowest bit bumps an even starting point to the next odd
    // number and leaves an odd one untouched.
    unsigned long candidate = rand_in_range(MIN_PRIME, upper) | 1UL;
    for (;;) {
        if (isPrime(candidate)) {
            return candidate;
        }
        candidate += 2;
        if (candidate <= upper) {
            continue;
        }
        // We walked past the upper bound without finding a prime: draw a
        // fresh odd starting point from the lower half of the range and keep
        // searching. The bound checked above stays at `upper`; only the new
        // draw is restricted to `upper / 2`.
        //
        // NOTE(review): the original author flagged that this may have issues
        // if the min and max values are too close together.
        candidate = rand_in_range(MIN_PRIME, upper / 2) | 1UL;
    }
}
/**
 * @brief Calculate the greatest common divisor (GCD) of two numbers using
 * the Euclidean algorithm
 *
 * @param a first number
 * @param b second number
 * @return the gcd of both numbers (`a` when `b` is 0)
 */
unsigned long gcd(unsigned long a, unsigned long b) {
    // Repeatedly replace (a, b) with (b, a mod b) until the remainder is 0
    while (b != 0) {
        const unsigned long remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
/**
 * @brief Calculate the greatest common divisor (GCD) of two numbers
 * recursively via the Extended Euclidean Algorithm
 *
 * Besides the GCD, the function produces coefficients satisfying
 * `a * (*x) + b * (*y) == gcd(a, b)`.
 *
 * @param a First number to find GCD against b
 * @param b Second number to find GCD against a
 * @param x Output: coefficient paired with `a`
 * @param y Output: coefficient paired with `b`
 * @return the gcd of both numbers
 */
signed long long gcdExtended(signed long long a, signed long long b,
                             signed long long *x, signed long long *y) {
    if (a != 0) {
        // Coefficients produced by the recursive step on (b mod a, a)
        signed long long inner_x;
        signed long long inner_y;
        const signed long long divisor =
            gcdExtended(b % a, a, &inner_x, &inner_y);
        // Translate the coefficients of the smaller problem back into
        // coefficients for (a, b)
        const signed long long quotient = b / a;
        *x = inner_y - quotient * inner_x;
        *y = inner_x;
        return divisor;
    }
    // Base case: gcd(0, b) == b == 0 * 0 + b * 1
    *x = 0;
    *y = 1;
    return b;
}
/**
 * @brief Calculate the modulo inverse using the extended Euclidean Algorithm
 *
 * Prints a diagnostic to stderr and terminates the program with
 * `EXIT_FAILURE` if the modulus is zero, if either operand is too large to
 * be represented as a `signed long long`, or if `a` and `m` are not coprime
 * (in which case no inverse exists).
 *
 * @param a The value to be inverted mod m
 * @param m The modulus, a positive integer greater than 1
 * @return The value `x` in `[0, m)` such that `a * x ≡ 1 (mod m)`
 */
unsigned long modInverse(unsigned long a, unsigned long m) {
    // A zero modulus would make the final `% m_0` below a division by zero.
    if (m == 0) {
        fprintf(stderr,
                "Failed to determine modular inverse for `%lu`, the modulus "
                "must not be zero!\n",
                a);
        exit(EXIT_FAILURE);
    }
    // You might notice the evil casting going on below. There's a really good
    // reason for that! The extended euclidean algo. pretty much has a hard
    // requirement on using signed integers. To satisfy this condition, we're
    // using `signed long long`. Where `unsigned long` is as wide as
    // `long long`, values above LLONG_MAX would not survive the cast, so
    // those are rejected up front instead of being silently mangled.
    //
    // Some more enlightening information on this problem can be found at
    // https://jeffhurchalla.com/2018/10/13/implementing-the-extended-euclidean-algorithm-with-unsigned-inputs/
    //
    // Frankly, there are much, MUCH, faster ways of doing this if we didn't
    // *have* to use the extended euclidean algo. (mostly in the form of cursed
    // bitshifting which FIPS-186-5 has some resources on).
    if ((unsigned long long)a > (unsigned long long)LLONG_MAX ||
        (unsigned long long)m > (unsigned long long)LLONG_MAX) {
        fprintf(stderr,
                "Failed to determine modular inverse for `%lu` and `%lu`, "
                "values are too large!\n",
                a, m);
        exit(EXIT_FAILURE);
    }
    signed long long x, y, a_0 = (signed long long)a, m_0 = (signed long long)m;
    signed long long g = gcdExtended(a_0, m_0, &x, &y);
    if (g != 1) {
        fprintf(stderr,
                "Failed to determine modular inverse for `%lu` and `%lu`, they "
                "may not be coprime!\n",
                a, m);
        exit(EXIT_FAILURE);
    }
    // `x` may be negative; normalize it into the range [0, m)
    return (unsigned long)((x % m_0 + m_0) % m_0);
}
/**
 * @brief Multiply two values modulo `m` without intermediate overflow
 *
 * Uses binary "double and add" so every intermediate value stays below `m`.
 * `m` must be non-zero.
 */
static unsigned long mulMod(unsigned long a, unsigned long b, unsigned long m) {
    unsigned long result = 0;
    a %= m;
    b %= m;
    while (b > 0) {
        if (b & 1UL) {
            // result = (result + a) % m, computed without overflowing
            result = (result >= m - a) ? result - (m - a) : result + a;
        }
        // a = (2 * a) % m, computed without overflowing
        a = (a >= m - a) ? a - (m - a) : a + a;
        b >>= 1;
    }
    return result;
}

/**
 * @brief Modular Exponentiation
 *
 * Computes `base` raised to `exp`, reduced modulo `num`, using iterative
 * square-and-multiply. All multiplications are carried out with an
 * overflow-safe modular multiply, so the result is correct for any modulus
 * that fits in an `unsigned long`.
 *
 * See
 * https://www.cs.ucf.edu/~dmarino/ucf/cis3362/lectures/newlecs/FastModExpo.pdf
 * for more information
 *
 * @param base The base value
 * @param exp The exponent
 * @param num The modulus; should be non-zero
 * @return `base^exp mod num`; 0 when `num` is 0 (the operation is undefined
 * there, and this avoids a division by zero)
 */
unsigned long modExp(unsigned long base, unsigned long exp, unsigned long num) {
    if (num == 0) {
        return 0;
    }
    // `1 % num` rather than `1` so that a modulus of 1 correctly yields 0
    unsigned long result = 1 % num;
    base %= num;
    while (exp > 0) {
        if (exp & 1UL) {
            result = mulMod(result, base, num);
        }
        exp >>= 1;
        if (exp > 0) {
            base = mulMod(base, base, num);
        }
    }
    return result;
}
/**
 * @brief Generate keys for RSA encryption given some p & q primes
 *
 * Prints a diagnostic to stderr and terminates the program with
 * `EXIT_FAILURE` if `p`/`q` are unusable (less than 2, equal to each other,
 * or so large that `p * q` overflows) or if no valid `e` can be found.
 *
 * @param p first secret large prime number
 * @param q second secret large prime number, distinct from `p`
 * @param n will be set to the result of p * q
 * @param e in/out: the candidate public exponent; incremented until it is
 * coprime with φ(n) (while it is still less than φ(n))
 * @param d will be set to the mod inverse of e % φ(n), where
 * e*d ≡ 1 (mod φ(n))
 */
void generateKeys(unsigned long p, unsigned long q, unsigned long *n,
                  unsigned long *e, unsigned long *d) {
    // The totient formula (p - 1) * (q - 1) used below only holds for two
    // DISTINCT primes; with p == q decryption would not recover the
    // plaintext. Values below 2 cannot be prime and would make `phi` zero or
    // wrap around.
    if (p < 2 || q < 2 || p == q) {
        fprintf(stderr,
                "Invalid `p` & `q` values given, they must be distinct "
                "primes! `p`: '%lu' | `q`: '%lu'\n",
                p, q);
        exit(EXIT_FAILURE);
    }
    // Refuse inputs whose product does not fit in an `unsigned long` rather
    // than silently producing a wrapped modulus.
    if (p > ULONG_MAX / q) {
        fprintf(stderr,
                "Given `p` & `q` values are too large, `p * q` overflows! "
                "`p`: '%lu' | `q`: '%lu'\n",
                p, q);
        exit(EXIT_FAILURE);
    }
    *n = p * q;
    unsigned long phi = (p - 1) * (q - 1);
    // Pick a valid `e` for our given totient if `e` is not already valid (it
    // should be)
    while (*e < phi && gcd(*e, phi) != 1) {
        (*e)++;
    }
    // If we don't have a valid value for `e` we abort. Technically we could
    // solve this by wrapping `e` to be less than phi and start searching for
    // values again, but we'd rather abort early as `e` is most typically
    // statically set in RSA calculations.
    if (gcd(*e, phi) != 1) {
        fprintf(stderr,
                "Failed to find valid `e` value for given `phi` value! `e`: "
                "'%lu' | `phi`: '%lu'\n",
                *e, phi);
        exit(EXIT_FAILURE);
    }
    *d = modInverse(*e, phi);
}
/**
 * @brief Encrypt plaintext with RSA
 *
 * Raises the plaintext to the power `e` modulo `n` via `modExp`.
 *
 * @param plaintext The integer message to encrypt
 * @param e Encryption key
 * @param n Combined secret values
 * @return The resulting ciphertext
 */
unsigned long encrypt(unsigned long plaintext, unsigned long e,
                      unsigned long n) {
    const unsigned long ciphertext = modExp(plaintext, e, n);
    return ciphertext;
}
/**
 * @brief Decrypt RSA encrypted ciphertext
 *
 * Raises the ciphertext to the power `d` modulo `n` via `modExp`.
 *
 * @param ciphertext The text to decrypt
 * @param d Decryption key
 * @param n Combined secret values
 * @return The recovered plaintext
 */
unsigned long decrypt(unsigned long ciphertext, unsigned long d,
                      unsigned long n) {
    const unsigned long recovered = modExp(ciphertext, d, n);
    return recovered;
}
/**
 * @brief Interactive RSA demo
 *
 * Obtains two primes (either typed in by the user or randomly generated),
 * derives an RSA key pair from them, then encrypts and decrypts a single
 * integer plaintext read from standard input, printing each step.
 *
 * @return 0 on success; the program exits with `EXIT_FAILURE` on any invalid
 * input
 */
int main(void) {
    unsigned long p = 0;
    unsigned long q = 0;
    unsigned long e = 65537; // A more generally used value for `e`
    char choice = '\0';
    printf("Would you like to choose your own values for `p` & `q`? (Y/n)? ");
    // `scanf` returns the number of items converted, or EOF on input failure;
    // comparing against the expected count catches both EOF and partial
    // matches.
    if (scanf("%c", &choice) != 1) {
        fprintf(stderr, "Failed to get a choice\n");
        exit(EXIT_FAILURE);
    }
    if (choice == 'y' || choice == 'Y') {
        printf("Enter a prime number for both p and q: ");
        // Both values must be converted, otherwise `q` would be left unset
        if (scanf("%lu %lu", &p, &q) != 2) {
            fprintf(stderr, "Unable to get numbers from input!\n");
            exit(EXIT_FAILURE);
        }
    } else if (choice == 'n' || choice == 'N') {
        printf("Generating random primes for `p` & `q`...\n");
        srand((unsigned int)time(NULL));
        p = genPrime();
        q = genPrime();
    } else {
        fprintf(
            stderr,
            "Invalid answer `%c` given! Type one of `Y`, `y`, `N`, or `n`.\n",
            choice);
        exit(EXIT_FAILURE);
    }
    printf("Got (p, q): (%lu, %lu)\n", p, q);
    unsigned long n, d;
    if (!isPrime(p)) {
        fprintf(stderr, "Given `p` value was not prime, received '%lu'!\n", p);
        exit(EXIT_FAILURE);
    }
    if (!isPrime(q)) {
        fprintf(stderr, "Given `q` value was not prime, received '%lu'!\n", q);
        exit(EXIT_FAILURE);
    }
    // Generate RSA keys
    printf("Generating keys...\n");
    generateKeys(p, q, &n, &e, &d);
    printf("Public key (e, n): (%lu, %lu)\n", e, n);
    printf("Private key (d, n): (%lu, %lu)\n\n", d, n);
    // Encrypt and decrypt a sample plaintext, which we assume is given as an
    // integer value
    unsigned long plaintext = 0;
    printf("Enter an integer between 0 and %lu as plain text to be encrypted: ",
           n - 1);
    if (scanf("%lu", &plaintext) != 1) {
        fprintf(stderr, "Failed to read plaintext!\n");
        exit(EXIT_FAILURE);
    }
    // RSA can only round-trip messages that are strictly smaller than `n`
    if (plaintext > n - 1) {
        fprintf(stderr,
                "Unable to RSA encrypt & decrypt given `plaintext`: "
                "'%lu'!\nPlaintext value was more than `n-1`: '%lu'\n",
                plaintext, n - 1);
        exit(EXIT_FAILURE);
    }
    printf("Original plaintext: %lu\n", plaintext);
    unsigned long ciphertext = encrypt(plaintext, e, n);
    printf("Encrypted ciphertext: %lu\n", ciphertext);
    unsigned long decrypted = decrypt(ciphertext, d, n);
    printf("Decrypted plaintext: %lu\n\n", decrypted);
    return 0;
}