1#![allow(non_snake_case)]
7
8use fast_paillier::backend::Integer;
9
/// Precomputed table for evaluating `s^x * t^y mod N` with fixed bases `s`,
/// `t` and a fixed modulus `N`.
///
/// Created by [`MultiexpTable::build`] and evaluated by
/// [`MultiexpTable::prod_exp`].
///
/// NOTE: with the `serde` feature enabled, the field names and their order
/// make up the serialized representation, so keep them stable.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MultiexpTable {
    // `s^(256^i) mod N` for `i = 0, 1, ..., s.len() - 1`.
    s: Vec<Integer>,
    // `1 - 2^(8 * s.len())`: a negative exponent `x` is evaluated as
    // `s^(x - ell_x) * s^ell_x`.
    ell_x: Integer,
    // `s^ell_x mod N`.
    s_to_ell_x: Integer,
    // `t^(256^i) mod N` for `i = 0, 1, ..., t.len() - 1`.
    t: Vec<Integer>,
    // `1 - 2^(8 * t.len())`, the offset used for a negative exponent `y`.
    ell_y: Integer,
    // `t^ell_y mod N`.
    t_to_ell_y: Integer,
    // The modulus all table entries are reduced by.
    N: Integer,
}
22
23impl MultiexpTable {
24 pub fn build(s: &Integer, t: &Integer, x_bits: u32, y_bits: u32, N: Integer) -> Option<Self> {
30 if s.cmp0().is_le()
31 || t.cmp0().is_le()
32 || N <= Integer::one()
33 || !s.gcd_ref(&N).is_one()
34 || !t.gcd_ref(&N).is_one()
35 {
36 return None;
37 }
38 let k_x = x_bits / 8 + 1;
39 let k_y = y_bits / 8 + 1;
40 let mut s_table = Vec::with_capacity(k_x.try_into().ok()?);
41 let mut t_table = Vec::with_capacity(k_y.try_into().ok()?);
42
43 let B: u32 = 256;
44 for i in 0..k_x {
45 let B_to_i = Integer::u_pow_u(B, i);
46 s_table.push(s.pow_mod_ref(&B_to_i, &N)?);
47 }
48 for i in 0..k_y {
49 let B_to_i = Integer::u_pow_u(B, i);
50 t_table.push(t.pow_mod_ref(&B_to_i, &N)?);
51 }
52
53 let ell_x = -(Integer::one() << (k_x * 8)) + 1;
55 let s_to_ell_x = s.pow_mod_ref(&ell_x, &N)?;
56 let ell_y = -(Integer::one() << (k_y * 8)) + 1;
58 let t_to_ell_y = t.pow_mod_ref(&ell_y, &N)?;
59
60 Some(Self {
61 s: s_table,
62 ell_x,
63 s_to_ell_x,
64 t: t_table,
65 ell_y,
66 t_to_ell_y,
67 N,
68 })
69 }
70
71 pub fn prod_exp(&self, x: &Integer, y: &Integer) -> Option<Integer> {
75 let x_is_neg = x.cmp0().is_lt();
76 let x_digits = if !x_is_neg {
78 x.to_bytes_lsf()
79 } else {
80 let x = x - &self.ell_x;
81 if x.cmp0().is_lt() {
82 return None;
84 }
85 x.to_bytes_lsf()
86 };
87
88 let y_is_neg = y.cmp0().is_lt();
89 let y_digits = if !y_is_neg {
91 y.to_bytes_lsf()
92 } else {
93 let y = y - &self.ell_y;
94 if y.cmp0().is_lt() {
95 return None;
97 }
98 y.to_bytes_lsf()
99 };
100
101 if x_digits.len() > self.s.len() || y_digits.len() > self.t.len() {
102 return None;
104 }
105
106 let mut digits_table = [(); 255].map(|_| None);
107 build_digits_table(&mut digits_table, &self.s, &x_digits, &self.N);
108 build_digits_table(&mut digits_table, &self.t, &y_digits, &self.N);
109
110 let mut res = Integer::one();
111 let mut acc = Integer::one();
112 for d in digits_table.iter().rev() {
113 if let Some(d) = d {
114 acc = (acc * d) % &self.N;
115 }
116 res = (res * &acc) % &self.N;
117 }
118
119 if x_is_neg {
120 res = (res * &self.s_to_ell_x) % &self.N;
121 }
122 if y_is_neg {
123 res = (res * &self.t_to_ell_y) % &self.N;
124 }
125
126 Some(res)
127 }
128
129 pub fn max_exponents_size(&self) -> (usize, usize) {
134 (self.s.len() * 8, self.t.len() * 8)
135 }
136
137 pub fn size_in_bytes(&self) -> usize {
139 let Self {
140 s,
141 ell_x,
142 s_to_ell_x,
143 t,
144 ell_y,
145 t_to_ell_y,
146 N,
147 } = self;
148
149 let vec_len = 2 * (usize::BITS as usize / 8);
151 let int_len = (5 + s.len() + t.len()) * (usize::BITS as usize / 8);
153
154 let s: usize = s.iter().map(|s_i| s_i.significant_dwords()).sum();
155 let ell_x = ell_x.significant_dwords();
156 let s_to_ell_x = s_to_ell_x.significant_dwords();
157 let t: usize = t.iter().map(|t_i| t_i.significant_dwords()).sum();
158 let ell_y = ell_y.significant_dwords();
159 let t_to_ell_y = t_to_ell_y.significant_dwords();
160 let N = N.significant_dwords();
161
162 let limbs_bytes =
163 (u32::BITS as usize / 8) * (s + ell_x + s_to_ell_x + t + ell_y + t_to_ell_y + N);
164
165 vec_len + int_len + limbs_bytes
166 }
167}
168
169fn build_digits_table(
170 table: &mut [Option<Integer>; 255],
171 base: &[Integer],
172 digits: &[u8],
173 N: &Integer,
174) {
175 for (i, digit) in digits.iter().copied().enumerate() {
176 if digit != 0 {
177 match &mut table[usize::from(digit - 1)] {
178 Some(out) => {
179 *out *= &base[i];
180 *out %= N;
181 }
182 out @ None => *out = Some(base[i].clone()),
183 }
184 }
185 }
186}
187
#[cfg(test)]
mod test {
    use fast_paillier::backend::Integer;

    use super::MultiexpTable;

    /// Draws a random `bits`-bit magnitude, then negates it on a fair coin flip.
    fn random_signed(bits: u32, rng: &mut rand_dev::DevRng) -> Integer {
        let magnitude = Integer::random_bits(bits, &mut *rng);
        if rand::Rng::gen::<bool>(rng) {
            -magnitude
        } else {
            magnitude
        }
    }

    /// Cross-checks `prod_exp` against direct modular exponentiation for
    /// random signed exponents.
    #[test]
    fn multiexp_works() {
        let (s, t, N) = (Integer::from(3), Integer::from(7), Integer::from(100000));
        let (x_bits, y_bits) = (48, 32);

        let table = MultiexpTable::build(&s, &t, x_bits, y_bits, N.clone()).unwrap();
        let mut rng = rand_dev::DevRng::new();

        for _ in 0..100 {
            let x = random_signed(x_bits, &mut rng);
            let y = random_signed(y_bits, &mut rng);
            println!("x={x} y={y}");

            let actual = table.prod_exp(&x, &y).unwrap();
            let s_part = s.pow_mod_ref(&x, &N).unwrap();
            let t_part = t.pow_mod_ref(&y, &N).unwrap();
            assert_eq!(actual, (s_part * t_part) % &N);
        }
    }
}