Skip to main content

cggmp24_keygen/
non_threshold.rs

1use alloc::vec::Vec;
2
3use digest::Digest;
4use generic_ec::{Curve, NonZero, Point, Scalar, SecretScalar};
5use generic_ec_zkp::schnorr_pok;
6use rand_core::{CryptoRng, RngCore};
7use round_based::{
8    rounds_router::simple_store::RoundInput, rounds_router::RoundsRouter, Delivery, Mpc, MpcParty,
9    Outgoing, ProtocolMessage, SinkExt,
10};
11use serde::{Deserialize, Serialize};
12
13use crate::progress::Tracer;
14use crate::{
15    errors::IoError,
16    key_share::{CoreKeyShare, DirtyCoreKeyShare, DirtyKeyInfo, Validate},
17    security_level::SecurityLevel,
18    utils, ExecutionId,
19};
20
21use super::{Bug, KeygenAborted, KeygenError};
22
23macro_rules! prefixed {
24    ($name:tt) => {
25        concat!("dfns.cggmp24.keygen.non_threshold.", $name)
26    };
27}
28
29/// Message of key generation protocol
30#[derive(ProtocolMessage, Clone, Serialize, Deserialize)]
31#[serde(bound = "")]
32pub enum Msg<E: Curve, L: SecurityLevel, D: Digest> {
33    /// Round 1 message
34    Round1(MsgRound1<D>),
35    /// Reliability check message (optional additional round)
36    ReliabilityCheck(MsgReliabilityCheck<D>),
37    /// Round 2 message
38    Round2(MsgRound2<E, L>),
39    /// Round 3 message
40    Round3(MsgRound3<E>),
41}
42
43/// Message from round 1
44#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
45#[serde(bound = "")]
46#[udigest(bound = "")]
47#[udigest(tag = prefixed!("round1"))]
48pub struct MsgRound1<D: Digest> {
49    /// $V_i$
50    #[udigest(as_bytes)]
51    pub commitment: digest::Output<D>,
52}
53/// Message from round 2
54#[serde_with::serde_as]
55#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
56#[serde(bound = "")]
57#[udigest(bound = "")]
58#[udigest(tag = prefixed!("round2"))]
59pub struct MsgRound2<E: Curve, L: SecurityLevel> {
60    /// `rid_i`
61    #[serde_as(as = "utils::HexOrBin")]
62    #[udigest(as_bytes)]
63    pub rid: L::KappaBytes,
64    /// $X_i$
65    pub X: NonZero<Point<E>>,
66    /// $A_i$
67    pub sch_commit: schnorr_pok::Commit<E>,
68    /// Party contribution to chain code
69    #[cfg(feature = "hd-wallet")]
70    #[serde_as(as = "Option<utils::HexOrBin>")]
71    #[udigest(as = Option<udigest::Bytes>)]
72    pub chain_code: Option<hd_wallet::ChainCode>,
73    /// $u_i$
74    #[serde(with = "hex::serde")]
75    #[udigest(as_bytes)]
76    pub decommit: L::KappaBytes,
77}
78/// Message from round 3
79#[derive(Clone, Serialize, Deserialize)]
80#[serde(bound = "")]
81pub struct MsgRound3<E: Curve> {
82    /// $\psi_i$
83    pub sch_proof: schnorr_pok::Proof<E>,
84}
85/// Message parties exchange to ensure reliability of broadcast channel
86#[derive(Clone, Serialize, Deserialize)]
87#[serde(bound = "")]
88pub struct MsgReliabilityCheck<D: Digest>(pub digest::Output<D>);
89
90mod unambiguous {
91    use crate::{ExecutionId, SecurityLevel};
92    use generic_ec::Curve;
93
94    #[derive(udigest::Digestable)]
95    #[udigest(tag = prefixed!("hash_commitment"))]
96    #[udigest(bound = "")]
97    pub struct HashCom<'a, E: Curve, L: SecurityLevel> {
98        pub sid: ExecutionId<'a>,
99        pub party_index: u16,
100        pub decommitment: &'a super::MsgRound2<E, L>,
101    }
102
103    #[derive(udigest::Digestable)]
104    #[udigest(tag = prefixed!("schnorr_pok"))]
105    #[udigest(bound = "")]
106    pub struct SchnorrPok<'a, E: Curve> {
107        pub sid: ExecutionId<'a>,
108        pub prover: u16,
109        #[udigest(as_bytes)]
110        pub rid: &'a [u8],
111        pub X: &'a generic_ec::NonZero<generic_ec::Point<E>>,
112        pub sch_commit: &'a generic_ec_zkp::schnorr_pok::Commit<E>,
113    }
114
115    #[derive(udigest::Digestable)]
116    #[udigest(tag = prefixed!("echo_round"))]
117    #[udigest(bound = "")]
118    pub struct Echo<'a, D: digest::Digest> {
119        pub sid: ExecutionId<'a>,
120        pub commitment: &'a super::MsgRound1<D>,
121    }
122}
123
124pub async fn run_keygen<E, R, M, L, D>(
125    mut tracer: Option<&mut dyn Tracer>,
126    i: u16,
127    n: u16,
128    reliable_broadcast_enforced: bool,
129    sid: ExecutionId<'_>,
130    rng: &mut R,
131    party: M,
132    #[cfg(feature = "hd-wallet")] hd_enabled: bool,
133) -> Result<CoreKeyShare<E>, KeygenError>
134where
135    E: Curve,
136    L: SecurityLevel,
137    D: Digest + Clone + 'static,
138    R: RngCore + CryptoRng,
139    M: Mpc<ProtocolMessage = Msg<E, L, D>>,
140{
141    tracer.protocol_begins();
142
143    tracer.stage("Setup networking");
144    let MpcParty { delivery, .. } = party.into_party();
145    let (incomings, mut outgoings) = delivery.split();
146
147    let mut rounds = RoundsRouter::<Msg<E, L, D>>::builder();
148    let round1 = rounds.add_round(RoundInput::<MsgRound1<D>>::broadcast(i, n));
149    let round1_sync = rounds.add_round(RoundInput::<MsgReliabilityCheck<D>>::broadcast(i, n));
150    let round2 = rounds.add_round(RoundInput::<MsgRound2<E, L>>::broadcast(i, n));
151    let round3 = rounds.add_round(RoundInput::<MsgRound3<E>>::broadcast(i, n));
152    let mut rounds = rounds.listen(incomings);
153
154    // Round 1
155    tracer.round_begins();
156
157    tracer.stage("Sample x_i, rid_i, chain_code");
158    let x_i = NonZero::<SecretScalar<E>>::random(rng);
159    let X_i = Point::generator() * &x_i;
160
161    let mut rid = L::KappaBytes::default();
162    rng.fill_bytes(rid.as_mut());
163
164    #[cfg(feature = "hd-wallet")]
165    let chain_code_local = if hd_enabled {
166        let mut chain_code = hd_wallet::ChainCode::default();
167        rng.fill_bytes(&mut chain_code);
168        Some(chain_code)
169    } else {
170        None
171    };
172
173    tracer.stage("Sample schnorr commitment");
174    let (sch_secret, sch_commit) = schnorr_pok::prover_commits_ephemeral_secret::<E, _>(rng);
175
176    tracer.stage("Commit to public data");
177    let my_decommitment = MsgRound2 {
178        rid,
179        X: X_i,
180        sch_commit: sch_commit.clone(),
181        #[cfg(feature = "hd-wallet")]
182        chain_code: chain_code_local,
183        decommit: {
184            let mut nonce = L::KappaBytes::default();
185            rng.fill_bytes(nonce.as_mut());
186            nonce
187        },
188    };
189    let hash_commit = udigest::hash::<D>(&unambiguous::HashCom {
190        sid,
191        party_index: i,
192        decommitment: &my_decommitment,
193    });
194    let my_commitment = MsgRound1 {
195        commitment: hash_commit,
196    };
197
198    tracer.send_msg();
199    outgoings
200        .send(Outgoing::broadcast(Msg::Round1(my_commitment.clone())))
201        .await
202        .map_err(IoError::send_message)?;
203    tracer.msg_sent();
204
205    // Round 2
206    tracer.round_begins();
207
208    tracer.receive_msgs();
209    let commitments = rounds
210        .complete(round1)
211        .await
212        .map_err(IoError::receive_message)?;
213    tracer.msgs_received();
214
215    // Optional reliability check
216    if reliable_broadcast_enforced {
217        tracer.stage("Hash received msgs (reliability check)");
218        let h_i = udigest::hash_iter::<D>(
219            commitments
220                .iter_including_me(&my_commitment)
221                .map(|commitment| unambiguous::Echo { sid, commitment }),
222        );
223
224        tracer.send_msg();
225        outgoings
226            .send(Outgoing::broadcast(Msg::ReliabilityCheck(
227                MsgReliabilityCheck(h_i.clone()),
228            )))
229            .await
230            .map_err(IoError::send_message)?;
231        tracer.msg_sent();
232
233        tracer.round_begins();
234
235        tracer.receive_msgs();
236        let round1_hashes = rounds
237            .complete(round1_sync)
238            .await
239            .map_err(IoError::receive_message)?;
240        tracer.msgs_received();
241
242        tracer.stage("Assert other parties hashed messages (reliability check)");
243        let parties_have_different_hashes = round1_hashes
244            .into_iter_indexed()
245            .filter(|(_j, _msg_id, hash_j)| hash_j.0 != h_i)
246            .map(|(j, msg_id, _)| (j, msg_id))
247            .collect::<Vec<_>>();
248        if !parties_have_different_hashes.is_empty() {
249            return Err(KeygenAborted::Round1NotReliable(parties_have_different_hashes).into());
250        }
251    }
252
253    tracer.send_msg();
254    outgoings
255        .send(Outgoing::broadcast(Msg::Round2(my_decommitment.clone())))
256        .await
257        .map_err(IoError::send_message)?;
258    tracer.msg_sent();
259
260    // Round 3
261    tracer.round_begins();
262
263    tracer.receive_msgs();
264    let decommitments = rounds
265        .complete(round2)
266        .await
267        .map_err(IoError::receive_message)?;
268    tracer.msgs_received();
269
270    tracer.stage("Validate decommitments");
271    let blame = utils::collect_blame(&commitments, &decommitments, |j, com, decom| {
272        let com_expected = udigest::hash::<D>(&unambiguous::HashCom {
273            sid,
274            party_index: j,
275            decommitment: decom,
276        });
277        com.commitment != com_expected
278    });
279    if !blame.is_empty() {
280        return Err(KeygenAborted::InvalidDecommitment(blame).into());
281    }
282
283    #[cfg(feature = "hd-wallet")]
284    let chain_code = if hd_enabled {
285        tracer.stage("Calculate chain_code");
286        let blame = utils::collect_simple_blame(&decommitments, |decom| decom.chain_code.is_none());
287        if !blame.is_empty() {
288            return Err(KeygenAborted::MissingChainCode(blame).into());
289        }
290        Some(decommitments.iter_including_me(&my_decommitment).try_fold(
291            hd_wallet::ChainCode::default(),
292            |acc, decom| {
293                Ok::<_, Bug>(utils::xor_array(
294                    acc,
295                    decom.chain_code.ok_or(Bug::NoChainCode)?,
296                ))
297            },
298        )?)
299    } else {
300        None
301    };
302
303    tracer.stage("Calculate challege rid");
304    let rid = decommitments
305        .iter_including_me(&my_decommitment)
306        .map(|d| &d.rid)
307        .fold(L::KappaBytes::default(), utils::xor_array);
308    let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
309        sid,
310        prover: i,
311        rid: rid.as_ref(),
312        X: &X_i,
313        sch_commit: &sch_commit,
314    });
315    let challenge = schnorr_pok::Challenge { nonce: challenge };
316
317    tracer.stage("Prove knowledge of `x_i`");
318    let sch_proof = schnorr_pok::prove(&sch_secret, &challenge, &x_i);
319
320    tracer.send_msg();
321    let my_sch_proof = MsgRound3 { sch_proof };
322    outgoings
323        .send(Outgoing::broadcast(Msg::Round3(my_sch_proof.clone())))
324        .await
325        .map_err(IoError::send_message)?;
326    tracer.msg_sent();
327
328    // Round 4
329    tracer.round_begins();
330
331    tracer.receive_msgs();
332    let sch_proofs = rounds
333        .complete(round3)
334        .await
335        .map_err(IoError::receive_message)?;
336    tracer.msgs_received();
337
338    tracer.stage("Validate schnorr proofs");
339    let blame = utils::collect_blame(&decommitments, &sch_proofs, |j, decom, sch_proof| {
340        let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
341            sid,
342            prover: j,
343            rid: rid.as_ref(),
344            X: &decom.X,
345            sch_commit: &decom.sch_commit,
346        });
347        let challenge = schnorr_pok::Challenge { nonce: challenge };
348        sch_proof
349            .sch_proof
350            .verify(&decom.sch_commit, &challenge, &decom.X)
351            .is_err()
352    });
353    if !blame.is_empty() {
354        return Err(KeygenAborted::InvalidSchnorrProof(blame).into());
355    }
356
357    tracer.protocol_ends();
358
359    Ok(DirtyCoreKeyShare {
360        i,
361        key_info: DirtyKeyInfo {
362            curve: Default::default(),
363            shared_public_key: NonZero::from_point(
364                decommitments
365                    .iter_including_me(&my_decommitment)
366                    .map(|d| d.X)
367                    .sum(),
368            )
369            .ok_or(Bug::ZeroPk)?,
370            public_shares: decommitments
371                .iter_including_me(&my_decommitment)
372                .map(|d| d.X)
373                .collect(),
374            vss_setup: None,
375            #[cfg(feature = "hd-wallet")]
376            chain_code,
377        },
378        x: x_i,
379    }
380    .validate()
381    .map_err(|e| Bug::InvalidKeyShare(e.into_error()))?)
382}