Skip to main content

paillier_zk/
common.rs

1pub mod sqrt;
2
3use std::sync::Arc;
4
5use fast_paillier::backend::Integer;
6use generic_ec::Scalar;
7
8/// Auxiliary data known to both prover and verifier
9#[cfg_attr(
10    feature = "__internal_doctest",
11    derive(serde::Serialize, serde::Deserialize)
12)]
13#[derive(Clone, Debug)]
14pub struct Aux {
15    /// ring-pedersen parameter
16    pub s: Integer,
17    /// ring-pedersen parameter
18    pub t: Integer,
19    /// N^ in paper
20    pub rsa_modulo: Integer,
21    /// Precomuted table for computing `s^x t^y mod rsa_modulo` faster
22    ///
23    /// If absent, optimization is disabled.
24    #[cfg_attr(feature = "__internal_doctest", serde(skip))]
25    pub multiexp: Option<Arc<crate::multiexp::MultiexpTable>>,
26    #[cfg_attr(feature = "__internal_doctest", serde(skip))]
27    pub crt: Option<fast_paillier::utils::CrtExp>,
28}
29
30impl Aux {
31    /// Returns `s^x t^y mod rsa_modulo`
32    pub fn combine(&self, x: &Integer, y: &Integer) -> Result<Integer, BadExponent> {
33        if let Some(table) = &self.multiexp {
34            match table.prod_exp(x, y) {
35                Some(res) => return Ok(res),
36                None if cfg!(debug_assertions) => {
37                    return Err(BadExponentReason::ExpSize {
38                        exp_size: (x.significant_bits(), y.significant_bits()),
39                        max_exp_size: table.max_exponents_size(),
40                    }
41                    .into())
42                }
43                None => {
44                    // When debug assertions are disabled, we fallback to naive exponentiation
45                }
46            }
47        }
48
49        // Naive exponentiation when optimizations are not enabled
50        self.rsa_modulo
51            .combine(&self.s, x, &self.t, y)
52            .ok_or_else(BadExponent::undefined)
53    }
54
55    /// Returns `x^e mod rsa_modulo`
56    pub fn pow_mod(&self, x: &Integer, e: &Integer) -> Result<Integer, BadExponent> {
57        match &self.crt {
58            Some(crt) => {
59                let e = crt.prepare_exponent(e);
60                crt.exp(x, &e).ok_or_else(BadExponent::undefined)
61            }
62            None => Ok(x
63                .pow_mod_ref(e, &self.rsa_modulo)
64                .ok_or_else(BadExponent::undefined)?),
65        }
66    }
67
68    /// Checks if `x` is in multiplicative group Z<super>*</super><sub>N</sub> where `N = ` [`rsa_modulo`](Self::rsa_modulo)
69    pub fn is_in_mult_group(&self, x: &Integer) -> bool {
70        x.in_mult_group_of(&self.rsa_modulo)
71    }
72
73    /// Returns a stripped version of `Aux` that contains only public data which can be digested
74    /// via [`udigest::Digestable`]
75    pub fn digest_public_data(&self) -> impl udigest::Digestable {
76        udigest::inline_struct!("paillier_zk.aux" {
77            s: udigest::Bytes(self.s.to_bytes_msf()),
78            t: udigest::Bytes(self.t.to_bytes_msf()),
79            rsa_modulo: udigest::Bytes(self.rsa_modulo.to_bytes_msf()),
80        })
81    }
82}
83
84/// Error indicating that proof is invalid
85#[derive(Debug, Clone, thiserror::Error)]
86#[error("invalid proof")]
87pub struct InvalidProof(
88    #[source]
89    #[from]
90    InvalidProofReason,
91);
92
93/// Reason for failure. If the proof fails, you should only be interested in a
94/// reason for debugging purposes
95#[non_exhaustive]
96#[derive(Debug, PartialEq, Eq, Clone, Copy, thiserror::Error)]
97pub enum InvalidProofReason {
98    /// One equality doesn't hold. Parameterized by equality index
99    #[error("equality check failed {0}")]
100    EqualityCheck(usize),
101    /// Check that integer belongs to multiplicative group failed. Parameterized by equality index
102    #[error("mult group check failed {0}")]
103    MultGroupCheck(usize),
104    /// One range check doesn't hold. Parameterized by check index
105    #[error("range check failed {0}")]
106    RangeCheck(usize),
107    /// Encryption of supplied data failed when attempting to verify
108    #[error("encryption failed")]
109    Encryption,
110    #[error("paillier encryption failed")]
111    PaillierEnc,
112    #[error("paillier homomorphic op failed")]
113    PaillierOp,
114    /// Failed to evaluate powmod
115    #[error("powmod failed")]
116    ModPow,
117    /// Paillier-Blum modulus is prime
118    #[error("modulus is prime")]
119    ModulusIsPrime,
120    /// Paillier-Blum modulus is even
121    #[error("modulus is even")]
122    ModulusIsEven,
123    /// Proof's z value in n-th power does not equal commitment value
124    #[error("incorrect nth root")]
125    IncorrectNthRoot,
126    /// Proof's x value in 4-th power does not equal commitment value
127    #[error("incorrect 4th root")]
128    IncorrectFourthRoot,
129    /// Conversion failed (e.g. from u32 to usize)
130    #[error("conversion failed")]
131    Conversion,
132}
133
134impl InvalidProof {
135    #[cfg(test)]
136    pub(crate) fn reason(&self) -> InvalidProofReason {
137        self.0
138    }
139}
140
141impl From<BadExponent> for InvalidProof {
142    fn from(_err: BadExponent) -> Self {
143        InvalidProofReason::ModPow.into()
144    }
145}
146
147impl From<PaillierError> for InvalidProof {
148    fn from(_err: PaillierError) -> Self {
149        InvalidProof(InvalidProofReason::Encryption)
150    }
151}
152
153/// Error indicating that encryption failed
154#[derive(Clone, Copy, Debug, thiserror::Error)]
155#[error("paillier encryption failed")]
156pub struct PaillierError;
157
158pub trait IntegerExt: Sized {
159    /// Embed BigInt into chosen scalar type
160    fn to_scalar<C: generic_ec::Curve>(&self) -> Scalar<C>;
161
162    /// Returns prime order of curve C
163    fn curve_order<C: generic_ec::Curve>() -> Self;
164
165    /// Generates a random integer in interval
166    /// `[-range/2; range/2]` if range is even
167    /// `[-(range-1)/2; (range-1)/2]` if range is odd
168    fn from_rng_half_pm<R: rand_core::RngCore>(rng: &mut R, range: &Self) -> Self;
169
170    /// Checks whether `self` is in interval
171    /// `[-range/2; range/2]` when range is even
172    /// `[-(range-1)/2; (range-1)/2]` when range is odd
173    fn is_in_half_pm(&self, range: &Self) -> bool;
174}
175
176impl IntegerExt for Integer {
177    fn to_scalar<C: generic_ec::Curve>(&self) -> Scalar<C> {
178        let bytes_be = self.to_bytes_msf();
179        let s = Scalar::<C>::from_be_bytes_mod_order(bytes_be);
180        if self.cmp0().is_ge() {
181            s
182        } else {
183            -s
184        }
185    }
186
187    fn curve_order<C: generic_ec::Curve>() -> Self {
188        let order_minus_one = -Scalar::<C>::one();
189        let i = Integer::from_bytes_msf(&order_minus_one.to_be_bytes());
190        i + 1
191    }
192
193    fn from_rng_half_pm<R: rand_core::RngCore>(rng: &mut R, range: &Self) -> Self {
194        if range.is_even() {
195            let half_range = range >> 1;
196            let range_plus_one = range + 1u32;
197            range_plus_one.random_below(rng) - half_range
198        } else {
199            // range is odd, so half of the range minus one (that is `(range -
200            // 1) / 2`) is range / 2
201            let half_range_minus_one = range >> 1;
202            range.random_below_ref(rng) - half_range_minus_one
203        }
204    }
205
206    fn is_in_half_pm(&self, range: &Self) -> bool {
207        // If range is even, range >> 1 is exactly range / 2
208        // If range is odd, range >> 1 == (range - 1) >> 1 == (range - 1) / 2 as
209        // the lowest bit is discarded either way
210        let bound = range >> 1;
211        self.cmp_abs(&bound).is_le()
212    }
213}
214
215/// Error indicating that computation cannot be evaluated because of bad exponent
216///
217/// Returned by [`Aux::pow_mod`] and other functions that do exponentiation internally
218#[derive(Clone, Copy, Debug, thiserror::Error)]
219#[error(transparent)]
220pub struct BadExponent(#[from] BadExponentReason);
221
222impl BadExponent {
223    /// Constructs an error that exponent is undefined
224    pub fn undefined() -> Self {
225        Self(BadExponentReason::Undefined)
226    }
227}
228
229#[derive(Clone, Copy, Debug, thiserror::Error)]
230enum BadExponentReason {
231    #[error("exponent is undefined")]
232    Undefined,
233    #[error("multiexp error: exponent size is too large (exponents size: {exp_size:?}, max exponent size: {max_exp_size:?})")]
234    ExpSize {
235        exp_size: (u64, u64),
236        max_exp_size: (usize, usize),
237    },
238}
239
/// Returns `Err(err)` if `assertion` is false
///
/// Otherwise returns `Ok(())`, and `err` is dropped unused.
pub fn fail_if<E>(err: E, assertion: bool) -> Result<(), E> {
    assertion.then_some(()).ok_or(err)
}
248
/// Returns `Err(err)` if `lhs != rhs`
///
/// Equality is decided with `lhs == rhs` ([`PartialEq::eq`]); when it holds,
/// `Ok(())` is returned and `err` is dropped unused.
pub fn fail_if_ne<T: PartialEq, E>(err: E, lhs: T, rhs: T) -> Result<(), E> {
    let equal = lhs == rhs;
    equal.then_some(()).ok_or(err)
}
257
258pub mod encoding {
259    /// Digests a fast-paillier backend integer
260    pub struct Integer;
261    impl udigest::DigestAs<fast_paillier::backend::Integer> for Integer {
262        fn digest_as<B: udigest::Buffer>(
263            value: &fast_paillier::backend::Integer,
264            encoder: udigest::encoding::EncodeValue<B>,
265        ) {
266            let digits = value.to_bytes_msf();
267            encoder.encode_leaf_value(digits)
268        }
269    }
270
271    /// Digests any encryption key
272    pub struct AnyEncryptionKey;
273    impl udigest::DigestAs<&dyn fast_paillier::AnyEncryptionKey> for AnyEncryptionKey {
274        fn digest_as<B: udigest::Buffer>(
275            value: &&dyn fast_paillier::AnyEncryptionKey,
276            encoder: udigest::encoding::EncodeValue<B>,
277        ) {
278            Integer::digest_as(value.n(), encoder)
279        }
280    }
281}
282
/// Common helpers shared across this crate's unit tests
///
/// The module is compiled only under `#[cfg(test)]`, so it is NOT available to
/// doctests: those are built against the library compiled without `cfg(test)`.
#[cfg(test)]
pub mod test {
    use fast_paillier::backend::Integer;

    /// Generates a random Paillier decryption key from two 1536-bit Blum primes
    ///
    /// Returns `None` if `DecryptionKey::from_primes` rejects the primes.
    pub fn random_key<R: rand_core::RngCore>(rng: &mut R) -> Option<fast_paillier::DecryptionKey> {
        let p = generate_blum_prime(rng, 1536);
        let q = generate_blum_prime(rng, 1536);
        fast_paillier::DecryptionKey::from_primes(p, q).ok()
    }

    /// Generates random ring-Pedersen parameters over a fresh RSA modulus
    ///
    /// `N = p * q` for two 1536-bit Blum primes. `t = r^2 mod N` for a random `r` sampled
    /// from the multiplicative group of `N`, and `s = t^lambda mod N` for a random
    /// `lambda` below `phi(N)`. No precomputed tables (`multiexp`, `crt`) are set.
    ///
    /// NOTE(review): `lambda` may be sampled as 0 (making `s = 1`); negligible for
    /// 3072-bit moduli, but worth knowing if tests ever use small parameters.
    pub fn aux<R: rand_core::RngCore>(rng: &mut R) -> super::Aux {
        let p = generate_blum_prime(rng, 1536);
        let q = generate_blum_prime(rng, 1536);
        let n = &p * &q;

        let (s, t) = {
            let phi_n = (p - 1u8) * (q - 1u8);
            let r = Integer::sample_in_mult_group_of(rng, &n);
            let lambda = phi_n.random_below(rng);

            let t = r.square().modulo(&n);
            let s = t.pow_mod_ref(&lambda, &n).unwrap();

            (s, t)
        };

        super::Aux {
            s,
            t,
            rsa_modulo: n,
            multiexp: None,
            crt: None,
        }
    }

    /// Samples random primes of `bits_size` bits until one is congruent to 3 mod 4
    pub fn generate_blum_prime(rng: &mut impl rand_core::RngCore, bits_size: u32) -> Integer {
        loop {
            let candidate = Integer::generate_prime(rng, bits_size);
            if candidate.mod_u(4) == 3 {
                return candidate;
            }
        }
    }
}
328
#[cfg(test)]
mod _test {
    use fast_paillier::backend::Integer;

    use super::IntegerExt;

    #[test]
    fn to_scalar_encoding() {
        type E = generic_ec::curves::Secp256k1;

        let bytes = [123u8, 231u8];
        let int = u16::from_be_bytes(bytes);
        let bn = Integer::from(int);
        let scalar = bn.to_scalar();
        assert_eq!(scalar, generic_ec::Scalar::<E>::from(int));

        assert_eq!(bn.to_bytes_msf(), &bytes);

        let curve_order = Integer::curve_order::<E>();
        assert_eq!(curve_order.to_scalar(), generic_ec::Scalar::<E>::zero());
        assert_eq!(
            (curve_order - 1u8).to_scalar(),
            -generic_ec::Scalar::<E>::one()
        );
    }

    #[test]
    fn multiexp() {
        let mut rng = rand_dev::DevRng::new();
        let mut aux = super::test::aux(&mut rng);
        let table = std::sync::Arc::new(
            crate::multiexp::MultiexpTable::build(&aux.s, &aux.t, 512, 448, aux.rsa_modulo.clone())
                .unwrap(),
        );
        let (x_bits, y_bits) = table.max_exponents_size();
        aux.multiexp = Some(table);

        // Corner case: upper bound
        let x_max = (Integer::one() << x_bits) - 1;
        let y_max = (Integer::one() << y_bits) - 1;
        let actual = aux.combine(&x_max, &y_max).unwrap();
        let expected = aux
            .rsa_modulo
            .combine(&aux.s, &x_max, &aux.t, &y_max)
            .unwrap();
        assert_eq!(actual, expected);

        // Corner case: lower bound
        let x_min = -&x_max;
        let y_min = -&y_max;
        let actual = aux.combine(&x_min, &y_min).unwrap();
        let expected = aux
            .rsa_modulo
            .combine(&aux.s, &x_min, &aux.t, &y_min)
            .unwrap();
        assert_eq!(actual, expected);

        // Random integers within the range
        for _ in 0..100 {
            let x = (&x_max + 1u8).random_below(&mut rng);
            let y = (&y_max + 1u8).random_below(&mut rng);

            let x = if rand::Rng::gen(&mut rng) { x } else { -x };
            let y = if rand::Rng::gen(&mut rng) { y } else { -y };

            let actual = aux.combine(&x, &y).unwrap();
            let expected = aux.rsa_modulo.combine(&aux.s, &x, &aux.t, &y).unwrap();
            // Report the failing inputs only on mismatch instead of printing every pair
            assert_eq!(actual, expected, "mismatch for x = {x}, y = {y}");
        }
    }

    /// Samples `from_rng_half_pm(range)` 10000 times and returns the observed `(min, max)`
    ///
    /// Both start at zero, which lies inside every interval `from_rng_half_pm` samples from.
    fn observed_half_pm_bounds<R: rand_core::RngCore>(
        rng: &mut R,
        range: &Integer,
    ) -> (Integer, Integer) {
        let mut min = Integer::from(0);
        let mut max = Integer::from(0);
        for _ in 0..10000 {
            let value = Integer::from_rng_half_pm(rng, range);
            if value > max {
                max.clone_from(&value);
            }
            if value < min {
                min.clone_from(&value);
            }
        }
        (min, max)
    }

    #[test]
    fn test_from_rng_half_pm_bounds() {
        let mut rng = rand_dev::DevRng::new();

        // Even range: samples must cover exactly [-range/2; range/2]
        let range = Integer::from(10);
        let upper_bound = &range >> 1;
        let lower_bound = -&upper_bound;
        let (min, max) = observed_half_pm_bounds(&mut rng, &range);
        assert_eq!(
            min, lower_bound,
            "Minimum value {min} did not match expected lower bound {lower_bound}"
        );
        assert_eq!(
            max, upper_bound,
            "Maximum value {max} did not match expected upper bound {upper_bound}"
        );

        // Odd range: samples must cover exactly [-(range-1)/2; (range-1)/2]
        let range = Integer::from(9);
        let range_minus_one = &range - Integer::one();
        let upper_bound = range_minus_one >> 1;
        let lower_bound = -&upper_bound;
        let (min, max) = observed_half_pm_bounds(&mut rng, &range);
        assert_eq!(
            min, lower_bound,
            "Minimum value {min} did not match expected lower bound {lower_bound}"
        );
        assert_eq!(
            max, upper_bound,
            "Maximum value {max} did not match expected upper bound {upper_bound}"
        );
    }

    #[test]
    fn test_is_in_half_pm() {
        // (range, value, expected membership); table-driven so that every failure
        // message reports the value actually being checked
        let cases = [
            // Even range 10: interval is [-5; 5]
            (10, -6, false),
            (10, -5, true),
            (10, 0, true),
            (10, 5, true),
            (10, 6, false),
            // Odd range 9: interval is [-4; 4]
            (9, -5, false),
            (9, -4, true),
            (9, 0, true),
            (9, 4, true),
            (9, 5, false),
        ];
        for (range, value, expected) in cases {
            let range = Integer::from(range);
            let value = Integer::from(value);
            assert_eq!(
                value.is_in_half_pm(&range),
                expected,
                "is_in_half_pm({value}, range = {range}) should be {expected}"
            );
        }
    }
}