#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

/**
 * @brief Determine if the given number is prime or not
 *
 * Uses trial division: first by 2, then by every odd candidate up to the
 * square root of `num`.
 *
 * @param num The number to test
 * @return `true` if the number is prime or `false` if not (0 and 1 are not
 * prime)
 */
bool isPrime(unsigned long num) {
    if (num < 2) {
        return false;
    }
    if (num % 2 == 0) {
        // 2 is the only even prime
        return num == 2;
    }

    // `i <= num / i` is the overflow-free way of writing `i * i <= num`
    for (unsigned long i = 3; i <= num / i; i += 2) {
        if (num % i == 0) {
            return false;
        }
    }

    return true;
}

/**
 * @brief Generate a random number in the half-open range [`min`, `max`)
 *
 * The value comes from `rand()`, so the caller is responsible for seeding the
 * generator with `srand()`. The distance the result can be from `min` is
 * additionally capped by `RAND_MAX`.
 *
 * @param min The minimum allowed value for the random number (inclusive)
 * @param max The upper bound for the random number (exclusive)
 * @return The newly created random number, or `min` when the range is empty
 * (`max <= min`)
 */
unsigned long rand_in_range(unsigned long min, unsigned long max) {
    // An empty range would otherwise be a modulo by zero (or a wrapped,
    // meaningless span when `max < min`)
    if (max <= min) {
        return min;
    }

    return ((unsigned long)rand() % (max - min)) + min;
}

#define MIN_PRIME 2
// On platforms where `unsigned short` is half the width of `unsigned long`
// (or narrower), this value is below the square root of the largest
// `unsigned long`, so the product of two generated primes cannot overflow.
// We subtract 100 from it to ENSURE it stays in range for any following
// calculations.
#define MAX_PRIME (USHRT_MAX - 100)

/**
 * @brief Generate a random prime number between `MIN_PRIME` & `MAX_PRIME`
 *
 * A random odd starting point is picked and then walked upwards in steps of
 * two until a prime is hit.
 *
 * Note that this implementation is REALLY slow for huge values!
 *
 * @return a prime number
 */
unsigned long genPrime(void) {
    const unsigned long max = MAX_PRIME;

    // Force the candidate to be odd; an even candidate is bumped up by one
    unsigned long prime = rand_in_range(MIN_PRIME, max) | 1UL;

    while (!isPrime(prime)) {
        prime += 2;

        // If we walked past the maximum without finding a prime, restart from
        // a random odd value in the lower half of the range, which leaves
        // plenty of room to find a prime before hitting the maximum again.
        //
        // NOTE: This relies on `MIN_PRIME` and `MAX_PRIME` being far apart; it
        // may misbehave if the two are ever set too close together.
        if (prime > max) {
            prime = rand_in_range(MIN_PRIME, max / 2) | 1UL;
        }
    }

    return prime;
}

/**
 * @brief Calculate the greatest common divisor (GCD) of two numbers using the
 * iterative Euclidean algorithm
 *
 * @param a first number
 * @param b second number
 * @return the gcd of both numbers (`gcd(a, 0)` is `a`)
 */
unsigned long gcd(unsigned long a, unsigned long b) {
    while (b != 0) {
        unsigned long remainder = a % b;
        a = b;
        b = remainder;
    }

    return a;
}

/**
 * @brief Calculate the greatest common divisor (GCD) of two numbers
 * recursively via the Extended Euclidean Algorithm
 *
 * @param a First number to find GCD against b
 * @param b Second number to find GCD against a
 * @param x Output: set to a coefficient such that `a * x + b * y == gcd`
 * @param y Output: set to a coefficient such that `a * x + b * y == gcd`
 * @return the gcd of both numbers
 */
signed long long gcdExtended(signed long long a, signed long long b,
                             signed long long *x, signed long long *y) {
    // Base Case: gcd(0, b) == b == 0 * 0 + b * 1
    if (a == 0) {
        *x = 0;
        *y = 1;
        return b;
    }

    // To store results of recursive call
    signed long long x1, y1;
    signed long long g = gcdExtended(b % a, a, &x1, &y1);

    // Update x and y using results of recursive call
    *x = y1 - (b / a) * x1;
    *y = x1;

    return g;
}

/**
 * @brief Calculate the modulo inverse using the extended Euclidean Algorithm
 *
 * Prints an error and terminates the program with `EXIT_FAILURE` when no
 * inverse can be computed (the inputs are not coprime, the modulus is zero, or
 * a value does not fit in a `signed long long`).
 *
 * @param a The value to be inverted mod m
 * @param m The modulus, a positive integer greater than 1
 * @return The value `d` in the range [0, m) for which `a * d ≡ 1 (mod m)`
 */
unsigned long modInverse(unsigned long a, unsigned long m) {
    // You might notice the evil casting going on below. There's a really good
    // reason for that! The extended euclidean algo. pretty much has a hard
    // requirement on using signed integers. To satisfy this condition, we're
    // using `signed long long` to hold the *unsigned long* values.
    //
    // That cast is only lossless while the values are no larger than
    // `LLONG_MAX`. Where `unsigned long` is as wide as `long long` that is
    // not guaranteed, so we check for it explicitly instead of silently
    // producing garbage.
    //
    // Some more enlightening information on this problem can be found at
    // https://jeffhurchalla.com/2018/10/13/implementing-the-extended-euclidean-algorithm-with-unsigned-inputs/
    //
    // Frankly, there are much, MUCH, faster ways of doing this if we didn't
    // *have* to use the extended euclidean algo. (mostly in the form of cursed
    // bitshifting which FIPS-186-5 has some resources on).
    if (m == 0 || (unsigned long long)a > (unsigned long long)LLONG_MAX ||
        (unsigned long long)m > (unsigned long long)LLONG_MAX) {
        fprintf(stderr,
                "Unable to determine modular inverse for `%lu` and `%lu`, the "
                "values are out of range!\n",
                a, m);
        exit(EXIT_FAILURE);
    }

    signed long long x, y;
    signed long long a_0 = (signed long long)a;
    signed long long m_0 = (signed long long)m;

    signed long long g = gcdExtended(a_0, m_0, &x, &y);
    if (g != 1) {
        fprintf(stderr,
                "Failed to determine modular inverse for `%lu` and `%lu`, they "
                "may not be coprime!\n",
                a, m);
        exit(EXIT_FAILURE);
    }

    // Normalize into [0, m). Adding `m_0` only when the remainder is negative
    // avoids the signed overflow `x % m_0 + m_0` could hit for a huge modulus.
    signed long long inverse = x % m_0;
    if (inverse < 0) {
        inverse += m_0;
    }

    return (unsigned long)inverse;
}

/**
 * @brief Add two residues modulo `m` without overflowing
 *
 * Both `a` and `b` must already be reduced, i.e. less than `m`.
 */
static unsigned long addMod(unsigned long a, unsigned long b, unsigned long m) {
    // `a + b >= m` rewritten so that nothing can wrap around
    return (a >= m - b) ? a - (m - b) : a + b;
}

/**
 * @brief Multiply two values modulo `m` without overflowing
 *
 * Uses the plain product when it fits in an `unsigned long` and falls back to
 * double-and-add otherwise. `m` must be non-zero.
 */
static unsigned long mulMod(unsigned long a, unsigned long b, unsigned long m) {
    a %= m;
    b %= m;

    // Fast path: the product cannot overflow
    if (b == 0 || a <= ULONG_MAX / b) {
        return (a * b) % m;
    }

    unsigned long product = 0;
    while (b > 0) {
        if (b & 1UL) {
            product = addMod(product, a, m);
        }
        a = addMod(a, a, m);
        b >>= 1;
    }

    return product;
}

/**
 * @brief Modular Exponentiation
 *
 * Computes `base` raised to `exp`, modulo `num`, with iterative
 * square-and-multiply. Every intermediate product is reduced with an
 * overflow-safe multiplication, so the result is correct for any modulus that
 * fits in an `unsigned long`.
 *
 * Prints an error and terminates the program with `EXIT_FAILURE` when `num` is
 * zero.
 *
 * See
 * https://www.cs.ucf.edu/~dmarino/ucf/cis3362/lectures/newlecs/FastModExpo.pdf
 * for more information
 *
 * @param base The value being raised to a power
 * @param exp The exponent
 * @param num The modulus, must be non-zero
 * @return `base ^ exp mod num`
 */
unsigned long modExp(unsigned long base, unsigned long exp, unsigned long num) {
    if (num == 0) {
        fprintf(stderr,
                "Unable to perform modular exponentiation with a modulus of "
                "0!\n");
        exit(EXIT_FAILURE);
    }

    // `1 % num` so that a modulus of 1 correctly yields 0
    unsigned long result = 1 % num;
    base %= num;

    while (exp > 0) {
        if (exp & 1UL) {
            result = mulMod(result, base, num);
        }
        base = mulMod(base, base, num);
        exp >>= 1;
    }

    return result;
}

/**
 * @brief Generate keys for RSA encryption given some p & q primes
 *
 * Prints an error and terminates the program with `EXIT_FAILURE` when the
 * primes are unusable (less than 2, equal to each other, or so large that
 * `p * q` overflows) or when no valid `e` can be found.
 *
 * @param p first secret large prime number
 * @param q second secret large prime number, must differ from `p`
 * @param n will be set to the result of p * q
 * @param e in/out: the desired public exponent; it is increased until it is
 * coprime with φ(n) if it is not already
 * @param d will be set to the mod inverse of e % φ(n), where
 * e*d ≡ 1 (mod φ(n))
 */
void generateKeys(unsigned long p, unsigned long q, unsigned long *n,
                  unsigned long *e, unsigned long *d) {
    // φ(n) == (p - 1) * (q - 1) only holds for two DISTINCT primes; with
    // p == q the derived keys would not decrypt correctly.
    if (p < 2 || q < 2 || p == q) {
        fprintf(stderr,
                "Invalid primes given for key generation, `p` and `q` must be "
                "distinct primes! `p`: '%lu' | `q`: '%lu'\n",
                p, q);
        exit(EXIT_FAILURE);
    }

    // Refuse to silently wrap around when computing `n`
    if (q > ULONG_MAX / p) {
        fprintf(stderr,
                "Given `p` & `q` values are too large, `p * q` would "
                "overflow! `p`: '%lu' | `q`: '%lu'\n",
                p, q);
        exit(EXIT_FAILURE);
    }

    *n = p * q;
    unsigned long phi = (p - 1) * (q - 1);

    // Pick a valid `e` for our given totient if `e` is not already valid (it
    // should be)
    while (gcd(*e, phi) != 1 && *e < phi) {
        (*e)++;
    }

    // If we don't have a valid value for `e` we abort. Technically we could
    // solve this by wrapping `e` to be less than phi and start searching for
    // values again, but we'd rather abort early as `e` is most typically
    // statically set in RSA calculations.
    if (gcd(*e, phi) != 1) {
        fprintf(stderr,
                "Failed to find valid `e` value for given `phi` value! `e`: "
                "'%lu' | `phi`: '%lu'\n",
                *e, phi);
        exit(EXIT_FAILURE);
    }

    *d = modInverse(*e, phi);
}

/**
 * @brief Encrypt plaintext with RSA
 *
 * @param plaintext The value to encrypt
 * @param e Encryption key
 * @param n Combined secret values
 * @return `plaintext ^ e mod n`
 */
unsigned long encrypt(unsigned long plaintext, unsigned long e,
                      unsigned long n) {
    return modExp(plaintext, e, n);
}

/**
 * @brief Decrypt RSA encrypted ciphertext
 *
 * @param ciphertext The text to decrypt
 * @param d Decryption key
 * @param n Combined secret values
 * @return `ciphertext ^ d mod n`
 */
unsigned long decrypt(unsigned long ciphertext, unsigned long d,
                      unsigned long n) {
    return modExp(ciphertext, d, n);
}

/**
 * @brief Interactive RSA demo
 *
 * Reads (or generates) the primes `p` & `q`, derives a key pair from them and
 * then encrypts and decrypts a single integer read from standard input.
 */
int main(void) {
    unsigned long p = 0;
    unsigned long q = 0;
    unsigned long e = 65537; // A more generally used value for `e`

    char choice;
    printf("Would you like to choose your own values for `p` & `q`? (Y/n)? ");
    // `scanf` returns EOF (not 0) when the input ends, so compare against the
    // number of items we expect instead of against 0
    if (scanf("%c", &choice) != 1) {
        fprintf(stderr, "Failed to get a choice\n");
        exit(EXIT_FAILURE);
    }

    if (choice == 'y' || choice == 'Y') {
        printf("Enter a prime number for both p and q: ");
        // Both values must be read; a partial read would leave `q` unset
        if (scanf("%lu %lu", &p, &q) != 2) {
            fprintf(stderr, "Unable to get numbers from input!\n");
            exit(EXIT_FAILURE);
        }
    } else if (choice == 'n' || choice == 'N') {
        printf("Generating random primes for `p` & `q`...\n");
        srand((unsigned int)time(NULL));
        p = genPrime();
        // RSA needs two distinct primes, so redraw on the rare collision
        do {
            q = genPrime();
        } while (q == p);
    } else {
        fprintf(
            stderr,
            "Invalid answer `%c` given! Type one of `Y`, `y`, `N`, or `n`.\n",
            choice);
        exit(EXIT_FAILURE);
    }

    printf("Got (p, q): (%lu, %lu)\n", p, q);

    unsigned long n, d;

    if (!isPrime(p)) {
        fprintf(stderr, "Given `p` value was not prime, received '%lu'!\n", p);
        exit(EXIT_FAILURE);
    }

    if (!isPrime(q)) {
        fprintf(stderr, "Given `q` value was not prime, received '%lu'!\n", q);
        exit(EXIT_FAILURE);
    }

    // Generate RSA keys
    printf("Generating keys...\n");
    generateKeys(p, q, &n, &e, &d);

    printf("Public key (e, n): (%lu, %lu)\n", e, n);
    printf("Private key (d, n): (%lu, %lu)\n\n", d, n);

    // Encrypt and decrypt a sample plaintext, which we assume is given as an
    // integer value
    unsigned long plaintext;
    printf("Enter an integer between 0 and %lu as plain text to be encrypted: ",
           n - 1);
    if (scanf("%lu", &plaintext) != 1) {
        fprintf(stderr, "Failed to read plaintext!\n");
        exit(EXIT_FAILURE);
    }

    if (plaintext >= n) {
        fprintf(stderr,
                "Unable to RSA encrypt & decrypt given `plaintext`: "
                "'%lu'!\nPlaintext value was more than `n-1`: '%lu'\n",
                plaintext, n - 1);
        exit(EXIT_FAILURE);
    }

    printf("Original plaintext: %lu\n", plaintext);

    unsigned long ciphertext = encrypt(plaintext, e, n);
    printf("Encrypted ciphertext: %lu\n", ciphertext);

    unsigned long decrypted = decrypt(ciphertext, d, n);
    printf("Decrypted plaintext: %lu\n\n", decrypted);

    return 0;
}