Multi-Party Computation was a cryptography challenge for 250 points in this year's Boston Key Party.
We were given the source code of a two-party protocol for private set intersection (PSI). In a secure PSI protocol, the server and the client compute the intersection of their input sets without learning anything else about the other side's input.
In the challenge we take the role of the client. The server reads the flag from a file and encodes it as a set as follows. Let \(n\) be the length of the flag. The server generates \(n\) random 48-bit integers and multiplies them by \(256\), so that the eight least significant bits are zero. The resulting numbers are sorted in ascending order, and the ASCII value of the \(i\)-th flag character is added to the \(i\)-th number.
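The encoding step can be sketched in Python 3 as follows. It mirrors the `POINTS` construction in `mpc.py` from the appendix. The function name `encode_flag`, its `rng` parameter and the sample flag are our own additions for illustration.

```python
import random

def encode_flag(flag, rng=random):
    """Encode each flag character into the low byte of a random multiple of 256."""
    # Random 48-bit integers times 256: the lowest eight bits are zero.
    points = sorted(rng.randrange(2**48) * 256 for _ in flag)
    # The i-th smallest value carries the i-th character in its low byte.
    return [p + ord(c) for p, c in zip(points, flag)]

points = encode_flag("FLAG{example}", random.Random(1))
print([p & 0xff for p in points])  # low bytes are the ASCII codes, in flag order
```

Every random part is a multiple of 256 and every added character is below 256, so the encoded list stays sorted. The only exception would be two equal random values, which is vanishingly unlikely with 48-bit integers.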
A Homomorphic Encryption Scheme
The application implements the Paillier cryptosystem and uses its homomorphic properties. Denote by \(Enc_{pk}(m)\) an encryption of the message \(m\) under the public key \(pk\). Because the encryption algorithm is probabilistic, the same message has many different encryptions.
If we multiply two encryptions under the same public key, we get an encryption of the sum of the two plaintexts: \(Enc_{pk}(m_1) \cdot Enc_{pk}(m_2) = Enc_{pk}(m_1 + m_2)\).
Raising an encryption to the power of a constant \(k\) yields an encryption of the message multiplied by \(k\): \(Enc_{pk}(m)^k = Enc_{pk}(k \cdot m)\).
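Here is a minimal Python 3 sketch of these two properties. It is not a port of `mpc.py`. For brevity it uses the common choice \(g = n + 1\) and \(\lambda = \operatorname{lcm}(p-1, q-1)\). It also takes two fixed Mersenne primes, which are far too small for real use, instead of 512-bit strong primes. The names `keygen`, `encrypt`, `decrypt`, `add` and `scale` are our own. `pow(x, -1, n)` requires Python 3.8 or later.

```python
import math
import random

def L(x, n):
    return (x - 1) // n

def keygen(p, q):
    """Toy Paillier keys from two given primes, using g = n + 1."""
    n = p * q
    lam = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)  # lcm(p-1, q-1)
    mu = pow(lam, -1, n)  # with g = n + 1, L(g^lam mod n^2) = lam mod n
    return (n, n + 1), (lam, mu)

def encrypt(pk, m):
    n, g = pk
    while True:
        r = random.randrange(1, n)  # fresh randomness: encryption is probabilistic
        if math.gcd(r, n) == 1:
            break
    return pow(g, m, n * n) * pow(r, n, n * n) % (n * n)

def decrypt(pk, sk, c):
    n, _ = pk
    lam, mu = sk
    return L(pow(c, lam, n * n), n) * mu % n

def add(pk, a, b):
    """Enc(m1) * Enc(m2) is an encryption of m1 + m2 (mod n)."""
    n, _ = pk
    return a * b % (n * n)

def scale(pk, a, k):
    """Enc(m)^k is an encryption of k * m (mod n)."""
    n, _ = pk
    return pow(a, k, n * n)

# Demo keys from the Mersenne primes 2^61 - 1 and 2^31 - 1 (illustration only).
pk, sk = keygen(2**61 - 1, 2**31 - 1)
c1, c2 = encrypt(pk, 1234), encrypt(pk, 5678)
print(decrypt(pk, sk, add(pk, c1, c2)))   # 6912
print(decrypt(pk, sk, scale(pk, c1, 3)))  # 3702
```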
The Protocol
The server's and the client's inputs are the sets \(A_s = \{x_1, \dotsc, x_{n_s}\}\) and \(A_c = \{y_1, \dotsc, y_{n_c}\}\), respectively, whose intersection is to be computed. Each element is encoded as a number.
First the client generates a Paillier key pair \(\langle pk, sk \rangle\) and computes the polynomial whose roots are exactly the elements of its set:
\[p(X) = \prod_{j=1}^{n_c} (X - y_j).\]
The polynomial is represented by the list of its coefficients. The client encrypts this polynomial with its public key \(pk\) by encrypting each coefficient.
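Computing this coefficient list can be sketched in plain Python as follows. No encryption is involved yet, and the modulus `n` stands in for the Paillier modulus. As in `mpc_monomial` from `mpc.py`, coefficients are stored lowest degree first. `mpc_multiply_poly` in the challenge source allocates one coefficient too many and adds the running value `result[i+j]` a second time. This sketch accumulates each product exactly once. The helper names `poly_from_roots` and `poly_eval` are our own.

```python
def poly_from_roots(roots, n):
    """Coefficients of prod (X - y) mod n, lowest degree first."""
    coeffs = [1]  # the constant polynomial 1
    for y in roots:
        # Multiply by the monomial (X - y), i.e. the coefficient list [-y, 1].
        nxt = [0] * (len(coeffs) + 1)
        for k, a in enumerate(coeffs):
            nxt[k] = (nxt[k] - y * a) % n
            nxt[k + 1] = (nxt[k + 1] + a) % n
        coeffs = nxt
    return coeffs

def poly_eval(coeffs, x, n):
    """Evaluate the polynomial at x modulo n (Horner's rule)."""
    result = 0
    for a in reversed(coeffs):
        result = (result * x + a) % n
    return result

n = 2**61 - 1  # any modulus works for this demo
coeffs = poly_from_roots([2, 3], n)
print(coeffs == [6, n - 5, 1])                           # True: X^2 - 5X + 6
print(poly_eval(coeffs, 3, n), poly_eval(coeffs, 4, n))  # 0 2
```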
It then sends its public key together with the encrypted polynomial to the server.
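The server evaluates the encrypted polynomial at each element \(x_i\) of its set. Writing \(p(X) = \sum_{k=0}^{n_c} a_k X^k\), it uses the homomorphic properties to multiply the \(k\)-th coefficient by the constant \(x_i^k\) and then adds up the resulting terms, which gives an encryption of \(p(x_i)\):
\[\alpha_i = \prod_{k=0}^{n_c} Enc_{pk}(a_k)^{x_i^k} = Enc_{pk}\Big(\sum_{k=0}^{n_c} a_k x_i^k\Big) = Enc_{pk}(p(x_i)).\]
Next, the server homomorphically multiplies the plaintext of \(\alpha_i\) by a large random number \(r_i\) and adds the element \(x_i\). This yields an encryption of \(z_i = r_i \cdot p(x_i) + x_i\):
\[\alpha_i^{r_i} \cdot Enc_{pk}(x_i) = Enc_{pk}(r_i \cdot p(x_i) + x_i).\]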
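The server's computation for one element can be sketched as below. It follows `mpc_evaluate_poly` and `mpc_server_side` from `mpc.py`, but in Python 3. It runs on a toy Paillier instance (\(g = n + 1\), fixed and far too small Mersenne primes) so that it works without PyCryptodome. `server_response`, `enc` and `dec` are our own names. The sketch draws \(r_i\) from \([1, n)\), whereas `mpc.py` uses `random.randrange(n)`, which can return 0.

```python
import math
import random

# Toy Paillier with g = n + 1 and fixed small primes (illustration only).
P, Q = 2**61 - 1, 2**31 - 1
N, N2 = P * Q, (P * Q) ** 2
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)
MU = pow(LAM, -1, N)

def enc(m):
    while True:
        r = random.randrange(1, N)
        if math.gcd(r, N) == 1:
            return pow(N + 1, m, N2) * pow(r, N, N2) % N2

def dec(c):
    return (pow(c, LAM, N2) - 1) // N * MU % N

def server_response(enc_coeffs, x):
    """Return Enc(r * p(x) + x) for a fresh random r, computed on ciphertexts only."""
    alpha, x_pow = 1, 1  # 1 is a trivial encryption of 0
    for c in enc_coeffs:
        # Enc(a_k)^(x^k) encrypts a_k * x^k; multiplying ciphertexts adds plaintexts.
        alpha = alpha * pow(c, x_pow, N2) % N2
        x_pow = x_pow * x % N
    r = random.randrange(1, N)
    return pow(alpha, r, N2) * enc(x) % N2

# Client set {5, 7}: p(X) = (X - 5)(X - 7) = X^2 - 12X + 35, lowest degree first.
client = {5, 7}
enc_poly = [enc(35), enc(-12 % N), enc(1)]
server = [5, 9, 13]
z = {dec(server_response(enc_poly, x)) for x in server}
print(client & z)  # {5} (with overwhelming probability)
```

Only the common element 5 comes back unmasked. For 9 and 13 the client sees \(r_i \cdot p(x_i) + x_i\) for an unknown \(r_i\).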
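By the construction of \(p(X)\), if \(x_i\) is in \(A_s \cap A_c\), then \(p(x_i) = 0\), since the elements of \(A_c\) are the roots of the polynomial. In that case the random factor vanishes and the server returns an encryption of \(x_i\) itself. For every other \(x_i\), the masked value \(r_i \cdot p(x_i) + x_i\) looks random to the client.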
The server sends the encryptions of \(Z = \{z_1, \dotsc, z_{n_s}\}\) back to the client in random order.
The client decrypts them and computes \(A_c \cap Z\), which with high probability equals \(A_c \cap A_s\).
The Exploitation
Let \(\Sigma \supseteq A_c \cup A_s\) be the set of all possible values. We want to learn the server's set \(A_s = A_s \cap \Sigma\), so we can construct a polynomial that has every element of \(\Sigma\) as a root. The easiest way to do this is to choose the constant polynomial \(p(X) = 0\). The server accepts it because it never checks that the polynomial it receives is monic. If we encrypt \(p(X)\) and send it to the server, every \(r_i \cdot p(x_i)\) vanishes and the server responds with encryptions of the set \(Z = A_s\).
To decode the flag, we sort the decrypted values in ascending order and take the lowest eight bits of each number. Every random part is a multiple of 256 and each added ASCII value is below 256, so sorting restores the original character order. Interpreting these bytes as ASCII characters yields the flag: FLAG{Monic polynomials FTW}
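The whole attack can be replayed locally with the sketch below. It is a Python 3 simulation on a toy Paillier instance (\(g = n + 1\), fixed and far too small Mersenne primes). A local function `server` stands in for `mpc_server_side` and the HTTP endpoint. All names and the sample flag are placeholders. The client side matches the exploit script in the appendix: it sends a single encryption of 0, then sorts the decrypted values and takes their low bytes.

```python
import math
import random

# Toy Paillier with g = n + 1 and fixed small primes (illustration only).
P, Q = 2**61 - 1, 2**31 - 1
N, N2 = P * Q, (P * Q) ** 2
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)
MU = pow(LAM, -1, N)

def enc(m):
    while True:
        r = random.randrange(1, N)
        if math.gcd(r, N) == 1:
            return pow(N + 1, m, N2) * pow(r, N, N2) % N2

def dec(c):
    return (pow(c, LAM, N2) - 1) // N * MU % N

def server(enc_poly, points):
    """Local stand-in for mpc_server_side: Enc(r * p(x) + x) per point, shuffled."""
    out = []
    for x in points:
        alpha, x_pow = 1, 1
        for c in enc_poly:
            alpha = alpha * pow(c, x_pow, N2) % N2
            x_pow = x_pow * x % N
        out.append(pow(alpha, random.randrange(N), N2) * enc(x) % N2)
    random.shuffle(out)
    return out

# Server-side set encoding the flag, built as in mpc.py.
flag = "FLAG{local test}"
points = sorted(random.randrange(2**48) * 256 for _ in flag)
points = [p + ord(ch) for p, ch in zip(points, flag)]

# Attack: the zero polynomial, sent as a single encrypted coefficient.
replies = server([enc(0)], points)
recovered = bytes(v & 0xff for v in sorted(dec(c) for c in replies)).decode()
print(recovered)  # FLAG{local test}
```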
Appendix
The given source code:
mpc.py
```python
from Crypto.Util import number
import random
from SocketServer import ThreadingMixIn
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
import sys
import json
import traceback

def L(x, n):
    return (x-1) // n

def paillier_keygen():
    # Returns (pk, sk)
    p = number.getStrongPrime(512)
    q = number.getStrongPrime(512)
    n = p*q
    lam = (p-1)*(q-1)/2
    while True:
        g = random.randrange(n**2)
        if number.GCD(g, n) != 1:
            continue
        mu_inv = L(pow(g, lam, n**2), n)
        if number.GCD(mu_inv, n) != 1:
            continue
        mu = number.inverse(mu_inv, n)
        break
    return (n, g), (lam, mu)

def paillier_encrypt((n, g), m):
    while True:
        r = random.randrange(n)
        if number.GCD(r, n) == 1:
            break
    return (pow(g, m, n**2) * pow(r, n, n**2)) % (n**2)

def paillier_decrypt((n, g), (lam, mu), c):
    return (L(pow(c, lam, n**2), n) * mu) % n

def paillier_add((n, g), a, b):
    return (a * b) % (n**2)

def paillier_multiply((n, g), a, k):
    return pow(a, k, n**2)

def mpc_monomial(point):
    return [-point, 1]

def mpc_multiply_poly(n, x, y):
    result = [0]*(len(x) + len(y))
    for i in range(len(x)):
        for j in range(len(y)):
            result[i+j] += (result[i+j] + x[i]*y[j]) % n
    return result

def mpc_encrypt_poly(pk, poly):
    return [paillier_encrypt(pk, term) for term in poly]

def mpc_client_genpoly((n, g), points):
    result = [1]
    for point in points:
        result = mpc_multiply_poly(n, result, mpc_monomial(point))
    return mpc_encrypt_poly(pk, result)

def mpc_evaluate_poly((n, g), poly, point):
    pow_point = point
    result = poly[0]
    for term in poly[1:]:
        result = paillier_add((n, g), result, paillier_multiply((n, g), term, pow_point))
        pow_point = (pow_point * point) % n
    return result

def mpc_server_side((n, g), poly, points):
    for point in points:
        result = mpc_evaluate_poly((n, g), poly, point)
        result = paillier_multiply((n, g), result, random.randrange(n))
        result = paillier_add((n, g), result, paillier_encrypt((n, g), point))
        yield result

def mpc_client_parseresults(pk, sk, c_points, s_points_enc):
    s_points = [paillier_decrypt(pk, sk, point) for point in s_points_enc]
    return set(c_points) & set(s_points)

class MpcHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        try:
            data_str = self.rfile.read(int(self.headers.getheader('content-length')))
            data = json.loads(data_str)
            n = data['n']
            if (n < 2**64):
                raise ValueError('too small')
            g = data['g']
            poly = data['poly']
            l = list(mpc_server_side((n, g), poly, POINTS))
            random.shuffle(l)
            result = json.dumps(l)
        except Exception as e:
            self.send_response(400)
        else:
            self.send_response(200)
            self.end_headers()
            self.wfile.write(result)

class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    pass

if __name__=="__main__":
    assert(len(sys.argv) >= 3)
    with open('FLAG.txt', 'r') as f:
        flag = f.read()[:-1]
    print flag
    POINTS = []
    for i in range(len(flag)):
        POINTS.append(random.randrange(2**48) * 256)
    POINTS.sort()
    for i in range(len(flag)):
        POINTS[i] += ord(flag[i])
    print POINTS
    server = ThreadedHTTPServer((sys.argv[1], int(sys.argv[2])), MpcHandler)
    server.serve_forever()
```
Our exploit script:
exploit.py
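```python
#!/usr/bin/env python3
import json
import random
import requests
from Crypto.Util import number

# Take paillier_{keygen,encrypt,decrypt} (and the helper L) from mpc.py above.
# Under Python 3 they need their tuple parameters unpacked in the function body
# and // instead of / when computing lam.

def main():
    pk, sk = paillier_keygen()
    # The zero polynomial: a single encrypted coefficient, Enc(0).
    poly = [paillier_encrypt(pk, 0)]
    data = {
        'n': pk[0],
        'g': pk[1],
        'poly': poly,
    }
    j = json.dumps(data)
    print(j)
    res = requests.post('http://mpc-1952363567.us-west-2.elb.amazonaws.com:1025/', data=j)
    enc_points = res.json()
    # Sorting restores the flag order; the low byte of each value is one character.
    points = sorted(paillier_decrypt(pk, sk, p) for p in enc_points)
    ascii_vals = [p & 0xff for p in points]
    print(ascii_vals)
    flag = bytes(ascii_vals)
    print(flag)

if __name__ == '__main__':
    main()
```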