// cggmp24_keygen/progress.rs

//! Traces progress of protocol execution
//!
//! Provides the [`Tracer`] trait that can be used to trace progress of an ongoing MPC protocol
//! execution. For instance, it can be implemented to report progress to the end user.
//!
//! Out of the box (with the `std` feature enabled), there's `PerfProfiler`, which can be used to
//! benchmark a protocol, and `Stderr`, which prints every traced event to stderr.

8/// Traces progress of protocol execution
9///
10/// See [module level documentation](self) for more details
11pub trait Tracer: Send + Sync {
12    /// Traces occurred event
13    fn trace_event(&mut self, event: Event);
14
15    /// Traces [`Event::ProtocolBegins`] event
16    fn protocol_begins(&mut self) {
17        self.trace_event(Event::ProtocolBegins)
18    }
19    /// Traces [`Event::RoundBegins`] event
20    fn round_begins(&mut self) {
21        self.trace_event(Event::RoundBegins { name: None })
22    }
23    /// Traces [`Event::RoundBegins`] event
24    fn named_round_begins(&mut self, round_name: &'static str) {
25        self.trace_event(Event::RoundBegins {
26            name: Some(round_name),
27        })
28    }
29    /// Traces [`Event::Stage`] event
30    fn stage(&mut self, stage: &'static str) {
31        self.trace_event(Event::Stage { name: stage })
32    }
33    /// Traces [`Event::ReceiveMsgs`] event
34    fn receive_msgs(&mut self) {
35        self.trace_event(Event::ReceiveMsgs)
36    }
37    /// Traces [`Event::MsgsReceived`] event
38    fn msgs_received(&mut self) {
39        self.trace_event(Event::MsgsReceived)
40    }
41    /// Traces [`Event::SendMsg`] event
42    fn send_msg(&mut self) {
43        self.trace_event(Event::SendMsg)
44    }
45    /// Traces [`Event::MsgSent`] event
46    fn msg_sent(&mut self) {
47        self.trace_event(Event::MsgSent)
48    }
49    /// Traces [`Event::ProtocolEnds`] event
50    fn protocol_ends(&mut self) {
51        self.trace_event(Event::ProtocolEnds)
52    }
53}
54
/// Event occurred during the protocol execution
///
/// Events are reported to a [`Tracer`] via [`Tracer::trace_event`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Event {
    /// Protocol begins
    ///
    /// This event is always emitted before any other events
    ProtocolBegins,

    /// Round begins
    RoundBegins {
        /// Name of the round, if it was given one
        name: Option<&'static str>,
    },
    /// Stage begins
    ///
    /// `PerfProfiler` attributes a stage to the current round, or to the setup phase if no
    /// round has begun yet.
    Stage {
        /// Name of the stage
        name: &'static str,
    },

    /// Protocol starts waiting for some messages to be received
    ReceiveMsgs,
    /// Awaited messages were received, round continues
    MsgsReceived,

    /// Protocol starts sending a message
    SendMsg,
    /// Message was sent, round continues
    MsgSent,

    /// Protocol completed
    ProtocolEnds,
}

88impl Tracer for &mut dyn Tracer {
89    fn trace_event(&mut self, event: Event) {
90        (*self).trace_event(event)
91    }
92}
93
94impl<T: Tracer> Tracer for &mut T {
95    fn trace_event(&mut self, event: Event) {
96        <T as Tracer>::trace_event(self, event)
97    }
98}
99
100impl<T: Tracer> Tracer for Option<T> {
101    fn trace_event(&mut self, event: Event) {
102        match self {
103            Some(tracer) => tracer.trace_event(event),
104            None => {
105                // no-op
106            }
107        }
108    }
109}
110
111#[cfg(feature = "std")]
112pub use requires_std::*;
113#[cfg(feature = "std")]
114mod requires_std {
115    use alloc::{vec, vec::Vec};
116    use core::fmt;
117    use std::time::{Duration, Instant};
118
119    use thiserror::Error;
120
121    use super::*;
122
123    /// Profiles performance of the protocol
124    ///
125    /// Implements [`Tracer`] trait so it can be embedded into protocol execution. `PerfProfiler` keeps track of time
126    /// passed between each step of protocol. After protocol is completed, you can obtain a [`PerfReport`] via
127    /// [`.get_report()`](PerfProfiler::get_report) method that contains all the measurements.
128    pub struct PerfProfiler {
129        last_timestamp: Option<Instant>,
130        ongoing_stage: Option<usize>,
131        protocol_began: Option<Instant>,
132        report: PerfReport,
133        error: Option<ProfileError>,
134    }
135
136    /// Performance report generated by [`PerfProfiler`]
137    #[derive(Debug, Clone)]
138    pub struct PerfReport {
139        /// Duration of setup phase (time after protocol began and before first round started)
140        pub setup: Duration,
141        /// Stages of setup phase
142        pub setup_stages: Vec<StageDuration>,
143        /// Performance report for each round
144        pub rounds: Vec<RoundDuration>,
145        display_io: bool,
146    }
147
148    /// Performance of specific round (part of [`PerfReport`])
149    #[derive(Debug, Clone)]
150    pub struct RoundDuration {
151        /// Round name (if provided)
152        pub round_name: Option<&'static str>,
153        /// Stages of the round
154        pub stages: Vec<StageDuration>,
155        /// Total duration of pure computation performed during the round
156        pub computation: Duration,
157        /// Total time we spent during this round on sending messages
158        pub sending: Duration,
159        /// Total time we spent during this round on receiving messages
160        pub receiving: Duration,
161    }
162
163    /// Performance of specific stage (part of [`PerfReport`])
164    #[derive(Debug, Clone)]
165    pub struct StageDuration {
166        /// Stage name
167        pub name: &'static str,
168        /// Duration of the stage
169        pub duration: Duration,
170    }
171
172    /// Protocol profiling resulted into error
173    #[derive(Debug, Error, Clone)]
174    #[error("profiler failed to trace protocol: it behaved unexpectedly")]
175    pub struct ProfileError(
176        #[source]
177        #[from]
178        ErrorReason,
179    );
180
181    #[derive(Debug, Error, Clone)]
182    enum ErrorReason {
183        #[error("protocol has never began")]
184        ProtocolNeverBegan,
185        #[error("tracing stage or sending/receiving message but round never began")]
186        RoundNeverBegan,
187        #[error("stage is ongoing, but it can't be finished with that event: {event:?}")]
188        CantFinishStage { event: Event },
189    }
190
191    impl Tracer for PerfProfiler {
192        fn trace_event(&mut self, event: Event) {
193            if self.error.is_none() {
194                if let Err(err) = self.try_trace_event(event) {
195                    self.error = Some(err)
196                }
197            }
198        }
199    }
200
201    impl PerfProfiler {
202        /// Constructs new [`PerfProfiler`]
203        pub fn new() -> Self {
204            Self {
205                last_timestamp: None,
206                ongoing_stage: None,
207                protocol_began: None,
208                report: PerfReport {
209                    setup: Duration::ZERO,
210                    setup_stages: vec![],
211                    rounds: vec![],
212                    display_io: true,
213                },
214                error: None,
215            }
216        }
217
218        /// Obtains a report
219        ///
220        /// Returns error if protocol behaved unexpectedly
221        pub fn get_report(&self) -> Result<PerfReport, ProfileError> {
222            if let Some(err) = self.error.clone() {
223                Err(err)
224            } else {
225                Ok(self.report.clone())
226            }
227        }
228
229        fn try_trace_event(&mut self, event: Event) -> Result<(), ProfileError> {
230            let now = Instant::now();
231
232            if Self::event_can_finish_ongoing_stage(&event) {
233                if let Some(stage_i) = self.ongoing_stage.take() {
234                    let last_timestamp = self.last_timestamp()?;
235
236                    if !self.report.rounds.is_empty() {
237                        let last_round = self.last_round_mut()?;
238                        last_round.stages[stage_i].duration += now - last_timestamp;
239                    } else {
240                        self.report.setup_stages[stage_i].duration += now - last_timestamp;
241                    }
242                }
243            } else if self.ongoing_stage.is_some() {
244                return Err(ErrorReason::CantFinishStage { event }.into());
245            }
246            match event {
247                Event::ProtocolBegins => {
248                    self.protocol_began = Some(now);
249                }
250                Event::RoundBegins { name } => {
251                    let last_timestamp = self.last_timestamp()?;
252                    match self.report.rounds.last_mut() {
253                        None => self.report.setup += now - last_timestamp,
254                        Some(last_round) => last_round.computation += now - last_timestamp,
255                    }
256                    self.report.rounds.push(RoundDuration {
257                        round_name: name,
258                        stages: vec![],
259                        computation: Duration::ZERO,
260                        sending: Duration::ZERO,
261                        receiving: Duration::ZERO,
262                    })
263                }
264                Event::Stage { name } => {
265                    let last_timestamp = self.last_timestamp()?;
266
267                    let stages = if !self.report.rounds.is_empty() {
268                        let last_round = self.last_round_mut()?;
269                        last_round.computation += now - last_timestamp;
270
271                        &mut last_round.stages
272                    } else {
273                        self.report.setup += now - last_timestamp;
274                        &mut self.report.setup_stages
275                    };
276
277                    let stage_i = stages.iter().position(|s| s.name == name);
278                    let stage_i = match stage_i {
279                        Some(i) => i,
280                        None => {
281                            stages.push(StageDuration {
282                                name,
283                                duration: Duration::ZERO,
284                            });
285                            stages.len() - 1
286                        }
287                    };
288                    self.ongoing_stage = Some(stage_i);
289                }
290                Event::ReceiveMsgs => {
291                    let last_timestamp = self.last_timestamp()?;
292                    let last_round = self.last_round_mut()?;
293                    last_round.computation += now - last_timestamp;
294                }
295                Event::MsgsReceived => {
296                    let last_timestamp = self.last_timestamp()?;
297                    let last_round = self.last_round_mut()?;
298                    last_round.receiving += now - last_timestamp;
299                }
300                Event::SendMsg => {
301                    let last_timestamp = self.last_timestamp()?;
302                    let last_round = self.last_round_mut()?;
303                    last_round.computation += now - last_timestamp;
304                }
305                Event::MsgSent => {
306                    let last_timestamp = self.last_timestamp()?;
307                    let last_round = self.last_round_mut()?;
308                    last_round.sending += now - last_timestamp;
309                }
310                Event::ProtocolEnds => {
311                    let last_timestamp = self.last_timestamp()?;
312                    let last_round = self.last_round_mut()?;
313                    last_round.computation += now - last_timestamp;
314                }
315            }
316
317            self.last_timestamp = Some(now);
318            Ok(())
319        }
320
321        fn last_timestamp(&self) -> Result<Instant, ProfileError> {
322            let last_timestamp = self.last_timestamp.ok_or(ErrorReason::ProtocolNeverBegan)?;
323            Ok(last_timestamp)
324        }
325        fn last_round_mut(&mut self) -> Result<&mut RoundDuration, ProfileError> {
326            let last_round = self
327                .report
328                .rounds
329                .last_mut()
330                .ok_or(ErrorReason::RoundNeverBegan)?;
331            Ok(last_round)
332        }
333        fn event_can_finish_ongoing_stage(event: &Event) -> bool {
334            matches!(
335                event,
336                Event::RoundBegins { .. }
337                    | Event::Stage { .. }
338                    | Event::ReceiveMsgs
339                    | Event::SendMsg
340                    | Event::ProtocolEnds
341            )
342        }
343    }
344
345    impl Default for PerfProfiler {
346        fn default() -> Self {
347            Self::new()
348        }
349    }
350
351    impl PerfReport {
352        /// Specifies whether time spent on i/o should be rendered in the final report
353        ///
354        /// Time spent on i/o is the time when signer was sending messages or waiting other
355        /// parties to send messages
356        pub fn display_io(mut self, display: bool) -> Self {
357            self.display_io = display;
358            self
359        }
360    }
361
362    impl fmt::Display for PerfReport {
363        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
364            let total_computation =
365                self.setup + self.rounds.iter().map(|r| r.computation).sum::<Duration>();
366            let total_send = if self.display_io {
367                self.rounds.iter().map(|r| r.sending).sum::<Duration>()
368            } else {
369                Duration::ZERO
370            };
371            let total_recv = if self.display_io {
372                self.rounds.iter().map(|r| r.receiving).sum::<Duration>()
373            } else {
374                Duration::ZERO
375            };
376            let total_io = total_send + total_recv;
377            let total = total_computation + total_io;
378
379            writeln!(f, "Protocol Performance:")?;
380            writeln!(f, "  - Protocol took {total:.2?} to complete")?;
381            if self.display_io {
382                writeln!(
383                    f,
384                    "    - Computation: {total_computation:.2?} ({})",
385                    percent(total_computation, total)
386                )?;
387                writeln!(
388                    f,
389                    "    - I/O: {total_io:.2?} ({})",
390                    percent(total_io, total)
391                )?;
392                writeln!(f, "      - Send: {total_send:.2?}")?;
393                writeln!(f, "      - Recv: {total_recv:.2?}")?;
394            }
395
396            writeln!(f, "In particular:")?;
397            Self::fmt_round(f, 0, Some("Stage"), &self.setup_stages, self.setup, None)?;
398
399            for (i, round) in self.rounds.iter().enumerate() {
400                Self::fmt_round(
401                    f,
402                    i + 1,
403                    round.round_name,
404                    &round.stages,
405                    round.computation,
406                    if self.display_io {
407                        Some((round.sending, round.receiving))
408                    } else {
409                        None
410                    },
411                )?;
412            }
413
414            Ok(())
415        }
416    }
417
418    impl PerfReport {
419        fn fmt_round(
420            f: &mut fmt::Formatter,
421            i: usize,
422            round_name: Option<&str>,
423            stages: &[StageDuration],
424            computation: Duration,
425            io: Option<(Duration, Duration)>, // (sending, receiving)
426        ) -> fmt::Result {
427            let total_duration = computation + io.map(|(s, r)| s + r).unwrap_or_default();
428            if let Some(round_name) = round_name {
429                writeln!(f, "  - {round_name}: {total_duration:.2?}")?
430            } else {
431                writeln!(f, "  - Round {i}: {total_duration:.2?}")?
432            }
433
434            Self::fmt_stages(f, total_duration, stages)?;
435
436            if let Some((sending, receiving)) = io {
437                let total_io = sending + receiving;
438                writeln!(
439                    f,
440                    "    - I/O: {:.2?} ({})",
441                    total_io,
442                    percent(total_io, total_duration)
443                )?;
444                writeln!(f, "      - Send: {sending:.2?}")?;
445                writeln!(f, "      - Recv: {receiving:.2?}")?;
446            }
447
448            if !stages.is_empty() || io.is_some() {
449                let stages_total = stages.iter().map(|s| s.duration).sum::<Duration>();
450                let unstaged = computation - stages_total;
451                let percent = percent(unstaged, total_duration);
452                writeln!(f, "    - Unstaged: {unstaged:.2?} ({percent})")?;
453            }
454
455            Ok(())
456        }
457
458        fn fmt_stages(
459            f: &mut fmt::Formatter,
460            total: Duration,
461            stages: &[StageDuration],
462        ) -> fmt::Result {
463            for stage in stages {
464                writeln!(
465                    f,
466                    "    - {}: {:.2?} ({})",
467                    stage.name,
468                    stage.duration,
469                    percent(stage.duration, total),
470                )?;
471            }
472            Ok(())
473        }
474    }
475
476    fn percent(part: Duration, total: Duration) -> impl fmt::Display {
477        struct Percentage(Duration, Duration);
478
479        impl fmt::Display for Percentage {
480            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
481                let percent = self.0.as_secs_f64() / self.1.as_secs_f64() * 100.;
482                write!(f, "{percent:.1}%")
483            }
484        }
485
486        Percentage(part, total)
487    }
488
489    /// Prints progress of the protocol to stderr
490    #[derive(Default)]
491    pub struct Stderr {
492        prefix: Option<std::string::String>,
493    }
494
495    impl Stderr {
496        /// Constructs an stderr tracer
497        pub fn new() -> Self {
498            Self::default()
499        }
500
501        /// Sets a prefix to be printed for each event
502        pub fn with_prefix(mut self, prefix: impl std::string::ToString) -> Self {
503            self.prefix = Some(prefix.to_string());
504            self
505        }
506    }
507
508    impl Tracer for Stderr {
509        fn trace_event(&mut self, event: Event) {
510            if let Some(prefix) = &self.prefix {
511                std::eprintln!("{prefix}: {event:?}")
512            } else {
513                std::eprintln!("{event:?}")
514            }
515        }
516    }
517}