wsts/
traits.rs

1use core::{cmp::PartialEq, fmt::Debug};
2use hashbrown::{HashMap, HashSet};
3use polynomial::Polynomial;
4use rand_core::{CryptoRng, RngCore};
5use serde::{Deserialize, Serialize};
6use std::fmt;
7
8use crate::{
9    common::{MerkleRoot, Nonce, PolyCommitment, PublicNonce, Signature, SignatureShare},
10    curve::{point::Point, scalar::Scalar},
11    errors::{AggregatorError, DkgError},
12    taproot::SchnorrProof,
13};
14
15#[derive(Clone, Deserialize, Serialize, PartialEq)]
16/// The saved state required to reconstruct a party
17pub struct PartyState {
18    /// The party's private polynomial
19    pub polynomial: Option<Polynomial<Scalar>>,
20    /// The key IDS and associate private keys for this party
21    pub private_keys: Vec<(u32, Scalar)>,
22    /// The nonce being used by this party
23    pub nonce: Nonce,
24}
25
26impl fmt::Debug for PartyState {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_struct("PartyState").finish_non_exhaustive()
29    }
30}
31
32#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
33/// The saved state required to reconstruct a signer
34pub struct SignerState {
35    /// The signer ID
36    pub id: u32,
37    /// The key IDs this signer controls
38    pub key_ids: Vec<u32>,
39    /// The total number of keys
40    pub num_keys: u32,
41    /// The total number of parties
42    pub num_parties: u32,
43    /// The threshold for signing
44    pub threshold: u32,
45    /// The aggregate group public key
46    pub group_key: Point,
47    /// The party IDs and associated state for this signer
48    pub parties: Vec<(u32, PartyState)>,
49}
50
51/// A trait which provides a common `Signer` interface for `v1` and `v2`
52pub trait Signer: Clone + Debug + PartialEq {
53    /// Create a new `Signer`
54    fn new<RNG: RngCore + CryptoRng>(
55        party_id: u32,
56        key_ids: &[u32],
57        num_signers: u32,
58        num_keys: u32,
59        threshold: u32,
60        rng: &mut RNG,
61    ) -> Self;
62
63    /// Load a signer from the previously saved `state`
64    fn load(state: &SignerState) -> Self;
65
66    /// Save the state required to reconstruct the party
67    fn save(&self) -> SignerState;
68
69    /// Get the signer ID for this signer
70    fn get_id(&self) -> u32;
71
72    /// Get all key IDs for this signer
73    fn get_key_ids(&self) -> Vec<u32>;
74
75    /// Get the total number of parties
76    fn get_num_parties(&self) -> u32;
77
78    /// Get all poly commitments for this signer and the passed context
79    fn get_poly_commitments<RNG: RngCore + CryptoRng>(
80        &self,
81        ctx: &[u8],
82        rng: &mut RNG,
83    ) -> Vec<PolyCommitment>;
84
85    /// Reset all polynomials for this signer
86    fn reset_polys<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG);
87
88    /// Clear all polynomials for this signer
89    fn clear_polys(&mut self);
90
91    /// Get all private shares for this signer
92    fn get_shares(&self) -> HashMap<u32, HashMap<u32, Scalar>>;
93
94    /// Compute all secrets for this signer
95    fn compute_secrets(
96        &mut self,
97        shares: &HashMap<u32, HashMap<u32, Scalar>>,
98        polys: &HashMap<u32, PolyCommitment>,
99        ctx: &[u8],
100    ) -> Result<(), HashMap<u32, DkgError>>;
101
102    /// Generate all nonces for this signer
103    fn gen_nonces<RNG: RngCore + CryptoRng>(
104        &mut self,
105        secret_key: &Scalar,
106        rng: &mut RNG,
107    ) -> Vec<PublicNonce>;
108
109    /// Compute intermediate values
110    fn compute_intermediate(
111        msg: &[u8],
112        signer_ids: &[u32],
113        key_ids: &[u32],
114        nonces: &[PublicNonce],
115    ) -> (Vec<Point>, Point);
116
117    /// Validate that signer_id owns party_id
118    fn validate_party_id(
119        signer_id: u32,
120        party_id: u32,
121        signer_key_ids: &HashMap<u32, HashSet<u32>>,
122    ) -> bool;
123
124    /// Sign `msg` using all this signer's keys
125    fn sign(
126        &self,
127        msg: &[u8],
128        signer_ids: &[u32],
129        key_ids: &[u32],
130        nonces: &[PublicNonce],
131    ) -> Vec<SignatureShare>;
132
133    /// Sign `msg` using all this signer's keys
134    fn sign_schnorr(
135        &self,
136        msg: &[u8],
137        signer_ids: &[u32],
138        key_ids: &[u32],
139        nonces: &[PublicNonce],
140    ) -> Vec<SignatureShare>;
141
142    /// Sign `msg` using all this signer's keys and a tweaked public key
143    fn sign_taproot(
144        &self,
145        msg: &[u8],
146        signer_ids: &[u32],
147        key_ids: &[u32],
148        nonces: &[PublicNonce],
149        merkle_root: Option<MerkleRoot>,
150    ) -> Vec<SignatureShare>;
151}
152
153/// A trait which provides a common `Aggregator` interface for `v1` and `v2`
154pub trait Aggregator: Clone + Debug + PartialEq {
155    /// Construct an Aggregator with the passed parameters
156    fn new(num_keys: u32, threshold: u32) -> Self;
157
158    /// Initialize an Aggregator with the passed polynomial commitments
159    fn init(&mut self, poly_comms: &HashMap<u32, PolyCommitment>) -> Result<(), AggregatorError>;
160
161    /// Check and aggregate the signature shares into a FROST `Signature`
162    fn sign(
163        &mut self,
164        msg: &[u8],
165        nonces: &[PublicNonce],
166        sig_shares: &[SignatureShare],
167        key_ids: &[u32],
168    ) -> Result<Signature, AggregatorError>;
169
170    /// Check and aggregate the signature shares into a BIP-340 `SchnorrProof`.
171    /// <https://github.com/bitcoin/bips/blob/master/bip-0340.mediawiki>
172    fn sign_schnorr(
173        &mut self,
174        msg: &[u8],
175        nonces: &[PublicNonce],
176        sig_shares: &[SignatureShare],
177        key_ids: &[u32],
178    ) -> Result<SchnorrProof, AggregatorError>;
179
180    /// Check and aggregate the signature shares into a BIP-340 `SchnorrProof` with BIP-341 key tweaks
181    /// <https://github.com/bitcoin/bips/blob/master/bip-0340.mediawiki>
182    /// <https://github.com/bitcoin/bips/blob/master/bip-0341.mediawiki>
183    fn sign_taproot(
184        &mut self,
185        msg: &[u8],
186        nonces: &[PublicNonce],
187        sig_shares: &[SignatureShare],
188        key_ids: &[u32],
189        merkle_root: Option<MerkleRoot>,
190    ) -> Result<SchnorrProof, AggregatorError>;
191}
192
193/// Helper functions for tests
194pub mod test_helpers {
195    use hashbrown::HashMap;
196    use rand_core::{CryptoRng, RngCore};
197
198    use crate::{common::PolyCommitment, errors::DkgError, traits::Scalar, util::create_rng};
199
200    /// Run DKG on the passed signers
201    pub fn dkg<RNG: RngCore + CryptoRng, Signer: super::Signer>(
202        signers: &mut [Signer],
203        rng: &mut RNG,
204    ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
205        let ctx = 0u64.to_be_bytes();
206        let public_shares: HashMap<u32, PolyCommitment> = signers
207            .iter()
208            .flat_map(|s| s.get_poly_commitments(&ctx, rng))
209            .map(|comm| (comm.id.id.get_u32(), comm))
210            .collect();
211        let mut private_shares = HashMap::new();
212
213        for signer in signers.iter() {
214            for (signer_id, signer_shares) in signer.get_shares() {
215                private_shares.insert(signer_id, signer_shares);
216            }
217        }
218
219        let mut secret_errors = HashMap::new();
220        for signer in signers.iter_mut() {
221            if let Err(signer_secret_errors) =
222                signer.compute_secrets(&private_shares, &public_shares, &ctx)
223            {
224                secret_errors.extend(signer_secret_errors.into_iter());
225            }
226        }
227
228        if secret_errors.is_empty() {
229            Ok(public_shares)
230        } else {
231            Err(secret_errors)
232        }
233    }
234
235    /// Remove the provided key ids from the list of private shares and execute compute secrets
236    fn compute_secrets_missing_private_shares<RNG: RngCore + CryptoRng, Signer: super::Signer>(
237        signers: &mut [Signer],
238        rng: &mut RNG,
239        missing_key_ids: &[u32],
240    ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
241        assert!(
242            !missing_key_ids.is_empty(),
243            "Cannot run a missing shares test without specificying at least one missing key id"
244        );
245        let ctx = 0u64.to_be_bytes();
246        let polys: HashMap<u32, PolyCommitment> = signers
247            .iter()
248            .flat_map(|s| s.get_poly_commitments(&ctx, rng))
249            .map(|comm| (comm.id.id.get_u32(), comm))
250            .collect();
251        let mut private_shares = HashMap::new();
252
253        for signer in signers.iter() {
254            for (signer_id, mut signer_shares) in signer.get_shares() {
255                for key_id in missing_key_ids {
256                    if signer.get_key_ids().contains(key_id) {
257                        signer_shares.remove(key_id);
258                    }
259                }
260                private_shares.insert(signer_id, signer_shares);
261            }
262        }
263
264        let mut secret_errors = HashMap::new();
265        for signer in signers.iter_mut() {
266            if let Err(signer_secret_errors) = signer.compute_secrets(&private_shares, &polys, &ctx)
267            {
268                secret_errors.extend(signer_secret_errors.into_iter());
269            }
270        }
271
272        if secret_errors.is_empty() {
273            Ok(polys)
274        } else {
275            Err(secret_errors)
276        }
277    }
278
279    #[allow(non_snake_case)]
280    /// Run compute secrets test to trigger MissingPrivateShares code path
281    pub fn run_compute_secrets_missing_private_shares<Signer: super::Signer>() {
282        let Nk: u32 = 10;
283        let Np: u32 = 4;
284        let T: u32 = 7;
285        let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8], vec![9, 10]];
286        let missing_key_ids = vec![1, 7];
287        let mut rng = create_rng();
288        let mut signers: Vec<Signer> = signer_ids
289            .iter()
290            .enumerate()
291            .map(|(id, ids)| Signer::new(id.try_into().unwrap(), ids, Nk, Np, T, &mut rng))
292            .collect();
293
294        match compute_secrets_missing_private_shares(&mut signers, &mut rng, &missing_key_ids) {
295            Ok(polys) => panic!("Got a result with missing public shares: {polys:?}"),
296            Err(secret_errors) => {
297                for (_, error) in secret_errors {
298                    assert!(matches!(error, DkgError::MissingPrivateShares(_)));
299                }
300            }
301        }
302    }
303
304    /// Check that bad polynomial lengths are properly caught as errors during DKG
305    pub fn bad_polynomial_length<Signer: super::Signer, F: Fn(u32) -> u32>(func: F) {
306        let num_keys: u32 = 10;
307        let num_signers: u32 = 4;
308        let threshold: u32 = 7;
309        let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3, 4], vec![5, 6, 7], vec![8, 9], vec![10]];
310        let mut rng = create_rng();
311        let mut signers: Vec<Signer> = signer_ids
312            .iter()
313            .enumerate()
314            .map(|(id, ids)| {
315                if *ids == vec![10] {
316                    Signer::new(
317                        id.try_into().unwrap(),
318                        ids,
319                        num_signers,
320                        num_keys,
321                        func(threshold),
322                        &mut rng,
323                    )
324                } else {
325                    Signer::new(
326                        id.try_into().unwrap(),
327                        ids,
328                        num_signers,
329                        num_keys,
330                        threshold,
331                        &mut rng,
332                    )
333                }
334            })
335            .collect();
336
337        if dkg(&mut signers, &mut rng).is_ok() {
338            panic!("DKG should have failed")
339        }
340    }
341
342    /// Check that bad polynomial commitments are properly caught as errors during DKG
343    pub fn bad_polynomial_commitment<Signer: super::Signer>() {
344        let num_keys: u32 = 10;
345        let num_signers: u32 = 4;
346        let threshold: u32 = 7;
347        let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3, 4], vec![5, 6, 7], vec![8, 9], vec![10]];
348        let mut rng = create_rng();
349        let mut signers: Vec<Signer> = signer_ids
350            .iter()
351            .enumerate()
352            .map(|(id, ids)| {
353                Signer::new(
354                    id.try_into().unwrap(),
355                    ids,
356                    num_signers,
357                    num_keys,
358                    threshold,
359                    &mut rng,
360                )
361            })
362            .collect();
363
364        // The code that follows is essentially the same code that we have
365        // in the `dkg` helper function above, except we've corrupted the
366        // schnorr proof so that we can test verification would fail at
367        // the end.
368        let ctx = 0u64.to_be_bytes();
369        let bad_party_id = 2u32;
370        let public_shares: HashMap<u32, PolyCommitment> = signers
371            .iter()
372            .flat_map(|s| s.get_poly_commitments(&ctx, &mut rng))
373            .map(|comm| {
374                let party_id = comm.id.id.get_u32();
375                if party_id == bad_party_id {
376                    // alter the schnorr proof so it will fail verification
377                    let mut bad_comm = comm.clone();
378                    bad_comm.id.proof.s += Scalar::from(1);
379                    (party_id, bad_comm)
380                } else {
381                    (party_id, comm)
382                }
383            })
384            .collect();
385        let mut private_shares = HashMap::new();
386
387        for signer in signers.iter() {
388            for (signer_id, signer_shares) in signer.get_shares() {
389                private_shares.insert(signer_id, signer_shares);
390            }
391        }
392
393        let mut secret_errors = HashMap::new();
394        for signer in signers.iter_mut() {
395            if let Err(signer_secret_errors) =
396                signer.compute_secrets(&private_shares, &public_shares, &ctx)
397            {
398                secret_errors.extend(signer_secret_errors.into_iter());
399            }
400        }
401
402        assert!(!secret_errors.is_empty());
403    }
404}