Skip to main content

paillier_zk/
multiexp.rs

//! Optimized multiexponentiation with precomputations
//!
//! Many ZK proofs require computing `s^x t^y mod N` where `s`, `t`, and `N` are known in advance.
//! This module provides [`MultiexpTable`], which computes such multiexponentiations faster.
5
6#![allow(non_snake_case)]
7
8use fast_paillier::backend::Integer;
9
/// Precomputed table for performing faster multiexponentiation
///
/// Holds everything [`MultiexpTable::prod_exp`] needs to evaluate `s^x t^y mod N`
/// for fixed `s`, `t` and `N`: byte-indexed powers of both bases plus correction
/// factors used when an exponent is negative.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MultiexpTable {
    // `s[i] = s^(256^i) mod N`, one entry per supported byte of `x`
    s: Vec<Integer>,
    // smallest negative value supported for `x`: `-(2^(8 * s.len())) + 1`
    ell_x: Integer,
    // `s^ell_x mod N`, multiplied in when `x` is negative
    s_to_ell_x: Integer,
    // `t[i] = t^(256^i) mod N`, one entry per supported byte of `y`
    t: Vec<Integer>,
    // smallest negative value supported for `y`: `-(2^(8 * t.len())) + 1`
    ell_y: Integer,
    // `t^ell_y mod N`, multiplied in when `y` is negative
    t_to_ell_y: Integer,
    // modulus all computations are reduced by
    N: Integer,
}
22
23impl MultiexpTable {
24    /// Builds a multiexponentiation table to perform `s^x t^y mod N` faster
25    /// where `x` and `y` are up to `x_bits` and `y_bits`
26    ///
27    /// Returns `None` is `s` or `t` are non-positive or if any of them are not co-prime to `N` or
28    /// if `N` is less than 2.
29    pub fn build(s: &Integer, t: &Integer, x_bits: u32, y_bits: u32, N: Integer) -> Option<Self> {
30        if s.cmp0().is_le()
31            || t.cmp0().is_le()
32            || N <= Integer::one()
33            || !s.gcd_ref(&N).is_one()
34            || !t.gcd_ref(&N).is_one()
35        {
36            return None;
37        }
38        let k_x = x_bits / 8 + 1;
39        let k_y = y_bits / 8 + 1;
40        let mut s_table = Vec::with_capacity(k_x.try_into().ok()?);
41        let mut t_table = Vec::with_capacity(k_y.try_into().ok()?);
42
43        let B: u32 = 256;
44        for i in 0..k_x {
45            let B_to_i = Integer::u_pow_u(B, i);
46            s_table.push(s.pow_mod_ref(&B_to_i, &N)?);
47        }
48        for i in 0..k_y {
49            let B_to_i = Integer::u_pow_u(B, i);
50            t_table.push(t.pow_mod_ref(&B_to_i, &N)?);
51        }
52
53        // smallest negative value possible for `x`
54        let ell_x = -(Integer::one() << (k_x * 8)) + 1;
55        let s_to_ell_x = s.pow_mod_ref(&ell_x, &N)?;
56        // smallest negative value possible for `y`
57        let ell_y = -(Integer::one() << (k_y * 8)) + 1;
58        let t_to_ell_y = t.pow_mod_ref(&ell_y, &N)?;
59
60        Some(Self {
61            s: s_table,
62            ell_x,
63            s_to_ell_x,
64            t: t_table,
65            ell_y,
66            t_to_ell_y,
67            N,
68        })
69    }
70
71    /// Calculates `s^x t^y mod N`
72    ///
73    /// Returns `None` if either `x` or `y` do not fit into `x_bits` or `y_bits` provided in [`MultiexpTable::build`].
74    pub fn prod_exp(&self, x: &Integer, y: &Integer) -> Option<Integer> {
75        let x_is_neg = x.cmp0().is_lt();
76        // `x_digits` correspond to digits of `x` is it's non-negative, and `x - ell_x` otherwise
77        let x_digits = if !x_is_neg {
78            x.to_bytes_lsf()
79        } else {
80            let x = x - &self.ell_x;
81            if x.cmp0().is_lt() {
82                // `x` is less than lower bound
83                return None;
84            }
85            x.to_bytes_lsf()
86        };
87
88        let y_is_neg = y.cmp0().is_lt();
89        // `y_digits` correspond to digits of `y` is it's non-negative, and `y - ell_y` otherwise
90        let y_digits = if !y_is_neg {
91            y.to_bytes_lsf()
92        } else {
93            let y = y - &self.ell_y;
94            if y.cmp0().is_lt() {
95                // `y` is less than lower bound
96                return None;
97            }
98            y.to_bytes_lsf()
99        };
100
101        if x_digits.len() > self.s.len() || y_digits.len() > self.t.len() {
102            // `x` or `y` are higher than upper bound
103            return None;
104        }
105
106        let mut digits_table = [(); 255].map(|_| None);
107        build_digits_table(&mut digits_table, &self.s, &x_digits, &self.N);
108        build_digits_table(&mut digits_table, &self.t, &y_digits, &self.N);
109
110        let mut res = Integer::one();
111        let mut acc = Integer::one();
112        for d in digits_table.iter().rev() {
113            if let Some(d) = d {
114                acc = (acc * d) % &self.N;
115            }
116            res = (res * &acc) % &self.N;
117        }
118
119        if x_is_neg {
120            res = (res * &self.s_to_ell_x) % &self.N;
121        }
122        if y_is_neg {
123            res = (res * &self.t_to_ell_y) % &self.N;
124        }
125
126        Some(res)
127    }
128
129    /// Returns max size of exponents (in bits) that can be computed
130    ///
131    /// Max exponent size is guaranteed to be equal or greater than `x_bits` and `y_bits`
132    /// provided in [MultiexpTable::build]
133    pub fn max_exponents_size(&self) -> (usize, usize) {
134        (self.s.len() * 8, self.t.len() * 8)
135    }
136
137    /// Estimates size of the table in RAM in bytes
138    pub fn size_in_bytes(&self) -> usize {
139        let Self {
140            s,
141            ell_x,
142            s_to_ell_x,
143            t,
144            ell_y,
145            t_to_ell_y,
146            N,
147        } = self;
148
149        // A few bytes to encode length of Vec `s` and `t`
150        let vec_len = 2 * (usize::BITS as usize / 8);
151        // And a few bytes more to encode length of each integer
152        let int_len = (5 + s.len() + t.len()) * (usize::BITS as usize / 8);
153
154        let s: usize = s.iter().map(|s_i| s_i.significant_dwords()).sum();
155        let ell_x = ell_x.significant_dwords();
156        let s_to_ell_x = s_to_ell_x.significant_dwords();
157        let t: usize = t.iter().map(|t_i| t_i.significant_dwords()).sum();
158        let ell_y = ell_y.significant_dwords();
159        let t_to_ell_y = t_to_ell_y.significant_dwords();
160        let N = N.significant_dwords();
161
162        let limbs_bytes =
163            (u32::BITS as usize / 8) * (s + ell_x + s_to_ell_x + t + ell_y + t_to_ell_y + N);
164
165        vec_len + int_len + limbs_bytes
166    }
167}
168
169fn build_digits_table(
170    table: &mut [Option<Integer>; 255],
171    base: &[Integer],
172    digits: &[u8],
173    N: &Integer,
174) {
175    for (i, digit) in digits.iter().copied().enumerate() {
176        if digit != 0 {
177            match &mut table[usize::from(digit - 1)] {
178                Some(out) => {
179                    *out *= &base[i];
180                    *out %= N;
181                }
182                out @ None => *out = Some(base[i].clone()),
183            }
184        }
185    }
186}
187
#[cfg(test)]
mod test {
    use fast_paillier::backend::Integer;

    use super::MultiexpTable;

    /// Compares `MultiexpTable::prod_exp` against plain modular exponentiation
    /// on random exponents of either sign.
    #[test]
    fn multiexp_works() {
        let N = Integer::from(100000);
        let (s, t) = (Integer::from(3), Integer::from(7));
        let (x_bits, y_bits) = (48, 32);

        let table = MultiexpTable::build(&s, &t, x_bits, y_bits, N.clone()).unwrap();

        let mut rng = rand_dev::DevRng::new();

        for _ in 0..100 {
            // Draw the magnitude first, then flip the sign with probability 1/2
            let x = Integer::random_bits(x_bits, &mut rng);
            let x = if rand::Rng::gen(&mut rng) { -x } else { x };

            let y = Integer::random_bits(y_bits, &mut rng);
            let y = if rand::Rng::gen(&mut rng) { -y } else { y };
            println!("x={x} y={y}");

            let actual = table.prod_exp(&x, &y).unwrap();

            // Reference value computed without the table
            let s_to_x = s.pow_mod_ref(&x, &N).unwrap();
            let t_to_y = t.pow_mod_ref(&y, &N).unwrap();
            let expected = (s_to_x * t_to_y) % &N;
            assert_eq!(actual, expected);
        }
    }
}