1use alloc::vec::Vec;
2
3use digest::Digest;
4use generic_ec::{Curve, NonZero, Point, Scalar, SecretScalar};
5use generic_ec_zkp::schnorr_pok;
6use rand_core::{CryptoRng, RngCore};
7use round_based::{
8 rounds_router::simple_store::RoundInput, rounds_router::RoundsRouter, Delivery, Mpc, MpcParty,
9 Outgoing, ProtocolMessage, SinkExt,
10};
11use serde::{Deserialize, Serialize};
12
13use crate::progress::Tracer;
14use crate::{
15 errors::IoError,
16 key_share::{CoreKeyShare, DirtyCoreKeyShare, DirtyKeyInfo, Validate},
17 security_level::SecurityLevel,
18 utils, ExecutionId,
19};
20
21use super::{Bug, KeygenAborted, KeygenError};
22
/// Builds a domain-separation tag for this protocol at compile time.
///
/// Prepends the `dfns.cggmp24.keygen.non_threshold.` namespace to the given
/// string literal using [`concat!`], so the result is a `&'static str` usable
/// wherever a constant string is expected (e.g. `#[udigest(tag = ...)]`).
macro_rules! prefixed {
    ($suffix:tt) => (
        concat!("dfns.cggmp24.keygen.non_threshold.", $suffix)
    );
}
28
/// Message of the non-threshold key generation protocol.
///
/// Each variant corresponds to one of the rounds registered on the rounds
/// router in `run_keygen`.
#[derive(ProtocolMessage, Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub enum Msg<E: Curve, L: SecurityLevel, D: Digest> {
    /// Round 1: hash commitment to the sender's round-2 message.
    Round1(MsgRound1<D>),
    /// Reliability check: hash of all round-1 messages the sender observed.
    /// Only sent when reliable broadcast is enforced.
    ReliabilityCheck(MsgReliabilityCheck<D>),
    /// Round 2: decommitment (opening) of the round-1 commitment.
    Round2(MsgRound2<E, L>),
    /// Round 3: Schnorr proof of knowledge of the sender's secret share.
    Round3(MsgRound3<E>),
}
42
/// Message from round 1: hash commitment to the sender's [`MsgRound2`].
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[serde(bound = "")]
#[udigest(bound = "")]
#[udigest(tag = prefixed!("round1"))]
pub struct MsgRound1<D: Digest> {
    /// Digest (computed with `D`) of the session id, the sender's index and
    /// the sender's round-2 message.
    #[udigest(as_bytes)]
    pub commitment: digest::Output<D>,
}
/// Message from round 2: opening (decommitment) of the round-1 commitment.
///
/// Receivers hash this message together with the session id and the sender's
/// index and compare the result with the sender's `MsgRound1::commitment`.
#[serde_with::serde_as]
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[serde(bound = "")]
#[udigest(bound = "")]
#[udigest(tag = prefixed!("round2"))]
pub struct MsgRound2<E: Curve, L: SecurityLevel> {
    /// Sender's random contribution `rid_i`. All contributions are XORed into
    /// the shared `rid` that enters the Schnorr challenge.
    #[serde_as(as = "utils::HexOrBin")]
    #[udigest(as_bytes)]
    pub rid: L::KappaBytes,
    /// Sender's public share `X_i = g * x_i`.
    pub X: NonZero<Point<E>>,
    /// Commitment of the sender's Schnorr proof of knowledge of `x_i`.
    pub sch_commit: schnorr_pok::Commit<E>,
    /// Sender's contribution to the chain code (XORed across all parties).
    /// The local party sets it to `None` when HD derivation is disabled.
    #[cfg(feature = "hd-wallet")]
    #[serde_as(as = "Option<utils::HexOrBin>")]
    #[udigest(as = Option<udigest::Bytes>)]
    pub chain_code: Option<hd_wallet::ChainCode>,
    /// Random nonce included in the committed data, so the round-1 commitment
    /// hides the other fields.
    // NOTE(review): unlike `rid`, this field is serialized with `hex::serde`
    // instead of `utils::HexOrBin`; unifying them would change the wire format,
    // so confirm the intended encoding before touching it.
    #[serde(with = "hex::serde")]
    #[udigest(as_bytes)]
    pub decommit: L::KappaBytes,
}
/// Message from round 3: Schnorr proof of knowledge of the sender's secret share.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgRound3<E: Curve> {
    /// Proof checked against the `X` and `sch_commit` the sender revealed in round 2.
    pub sch_proof: schnorr_pok::Proof<E>,
}
/// Reliability-check (echo) message: hash of all round-1 messages the sender
/// observed, including its own.
///
/// Only exchanged when reliable broadcast is enforced.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgReliabilityCheck<D: Digest>(pub digest::Output<D>);
89
/// Hash inputs used by the protocol.
///
/// These structures are not protocol messages. In this file they are only
/// hashed with `udigest`, to derive three things:
/// - the round-1 commitments (`HashCom`);
/// - the Schnorr challenges (`SchnorrPok`);
/// - the reliability-check hashes (`Echo`).
///
/// Each structure has its own `prefixed!` tag, which domain-separates the
/// resulting hashes.
// NOTE(review): the hashed encoding presumably depends on these structs' fields
// (and possibly their names/order); any change alters every commitment and
// challenge and breaks interoperability — confirm against udigest before editing.
mod unambiguous {
    use crate::{ExecutionId, SecurityLevel};
    use generic_ec::Curve;

    /// Data hashed into a party's round-1 commitment: its round-2 message
    /// bound to the session id and the party's index.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("hash_commitment"))]
    #[udigest(bound = "")]
    pub struct HashCom<'a, E: Curve, L: SecurityLevel> {
        /// Session (execution) id.
        pub sid: ExecutionId<'a>,
        /// Index of the committing party.
        pub party_index: u16,
        /// The committed round-2 message, including its random `decommit` nonce.
        pub decommitment: &'a super::MsgRound2<E, L>,
    }

    /// Data hashed into the Schnorr proof challenge of party `prover`.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("schnorr_pok"))]
    #[udigest(bound = "")]
    pub struct SchnorrPok<'a, E: Curve> {
        /// Session (execution) id.
        pub sid: ExecutionId<'a>,
        /// Index of the party producing the proof.
        pub prover: u16,
        /// Shared `rid`; the keygen passes the XOR of all parties' contributions.
        #[udigest(as_bytes)]
        pub rid: &'a [u8],
        /// Prover's public share (the statement being proven).
        pub X: &'a generic_ec::NonZero<generic_ec::Point<E>>,
        /// Prover's Schnorr commitment.
        pub sch_commit: &'a generic_ec_zkp::schnorr_pok::Commit<E>,
    }

    /// One element of the reliability-check hash: a round-1 message bound to
    /// the session id.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("echo_round"))]
    #[udigest(bound = "")]
    pub struct Echo<'a, D: digest::Digest> {
        /// Session (execution) id.
        pub sid: ExecutionId<'a>,
        /// A round-1 commitment message (own or received).
        pub commitment: &'a super::MsgRound1<D>,
    }
}
123
124pub async fn run_keygen<E, R, M, L, D>(
125 mut tracer: Option<&mut dyn Tracer>,
126 i: u16,
127 n: u16,
128 reliable_broadcast_enforced: bool,
129 sid: ExecutionId<'_>,
130 rng: &mut R,
131 party: M,
132 #[cfg(feature = "hd-wallet")] hd_enabled: bool,
133) -> Result<CoreKeyShare<E>, KeygenError>
134where
135 E: Curve,
136 L: SecurityLevel,
137 D: Digest + Clone + 'static,
138 R: RngCore + CryptoRng,
139 M: Mpc<ProtocolMessage = Msg<E, L, D>>,
140{
141 tracer.protocol_begins();
142
143 tracer.stage("Setup networking");
144 let MpcParty { delivery, .. } = party.into_party();
145 let (incomings, mut outgoings) = delivery.split();
146
147 let mut rounds = RoundsRouter::<Msg<E, L, D>>::builder();
148 let round1 = rounds.add_round(RoundInput::<MsgRound1<D>>::broadcast(i, n));
149 let round1_sync = rounds.add_round(RoundInput::<MsgReliabilityCheck<D>>::broadcast(i, n));
150 let round2 = rounds.add_round(RoundInput::<MsgRound2<E, L>>::broadcast(i, n));
151 let round3 = rounds.add_round(RoundInput::<MsgRound3<E>>::broadcast(i, n));
152 let mut rounds = rounds.listen(incomings);
153
154 tracer.round_begins();
156
157 tracer.stage("Sample x_i, rid_i, chain_code");
158 let x_i = NonZero::<SecretScalar<E>>::random(rng);
159 let X_i = Point::generator() * &x_i;
160
161 let mut rid = L::KappaBytes::default();
162 rng.fill_bytes(rid.as_mut());
163
164 #[cfg(feature = "hd-wallet")]
165 let chain_code_local = if hd_enabled {
166 let mut chain_code = hd_wallet::ChainCode::default();
167 rng.fill_bytes(&mut chain_code);
168 Some(chain_code)
169 } else {
170 None
171 };
172
173 tracer.stage("Sample schnorr commitment");
174 let (sch_secret, sch_commit) = schnorr_pok::prover_commits_ephemeral_secret::<E, _>(rng);
175
176 tracer.stage("Commit to public data");
177 let my_decommitment = MsgRound2 {
178 rid,
179 X: X_i,
180 sch_commit: sch_commit.clone(),
181 #[cfg(feature = "hd-wallet")]
182 chain_code: chain_code_local,
183 decommit: {
184 let mut nonce = L::KappaBytes::default();
185 rng.fill_bytes(nonce.as_mut());
186 nonce
187 },
188 };
189 let hash_commit = udigest::hash::<D>(&unambiguous::HashCom {
190 sid,
191 party_index: i,
192 decommitment: &my_decommitment,
193 });
194 let my_commitment = MsgRound1 {
195 commitment: hash_commit,
196 };
197
198 tracer.send_msg();
199 outgoings
200 .send(Outgoing::broadcast(Msg::Round1(my_commitment.clone())))
201 .await
202 .map_err(IoError::send_message)?;
203 tracer.msg_sent();
204
205 tracer.round_begins();
207
208 tracer.receive_msgs();
209 let commitments = rounds
210 .complete(round1)
211 .await
212 .map_err(IoError::receive_message)?;
213 tracer.msgs_received();
214
215 if reliable_broadcast_enforced {
217 tracer.stage("Hash received msgs (reliability check)");
218 let h_i = udigest::hash_iter::<D>(
219 commitments
220 .iter_including_me(&my_commitment)
221 .map(|commitment| unambiguous::Echo { sid, commitment }),
222 );
223
224 tracer.send_msg();
225 outgoings
226 .send(Outgoing::broadcast(Msg::ReliabilityCheck(
227 MsgReliabilityCheck(h_i.clone()),
228 )))
229 .await
230 .map_err(IoError::send_message)?;
231 tracer.msg_sent();
232
233 tracer.round_begins();
234
235 tracer.receive_msgs();
236 let round1_hashes = rounds
237 .complete(round1_sync)
238 .await
239 .map_err(IoError::receive_message)?;
240 tracer.msgs_received();
241
242 tracer.stage("Assert other parties hashed messages (reliability check)");
243 let parties_have_different_hashes = round1_hashes
244 .into_iter_indexed()
245 .filter(|(_j, _msg_id, hash_j)| hash_j.0 != h_i)
246 .map(|(j, msg_id, _)| (j, msg_id))
247 .collect::<Vec<_>>();
248 if !parties_have_different_hashes.is_empty() {
249 return Err(KeygenAborted::Round1NotReliable(parties_have_different_hashes).into());
250 }
251 }
252
253 tracer.send_msg();
254 outgoings
255 .send(Outgoing::broadcast(Msg::Round2(my_decommitment.clone())))
256 .await
257 .map_err(IoError::send_message)?;
258 tracer.msg_sent();
259
260 tracer.round_begins();
262
263 tracer.receive_msgs();
264 let decommitments = rounds
265 .complete(round2)
266 .await
267 .map_err(IoError::receive_message)?;
268 tracer.msgs_received();
269
270 tracer.stage("Validate decommitments");
271 let blame = utils::collect_blame(&commitments, &decommitments, |j, com, decom| {
272 let com_expected = udigest::hash::<D>(&unambiguous::HashCom {
273 sid,
274 party_index: j,
275 decommitment: decom,
276 });
277 com.commitment != com_expected
278 });
279 if !blame.is_empty() {
280 return Err(KeygenAborted::InvalidDecommitment(blame).into());
281 }
282
283 #[cfg(feature = "hd-wallet")]
284 let chain_code = if hd_enabled {
285 tracer.stage("Calculate chain_code");
286 let blame = utils::collect_simple_blame(&decommitments, |decom| decom.chain_code.is_none());
287 if !blame.is_empty() {
288 return Err(KeygenAborted::MissingChainCode(blame).into());
289 }
290 Some(decommitments.iter_including_me(&my_decommitment).try_fold(
291 hd_wallet::ChainCode::default(),
292 |acc, decom| {
293 Ok::<_, Bug>(utils::xor_array(
294 acc,
295 decom.chain_code.ok_or(Bug::NoChainCode)?,
296 ))
297 },
298 )?)
299 } else {
300 None
301 };
302
303 tracer.stage("Calculate challege rid");
304 let rid = decommitments
305 .iter_including_me(&my_decommitment)
306 .map(|d| &d.rid)
307 .fold(L::KappaBytes::default(), utils::xor_array);
308 let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
309 sid,
310 prover: i,
311 rid: rid.as_ref(),
312 X: &X_i,
313 sch_commit: &sch_commit,
314 });
315 let challenge = schnorr_pok::Challenge { nonce: challenge };
316
317 tracer.stage("Prove knowledge of `x_i`");
318 let sch_proof = schnorr_pok::prove(&sch_secret, &challenge, &x_i);
319
320 tracer.send_msg();
321 let my_sch_proof = MsgRound3 { sch_proof };
322 outgoings
323 .send(Outgoing::broadcast(Msg::Round3(my_sch_proof.clone())))
324 .await
325 .map_err(IoError::send_message)?;
326 tracer.msg_sent();
327
328 tracer.round_begins();
330
331 tracer.receive_msgs();
332 let sch_proofs = rounds
333 .complete(round3)
334 .await
335 .map_err(IoError::receive_message)?;
336 tracer.msgs_received();
337
338 tracer.stage("Validate schnorr proofs");
339 let blame = utils::collect_blame(&decommitments, &sch_proofs, |j, decom, sch_proof| {
340 let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
341 sid,
342 prover: j,
343 rid: rid.as_ref(),
344 X: &decom.X,
345 sch_commit: &decom.sch_commit,
346 });
347 let challenge = schnorr_pok::Challenge { nonce: challenge };
348 sch_proof
349 .sch_proof
350 .verify(&decom.sch_commit, &challenge, &decom.X)
351 .is_err()
352 });
353 if !blame.is_empty() {
354 return Err(KeygenAborted::InvalidSchnorrProof(blame).into());
355 }
356
357 tracer.protocol_ends();
358
359 Ok(DirtyCoreKeyShare {
360 i,
361 key_info: DirtyKeyInfo {
362 curve: Default::default(),
363 shared_public_key: NonZero::from_point(
364 decommitments
365 .iter_including_me(&my_decommitment)
366 .map(|d| d.X)
367 .sum(),
368 )
369 .ok_or(Bug::ZeroPk)?,
370 public_shares: decommitments
371 .iter_including_me(&my_decommitment)
372 .map(|d| d.X)
373 .collect(),
374 vss_setup: None,
375 #[cfg(feature = "hd-wallet")]
376 chain_code,
377 },
378 x: x_i,
379 }
380 .validate()
381 .map_err(|e| Bug::InvalidKeyShare(e.into_error()))?)
382}