Bleichenbacher 1998

Suppose there is an oracle that, given a ciphertext $c$, tells you whether the two most significant bytes of its plaintext $m$ are 0x0002, that is, whether $m$ is PKCS#1 conforming.

Let $k$ be the byte length of $n$, so that $2^{8(k-1)} \le n < 2^{8k}$, and let $B = 2^{8(k-2)}$. If the oracle says $m$ is PKCS conforming, that means

$$2B \le m < 3B$$
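To make the bound concrete, here is a minimal sketch that computes $k$ and $B$ from a modulus and checks conformance both ways: by the leading bytes and by the interval. The names `pkcs_bounds` and `is_conforming` and the 3-byte modulus are illustrative assumptions, not part of the original notes.

```python
def pkcs_bounds(n: int):
    """Return (k, B): k is the byte length of n, B = 2^(8(k-2))."""
    k = (n.bit_length() + 7) // 8
    return k, 1 << (8 * (k - 2))


def is_conforming(m: int, n: int) -> bool:
    """True when the k-byte big-endian encoding of m starts with 0x00 0x02."""
    k, _ = pkcs_bounds(n)
    return m.to_bytes(k, "big")[:2] == b"\x00\x02"


n = 0xC34F1B  # hypothetical 3-byte modulus, only used to show the bounds
k, B = pkcs_bounds(n)

# The byte-prefix test and the interval test 2B <= m < 3B agree at the edges.
for m in (2 * B - 1, 2 * B, 3 * B - 1, 3 * B):
    assert is_conforming(m, n) == (2 * B <= m < 3 * B)
```

A real oracle would only see ciphertexts and decrypt internally; the check above is what it effectively leaks.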

Algorithm

Let $M_i$ be a set of intervals that contains $m_0$, i.e. $m_0$ lies in some interval of $M_i$.

step 1

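Given a ciphertext $c$, choose random integers $s_0$ until the plaintext of $c \cdot s_0^e \bmod n$ is PKCS conforming, then set $c_0 = c \cdot s_0^e \bmod n$, $M_0 = \{[2B, 3B-1]\}$, $i = 1$. Write $m_0 = m s_0 \bmod n$ for the plaintext of $c_0$.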
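The blinding step can be sketched as follows. The toy key, the stand-in `oracle` (which simply decrypts with the private exponent) and the helper name `blind` are assumptions made for illustration; a real attacker only has black-box access to the oracle.

```python
import random

# Hypothetical toy key, far too small for real use.
p, q, e = 1009, 1013, 5
n = p * q
d = pow(e, -1, (p - 1) * (q - 1))  # modular inverse, needs Python 3.8+
k = (n.bit_length() + 7) // 8
B = 1 << (8 * (k - 2))


def oracle(c: int) -> bool:
    """Stand-in padding oracle: decrypt and test 2B <= m < 3B."""
    return 2 * B <= pow(c, d, n) < 3 * B


def blind(c: int, rng: random.Random):
    """Step 1: draw s0 at random until c * s0^e mod n is PKCS conforming."""
    while True:
        s0 = rng.randrange(1, n)
        c0 = (c * pow(s0, e, n)) % n
        if oracle(c0):
            return s0, c0


m = 123456                      # arbitrary plaintext, not conforming
c = pow(m, e, n)
s0, c0 = blind(c, random.Random(0))
M0 = [(2 * B, 3 * B - 1)]       # initial interval set for m0 = m * s0 mod n
```

If $c$ is already conforming, $s_0 = 1$ works and the random search can be skipped.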

step 2

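If $i = 1$: for any $1 < s_1 \le \frac{n}{3B}$ the product $s_1 m_0$ lies in $[4B, n)$, so it is not reduced modulo $n$ and cannot be PKCS conforming. So search for the smallest integer $s_1 > \frac{n}{3B}$ such that $s_1 m_0 \bmod n$ is PKCS conforming (tested by querying the oracle on $c_0 s_1^e \bmod n$).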

If $i > 1$ and $M_{i-1}$ contains at least two intervals, search for the smallest integer $s_i > s_{i-1}$ such that $s_i m_0 \bmod n$ is PKCS conforming.

If $i > 1$ and $M_{i-1} = \{[a,b]\}$, choose small integers $r_i, s_i$ such that

$$r_i \ge 2\,\frac{b s_{i-1} - 2B}{n}$$

and

$$\frac{2B + r_i n}{b} \le s_i < \frac{3B + r_i n}{a}$$

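These bounds on $s_i$ follow from $a \le m_0 \le b$: if $2B \le s_i m_0 - r_i n \le 3B - 1$, then

$$\frac{2B + r_i n}{b} \le \frac{2B + r_i n}{m_0} \le s_i \le \frac{3B - 1 + r_i n}{m_0} \le \frac{3B - 1 + r_i n}{a}$$

Try these candidates $s_i$ until $s_i m_0 \bmod n$ is PKCS conforming.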
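The candidate pairs for the single-interval case can be enumerated without touching the oracle. The following is a sketch under assumed toy values; `step2c_candidates` and `ceil_div` are hypothetical names, and in the attack each yielded $s_i$ would be passed to the oracle until one is accepted.

```python
from itertools import count, islice


def ceil_div(x: int, y: int) -> int:
    """Ceiling of x / y for integers with y > 0."""
    return -(-x // y)


def step2c_candidates(a: int, b: int, s_prev: int, n: int, B: int):
    """Yield (r, s) with r >= 2(b*s_prev - 2B)/n and (2B + r*n)/b <= s < (3B + r*n)/a."""
    for r in count(ceil_div(2 * (b * s_prev - 2 * B), n)):
        lo = ceil_div(2 * B + r * n, b)
        hi = ceil_div(3 * B + r * n, a)   # exclusive upper bound
        for s in range(lo, hi):
            yield r, s


# Hypothetical toy values: modulus, B, current interval [600, 610], previous s.
n, B = 1009 * 1013, 256
first = list(islice(step2c_candidates(600, 610, 5000, n, B), 5))
```

Because the lower bound on $r_i$ carries a factor of 2, the first candidates for $s_i$ are roughly twice $s_{i-1}$ in this example.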

step 3

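After $s_i$ has been found, set

$$M_i = \bigcup_{[a,b] \in M_{i-1},\ r_i} \left\{ \left[ \max\left(a, \left\lceil \frac{2B + r_i n}{s_i} \right\rceil\right),\ \min\left(b, \left\lfloor \frac{3B - 1 + r_i n}{s_i} \right\rfloor\right) \right] \right\}$$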

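where, for each $[a,b] \in M_{i-1}$, $r_i$ ranges over the integers with

$$\frac{a s_i - 3B + 1}{n} \le r_i \le \frac{b s_i - 2B}{n}$$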

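because, for some integer $r_i$,

$$\begin{aligned}
& 2B \le s_i m_0 \bmod n \le 3B - 1 \\
\Rightarrow\ & 2B \le s_i m_0 - r_i n \le 3B - 1 \\
\Rightarrow\ & 2B + r_i n \le s_i m_0 \le 3B - 1 + r_i n \\
\Rightarrow\ & \frac{2B + r_i n}{s_i} \le m_0 \le \frac{3B - 1 + r_i n}{s_i}
\end{aligned}$$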

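and for the interval $[a,b] \in M_{i-1}$ that contains $m_0$, we know that $2B \le s_i m_0 - r_i n \le 3B - 1$ with $a \le m_0 \le b$, so

$$a s_i - (3B - 1) \le r_i n \le b s_i - 2B$$

Dividing by $n$ gives the range of $r_i$, and intersecting the bounds on $m_0$ with $[a,b]$ gives the update for $M_i$ stated at the start of this step.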
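As a sanity check of the interval update, here is a small sketch that applies it once and confirms that the true $m_0$ survives. The function name `narrow`, the toy modulus and the chosen $m_0$ are assumptions for illustration; the search for $s$ cheats by using $m_0$ directly in place of an oracle.

```python
def narrow(M, s, n, B):
    """Step 3: intersect each [a, b] with the ranges implied by 2B <= s*m0 - r*n <= 3B - 1."""
    out = []
    for a, b in M:
        r_lo = -(-(a * s - 3 * B + 1) // n)        # ceil((a*s - 3B + 1) / n)
        r_hi = (b * s - 2 * B) // n                # floor((b*s - 2B) / n)
        for r in range(r_lo, r_hi + 1):
            lo = max(a, -(-(2 * B + r * n) // s))  # ceil((2B + r*n) / s)
            hi = min(b, (3 * B - 1 + r * n) // s)  # floor((3B - 1 + r*n) / s)
            if lo <= hi:
                out.append((lo, hi))
    return out


# Hypothetical toy values; m0 is known here only so the result can be checked.
n, B, m0 = 1009 * 1013, 256, 700
start = -(-n // (3 * B))                           # smallest s worth trying
s = next(s for s in range(start, n) if 2 * B <= (s * m0) % n < 3 * B)
M1 = narrow([(2 * B, 3 * B - 1)], s, n, B)

assert any(lo <= m0 <= hi for lo, hi in M1)
```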

step 4

If $M_i$ contains only one interval and $M_i = \{[a,a]\}$, then $m_0 = a$ and $m \equiv s_0^{-1} a \pmod n$; otherwise set $i \leftarrow i + 1$ and go back to step 2.
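The final unblinding is a single modular inversion. A minimal sketch, assuming Python 3.8+ for `pow(s0, -1, n)`; the name `unblind` and the numbers are illustrative only.

```python
def unblind(m0: int, s0: int, n: int) -> int:
    """Recover m from m0 = m * s0 mod n (s0 must be invertible mod n)."""
    return (m0 * pow(s0, -1, n)) % n


n = 1009 * 1013        # hypothetical toy modulus
m, s0 = 123456, 4321   # hypothetical message and blinding value
m0 = (m * s0) % n

assert unblind(m0, s0, n) == m
```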


Code

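def ceil_int(a: int, b: int) -> int:
    # Assumed helper (not shown in the source): ceiling of a / b for integers, b > 0.
    return -(-a // b)


def floor_int(a: int, b: int) -> int:
    # Assumed helper (not shown in the source): floor of a / b for integers, b > 0.
    return a // b


def bleichenbacher_1998(n: int, e: int, c: int, oracle):
    """
    - input : `n (int)`, `e (int)`, `c (int)`, `oracle (func)` , `c` must be PKCS#1 conforming
    - output : `m (int)` , the plaintext of `c`
    - oracle func :
        - input : `c (int)`
        - output : `PKCS_conforming (bool)` , whether the plaintext of `c` is PKCS#1 conforming
    """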

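    # Step 1 (blinding) is skipped: c itself must already be conforming, so s0 = 1.
    assert oracle(c)
    # k is the byte length of n, so 2^(8(k-1)) <= n < 2^(8k), and B = 2^(8(k-2)).
    # (n.bit_length() // 8 undercounts k unless the bit length is a multiple of 8.)
    k = (n.bit_length() + 7) // 8
    B = 1 << (8 * (k - 2))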

    def bleichenbacher_orifind_s(lower_bound: int):
        # Step 2, linear search: smallest s >= lower_bound accepted by the oracle.
        si = lower_bound
        while True:
            new_c = (pow(si, e, n) * c) % n
            if oracle(new_c):
                return si
            si += 1

    def bleichenbacher_optfind_s(prev_si: int, a: int, b: int):
        # Step 2, single interval [a, b]: try s in [(2B + r*n)/b, (3B + r*n)/a) for growing r.
        ri = ceil_int(2 * (b * prev_si - 2 * B), n)
        while True:
            low_bound = ceil_int(2 * B + ri * n, b)
            high_bound = ceil_int(3 * B + ri * n, a)
            for si in range(low_bound, high_bound):
                new_c = (pow(si, e, n) * c) % n
                if oracle(new_c):
                    return si
            ri += 1

    def bleichenbacher_merge_M(si: int, M: list):
        # Step 3: narrow every interval in M using the accepted multiplier si.
        new_M = []
        for [a, b] in M:
            r_low = ceil_int(a * si - 3 * B + 1, n)
            r_high = floor_int(b * si - 2 * B, n) + 1
            for ri in range(r_low, r_high):
                interval_low = max(a, ceil_int(2 * B + ri * n, si))
                interval_high = min(b, floor_int(3 * B + ri * n - 1, si))
                if interval_high >= interval_low:
                    new_M.append([interval_low, interval_high])
        return new_M

    # Step 2 with i = 1: start the search at n / (3B).
    s = bleichenbacher_orifind_s(ceil_int(n, 3 * B))
    M = bleichenbacher_merge_M(s, [[2 * B, 3 * B - 1]])
    print(s, M)

    while True:
        if len(M) > 1:
            s = bleichenbacher_orifind_s(s + 1)
        else:
            if M[0][0] == M[0][1]:
                return M[0][0]
            s = bleichenbacher_optfind_s(s, M[0][0], M[0][1])
        M = bleichenbacher_merge_M(s, M)
        print(s, M)
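To try the function, one needs a key pair, a conforming ciphertext and an oracle. The setup below is a sketch with a hypothetical toy key; the stand-in `oracle` decrypts with the private exponent, which a real attacker would not have, and `demo` is an illustrative wrapper that accepts any function with the same signature as `bleichenbacher_1998`.

```python
# Hypothetical toy key, far too small for real use.
p, q, e = 1009, 1013, 5
n = p * q
d = pow(e, -1, (p - 1) * (q - 1))  # modular inverse, needs Python 3.8+
k = (n.bit_length() + 7) // 8
B = 1 << (8 * (k - 2))


def oracle(c: int) -> bool:
    """Stand-in padding oracle: decrypt with d and test 2B <= m < 3B."""
    return 2 * B <= pow(c, d, n) < 3 * B


m = 2 * B + 0xA7          # encodes as 00 02 A7, so it is PKCS conforming
c = pow(m, e, n)


def demo(attack):
    """Run an attack function that has the signature of bleichenbacher_1998."""
    return attack(n, e, c, oracle)
```

With the function above in scope, `demo(bleichenbacher_1998)` should return `m`; counting the calls to `oracle` gives a rough idea of the query cost for this toy modulus.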