RSA,Rivest-Shamir-Adleman 算法,是一個常見的非對稱加密算法。本文將簡明扼要通俗易懂地介紹 RSA 的原理,并給出 Python 實現(xiàn)。
本文同步發(fā)表于我的博客 https://clouder0.com/zh-cn/posts/rsa/
Why we need RSA?
加密的需求大家都很熟悉,但非對稱加密呢?我們?yōu)槭裁葱枰菍ΨQ加密?
想象以下的場景:
-
數(shù)字簽名: 你希望以你的名義發(fā)送一封郵件(或者發(fā)布任何的數(shù)據(jù)內(nèi)容),并且向公眾公開。
- 但是,你不希望有其他人能夠以你的名義發(fā)布內(nèi)容。
- 這個時候,公眾可以解密,你可以加密。你持有私鑰,公眾持有公鑰。私鑰簽名,公鑰驗簽。
-
加密通訊: 你建了一個小網(wǎng)站,你希望用戶的數(shù)據(jù)在傳輸?shù)侥愕木W(wǎng)站的過程中是安全的。也就是說,用戶發(fā)送加密后的數(shù)據(jù),你解密數(shù)據(jù)。
- 這個時候,用戶持有公鑰進行加密,你持有私鑰進行解密。
- 任何人都可以向你發(fā)送加密內(nèi)容,但只有持有私鑰的你才能解密。
這是非對稱加密的兩大經(jīng)典應(yīng)用場景。
如果只才能對稱加密的話,這兩種場景都是無法實現(xiàn)的:數(shù)字簽名自然不必說,加密通訊的話,由于你需要讓公眾具有加密的能力,但又不希望他們能夠解密,自然也需要非對稱。
How is this possible?為什么還能做出這種「有兩種密碼,一個加密一個解密」的神奇算法呢?
這一般都是利用了非對稱性。例如:將兩個質(zhì)數(shù)乘起來得到結(jié)果是簡單的,但想要對某個大數(shù)做質(zhì)因數(shù)分解,復(fù)雜度則極其高。
How RSA works...
RSA 利用的就是質(zhì)因數(shù)分解復(fù)雜度的非對稱性。
我們首先選擇兩個足夠大的質(zhì)數(shù),記為 $p,q$,然后:
-
計算 $n=pq$,這將用作模數(shù)。$n$ 的長度被稱為 key length,現(xiàn)在比較流行的是 512/2048/4096bits.
-
計算 $\lambda(n)$,即 $n$ 的 Carmichael's totient function. 也就是找到最小的正整數(shù) $m$ 使得 $a^m \equiv 1 \pmod n$ 對任意 $a$ 都成立。
- 在原教旨 RSA 中,選取的 $\lambda(n)$ 實際上是 Euler's totient function,不是最小的正數(shù) $m$,會比較大,但也是能用。
- 現(xiàn)在一般選用 $\lambda(n) = \operatorname{lcm}(p-1,q-1)$. 證明等會再說。
-
選擇一個整數(shù) $e$,滿足 $e \in (1,\lambda(n))$ 且 $\gcd(e, \lambda(n)) = 1$.
- $e$ 最好具有較短的 bit-length 和較小的 hamming weight,大家經(jīng)常選用的是 $e=2^{16}+1= 65537$.
-
計算 $d \equiv e^{-1} \pmod {\lambda(n)}$,也就是 $e$ 的乘法逆元。
好的,密鑰生成部分結(jié)束了。
接下來,我們將 $(n, e)$ 作為加密密鑰,$(n,d)$ 作為解密密鑰。而剩下的 $p,q,\lambda(n)$ 應(yīng)當(dāng)保密或者直接扔掉。
然后就是加密了,加密相當(dāng)?shù)暮唵伟?,加入我們想要傳遞原文 $M$,首先使用 padding 將其變成 $m$,滿足 $0 \le m < n$. 這里的 padding 只要是一種可逆的變換就行了。
然后計算:$c \equiv m^e \pmod n$,這里的 $c$ 就是我們的加密結(jié)果了。
使用快速冪,可以在較短的時間內(nèi)完成計算。
解密也相當(dāng)?shù)暮唵?,我們持有密?$c$,想要獲得 padded 后的原文 $m$,那么:
$$c^d \equiv (m^e)^d \equiv m \pmod n$$這里利用的核心原理是:$ed \equiv 1 \pmod n$,實際上這就是 $d$ 的定義式。
相信大家已經(jīng)完全理解 RSA 了,笑。
Math behind the scene
讓我們思考一下,RSA 算法的執(zhí)行流程已經(jīng)講完了,但它為什么能保證安全性、為什么能保證正確性呢?
RSA 的核心原理是:$e$ 和 $d$ 只有一個公開。而 $m^{ed} \equiv m^{\lambda(n)} \equiv m \pmod n$.
這里 $ed \equiv \lambda(n) \pmod n$ 就是解密密鑰 $d$ 的定義式。而 $m^{\lambda(n)} \equiv m \pmod n$ 就是 $\lambda(n)$ 的定義式。
實際上,$e$ 和 $d$ 是相當(dāng)對稱的。假如持有 $e$ 進行加密,加密后 $c=m^e$,則 $c^d \equiv m$. 用 $d$ 加密也是一樣的:$c=m^d, c^e = m$.
也就是——實際上持有公鑰的用戶也可以既加密、又解密……嗎?比如我們原本約定好公鑰加密,私鑰解密,that's fine. 但哪天你抽風(fēng)了說我們換換位置,公鑰解密私鑰加密,那也是無縫切換。
當(dāng)然,工程實踐上公鑰經(jīng)常取固定的 $e=65537$,嘛。
接下來還有一個問題,$\lambda(n) = \operatorname{lcm}(p-1,q-1)$,為什么就有 $m^{\lambda(n)} \equiv m \pmod n$?
根據(jù)眾所周知的費馬小定理,我們知道:當(dāng) $p$ 為素數(shù)時,$a^{p-1} \equiv 1 \pmod p$.
而當(dāng) $n=pq$ 時,顯然 $n$ 就不是素數(shù)了,我們要找到 $a^{\lambda(n)} \equiv 1 \pmod n$,這個時候可以使用歐拉定理:
$$a^b \equiv \begin{cases} a^{b \bmod \varphi(p)},b < \varphi(p) \\ a^{b \bmod \varphi(p) + \varphi(p)},b \geq \varphi(p) \end{cases} \pmod{p}$$其中 $\varphi(p)$ 為歐拉函數(shù)。歐拉函數(shù)滿足積性,也就是 $\varphi(pq) = \varphi(p) \times \varphi(q)$. 并且有對于素數(shù) $p$,$\varphi(p) = p-1$.
那么 $\varphi(n) = \varphi(pq) = \varphi(p) \times \varphi(q) = (p-1)(q-1)$,非常 reasonable.
那么顯然,我們就可以得到:
$$a^{\varphi(n)} \equiv a^0 \equiv 1 \pmod n$$這就算是求出了一個滿足需要的 $\lambda(n)$...了嗎?注意到我們的定義是最小的 $m$ 使得 $a^m \equiv 1 \pmod n$,這里的 $\varphi(n)$ 未必是最小的。
當(dāng)然,實際上是不是最小的其實對 RSA 影響不大。
接下來就是 Carmichael function,其計算如下:
$$\lambda(n) = \begin{cases} \varphi(n), &\text{if } n \text{ is }1,2,3,4 \text{ or an odd prime power} \\ \dfrac{1}{2}\varphi(n), &\text{ if } n = 2^r, r \ge 3 \\ \operatorname{lcm}\left( \lambda(n_1),\cdots,\lambda(n_k) \right), &\text{ if } n = n_1n_2\cdots n_k, \text{ where } n_i \text{ are}\\ &\text{ power of distinct prime numbers} \end{cases}$$在這里,因為 $n=pq$ 且 $p,q$ 都是質(zhì)數(shù),那么 $\lambda(pq) = \operatorname{lcm}(\varphi(p),\varphi(q))= \operatorname{lcm}(p-1,q-1)$.
Implementation
涉及到大數(shù)運算,人生苦短,我用 Python.
但是 Python 確實不是很快,我決定使用稍微短一些的 pq. 1024bits 吧,這樣最終的 $n$ 就是 2048bits.
以下是核心代碼:
import random
def miller_rabin(n: int, k: int):
"""use miller rabin method to test prime."""
if n == 2:
return True
if n % 2 == 0:
return False
r, s = 0, n - 1
while s % 2 == 0:
r += 1
s //= 2
for _ in range(k):
a = random.randrange(2, n - 1)
x = pow(a, s, n)
if x == 1 or x == n - 1:
continue
for _ in range(r - 1):
x = pow(x, 2, n)
if x == n - 1:
break
else:
return False
return True
def exgcd(a: int, b: int):
"""exgcd to cacl inverse."""
if b == 0:
return a, 1, 0
d, x, y = exgcd(b, a % b)
x, y = y, x - (a // b) * y
return d, x, y
def is_prime(n: int) -> bool:
return miller_rabin(n, 40)
def inv(a: int, m: int) -> int:
"""calc modular inverse."""
d, x, y = exgcd(a, m)
if d != 1:
raise RuntimeError("modular inverse does not exist")
return x % m
def rsa_encrypt(m: int, e: int, n: int) -> int:
return pow(m, e, n)
def rsa_decrypt(c: int, d: int, n: int) -> int:
return pow(c, d, n)
def gcd(a: int, b: int) -> int:
if b == 0:
return a
return gcd(b, a % b)
def lcm(a: int, b: int) -> int:
return a * b // gcd(a, b)
def rsa_gen(p: int, q: int) -> tuple[int, int, int]:
n = p * q
l = lcm(p - 1, q - 1)
e = 65537
d = inv(e, l)
return n, e, d
def get_big_prime():
while True:
p = random.getrandbits(1024)
if is_prime(p):
return p
def get_pq() -> tuple[int, int]:
return get_big_prime(), get_big_prime()
def main():
n, e, d = rsa_gen(*get_pq())
print("n =", n)
print("e =", e)
print("d =", d)
origin = int(input("origin: "))
c = rsa_encrypt(origin, e, n)
print("c =", c)
print("origin =", rsa_decrypt(c, d, n))
assert origin == rsa_decrypt(c, d, n)
print("OK")
if __name__ == "__main__":
main()
一般而言,RSA 的速度較為緩慢,我們可以將 RSA 和對稱加密配合使用,比如說用 RSA 傳遞對稱加密的密鑰,以實現(xiàn)加密通訊。
處理的長度過長的時候,需要分塊。emmm,注意到計算在 $\bmod n$ 下進行,需要分塊后比 $n$ 小。
def encrypt_file(n, e):
with open("input.txt", "rb") as f:
data = f.read()
# chunking by 255 bytes
chunks = [data[i : i + 255] for i in range(0, len(data), 255)]
with open("output.txt", "wb") as f:
for chunk in chunks:
m = int.from_bytes(chunk, "little")
c = rsa_encrypt(m, e, n).to_bytes(512, "little")
f.write(c)
def decrypt_file(n, d):
with open("output.txt", "rb") as f:
data = f.read()
chunks = [data[i : i + 512] for i in range(0, len(data), 512)]
with open("output_de.txt", "wb") as f:
for chunk in chunks[:-1]:
c = int.from_bytes(chunk, "little")
m = rsa_decrypt(c, d, n).to_bytes(255, "little")
f.write(m)
c = int.from_bytes(chunks[-1], "little")
m = rsa_decrypt(c, d, n).to_bytes(255, "little")
# trim trailing zeros
while m[-1] == 0:
m = m[:-1]
f.write(m)
def test_file():
n, e, d = rsa_gen(*get_pq())
encrypt_file(n, e)
decrypt_file(n, d)