Boston Key Party 2017: Multi-Party Computation

Multi-Party Computation was a cryptography challenge for 250 points in this year's Boston Key Party.

We were given the source code of a two-party protocol for private set intersection (PSI). In a secure PSI protocol, the server and the client compute the intersection of their input sets without learning anything else about the other party's input.

In the challenge we take the role of the client. The server reads the flag from a file and encodes it as a set as follows. Let \(n\) be the length of the flag. The server generates \(n\) random integers below \(2^{48}\) and multiplies them by \(256\), so that their eight least significant bits are zero. The resulting numbers are sorted in ascending order, and the ASCII value of the \(i\)-th flag character is added to the \(i\)-th number.
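The encoding step can be sketched in Python 3 as follows. This is a port of the `POINTS` construction in `mpc.py` from the appendix. The function name `encode_flag` and the example flag are hypothetical and chosen only for illustration.

```python
import random


def encode_flag(flag):
    """Encode a flag as a sorted list of integers, mirroring POINTS in mpc.py."""
    # Random multiples of 256, i.e. integers whose lowest 8 bits are zero
    bases = sorted(random.randrange(2**48) * 256 for _ in flag)
    # The i-th smallest number carries the i-th flag character in its low byte
    return [b + ord(ch) for b, ch in zip(bases, flag)]


points = encode_flag("FLAG{example}")  # hypothetical flag
for p in points[:3]:
    print(hex(p))
```

Adding a value below 256 never changes the relative order of distinct multiples of 256, so the points stay sorted after the characters are added.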

A Homomorphic Encryption Scheme

The application implements the Paillier cryptosystem and uses its homomorphic properties. Denote by \(Enc_{pk}(m)\) an encryption of the message \(m\) under the public key \(pk\). Because the encryption algorithm is probabilistic, the same message has many different encryptions.

If we multiply two encryptions under the same public key, we get an encryption of the sum of the two plaintexts.

\begin{equation*} Enc_{pk}(m_1) \cdot Enc_{pk}(m_2) = Enc_{pk}(m_1 + m_2) \end{equation*}

Raising an encrypted message to the power of a constant \(k\) yields an encryption of the message multiplied by \(k\).

\begin{equation*} Enc_{pk}(m_1)^k = Enc_{pk}(k \cdot m_1) \end{equation*}
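A toy Paillier instance is enough to check both properties. This sketch assumes tiny hard-coded primes and the common simplification \(g = N + 1\), where \(N = pq\) is the public modulus. The challenge instead uses 512-bit primes and a random \(g\). The sketch is for illustration only and is not secure. Ciphertexts are combined modulo \(N^2\), and plaintexts live modulo \(N\). The modular inverse via `pow(x, -1, N)` needs Python 3.8 or later.

```python
import math
import random

# Toy parameters: tiny primes and g = N + 1 (illustration only, NOT secure)
P, Q = 1009, 1013
N = P * Q
N2 = N * N
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)  # lcm(p - 1, q - 1)
MU = pow(LAM, -1, N)  # valid because g = N + 1


def enc(m):
    """Paillier encryption: g^m * r^N mod N^2 with a random unit r."""
    r = random.randrange(1, N)
    while math.gcd(r, N) != 1:
        r = random.randrange(1, N)
    return pow(N + 1, m, N2) * pow(r, N, N2) % N2


def dec(c):
    """Paillier decryption: L(c^lambda mod N^2) * mu mod N, with L(x) = (x - 1) // N."""
    return (pow(c, LAM, N2) - 1) // N * MU % N


c1, c2 = enc(17), enc(25)
print(dec(c1 * c2 % N2))    # 42: product of ciphertexts -> sum of plaintexts
print(dec(pow(c1, 3, N2)))  # 51: ciphertext to the k-th power -> k * plaintext
```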

The Protocol

The inputs of the server and the client are the sets \(A_s = \{x_1, \dotsc, x_{n_s}\}\) and \(A_c = \{y_1, \dotsc, y_{n_c}\}\), respectively, whose intersection is to be computed. Each element is encoded as a number.

First, the client generates a Paillier key pair \(\langle pk, sk \rangle\) and computes the following polynomial, whose roots are exactly the elements of its set.

\begin{equation*} \begin{split} p(X) &= (X - y_1) \dotsm (X - y_{n_c}) \\ &= p_0 + p_1X + p_2X^2 + \dotsb + p_{n_c}X^{n_c} \end{split} \end{equation*}

The polynomial is represented by the list of its coefficients. The client encrypts this polynomial with its public key \(pk\) by encrypting each coefficient.

\begin{equation*} Enc_{pk}(p_0), Enc_{pk}(p_1), Enc_{pk}(p_2), \dotsc, Enc_{pk}(p_{n_c}) \end{equation*}
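Expanding the product into coefficients is a repeated multiplication by the monomials \(X - y_j\). The sketch below uses hypothetical helper names and orders coefficients from \(p_0\) upwards. As in `mpc.py`, coefficients are reduced modulo the public modulus; here a small stand-in modulus is assumed. The sketch also checks that the polynomial vanishes on the client's elements. Note that `mpc_multiply_poly` in the appendix adds `result[i+j]` to itself on every update; this sketch accumulates the products normally.

```python
def poly_mul(a, b, n):
    """Multiply two coefficient lists (lowest degree first) modulo n."""
    out = [0] * (len(a) + len(b) - 1)
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            out[i + j] = (out[i + j] + ai * bj) % n
    return out


def poly_from_roots(roots, n):
    """Coefficients p_0, ..., p_k of (X - y_1)...(X - y_k) modulo n."""
    coeffs = [1]
    for y in roots:
        coeffs = poly_mul(coeffs, [-y % n, 1], n)  # multiply by (X - y)
    return coeffs


def poly_eval(coeffs, x, n):
    """Evaluate p(x) = sum_k p_k * x^k modulo n."""
    return sum(c * pow(x, k, n) for k, c in enumerate(coeffs)) % n


n = 101  # small stand-in for the Paillier modulus
coeffs = poly_from_roots([2, 3], n)
print(coeffs)                   # [6, 96, 1], i.e. X^2 - 5X + 6 mod 101
print(poly_eval(coeffs, 3, n))  # 0, since 3 is a root
print(poly_eval(coeffs, 5, n))  # 6 = (5 - 2) * (5 - 3)
```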

It then sends its public key together with the encrypted polynomial to the server.

The server evaluates the encrypted polynomial at each element \(x_i\) of its set. Using the homomorphic properties, it multiplies the \(k\)-th coefficient by the constant \(x_i^k\) and then adds up the resulting terms.

\begin{equation*} \begin{split} &Enc_{pk}(p_0) \cdot Enc_{pk}(p_1)^{x_i} \cdot Enc_{pk}(p_2)^{x_i^2} \dotsm Enc_{pk}(p_{n_c})^{x_i^{n_c}} \\ = &Enc_{pk}(p_0) \cdot Enc_{pk}(p_1 \cdot x_i) \cdot Enc_{pk}(p_2 \cdot x_i^2) \dotsm Enc_{pk}(p_{n_c} \cdot x_i^{n_c}) \\ = &Enc_{pk}(p_0 + p_1 \cdot x_i + p_2 \cdot x_i^2 + \dotsb + p_{n_c} \cdot x_i^{n_c}) \\ = &Enc_{pk}(p(x_i)) \end{split} \end{equation*}

Next, the server multiplies the encrypted value \(p(x_i)\) by a large random number \(r_i\) and adds the element \(x_i\) itself.

\begin{equation*} z_i = Enc_{pk}(p(x_i))^{r_i} \cdot Enc_{pk}(x_i) = Enc_{pk}(p(x_i) \cdot {r_i} + x_i) \end{equation*}

By the construction of \(p(X)\), if \(x_i\) is in \(A_s \cap A_c\), then \(p(x_i) = 0\), since the elements of \(A_c\) are the roots of the polynomial. Otherwise \(p(x_i) \neq 0\), and the random factor \(r_i\) masks \(x_i\).

\begin{equation*} z_i = \begin{cases} Enc_{pk}(x_i) & \text{if } x_i \in A_s \cap A_c \\ Enc_{pk}(\text{some random-looking number}) & \text{if } x_i \notin A_s \cap A_c \\ \end{cases} \end{equation*}

The server sends a random permutation of the set \(Z = \{z_1, \dotsc, z_{n_s}\}\) back to the client.

The client decrypts the elements of \(Z\) and intersects them with \(A_c\). With high probability, this yields exactly \(A_c \cap A_s\).
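An honest run of the protocol can be simulated end to end. This sketch uses a toy Paillier instance (tiny primes, \(g = N + 1\), not secure) and two small example sets chosen for illustration. The names `client_poly` and `server_side` are hypothetical. `server_side` follows the structure of `mpc_evaluate_poly` and `mpc_server_side` in the appendix and adds the final shuffle.

```python
import math
import random

# Toy Paillier instance: tiny primes and g = N + 1 (illustration only, NOT secure)
P, Q = 1009, 1013
N, N2 = P * Q, (P * Q) ** 2
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)
MU = pow(LAM, -1, N)


def enc(m):
    r = random.randrange(1, N)
    while math.gcd(r, N) != 1:
        r = random.randrange(1, N)
    return pow(N + 1, m, N2) * pow(r, N, N2) % N2


def dec(c):
    return (pow(c, LAM, N2) - 1) // N * MU % N


def client_poly(roots):
    """Client: encrypt the coefficients of prod (X - y) mod N, lowest degree first."""
    coeffs = [1]
    for y in roots:
        nxt = [0] * (len(coeffs) + 1)
        for k, c in enumerate(coeffs):
            nxt[k] = (nxt[k] - y * c) % N      # -y * p_k X^k
            nxt[k + 1] = (nxt[k + 1] + c) % N  # p_k X^(k+1)
        coeffs = nxt
    return [enc(c) for c in coeffs]


def server_side(enc_poly, points):
    """Server: return a shuffled list of Enc(r_i * p(x_i) + x_i)."""
    replies = []
    for x in points:
        acc = 1  # trivial encryption of 0
        for k, c in enumerate(enc_poly):
            acc = acc * pow(c, pow(x, k, N), N2) % N2  # add p_k * x^k homomorphically
        r = random.randrange(1, N)
        replies.append(pow(acc, r, N2) * enc(x) % N2)  # Enc(r * p(x) + x)
    random.shuffle(replies)
    return replies


client_set = {3, 5, 8, 13}
server_set = {2, 5, 13, 21}
Z = server_side(client_poly(sorted(client_set)), sorted(server_set))
print(client_set & {dec(z) for z in Z})  # {5, 13}, except with negligible probability
```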

The Exploitation

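Let \(\Sigma \supseteq A_c \cup A_s\) be the set of all possible values. We want to learn the server's set \(A_s = A_s \cap \Sigma\), so we can construct a polynomial that has all elements of \(\Sigma\) as roots. The easiest way to do this is to choose the zero polynomial \(p(X) = 0\). The server accepts it, because it never checks that the received polynomial is monic, or even non-zero. Indeed, if we encrypt \(p(X)\) and send it to the server, then \(z_i = Enc_{pk}(0 \cdot r_i + x_i) = Enc_{pk}(x_i)\) for every \(i\). The server therefore responds with encryptions of all elements of \(A_s\) in random order.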
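The attack can be reproduced against a local stand-in for the server. This sketch makes several assumptions:

- It uses a toy Paillier instance (tiny primes, \(g = N + 1\), not secure).
- It scales down the flag encoding from `mpc.py`: the bases are distinct values below \(2^{10}\), so that every point fits below the toy modulus.
- The flag `FLAG` and all function names are hypothetical.

Like `mpc_server_side` in the appendix, `server_side` accepts any list of encrypted coefficients without validating it.

```python
import math
import random

# Toy Paillier instance: tiny primes and g = N + 1 (illustration only, NOT secure)
P, Q = 1009, 1013
N, N2 = P * Q, (P * Q) ** 2
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)
MU = pow(LAM, -1, N)


def enc(m):
    r = random.randrange(1, N)
    while math.gcd(r, N) != 1:
        r = random.randrange(1, N)
    return pow(N + 1, m, N2) * pow(r, N, N2) % N2


def dec(c):
    return (pow(c, LAM, N2) - 1) // N * MU % N


def server_side(enc_poly, points):
    """Return a shuffled list of Enc(r_i * p(x_i) + x_i); enc_poly is not validated."""
    replies = []
    for x in points:
        acc = 1  # trivial encryption of 0
        for k, c in enumerate(enc_poly):
            acc = acc * pow(c, pow(x, k, N), N2) % N2
        r = random.randrange(1, N)
        replies.append(pow(acc, r, N2) * enc(x) % N2)
    random.shuffle(replies)
    return replies


# Scaled-down server set encoding a hypothetical flag
flag = "FLAG"
bases = sorted(random.sample(range(2**10), len(flag)))
server_points = [b * 256 + ord(ch) for b, ch in zip(bases, flag)]

# Attack: send the zero polynomial p(X) = 0 as a single encrypted coefficient
Z = server_side([enc(0)], server_points)
recovered = sorted(dec(z) for z in Z)
print(recovered == server_points)  # True: every z_i decrypts to x_i
```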

To decode the flag, we sort the decrypted values in ascending order and take the lowest eight bits of each number. Interpreting these as ASCII characters yields the flag: FLAG{Monic polynomials FTW}
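The decoding step amounts to a sort followed by a bit mask. Here is a minimal sketch with a hypothetical helper `decode_flag` and made-up points that have already been decrypted:

```python
def decode_flag(points):
    """Undo the server's shuffle by sorting, then read the flag byte from each low 8 bits."""
    return bytes(p & 0xFF for p in sorted(points)).decode("ascii")


# Made-up decrypted points in shuffled order: bases 3, 7, 1 carry 'K', 'P', 'B'
points = [3 * 256 + ord("K"), 7 * 256 + ord("P"), 1 * 256 + ord("B")]
print(decode_flag(points))  # BKP
```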

Appendix

The given source code:

mpc.py

from Crypto.Util import number
import random

from SocketServer import ThreadingMixIn
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler

import sys
import json

import traceback

def L(x, n):
  return (x-1) // n


def paillier_keygen():
  # Returns (pk, sk)
  p = number.getStrongPrime(512)
  q = number.getStrongPrime(512)
  n = p*q
  lam = (p-1)*(q-1)/2
  while True:
    g = random.randrange(n**2)
    if number.GCD(g, n) != 1:
      continue
    mu_inv = L(pow(g, lam, n**2), n)
    if number.GCD(mu_inv, n) != 1:
      continue
    mu = number.inverse(mu_inv, n)
    break
  return (n, g), (lam, mu)

def paillier_encrypt((n, g), m):
  while True:
    r = random.randrange(n)
    if number.GCD(r, n) == 1:
      break
  return (pow(g, m, n**2) * pow(r, n, n**2)) % (n**2)

def paillier_decrypt((n, g), (lam, mu), c):
  return (L(pow(c, lam, n**2), n) * mu) % n

def paillier_add((n, g), a, b):
  return (a * b) % (n**2)

def paillier_multiply((n, g), a, k):
  return pow(a, k, n**2)

def mpc_monomial(point):
  return [-point, 1]

def mpc_multiply_poly(n, x, y):
  result = [0]*(len(x) + len(y))
  for i in range(len(x)):
    for j in range(len(y)):
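      # NOTE: this adds result[i+j] to itself; a correct update would be
      # result[i+j] = (result[i+j] + x[i]*y[j]) % n
      result[i+j] += (result[i+j] + x[i]*y[j]) % n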
  return result

def mpc_encrypt_poly(pk, poly):
  return [paillier_encrypt(pk, term) for term in poly]

def mpc_client_genpoly((n, g), points):
  result = [1]
  for point in points:
    result = mpc_multiply_poly(n, result, mpc_monomial(point))
  # NOTE: `pk` is not defined in this function (NameError); (n, g) was presumably intended
  return mpc_encrypt_poly(pk, result)

def mpc_evaluate_poly((n, g), poly, point):
  pow_point = point
  result = poly[0]
  for term in poly[1:]:
    result = paillier_add((n, g), result, paillier_multiply((n, g), term, pow_point))
    pow_point = (pow_point * point) % n
  return result

def mpc_server_side((n, g), poly, points):
  for point in points:
    result = mpc_evaluate_poly((n, g), poly, point)
    result = paillier_multiply((n, g), result, random.randrange(n))
    result = paillier_add((n, g), result, paillier_encrypt((n, g), point))
    yield result

def mpc_client_parseresults(pk, sk, c_points, s_points_enc):
  s_points = [paillier_decrypt(pk, sk, point) for point in s_points_enc]
  return set(c_points) & set(s_points)



class MpcHandler(BaseHTTPRequestHandler):
  def do_POST(self):
    try:
      data_str = self.rfile.read(int(self.headers.getheader('content-length')))
      data = json.loads(data_str)
      n = data['n']
      if (n < 2**64):
        raise ValueError('too small')
      g = data['g']
      # poly is used as-is: nothing checks that it is monic or even non-zero
      poly = data['poly']
      l = list(mpc_server_side((n, g), poly, POINTS))
      random.shuffle(l)
      result = json.dumps(l)
    except Exception as e:
      self.send_response(400)
    else:
      self.send_response(200)
      self.end_headers()
      self.wfile.write(result)


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
  pass

if __name__=="__main__":
  assert(len(sys.argv) >= 3)

  with open('FLAG.txt', 'r') as f:
    flag = f.read()[:-1]
  print flag

  POINTS = []
  for i in range(len(flag)):
    POINTS.append(random.randrange(2**48) * 256)
  POINTS.sort()
  for i in range(len(flag)):
    POINTS[i] += ord(flag[i])
  print POINTS

  server = ThreadedHTTPServer((sys.argv[1], int(sys.argv[2])), MpcHandler)
  server.serve_forever()

Our exploit script:

exploit.py

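#!/usr/bin/env python3

import json
import random
import requests
from Crypto.Util import number


# Python 3 ports of L and paillier_{keygen,encrypt,decrypt} from mpc.py
# (no tuple parameters, integer division with //).
def L(x, n):
    return (x - 1) // n


def paillier_keygen():
    # Returns (pk, sk) = ((n, g), (lam, mu))
    p = number.getStrongPrime(512)
    q = number.getStrongPrime(512)
    n = p * q
    lam = (p - 1) * (q - 1) // 2
    while True:
        g = random.randrange(n ** 2)
        if number.GCD(g, n) != 1:
            continue
        mu_inv = L(pow(g, lam, n ** 2), n)
        if number.GCD(mu_inv, n) != 1:
            continue
        return (n, g), (lam, number.inverse(mu_inv, n))


def paillier_encrypt(pk, m):
    n, g = pk
    while True:
        r = random.randrange(n)
        if number.GCD(r, n) == 1:
            break
    return (pow(g, m, n ** 2) * pow(r, n, n ** 2)) % (n ** 2)


def paillier_decrypt(pk, sk, c):
    (n, _), (lam, mu) = pk, sk
    return (L(pow(c, lam, n ** 2), n) * mu) % n


def main():
    pk, sk = paillier_keygen()
    # The zero polynomial p(X) = 0, sent as its single coefficient Enc(0)
    poly = [paillier_encrypt(pk, 0)]
    data = {'n': pk[0], 'g': pk[1], 'poly': poly}
    j = json.dumps(data)
    print(j)
    res = requests.post('http://mpc-1952363567.us-west-2.elb.amazonaws.com:1025/', data=j)
    enc_points = res.json()
    # Each reply decrypts to p(x_i) * r_i + x_i = x_i, i.e. one of the flag points
    points = sorted(paillier_decrypt(pk, sk, p) for p in enc_points)
    ascii_vals = [p & 0xff for p in points]  # low byte holds the flag character
    print(ascii_vals)
    flag = bytes(ascii_vals)
    print(flag)


if __name__ == '__main__':
    main()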