Skip to main content

cggmp24/key_refresh/aux_only.rs

1use digest::Digest;
2use futures::SinkExt;
3use paillier_zk::{backend::Integer, no_small_factor as π_fac, paillier_blum_modulus as π_mod};
4use rand_core::{CryptoRng, RngCore};
5use round_based::{
6    rounds_router::{simple_store::RoundInput, RoundsRouter},
7    Delivery, Mpc, MpcParty, Outgoing, ProtocolMessage,
8};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    errors::IoError,
13    key_share::{AuxInfo, DirtyAuxInfo, PedersenParams, Validate},
14    progress::Tracer,
15    security_level::SecurityLevel,
16    utils,
17    utils::{collect_blame, AbortBlame},
18    zk::ring_pedersen_parameters as π_prm,
19    ExecutionId,
20};
21
22use super::{Bug, KeyRefreshError, PregeneratedPrimes, ProtocolAborted};
23
/// Builds a domain-separation tag for this module.
///
/// The argument is appended to the `dfns.cggmp24.aux_gen.` namespace at compile time,
/// e.g. `prefixed!("round1")` is the string literal `"dfns.cggmp24.aux_gen.round1"`.
macro_rules! prefixed {
    ($suffix:tt) => (concat!("dfns.cggmp24.aux_gen.", $suffix));
}
29
30/// Message of key refresh protocol
31#[derive(ProtocolMessage, Clone, Serialize, Deserialize)]
32#[serde(bound = "")]
33// 3 kilobytes for the largest option, and 2.5 kilobytes for second largest
34#[allow(clippy::large_enum_variant)]
35pub enum Msg<D: Digest, L: SecurityLevel> {
36    /// Round 1 message
37    Round1(MsgRound1<D>),
38    /// Round 2 message
39    Round2(MsgRound2<L>),
40    /// Round 3 message
41    Round3(MsgRound3),
42    /// Reliability check message (optional additional round)
43    ReliabilityCheck(MsgReliabilityCheck<D>),
44}
45
46/// Message from round 1
47#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
48#[udigest(tag = prefixed!("round1"))]
49#[udigest(bound = "")]
50#[serde(bound = "")]
51pub struct MsgRound1<D: Digest> {
52    /// $V_i$
53    #[udigest(as_bytes)]
54    pub commitment: digest::Output<D>,
55}
56/// Message from round 2
57#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
58#[udigest(tag = prefixed!("round2"))]
59#[udigest(bound = "")]
60#[serde(bound = "")]
61pub struct MsgRound2<L: SecurityLevel> {
62    /// $N_i$
63    #[udigest(as = utils::encoding::Integer)]
64    pub N: Integer,
65    /// $\hat N_i$
66    #[udigest(as = utils::encoding::Integer)]
67    pub hat_N: Integer,
68    /// $s_i$
69    #[udigest(as = utils::encoding::Integer)]
70    pub s: Integer,
71    /// $t_i$
72    #[udigest(as = utils::encoding::Integer)]
73    pub t: Integer,
74    /// $\hat \psi_i$
75    // this should be L::M instead, but no rustc support yet
76    pub params_proof: π_prm::Proof<{ crate::security_level::M }>,
77    /// $\rho_i$
78    // ideally it would be [u8; L::SECURITY_BYTES], but no rustc support yet
79    #[serde(with = "hex")]
80    #[udigest(as_bytes)]
81    pub rho_bytes: L::KappaBytes,
82    /// $u_i$
83    #[serde(with = "hex")]
84    #[udigest(as_bytes)]
85    pub decommit: L::KappaBytes,
86}
87/// Unicast message of round 3, sent to each participant
88#[derive(Clone, Serialize, Deserialize)]
89pub struct MsgRound3 {
90    /// $\psi_i$
91    // this should be L::M instead, but no rustc support yet
92    pub mod_proof: π_mod::NiProof<{ crate::security_level::M }>,
93    /// $\phi_i^j$
94    pub fac_proof: π_fac::NiProof,
95}
96
97/// Message from an optional round that enforces reliability check
98#[derive(Clone, Serialize, Deserialize)]
99#[serde(bound = "")]
100pub struct MsgReliabilityCheck<D: Digest>(pub digest::Output<D>);
101
102mod unambiguous {
103    use digest::Digest;
104
105    use crate::{ExecutionId, SecurityLevel};
106
107    #[derive(udigest::Digestable)]
108    #[udigest(tag = prefixed!("proof_prm"))]
109    pub struct ProofPrm<'a> {
110        pub sid: ExecutionId<'a>,
111        pub prover: u16,
112    }
113
114    #[derive(udigest::Digestable)]
115    #[udigest(tag = prefixed!("proof_mod"))]
116    pub struct ProofMod<'a> {
117        pub sid: ExecutionId<'a>,
118        #[udigest(as_bytes)]
119        pub rho: &'a [u8],
120        pub prover: u16,
121    }
122
123    #[derive(udigest::Digestable)]
124    #[udigest(tag = prefixed!("proof_fac"))]
125    #[udigest(bound = "")]
126    pub struct ProofFac<'a> {
127        pub sid: ExecutionId<'a>,
128        #[udigest(as_bytes)]
129        pub rho: &'a [u8],
130        pub prover: u16,
131    }
132
133    #[derive(udigest::Digestable)]
134    #[udigest(tag = prefixed!("hash_commitment"))]
135    #[udigest(bound = "")]
136    pub struct HashCom<'a, L: SecurityLevel> {
137        pub sid: ExecutionId<'a>,
138        pub prover: u16,
139        pub decommitment: &'a super::MsgRound2<L>,
140    }
141
142    #[derive(udigest::Digestable)]
143    #[udigest(tag = prefixed!("echo_round"))]
144    #[udigest(bound = "")]
145    pub struct Echo<'a, D: Digest> {
146        pub sid: ExecutionId<'a>,
147        pub commitment: &'a super::MsgRound1<D>,
148    }
149}
150
151pub async fn run_aux_gen<R, M, L, D>(
152    i: u16,
153    n: u16,
154    mut rng: &mut R,
155    party: M,
156    sid: ExecutionId<'_>,
157    pregenerated: PregeneratedPrimes<L>,
158    mut tracer: Option<&mut dyn Tracer>,
159    reliable_broadcast_enforced: bool,
160    compute_multiexp_table: bool,
161) -> Result<AuxInfo<L>, KeyRefreshError>
162where
163    R: RngCore + CryptoRng,
164    M: Mpc<ProtocolMessage = Msg<D, L>>,
165    L: SecurityLevel,
166    D: Digest + Clone + 'static,
167{
168    tracer.protocol_begins();
169
170    tracer.stage("Retrieve auxiliary data");
171
172    tracer.stage("Setup networking");
173    let MpcParty { delivery, .. } = party.into_party();
174    let (incomings, mut outgoings) = delivery.split();
175
176    let mut rounds = RoundsRouter::<Msg<D, L>>::builder();
177    let round1 = rounds.add_round(RoundInput::<MsgRound1<D>>::broadcast(i, n));
178    let round1_sync = rounds.add_round(RoundInput::<MsgReliabilityCheck<D>>::broadcast(i, n));
179    let round2 = rounds.add_round(RoundInput::<MsgRound2<L>>::broadcast(i, n));
180    let round3 = rounds.add_round(RoundInput::<MsgRound3>::p2p(i, n));
181    let mut rounds = rounds.listen(incomings);
182
183    // Round 1
184    tracer.round_begins();
185
186    let [p, q, hat_p, hat_q] = pregenerated.into_primes();
187
188    tracer.stage("Build Paillier key");
189    let N = &p * &q;
190
191    tracer.stage("Build Pedersen params");
192    let (pedersen_params, phi_hat_N, lambda) = utils::generate_pedersen_params(rng, hat_p, hat_q)?;
193
194    tracer.stage("Prove Πprm (ψˆ_i)");
195    let hat_psi = π_prm::prove::<{ crate::security_level::M }, D>(
196        &unambiguous::ProofPrm { sid, prover: i },
197        &mut rng,
198        π_prm::Data {
199            N: &pedersen_params.hat_N,
200            s: &pedersen_params.s,
201            t: &pedersen_params.t,
202        },
203        &phi_hat_N,
204        &lambda,
205    )
206    .map_err(Bug::PiPrm)?;
207
208    tracer.stage("Sample random bytes");
209    // rho_i in paper, this signer's share of bytes
210    let mut rho_bytes = L::KappaBytes::default();
211    rng.fill_bytes(rho_bytes.as_mut());
212
213    tracer.stage("Compute hash commitment and sample decommitment");
214    // V_i and u_i in paper
215    let decommitment = MsgRound2 {
216        N: N.clone(),
217        hat_N: pedersen_params.hat_N.clone(),
218        s: pedersen_params.s.clone(),
219        t: pedersen_params.t.clone(),
220        params_proof: hat_psi,
221        rho_bytes: rho_bytes.clone(),
222        decommit: {
223            let mut nonce = L::KappaBytes::default();
224            rng.fill_bytes(nonce.as_mut());
225            nonce
226        },
227    };
228    let hash_commit = udigest::hash::<D>(&unambiguous::HashCom {
229        sid,
230        prover: i,
231        decommitment: &decommitment,
232    });
233
234    tracer.send_msg();
235    let commitment = MsgRound1 {
236        commitment: hash_commit,
237    };
238    outgoings
239        .send(Outgoing::broadcast(Msg::Round1(commitment.clone())))
240        .await
241        .map_err(IoError::send_message)?;
242    tracer.msg_sent();
243
244    // Round 2
245    tracer.round_begins();
246
247    tracer.receive_msgs();
248    let commitments = rounds
249        .complete(round1)
250        .await
251        .map_err(IoError::receive_message)?;
252    tracer.msgs_received();
253
254    // Optional reliability check
255    if reliable_broadcast_enforced {
256        tracer.stage("Hash received msgs (reliability check)");
257        let h_i = udigest::hash_iter::<D>(
258            commitments
259                .iter_including_me(&commitment)
260                .map(|commitment| unambiguous::Echo { sid, commitment }),
261        );
262
263        tracer.send_msg();
264        outgoings
265            .send(Outgoing::broadcast(Msg::ReliabilityCheck(
266                MsgReliabilityCheck(h_i.clone()),
267            )))
268            .await
269            .map_err(IoError::send_message)?;
270        tracer.msg_sent();
271
272        tracer.round_begins();
273
274        tracer.receive_msgs();
275        let hashes = rounds
276            .complete(round1_sync)
277            .await
278            .map_err(IoError::receive_message)?;
279        tracer.msgs_received();
280
281        tracer.stage("Assert other parties hashed messages (reliability check)");
282        let parties_have_different_hashes = hashes
283            .into_iter_indexed()
284            .filter(|(_j, _msg_id, h_j)| h_i != h_j.0)
285            .map(|(j, msg_id, _)| AbortBlame::new(j, msg_id, msg_id))
286            .collect::<Vec<_>>();
287        if !parties_have_different_hashes.is_empty() {
288            return Err(ProtocolAborted::round1_not_reliable(parties_have_different_hashes).into());
289        }
290    }
291
292    tracer.send_msg();
293    outgoings
294        .send(Outgoing::broadcast(Msg::Round2(decommitment.clone())))
295        .await
296        .map_err(IoError::send_message)?;
297    tracer.msg_sent();
298
299    // Round 3
300    tracer.round_begins();
301
302    tracer.receive_msgs();
303    let decommitments = rounds
304        .complete(round2)
305        .await
306        .map_err(IoError::receive_message)?;
307    tracer.msgs_received();
308
309    // validate decommitments
310    tracer.stage("Validate round 1 decommitments");
311    let blame = collect_blame(&decommitments, &commitments, |j, decomm, comm| {
312        let com_expected = udigest::hash::<D>(&unambiguous::HashCom {
313            sid,
314            prover: j,
315            decommitment: decomm,
316        });
317        com_expected != comm.commitment
318    });
319    if !blame.is_empty() {
320        return Err(ProtocolAborted::invalid_decommitment(blame).into());
321    }
322    // validate parameters and param_proofs
323    tracer.stage("Validate bit length and П_prm (ψˆ_i)");
324    let blame = collect_blame(&decommitments, &decommitments, |j, d, _| {
325        if [&d.N, &d.hat_N]
326            .iter()
327            .any(|biprime| !crate::security_level::validate_public_paillier_key_size::<L>(biprime))
328        {
329            true
330        } else {
331            π_prm::verify::<{ crate::security_level::M }, D>(
332                &unambiguous::ProofPrm { sid, prover: j },
333                π_prm::Data {
334                    N: &d.hat_N,
335                    s: &d.s,
336                    t: &d.t,
337                },
338                &d.params_proof,
339            )
340            .is_err()
341        }
342    });
343    if !blame.is_empty() {
344        return Err(ProtocolAborted::invalid_ring_pedersen_parameters(blame).into());
345    }
346
347    tracer.stage("Add together shared random bytes");
348    // rho in paper, collective random bytes
349    let rho_bytes = decommitments
350        .iter()
351        .map(|d| &d.rho_bytes)
352        .fold(rho_bytes, utils::xor_array);
353
354    // common data for messages
355    tracer.stage("Compute П_mod (ψ_i)");
356    let psi = π_mod::non_interactive::prove::<{ crate::security_level::M }, D>(
357        &unambiguous::ProofMod {
358            sid,
359            rho: rho_bytes.as_ref(),
360            prover: i,
361        },
362        π_mod::Data { n: &N },
363        π_mod::PrivateData { p: &p, q: &q },
364        &mut rng,
365    )
366    .map_err(Bug::PiMod)?;
367    tracer.stage("Assemble security params for П_fac (ψ_i)");
368    let π_fac_security = π_fac::SecurityParams {
369        l: L::ELL,
370        epsilon: L::EPSILON,
371    };
372    let n_sqrt = N.sqrt_ref().ok_or(Bug::NegativeModulus)?;
373
374    // message to each party
375    for (j, _, d) in decommitments.iter_indexed() {
376        tracer.send_msg();
377
378        tracer.stage("Compute П_fac (ψ'_i,j)");
379        let psi_prime = π_fac::non_interactive::prove::<D>(
380            &unambiguous::ProofFac {
381                sid,
382                rho: rho_bytes.as_ref(),
383                prover: i,
384            },
385            &π_fac::Aux {
386                s: d.s.clone(),
387                t: d.t.clone(),
388                rsa_modulo: d.hat_N.clone(),
389                multiexp: None,
390                crt: None,
391            },
392            π_fac::Data {
393                n: &N,
394                n_root: &n_sqrt,
395            },
396            π_fac::PrivateData { p: &p, q: &q },
397            &π_fac_security,
398            &mut rng,
399        )
400        .map_err(Bug::PiFac)?;
401
402        tracer.send_msg();
403        let msg = MsgRound3 {
404            mod_proof: psi.clone(),
405            fac_proof: psi_prime.clone(),
406        };
407        outgoings
408            .feed(Outgoing::p2p(j, Msg::Round3(msg)))
409            .await
410            .map_err(IoError::send_message)?;
411        tracer.msg_sent();
412    }
413
414    tracer.send_msg();
415    outgoings.flush().await.map_err(IoError::send_message)?;
416    tracer.msg_sent();
417
418    // Output
419    tracer.round_begins();
420
421    tracer.receive_msgs();
422    let shares_msg_b = rounds
423        .complete(round3)
424        .await
425        .map_err(IoError::receive_message)?;
426    tracer.msgs_received();
427
428    tracer.stage("Validate ψ_j (П_mod)");
429    // verify mod proofs
430    let blame = collect_blame(
431        &decommitments,
432        &shares_msg_b,
433        |j, decommitment, proof_msg| {
434            π_mod::non_interactive::verify::<{ crate::security_level::M }, D>(
435                &unambiguous::ProofMod {
436                    sid,
437                    rho: rho_bytes.as_ref(),
438                    prover: j,
439                },
440                π_mod::Data { n: &decommitment.N },
441                &proof_msg.mod_proof,
442                rng,
443            )
444            .is_err()
445        },
446    );
447    if !blame.is_empty() {
448        return Err(ProtocolAborted::invalid_mod_proof(blame).into());
449    }
450
451    tracer.stage("Validate ψ'_j,i (П_fac)");
452    // verify fac proofs
453
454    let phi_common_aux: π_fac::Aux = (&pedersen_params).into();
455    let blame = collect_blame(
456        &decommitments,
457        &shares_msg_b,
458        |j, decommitment, proof_msg| {
459            let n_root = match decommitment.N.sqrt_ref() {
460                Some(root) => root,
461                None => return true,
462            };
463            π_fac::non_interactive::verify::<D>(
464                &unambiguous::ProofFac {
465                    sid,
466                    rho: rho_bytes.as_ref(),
467                    prover: j,
468                },
469                &phi_common_aux,
470                π_fac::Data {
471                    n: &decommitment.N,
472                    n_root: &n_root,
473                },
474                &π_fac_security,
475                &proof_msg.fac_proof,
476            )
477            .is_err()
478        },
479    );
480    if !blame.is_empty() {
481        return Err(ProtocolAborted::invalid_fac_proof(blame).into());
482    }
483
484    // verifications passed, compute final key shares
485
486    tracer.stage("Assemble auxiliary info");
487    let mut parties_pedersen = decommitments
488        .iter()
489        .map(|d| PedersenParams {
490            hat_N: d.hat_N.clone(),
491            s: d.s.clone(),
492            t: d.t.clone(),
493            multiexp: None,
494            crt: None,
495        })
496        .collect::<Vec<_>>();
497    parties_pedersen.insert(i.into(), pedersen_params);
498
499    let N = decommitments
500        .into_iter_including_me(decommitment)
501        .map(|d| d.N)
502        .collect::<Vec<_>>();
503    let mut aux = DirtyAuxInfo {
504        p,
505        q,
506        N,
507        pedersen_params: parties_pedersen,
508        security_level: std::marker::PhantomData,
509    };
510
511    if compute_multiexp_table {
512        tracer.stage("Precompute multiexp tables");
513
514        aux.precompute_multiexp_tables()
515            .map_err(Bug::BuildMultiexpTables)?;
516    }
517
518    let aux = aux
519        .validate()
520        .map_err(|err| Bug::InvalidShareGenerated(err.into_error()))?;
521
522    tracer.protocol_ends();
523    Ok(aux)
524}