Skip to main content

cggmp24_keygen/
threshold.rs

1use alloc::vec::Vec;
2
3use digest::Digest;
4use generic_ec::{Curve, NonZero, Point, Scalar, SecretScalar};
5use generic_ec_zkp::{polynomial::Polynomial, schnorr_pok};
6use rand_core::{CryptoRng, RngCore};
7use round_based::{
8    rounds_router::simple_store::RoundInput, rounds_router::RoundsRouter, Delivery, Mpc, MpcParty,
9    Outgoing, ProtocolMessage, SinkExt,
10};
11use serde::{Deserialize, Serialize};
12use serde_with::serde_as;
13
14use crate::progress::Tracer;
15use crate::{
16    errors::IoError,
17    key_share::{CoreKeyShare, DirtyCoreKeyShare, DirtyKeyInfo, Validate, VssSetup},
18    security_level::SecurityLevel,
19    utils, ExecutionId,
20};
21
22use super::{Bug, KeygenAborted, KeygenError};
23
/// Builds a domain-separation tag for threshold keygen.
///
/// Expands to the string literal `"dfns.cggmp24.keygen.threshold.<suffix>"`; used as the
/// `udigest` tag of every structure hashed in this module.
macro_rules! prefixed {
    ($suffix:tt) => (concat!("dfns.cggmp24.keygen.threshold.", $suffix));
}
29
/// Message of key generation protocol
///
/// Every message exchanged by [`run_threshold_keygen`] is one of these variants.
/// `ReliabilityCheck` is only sent when reliable broadcast is enforced.
///
/// NOTE(review): variant names and order presumably determine the serde encoding and the
/// round numbering produced by the `ProtocolMessage` derive — do not reorder or rename
/// without checking wire compatibility with other parties.
#[derive(ProtocolMessage, Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub enum Msg<E: Curve, L: SecurityLevel, D: Digest> {
    /// Round 1 message: hash commitment to the sender's round-2 broadcast
    Round1(MsgRound1<D>),
    /// Round 2a message: broadcast opening of the round-1 commitment
    Round2Broad(MsgRound2Broad<E, L>),
    /// Round 2b message: secret share sent privately to a single party
    Round2Uni(MsgRound2Uni<E>),
    /// Round 3 message: Schnorr proof of knowledge of the sender's secret share
    Round3(MsgRound3<E>),
    /// Reliability check message (optional additional round)
    ReliabilityCheck(MsgReliabilityCheck<D>),
}
45
/// Message from round 1
///
/// Carries only a hash commitment; the committed data is revealed in round 2 as
/// [`MsgRound2Broad`].
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[serde(bound = "")]
#[udigest(bound = "")]
#[udigest(tag = prefixed!("round1"))]
pub struct MsgRound1<D: Digest> {
    /// $V_i$
    ///
    /// Digest `D` of the sender's execution id, party index and round-2 broadcast
    /// message (hashed as `unambiguous::HashCom`).
    #[udigest(as_bytes)]
    pub commitment: digest::Output<D>,
}
/// Message from round 2 broadcasted to everyone
///
/// This is the opening of the sender's round-1 commitment: receivers recompute
/// `unambiguous::HashCom` over it and compare the result with [`MsgRound1::commitment`].
#[serde_as]
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[serde(bound = "")]
#[udigest(bound = "")]
#[udigest(tag = prefixed!("round2_broad"))]
pub struct MsgRound2Broad<E: Curve, L: SecurityLevel> {
    /// `rid_i`
    ///
    /// Random contribution to the shared `rid`; all parties' values are combined with
    /// `utils::xor_array`.
    #[serde_as(as = "utils::HexOrBin")]
    #[udigest(as_bytes)]
    pub rid: L::KappaBytes,
    /// $\vec S_i$
    ///
    /// Public image of the sender's secret polynomial `f` (computed as `f * G`).
    /// Receivers require exactly `t` coefficients; the constant term contributes to the
    /// shared public key.
    pub F: Polynomial<Point<E>>,
    /// $A_i$
    ///
    /// Schnorr commitment to the ephemeral secret used by the round-3 proof.
    pub sch_commit: schnorr_pok::Commit<E>,
    /// Party contribution to chain code
    ///
    /// Honest parties set this to `Some` when HD wallet support is enabled; if it is
    /// enabled locally, a `None` from any peer aborts keygen.
    #[cfg(feature = "hd-wallet")]
    #[serde_as(as = "Option<utils::HexOrBin>")]
    #[udigest(as = Option<udigest::Bytes>)]
    pub chain_code: Option<hd_wallet::ChainCode>,
    /// $u_i$
    ///
    /// Fresh random nonce included in the committed data.
    // NOTE(review): unlike `rid`, this field is serialized with `hex::serde` instead of
    // `utils::HexOrBin`; unifying them would change the wire format — confirm before changing.
    #[serde(with = "hex::serde")]
    #[udigest(as_bytes)]
    pub decommit: L::KappaBytes,
}
/// Message from round 2 unicasted to each party
///
/// Holds a secret share for the recipient; in this protocol it is only sent via
/// `Outgoing::p2p`.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgRound2Uni<E: Curve> {
    /// $\sigma_{i,j}$
    ///
    /// The sender's secret polynomial `f` evaluated at the recipient's point `j + 1`.
    /// The recipient checks it against the sender's `F` (Feldman VSS).
    pub sigma: Scalar<E>,
}
/// Message from round 3
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgRound3<E: Curve> {
    /// $\psi_i$
    ///
    /// Schnorr proof of knowledge of the sender's final secret share, verified against
    /// the sender's public share. The challenge is derived by hashing
    /// `unambiguous::SchnorrPok`.
    pub sch_proof: schnorr_pok::Proof<E>,
}
/// Message parties exchange to ensure reliability of broadcast channel
///
/// Wraps a digest over every party's round-1 commitment (the sender's own included),
/// each hashed as `unambiguous::Echo`. Any peer whose digest differs from ours makes
/// keygen abort with `KeygenAborted::Round1NotReliable`.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgReliabilityCheck<D: Digest>(pub digest::Output<D>);
99
/// Structures hashed with `udigest` by the keygen protocol.
///
/// NOTE(review): field names, field order and tags presumably all feed into the resulting
/// digests — changing any of them changes commitments/challenges and breaks
/// interoperability with parties running the previous code.
mod unambiguous {
    use generic_ec::{Curve, NonZero, Point};

    use crate::{ExecutionId, SecurityLevel};

    /// Input of the round-1 hash commitment.
    ///
    /// Binds the execution id and the committing party's index to its round-2
    /// broadcast message.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("hash_commitment"))]
    #[udigest(bound = "")]
    pub struct HashCom<'a, E: Curve, L: SecurityLevel> {
        pub sid: ExecutionId<'a>,
        pub party_index: u16,
        pub decommitment: &'a super::MsgRound2Broad<E, L>,
    }

    /// Input hashed into the Schnorr proof challenge.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("schnorr_pok"))]
    #[udigest(bound = "")]
    pub struct SchnorrPok<'a, E: Curve> {
        pub sid: ExecutionId<'a>,
        /// Index of the party producing the proof.
        pub prover: u16,
        /// Combined `rid` of all parties.
        #[udigest(as_bytes)]
        pub rid: &'a [u8],
        /// Prover's public share (the statement being proven).
        pub y: NonZero<Point<E>>,
        /// Prover's Schnorr commitment.
        pub h: Point<E>,
    }

    /// One entry of the reliability-check digest: a round-1 commitment bound to the
    /// execution id.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("echo_round"))]
    #[udigest(bound = "")]
    pub struct Echo<'a, D: digest::Digest> {
        pub sid: ExecutionId<'a>,
        pub commitment: &'a super::MsgRound1<D>,
    }
}
134
/// Runs the threshold (`t`-out-of-`n`) key generation protocol as party `i`.
///
/// Protocol outline, as implemented below:
/// 1. Sample `rid_i`, a Schnorr commitment, a random polynomial `f` of degree `t - 1` and
///    (optionally) a chain code contribution, then broadcast a hash commitment to the
///    public part of this data ([`MsgRound1`]).
/// 2. Optionally (`reliable_broadcast_enforced`), run an echo round: every party
///    broadcasts a digest of all round-1 commitments it holds, and keygen aborts if any
///    peer's digest differs from ours.
/// 3. Broadcast the opening ([`MsgRound2Broad`]) and send each peer `j` its share
///    `f(j + 1)` privately ([`MsgRound2Uni`]).
/// 4. Check every peer's opening, polynomial size and received share (Feldman VSS);
///    derive the combined `rid`, chain code, public shares and our own secret share; then
///    broadcast a Schnorr proof of knowledge of that secret share ([`MsgRound3`]).
/// 5. Verify all peers' proofs and assemble the key share.
///
/// # Parameters
/// - `tracer`: optional progress tracer notified at each round/stage.
/// - `i`: this party's index; its share is evaluated at point `i + 1`.
/// - `t`: threshold; every polynomial has `t` coefficients and the share records
///   `min_signers = t`.
/// - `n`: total number of parties.
/// - `reliable_broadcast_enforced`: enables the echo round described above.
/// - `sid`: execution id, bound into every hash computed by the protocol.
/// - `rng`: source of all randomness sampled by this party.
/// - `party`: connection to the other parties.
/// - `hd_enabled` (feature `hd-wallet`): if `true`, every party must contribute a chain
///   code and the resulting key share carries their combination.
///
/// NOTE(review): assumes `t >= 1` and `i < n` — `usize::from(t) - 1` underflows for
/// `t == 0` and `sigmas[usize::from(i)]` panics for `i >= n`. Presumably validated by the
/// caller; confirm.
///
/// # Errors
/// - `IoError` (converted into [`KeygenError`]) when sending or receiving a message fails.
/// - [`KeygenAborted`] when peers misbehave: inconsistent round-1 broadcast, invalid
///   decommitment, wrong polynomial size, failed Feldman check, missing chain code, or
///   invalid Schnorr proof. The offending parties are reported.
/// - [`Bug`] on internal-invariant violations (zero share, zero public key, invalid
///   resulting key share).
pub async fn run_threshold_keygen<E, R, M, L, D>(
    mut tracer: Option<&mut dyn Tracer>,
    i: u16,
    t: u16,
    n: u16,
    reliable_broadcast_enforced: bool,
    sid: ExecutionId<'_>,
    rng: &mut R,
    party: M,
    #[cfg(feature = "hd-wallet")] hd_enabled: bool,
) -> Result<CoreKeyShare<E>, KeygenError>
where
    E: Curve,
    L: SecurityLevel,
    D: Digest + Clone + 'static,
    R: RngCore + CryptoRng,
    M: Mpc<ProtocolMessage = Msg<E, L, D>>,
{
    tracer.protocol_begins();

    tracer.stage("Setup networking");
    let MpcParty { delivery, .. } = party.into_party();
    let (incomings, mut outgoings) = delivery.split();

    // `round1_sync` is registered unconditionally but only completed when
    // `reliable_broadcast_enforced` is set.
    let mut rounds = RoundsRouter::<Msg<E, L, D>>::builder();
    let round1 = rounds.add_round(RoundInput::<MsgRound1<D>>::broadcast(i, n));
    let round1_sync = rounds.add_round(RoundInput::<MsgReliabilityCheck<D>>::broadcast(i, n));
    let round2_broad = rounds.add_round(RoundInput::<MsgRound2Broad<E, L>>::broadcast(i, n));
    let round2_uni = rounds.add_round(RoundInput::<MsgRound2Uni<E>>::p2p(i, n));
    let round3 = rounds.add_round(RoundInput::<MsgRound3<E>>::broadcast(i, n));
    let mut rounds = rounds.listen(incomings);

    // Round 1
    tracer.round_begins();

    // The order of RNG draws (rid, Schnorr nonce, polynomial, chain code, decommit nonce)
    // fixes the output for a seeded `rng`; keep it stable if reproducibility matters.
    tracer.stage("Sample rid_i, schnorr commitment, polynomial, chain_code");
    let mut rid = L::KappaBytes::default();
    rng.fill_bytes(rid.as_mut());

    let (r, h) = schnorr_pok::prover_commits_ephemeral_secret::<E, _>(rng);

    // Secret polynomial of degree t - 1 and its public image F = f * G.
    let f = Polynomial::<SecretScalar<E>>::sample(rng, usize::from(t) - 1);
    let F = &f * &Point::generator();
    // sigmas[j] = f(j + 1): the share destined for party j (evaluation points are 1..=n).
    let sigmas = (0..n)
        .map(|j| {
            let x = Scalar::from(j + 1);
            f.value(&x)
        })
        .collect::<Vec<_>>();
    debug_assert_eq!(sigmas.len(), usize::from(n));

    #[cfg(feature = "hd-wallet")]
    let chain_code_local = if hd_enabled {
        let mut chain_code = hd_wallet::ChainCode::default();
        rng.fill_bytes(&mut chain_code);
        Some(chain_code)
    } else {
        None
    };

    tracer.stage("Commit to public data");
    let my_decommitment = MsgRound2Broad {
        rid,
        F: F.clone(),
        sch_commit: h,
        #[cfg(feature = "hd-wallet")]
        chain_code: chain_code_local,
        // u_i: fresh random nonce included in the committed data
        decommit: {
            let mut nonce = L::KappaBytes::default();
            rng.fill_bytes(nonce.as_mut());
            nonce
        },
    };
    // Commitment to our round-2 broadcast, bound to the execution id and our index.
    let hash_commit = udigest::hash::<D>(&unambiguous::HashCom {
        sid,
        party_index: i,
        decommitment: &my_decommitment,
    });

    tracer.send_msg();
    let my_commitment = MsgRound1 {
        commitment: hash_commit,
    };
    outgoings
        .send(Outgoing::broadcast(Msg::Round1(my_commitment.clone())))
        .await
        .map_err(IoError::send_message)?;
    tracer.msg_sent();

    // Round 2
    tracer.round_begins();

    tracer.receive_msgs();
    let commitments = rounds
        .complete(round1)
        .await
        .map_err(IoError::receive_message)?;
    tracer.msgs_received();

    // Optional reliability check
    if reliable_broadcast_enforced {
        // Digest over all round-1 commitments, our own included; every party that saw the
        // same broadcasts computes the same value.
        tracer.stage("Hash received msgs (reliability check)");
        let h_i = udigest::hash_iter::<D>(
            commitments
                .iter_including_me(&my_commitment)
                .map(|commitment| unambiguous::Echo { sid, commitment }),
        );

        tracer.send_msg();
        outgoings
            .send(Outgoing::broadcast(Msg::ReliabilityCheck(
                MsgReliabilityCheck(h_i.clone()),
            )))
            .await
            .map_err(IoError::send_message)?;
        tracer.msg_sent();

        tracer.round_begins();

        tracer.receive_msgs();
        let hashes = rounds
            .complete(round1_sync)
            .await
            .map_err(IoError::receive_message)?;
        tracer.msgs_received();

        // Blame (party index, message id) of every peer whose digest differs from ours.
        tracer.stage("Assert other parties hashed messages (reliability check)");
        let parties_have_different_hashes = hashes
            .into_iter_indexed()
            .filter(|(_j, _msg_id, h_j)| h_i != h_j.0)
            .map(|(j, msg_id, _)| (j, msg_id))
            .collect::<Vec<_>>();
        if !parties_have_different_hashes.is_empty() {
            return Err(KeygenAborted::Round1NotReliable(parties_have_different_hashes).into());
        }
    }

    // `feed` only queues the broadcast; the `send_all` below flushes it together with the
    // p2p shares.
    tracer.send_msg();
    outgoings
        .feed(Outgoing::broadcast(Msg::Round2Broad(
            my_decommitment.clone(),
        )))
        .await
        .map_err(IoError::send_message)?;

    // Each peer j receives f(j + 1) privately; our own share sigmas[i] stays local.
    let messages = utils::iter_peers(i, n).map(|j| {
        let message = MsgRound2Uni {
            sigma: sigmas[usize::from(j)],
        };
        Outgoing::p2p(j, Msg::Round2Uni(message))
    });
    outgoings
        .send_all(&mut futures_util::stream::iter(messages.map(Ok)))
        .await
        .map_err(IoError::send_message)?;
    tracer.msg_sent();

    // Round 3
    tracer.round_begins();

    tracer.receive_msgs();
    let decommitments = rounds
        .complete(round2_broad)
        .await
        .map_err(IoError::receive_message)?;
    let sigmas_msg = rounds
        .complete(round2_uni)
        .await
        .map_err(IoError::receive_message)?;
    tracer.msgs_received();

    // Recompute each peer's commitment from its opening and blame mismatches.
    tracer.stage("Validate decommitments");
    let blame = utils::collect_blame(&commitments, &decommitments, |j, com, decom| {
        let com_expected = udigest::hash::<D>(&unambiguous::HashCom {
            sid,
            party_index: j,
            decommitment: decom,
        });
        com.commitment != com_expected
    });
    if !blame.is_empty() {
        return Err(KeygenAborted::InvalidDecommitment(blame).into());
    }

    // Every polynomial must have exactly `t` coefficients. Note: in `.map(|t| t.0)` below,
    // `t` is the (party, msg_id, message) tuple, shadowing the threshold.
    tracer.stage("Validate data size");
    let blame = decommitments
        .iter_indexed()
        .filter(|(_, _, d)| d.F.degree() + 1 != usize::from(t))
        .map(|t| t.0)
        .collect::<Vec<_>>();
    if !blame.is_empty() {
        return Err(KeygenAborted::InvalidDataSize { parties: blame }.into());
    }

    // Check F_j(i + 1) == sigma_{j,i} * G for every peer j. The zip pairs messages by
    // position — presumably both stores yield peers in the same (ascending index) order;
    // confirm against round_based.
    tracer.stage("Validate Feldmann VSS");
    let blame = decommitments
        .iter_indexed()
        .zip(sigmas_msg.iter())
        .filter(|((_, _, d), s)| {
            d.F.value::<_, Point<_>>(&Scalar::from(i + 1)) != Point::generator() * s.sigma
        })
        .map(|t| t.0 .0)
        .collect::<Vec<_>>();
    if !blame.is_empty() {
        return Err(KeygenAborted::FeldmanVerificationFailed { parties: blame }.into());
    }

    // Combine every party's rid contribution (own included).
    tracer.stage("Compute rid");
    let rid = decommitments
        .iter_including_me(&my_decommitment)
        .map(|d| &d.rid)
        .fold(L::KappaBytes::default(), utils::xor_array);
    #[cfg(feature = "hd-wallet")]
    let chain_code = if hd_enabled {
        tracer.stage("Compute chain_code");
        let blame = utils::collect_simple_blame(&decommitments, |decom| decom.chain_code.is_none());
        if !blame.is_empty() {
            return Err(KeygenAborted::MissingChainCode(blame).into());
        }
        // After the check above every peer's chain code is `Some`; our own is `Some`
        // because `hd_enabled` is set, so `Bug::NoChainCode` should be unreachable.
        Some(decommitments.iter_including_me(&my_decommitment).try_fold(
            hd_wallet::ChainCode::default(),
            |acc, decom| {
                Ok::<_, Bug>(utils::xor_array(
                    acc,
                    decom.chain_code.ok_or(Bug::NoChainCode)?,
                ))
            },
        )?)
    } else {
        None
    };
    // Sum of all parties' polynomial images; ys[l] is party l's public share.
    tracer.stage("Compute Ys");
    let polynomial_sum = decommitments
        .iter_including_me(&my_decommitment)
        .map(|d| &d.F)
        .sum::<Polynomial<_>>();
    let ys = (0..n)
        .map(|l| polynomial_sum.value(&Scalar::from(l + 1)))
        .map(|y_j: Point<E>| NonZero::from_point(y_j).ok_or(Bug::ZeroShare))
        .collect::<Result<Vec<_>, _>>()?;
    // Our final secret share: sum of the shares received from peers plus our own
    // f(i + 1). `SecretScalar::new` takes the temporary by `&mut` — presumably to erase it;
    // confirm against generic_ec docs.
    tracer.stage("Compute sigma");
    let sigma: Scalar<E> = sigmas_msg.iter().map(|msg| msg.sigma).sum();
    let mut sigma = sigma + sigmas[usize::from(i)];
    let sigma = NonZero::from_secret_scalar(SecretScalar::new(&mut sigma)).ok_or(Bug::ZeroShare)?;
    debug_assert_eq!(Point::generator() * &sigma, ys[usize::from(i)]);

    // Challenge = hash of sid, our index, combined rid, our public share and our Schnorr
    // commitment.
    tracer.stage("Calculate challenge");
    let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
        sid,
        prover: i,
        rid: rid.as_ref(),
        y: ys[usize::from(i)],
        h: my_decommitment.sch_commit.0,
    });
    let challenge = schnorr_pok::Challenge { nonce: challenge };

    tracer.stage("Prove knowledge of `sigma_i`");
    let z = schnorr_pok::prove(&r, &challenge, &sigma);

    tracer.send_msg();
    let my_sch_proof = MsgRound3 { sch_proof: z };
    outgoings
        .send(Outgoing::broadcast(Msg::Round3(my_sch_proof.clone())))
        .await
        .map_err(IoError::send_message)?;
    tracer.msg_sent();

    // Output round
    tracer.round_begins();

    tracer.receive_msgs();
    let sch_proofs = rounds
        .complete(round3)
        .await
        .map_err(IoError::receive_message)?;
    tracer.msgs_received();

    // Recompute each peer's challenge and verify its proof against its public share.
    tracer.stage("Validate schnorr proofs");
    let blame = utils::collect_blame(&decommitments, &sch_proofs, |j, decom, sch_proof| {
        let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
            sid,
            prover: j,
            rid: rid.as_ref(),
            y: ys[usize::from(j)],
            h: decom.sch_commit.0,
        });
        let challenge = schnorr_pok::Challenge { nonce: challenge };
        sch_proof
            .sch_proof
            .verify(&decom.sch_commit, &challenge, &ys[usize::from(j)])
            .is_err()
    });
    if !blame.is_empty() {
        return Err(KeygenAborted::InvalidSchnorrProof(blame).into());
    }

    // Shared public key = sum of every party's constant term F_j(0).
    tracer.stage("Derive resulting public key and other data");
    let y: Point<E> = decommitments
        .iter_including_me(&my_decommitment)
        .map(|d| d.F.coefs()[0])
        .sum();
    // Evaluation points 1..=n that the shares were computed at.
    let key_shares_indexes = (1..=n)
        .map(|i| NonZero::from_scalar(Scalar::from(i)))
        .collect::<Option<Vec<_>>>()
        .ok_or(Bug::NonZeroScalar)?;

    tracer.protocol_ends();

    Ok(DirtyCoreKeyShare {
        i,
        key_info: DirtyKeyInfo {
            curve: Default::default(),
            shared_public_key: NonZero::from_point(y).ok_or(Bug::ZeroPk)?,
            public_shares: ys,
            vss_setup: Some(VssSetup {
                min_signers: t,
                I: key_shares_indexes,
            }),
            #[cfg(feature = "hd-wallet")]
            chain_code,
        },
        x: sigma,
    }
    .validate()
    .map_err(|err| Bug::InvalidKeyShare(err.into_error()))?)
}