1pub mod sqrt;
2
3use std::sync::Arc;
4
5use fast_paillier::backend::Integer;
6use generic_ec::Scalar;
7
/// Auxiliary public parameters: an RSA modulus `rsa_modulo` and two bases `s`, `t`.
///
/// NOTE(review): presumably ring-Pedersen parameters — the test helper
/// `test::aux` builds them as `t = r² mod N` and `s = t^λ mod N`; confirm the
/// same holds for production setups.
///
/// The optional `multiexp` and `crt` fields only accelerate
/// [`Aux::combine`] / [`Aux::pow_mod`]; they are skipped by serde.
#[cfg_attr(
    feature = "__internal_doctest",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Clone, Debug)]
pub struct Aux {
    /// First base; paired with exponent `x` in [`Aux::combine`].
    pub s: Integer,
    /// Second base; paired with exponent `y` in [`Aux::combine`].
    pub t: Integer,
    /// Modulus `N` used for all modular arithmetic on this struct.
    pub rsa_modulo: Integer,
    /// Optional precomputed multi-exponentiation table, tried first by
    /// [`Aux::combine`]. Not serialized.
    #[cfg_attr(feature = "__internal_doctest", serde(skip))]
    pub multiexp: Option<Arc<crate::multiexp::MultiexpTable>>,
    /// Optional CRT exponentiation helper, used by [`Aux::pow_mod`] when present.
    /// Not serialized.
    #[cfg_attr(feature = "__internal_doctest", serde(skip))]
    pub crt: Option<fast_paillier::utils::CrtExp>,
}
29
30impl Aux {
31 pub fn combine(&self, x: &Integer, y: &Integer) -> Result<Integer, BadExponent> {
33 if let Some(table) = &self.multiexp {
34 match table.prod_exp(x, y) {
35 Some(res) => return Ok(res),
36 None if cfg!(debug_assertions) => {
37 return Err(BadExponentReason::ExpSize {
38 exp_size: (x.significant_bits(), y.significant_bits()),
39 max_exp_size: table.max_exponents_size(),
40 }
41 .into())
42 }
43 None => {
44 }
46 }
47 }
48
49 self.rsa_modulo
51 .combine(&self.s, x, &self.t, y)
52 .ok_or_else(BadExponent::undefined)
53 }
54
55 pub fn pow_mod(&self, x: &Integer, e: &Integer) -> Result<Integer, BadExponent> {
57 match &self.crt {
58 Some(crt) => {
59 let e = crt.prepare_exponent(e);
60 crt.exp(x, &e).ok_or_else(BadExponent::undefined)
61 }
62 None => Ok(x
63 .pow_mod_ref(e, &self.rsa_modulo)
64 .ok_or_else(BadExponent::undefined)?),
65 }
66 }
67
68 pub fn is_in_mult_group(&self, x: &Integer) -> bool {
70 x.in_mult_group_of(&self.rsa_modulo)
71 }
72
73 pub fn digest_public_data(&self) -> impl udigest::Digestable {
76 udigest::inline_struct!("paillier_zk.aux" {
77 s: udigest::Bytes(self.s.to_bytes_msf()),
78 t: udigest::Bytes(self.t.to_bytes_msf()),
79 rsa_modulo: udigest::Bytes(self.rsa_modulo.to_bytes_msf()),
80 })
81 }
82}
83
84#[derive(Debug, Clone, thiserror::Error)]
86#[error("invalid proof")]
87pub struct InvalidProof(
88 #[source]
89 #[from]
90 InvalidProofReason,
91);
92
/// Specific cause of an [`InvalidProof`].
///
/// Marked `#[non_exhaustive]` so new reasons can be added without breaking
/// exhaustive matches downstream.
///
/// NOTE(review): the `usize` payload of `EqualityCheck`, `MultGroupCheck` and
/// `RangeCheck` presumably identifies which check of a proof failed — confirm
/// against the proof modules that construct these variants.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone, Copy, thiserror::Error)]
pub enum InvalidProofReason {
    #[error("equality check failed {0}")]
    EqualityCheck(usize),
    #[error("mult group check failed {0}")]
    MultGroupCheck(usize),
    #[error("range check failed {0}")]
    RangeCheck(usize),
    /// Produced when a [`PaillierError`] is converted into [`InvalidProof`].
    #[error("encryption failed")]
    Encryption,
    #[error("paillier encryption failed")]
    PaillierEnc,
    #[error("paillier homomorphic op failed")]
    PaillierOp,
    /// Produced when a [`BadExponent`] is converted into [`InvalidProof`].
    #[error("powmod failed")]
    ModPow,
    #[error("modulus is prime")]
    ModulusIsPrime,
    #[error("modulus is even")]
    ModulusIsEven,
    #[error("incorrect nth root")]
    IncorrectNthRoot,
    #[error("incorrect 4th root")]
    IncorrectFourthRoot,
    #[error("conversion failed")]
    Conversion,
}
133
134impl InvalidProof {
135 #[cfg(test)]
136 pub(crate) fn reason(&self) -> InvalidProofReason {
137 self.0
138 }
139}
140
141impl From<BadExponent> for InvalidProof {
142 fn from(_err: BadExponent) -> Self {
143 InvalidProofReason::ModPow.into()
144 }
145}
146
147impl From<PaillierError> for InvalidProof {
148 fn from(_err: PaillierError) -> Self {
149 InvalidProof(InvalidProofReason::Encryption)
150 }
151}
152
/// Detail-free error reporting that a Paillier encryption failed.
///
/// Converting it into [`InvalidProof`] yields [`InvalidProofReason::Encryption`].
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[error("paillier encryption failed")]
pub struct PaillierError;
157
/// Extra big-integer operations needed by the proofs.
pub trait IntegerExt: Sized {
    /// Converts the integer into a scalar of curve `C`, reducing modulo the
    /// curve order; a negative value maps to the negation of its magnitude.
    fn to_scalar<C: generic_ec::Curve>(&self) -> Scalar<C>;

    /// Returns the order of curve `C` as an integer.
    fn curve_order<C: generic_ec::Curve>() -> Self;

    /// Samples a random integer from the symmetric interval
    /// `[-⌊range/2⌋, ⌊range/2⌋]`, i.e. the interval accepted by
    /// [`is_in_half_pm`](IntegerExt::is_in_half_pm).
    fn from_rng_half_pm<R: rand_core::RngCore>(rng: &mut R, range: &Self) -> Self;

    /// Returns whether `|self| <= ⌊range/2⌋`.
    fn is_in_half_pm(&self, range: &Self) -> bool;
}
175
176impl IntegerExt for Integer {
177 fn to_scalar<C: generic_ec::Curve>(&self) -> Scalar<C> {
178 let bytes_be = self.to_bytes_msf();
179 let s = Scalar::<C>::from_be_bytes_mod_order(bytes_be);
180 if self.cmp0().is_ge() {
181 s
182 } else {
183 -s
184 }
185 }
186
187 fn curve_order<C: generic_ec::Curve>() -> Self {
188 let order_minus_one = -Scalar::<C>::one();
189 let i = Integer::from_bytes_msf(&order_minus_one.to_be_bytes());
190 i + 1
191 }
192
193 fn from_rng_half_pm<R: rand_core::RngCore>(rng: &mut R, range: &Self) -> Self {
194 if range.is_even() {
195 let half_range = range >> 1;
196 let range_plus_one = range + 1u32;
197 range_plus_one.random_below(rng) - half_range
198 } else {
199 let half_range_minus_one = range >> 1;
202 range.random_below_ref(rng) - half_range_minus_one
203 }
204 }
205
206 fn is_in_half_pm(&self, range: &Self) -> bool {
207 let bound = range >> 1;
211 self.cmp_abs(&bound).is_le()
212 }
213}
214
/// Error returned by [`Aux::combine`] and [`Aux::pow_mod`] when an
/// exponentiation cannot be carried out.
///
/// `#[error(transparent)]` forwards `Display` and `source()` to the private
/// inner reason.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[error(transparent)]
pub struct BadExponent(#[from] BadExponentReason);
221
222impl BadExponent {
223 pub fn undefined() -> Self {
225 Self(BadExponentReason::Undefined)
226 }
227}
228
/// Private detail carried inside a [`BadExponent`].
#[derive(Clone, Copy, Debug, thiserror::Error)]
enum BadExponentReason {
    /// The exponentiation returned no result.
    #[error("exponent is undefined")]
    Undefined,
    /// The multiexp table rejected the exponents; constructed by
    /// `Aux::combine` in debug builds only.
    #[error("multiexp error: exponent size is too large (exponents size: {exp_size:?}, max exponent size: {max_exp_size:?})")]
    ExpSize {
        /// `significant_bits()` of the `(x, y)` exponents.
        exp_size: (u64, u64),
        /// Limits reported by the table's `max_exponents_size()`.
        max_exp_size: (usize, usize),
    },
}
239
/// Returns `Err(err)` unless `assertion` holds, `Ok(())` otherwise.
///
/// Despite the name, this fails when `assertion` is **false**: read it as
/// "fail with `err` if this assertion does not hold".
pub fn fail_if<E>(err: E, assertion: bool) -> Result<(), E> {
    if !assertion {
        return Err(err);
    }
    Ok(())
}
248
/// Returns `Err(err)` when `lhs` and `rhs` compare unequal under `PartialEq`,
/// and `Ok(())` when they are equal.
pub fn fail_if_ne<T: PartialEq, E>(err: E, lhs: T, rhs: T) -> Result<(), E> {
    match lhs == rhs {
        true => Ok(()),
        false => Err(err),
    }
}
257
258pub mod encoding {
259 pub struct Integer;
261 impl udigest::DigestAs<fast_paillier::backend::Integer> for Integer {
262 fn digest_as<B: udigest::Buffer>(
263 value: &fast_paillier::backend::Integer,
264 encoder: udigest::encoding::EncodeValue<B>,
265 ) {
266 let digits = value.to_bytes_msf();
267 encoder.encode_leaf_value(digits)
268 }
269 }
270
271 pub struct AnyEncryptionKey;
273 impl udigest::DigestAs<&dyn fast_paillier::AnyEncryptionKey> for AnyEncryptionKey {
274 fn digest_as<B: udigest::Buffer>(
275 value: &&dyn fast_paillier::AnyEncryptionKey,
276 encoder: udigest::encoding::EncodeValue<B>,
277 ) {
278 Integer::digest_as(value.n(), encoder)
279 }
280 }
281}
282
#[cfg(test)]
pub mod test {
    //! Test helpers that build Paillier keys and auxiliary parameters from
    //! freshly generated Blum primes.

    use fast_paillier::backend::Integer;

    /// Bit size of every prime generated by these helpers.
    const PRIME_BITS: u32 = 1536;

    /// Builds a Paillier decryption key from two random Blum primes.
    ///
    /// Returns `None` if `DecryptionKey::from_primes` rejects the pair.
    pub fn random_key<R: rand_core::RngCore>(rng: &mut R) -> Option<fast_paillier::DecryptionKey> {
        let first = generate_blum_prime(rng, PRIME_BITS);
        let second = generate_blum_prime(rng, PRIME_BITS);
        fast_paillier::DecryptionKey::from_primes(first, second).ok()
    }

    /// Builds auxiliary parameters as follows:
    /// - `N = p·q` for Blum primes `p`, `q`;
    /// - `t = r² mod N` for a random `r` in the multiplicative group;
    /// - `s = t^λ mod N` for a random `λ < φ(N)`.
    ///
    /// No multiexp table or CRT helper is attached.
    pub fn aux<R: rand_core::RngCore>(rng: &mut R) -> super::Aux {
        let p = generate_blum_prime(rng, PRIME_BITS);
        let q = generate_blum_prime(rng, PRIME_BITS);
        let rsa_modulo = &p * &q;
        let phi = (p - 1u8) * (q - 1u8);

        let r = Integer::sample_in_mult_group_of(rng, &rsa_modulo);
        let lambda = phi.random_below(rng);

        let t = r.square().modulo(&rsa_modulo);
        let s = t.pow_mod_ref(&lambda, &rsa_modulo).unwrap();

        super::Aux {
            s,
            t,
            rsa_modulo,
            multiexp: None,
            crt: None,
        }
    }

    /// Generates a random `bits_size`-bit prime congruent to 3 mod 4 (a Blum
    /// prime) by resampling until the congruence holds.
    pub fn generate_blum_prime(rng: &mut impl rand_core::RngCore, bits_size: u32) -> Integer {
        let mut candidate = Integer::generate_prime(rng, bits_size);
        while candidate.mod_u(4) != 3 {
            candidate = Integer::generate_prime(rng, bits_size);
        }
        candidate
    }
}
328
#[cfg(test)]
mod _test {
    use fast_paillier::backend::Integer;

    use super::IntegerExt;

    #[test]
    fn to_scalar_encoding() {
        type E = generic_ec::curves::Secp256k1;

        let bytes = [123u8, 231u8];
        let int = u16::from_be_bytes(bytes);
        let bn = Integer::from(int);
        let scalar = bn.to_scalar();
        assert_eq!(scalar, generic_ec::Scalar::<E>::from(int));

        assert_eq!(bn.to_bytes_msf(), &bytes);

        // The curve order must reduce to zero, and `order - 1` to `-1`.
        let curve_order = Integer::curve_order::<E>();
        assert_eq!(curve_order.to_scalar(), generic_ec::Scalar::<E>::zero());
        assert_eq!(
            (curve_order - 1u8).to_scalar(),
            -generic_ec::Scalar::<E>::one()
        );
    }

    /// Asserts that `Aux::combine` (which goes through the multiexp table when
    /// one is attached) agrees with the generic `Integer::combine`. The
    /// exponents are only printed if the check fails.
    fn assert_combine_matches_generic(aux: &super::Aux, x: &Integer, y: &Integer) {
        let actual = aux.combine(x, y).unwrap();
        let expected = aux.rsa_modulo.combine(&aux.s, x, &aux.t, y).unwrap();
        assert_eq!(actual, expected, "combine mismatch for x = {x}, y = {y}");
    }

    #[test]
    fn multiexp() {
        let mut rng = rand_dev::DevRng::new();
        let mut aux = super::test::aux(&mut rng);
        let table = std::sync::Arc::new(
            crate::multiexp::MultiexpTable::build(&aux.s, &aux.t, 512, 448, aux.rsa_modulo.clone())
                .unwrap(),
        );
        let (x_bits, y_bits) = table.max_exponents_size();
        aux.multiexp = Some(table);

        // Largest-magnitude exponents the table supports, positive and negative.
        let x_max = (Integer::one() << x_bits) - 1;
        let y_max = (Integer::one() << y_bits) - 1;
        assert_combine_matches_generic(&aux, &x_max, &y_max);

        let x_min = -&x_max;
        let y_min = -&y_max;
        assert_combine_matches_generic(&aux, &x_min, &y_min);

        // Random exponents within the table limits, with random signs.
        for _ in 0..100 {
            let x = (&x_max + 1u8).random_below(&mut rng);
            let y = (&y_max + 1u8).random_below(&mut rng);

            let x = if rand::Rng::gen(&mut rng) { x } else { -x };
            let y = if rand::Rng::gen(&mut rng) { y } else { -y };

            assert_combine_matches_generic(&aux, &x, &y);
        }
    }

    /// Draws `samples` values with `from_rng_half_pm` and returns the observed
    /// `(min, max)`. Both start at zero, which lies in every sampled interval.
    fn sampled_extremes(
        rng: &mut impl rand_core::RngCore,
        range: &Integer,
        samples: usize,
    ) -> (Integer, Integer) {
        let mut min = Integer::from(0);
        let mut max = Integer::from(0);
        for _ in 0..samples {
            let value = Integer::from_rng_half_pm(rng, range);
            if value > max {
                max.clone_from(&value);
            }
            if value < min {
                min.clone_from(&value);
            }
        }
        (min, max)
    }

    #[test]
    fn test_from_rng_half_pm_bounds() {
        let mut rng = rand_dev::DevRng::new();

        // Even range: samples must cover exactly [-range/2, range/2].
        let range = Integer::from(10);
        let upper_bound = &range >> 1;
        let lower_bound = -&upper_bound;
        let (min, max) = sampled_extremes(&mut rng, &range, 10000);
        assert_eq!(
            min, lower_bound,
            "Minimum value {min} did not match expected lower bound {lower_bound}"
        );
        assert_eq!(
            max, upper_bound,
            "Maximum value {max} did not match expected upper bound {upper_bound}"
        );

        // Odd range: samples must cover exactly [-(range-1)/2, (range-1)/2].
        let range = Integer::from(9);
        let upper_bound = (&range - Integer::one()) >> 1;
        let lower_bound = -&upper_bound;
        let (min, max) = sampled_extremes(&mut rng, &range, 10000);
        assert_eq!(
            min, lower_bound,
            "Minimum value {min} did not match expected lower bound {lower_bound}"
        );
        assert_eq!(
            max, upper_bound,
            "Maximum value {max} did not match expected upper bound {upper_bound}"
        );
    }

    #[test]
    fn test_is_in_half_pm() {
        // Even range: the accepted interval is [-5, 5].
        let range = Integer::from(10);
        let a_1 = Integer::from(-6);
        let a_2 = Integer::from(-5);
        let a_3 = Integer::from(5);
        let a_4 = Integer::from(6);
        assert!(
            !a_1.is_in_half_pm(&range),
            "{a_1} should be outside [-range/2,range/2]"
        );
        assert!(
            a_2.is_in_half_pm(&range),
            "{a_2} should be in [-range/2,range/2]"
        );
        assert!(
            a_3.is_in_half_pm(&range),
            "{a_3} should be in [-range/2,range/2]"
        );
        assert!(
            !a_4.is_in_half_pm(&range),
            "{a_4} should be outside [-range/2,range/2]"
        );

        // Odd range: the accepted interval is [-4, 4].
        let range = Integer::from(9);
        let a_1 = Integer::from(-5);
        let a_2 = Integer::from(-4);
        let a_3 = Integer::from(4);
        let a_4 = Integer::from(5);
        assert!(
            !a_1.is_in_half_pm(&range),
            "{a_1} should be outside [-(range-1)/2,(range-1)/2]"
        );
        assert!(
            a_2.is_in_half_pm(&range),
            "{a_2} should be in [-(range-1)/2,(range-1)/2]"
        );
        assert!(
            a_3.is_in_half_pm(&range),
            "{a_3} should be in [-(range-1)/2,(range-1)/2]"
        );
        assert!(
            !a_4.is_in_half_pm(&range),
            "{a_4} should be outside [-(range-1)/2,(range-1)/2]"
        );
    }
}