wsts/
traits.rs

1use core::{cmp::PartialEq, fmt::Debug};
2use hashbrown::{HashMap, HashSet};
3use polynomial::Polynomial;
4use rand_core::{CryptoRng, RngCore};
5use serde::{Deserialize, Serialize};
6use std::fmt;
7
8use crate::{
9    common::{MerkleRoot, Nonce, PolyCommitment, PublicNonce, Signature, SignatureShare},
10    curve::{point::Point, scalar::Scalar},
11    errors::{AggregatorError, DkgError},
12    taproot::SchnorrProof,
13};
14
15#[derive(Clone, Deserialize, Serialize, PartialEq)]
16/// The saved state required to reconstruct a party
17pub struct PartyState {
18    /// The party's private polynomial
19    pub polynomial: Option<Polynomial<Scalar>>,
20    /// The key IDS and associate private keys for this party
21    pub private_keys: Vec<(u32, Scalar)>,
22    /// The nonce being used by this party
23    pub nonce: Nonce,
24}
25
26impl fmt::Debug for PartyState {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_struct("PartyState").finish_non_exhaustive()
29    }
30}
31
32#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
33/// The saved state required to reconstruct a signer
34pub struct SignerState {
35    /// The signer ID
36    pub id: u32,
37    /// The key IDs this signer controls
38    pub key_ids: Vec<u32>,
39    /// The total number of keys
40    pub num_keys: u32,
41    /// The total number of parties
42    pub num_parties: u32,
43    /// The threshold for signing
44    pub threshold: u32,
45    /// The aggregate group public key
46    pub group_key: Point,
47    /// The party IDs and associated state for this signer
48    pub parties: Vec<(u32, PartyState)>,
49}
50
51/// A trait which provides a common `Signer` interface for `v1` and `v2`
52pub trait Signer: Clone + Debug + PartialEq {
53    /// Create a new `Signer`
54    fn new<RNG: RngCore + CryptoRng>(
55        party_id: u32,
56        key_ids: &[u32],
57        num_signers: u32,
58        num_keys: u32,
59        threshold: u32,
60        rng: &mut RNG,
61    ) -> Self;
62
63    /// Load a signer from the previously saved `state`
64    fn load(state: &SignerState) -> Self;
65
66    /// Save the state required to reconstruct the party
67    fn save(&self) -> SignerState;
68
69    /// Get the signer ID for this signer
70    fn get_id(&self) -> u32;
71
72    /// Get all key IDs for this signer
73    fn get_key_ids(&self) -> Vec<u32>;
74
75    /// Get the total number of parties
76    fn get_num_parties(&self) -> u32;
77
78    /// Get all poly commitments for this signer and the passed context
79    fn get_poly_commitments<RNG: RngCore + CryptoRng>(
80        &self,
81        ctx: &[u8],
82        rng: &mut RNG,
83    ) -> Vec<PolyCommitment>;
84
85    /// Reset all polynomials for this signer
86    fn reset_polys<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG);
87
88    /// Clear all polynomials for this signer
89    fn clear_polys(&mut self);
90
91    /// Get all private shares for this signer
92    fn get_shares(&self) -> HashMap<u32, HashMap<u32, Scalar>>;
93
94    /// Compute all secrets for this signer
95    fn compute_secrets(
96        &mut self,
97        shares: &HashMap<u32, HashMap<u32, Scalar>>,
98        polys: &HashMap<u32, PolyCommitment>,
99        ctx: &[u8],
100    ) -> Result<(), HashMap<u32, DkgError>>;
101
102    /// Generate all nonces for this signer
103    fn gen_nonces<RNG: RngCore + CryptoRng>(
104        &mut self,
105        secret_key: &Scalar,
106        rng: &mut RNG,
107    ) -> Vec<PublicNonce>;
108
109    /// Compute intermediate values
110    fn compute_intermediate(
111        &self,
112        msg: &[u8],
113        signer_ids: &[u32],
114        key_ids: &[u32],
115        nonces: &[PublicNonce],
116    ) -> (Vec<Point>, Point);
117
118    /// Validate that signer_id owns party_id
119    fn validate_party_id(
120        signer_id: u32,
121        party_id: u32,
122        signer_key_ids: &HashMap<u32, HashSet<u32>>,
123    ) -> bool;
124
125    /// Sign `msg` using all this signer's keys
126    fn sign(
127        &self,
128        msg: &[u8],
129        signer_ids: &[u32],
130        key_ids: &[u32],
131        nonces: &[PublicNonce],
132    ) -> Vec<SignatureShare>;
133
134    /// Sign `msg` using all this signer's keys
135    fn sign_schnorr(
136        &self,
137        msg: &[u8],
138        signer_ids: &[u32],
139        key_ids: &[u32],
140        nonces: &[PublicNonce],
141    ) -> Vec<SignatureShare>;
142
143    /// Sign `msg` using all this signer's keys and a tweaked public key
144    fn sign_taproot(
145        &self,
146        msg: &[u8],
147        signer_ids: &[u32],
148        key_ids: &[u32],
149        nonces: &[PublicNonce],
150        merkle_root: Option<MerkleRoot>,
151    ) -> Vec<SignatureShare>;
152}
153
154/// A trait which provides a common `Aggregator` interface for `v1` and `v2`
155pub trait Aggregator: Clone + Debug + PartialEq {
156    /// Construct an Aggregator with the passed parameters
157    fn new(num_keys: u32, threshold: u32) -> Self;
158
159    /// Initialize an Aggregator with the passed polynomial commitments
160    fn init(&mut self, poly_comms: &HashMap<u32, PolyCommitment>) -> Result<(), AggregatorError>;
161
162    /// Check and aggregate the signature shares into a FROST `Signature`
163    fn sign(
164        &mut self,
165        msg: &[u8],
166        nonces: &[PublicNonce],
167        sig_shares: &[SignatureShare],
168        key_ids: &[u32],
169    ) -> Result<Signature, AggregatorError>;
170
171    /// Check and aggregate the signature shares into a BIP-340 `SchnorrProof`.
172    /// <https://github.com/bitcoin/bips/blob/master/bip-0340.mediawiki>
173    fn sign_schnorr(
174        &mut self,
175        msg: &[u8],
176        nonces: &[PublicNonce],
177        sig_shares: &[SignatureShare],
178        key_ids: &[u32],
179    ) -> Result<SchnorrProof, AggregatorError>;
180
181    /// Check and aggregate the signature shares into a BIP-340 `SchnorrProof` with BIP-341 key tweaks
182    /// <https://github.com/bitcoin/bips/blob/master/bip-0340.mediawiki>
183    /// <https://github.com/bitcoin/bips/blob/master/bip-0341.mediawiki>
184    fn sign_taproot(
185        &mut self,
186        msg: &[u8],
187        nonces: &[PublicNonce],
188        sig_shares: &[SignatureShare],
189        key_ids: &[u32],
190        merkle_root: Option<MerkleRoot>,
191    ) -> Result<SchnorrProof, AggregatorError>;
192}
193
194/// Helper functions for tests
195pub mod test_helpers {
196    use hashbrown::HashMap;
197    use rand_core::{CryptoRng, RngCore};
198
199    use crate::{common::PolyCommitment, errors::DkgError, traits::Scalar, util::create_rng};
200
201    /// Run DKG on the passed signers
202    pub fn dkg<RNG: RngCore + CryptoRng, Signer: super::Signer>(
203        signers: &mut [Signer],
204        rng: &mut RNG,
205    ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
206        let ctx = 0u64.to_be_bytes();
207        let public_shares: HashMap<u32, PolyCommitment> = signers
208            .iter()
209            .flat_map(|s| s.get_poly_commitments(&ctx, rng))
210            .map(|comm| (comm.id.id.get_u32(), comm))
211            .collect();
212        let mut private_shares = HashMap::new();
213
214        for signer in signers.iter() {
215            for (signer_id, signer_shares) in signer.get_shares() {
216                private_shares.insert(signer_id, signer_shares);
217            }
218        }
219
220        let mut secret_errors = HashMap::new();
221        for signer in signers.iter_mut() {
222            if let Err(signer_secret_errors) =
223                signer.compute_secrets(&private_shares, &public_shares, &ctx)
224            {
225                secret_errors.extend(signer_secret_errors);
226            }
227        }
228
229        if secret_errors.is_empty() {
230            Ok(public_shares)
231        } else {
232            Err(secret_errors)
233        }
234    }
235
236    /// Remove the provided key ids from the list of private shares and execute compute secrets
237    fn compute_secrets_missing_private_shares<RNG: RngCore + CryptoRng, Signer: super::Signer>(
238        signers: &mut [Signer],
239        rng: &mut RNG,
240        missing_key_ids: &[u32],
241    ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
242        assert!(
243            !missing_key_ids.is_empty(),
244            "Cannot run a missing shares test without specificying at least one missing key id"
245        );
246        let ctx = 0u64.to_be_bytes();
247        let polys: HashMap<u32, PolyCommitment> = signers
248            .iter()
249            .flat_map(|s| s.get_poly_commitments(&ctx, rng))
250            .map(|comm| (comm.id.id.get_u32(), comm))
251            .collect();
252        let mut private_shares = HashMap::new();
253
254        for signer in signers.iter() {
255            for (signer_id, mut signer_shares) in signer.get_shares() {
256                for key_id in missing_key_ids {
257                    if signer.get_key_ids().contains(key_id) {
258                        signer_shares.remove(key_id);
259                    }
260                }
261                private_shares.insert(signer_id, signer_shares);
262            }
263        }
264
265        let mut secret_errors = HashMap::new();
266        for signer in signers.iter_mut() {
267            if let Err(signer_secret_errors) = signer.compute_secrets(&private_shares, &polys, &ctx)
268            {
269                secret_errors.extend(signer_secret_errors);
270            }
271        }
272
273        if secret_errors.is_empty() {
274            Ok(polys)
275        } else {
276            Err(secret_errors)
277        }
278    }
279
280    #[allow(non_snake_case)]
281    /// Run compute secrets test to trigger MissingPrivateShares code path
282    pub fn run_compute_secrets_missing_private_shares<Signer: super::Signer>() {
283        let Nk: u32 = 10;
284        let Np: u32 = 4;
285        let T: u32 = 7;
286        let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8], vec![9, 10]];
287        let missing_key_ids = vec![1, 7];
288        let mut rng = create_rng();
289        let mut signers: Vec<Signer> = signer_ids
290            .iter()
291            .enumerate()
292            .map(|(id, ids)| Signer::new(id.try_into().unwrap(), ids, Nk, Np, T, &mut rng))
293            .collect();
294
295        match compute_secrets_missing_private_shares(&mut signers, &mut rng, &missing_key_ids) {
296            Ok(polys) => panic!("Got a result with missing public shares: {polys:?}"),
297            Err(secret_errors) => {
298                for (_, error) in secret_errors {
299                    assert!(matches!(error, DkgError::MissingPrivateShares(_)));
300                }
301            }
302        }
303    }
304
305    /// Check that bad polynomial lengths are properly caught as errors during DKG
306    pub fn bad_polynomial_length<Signer: super::Signer, F: Fn(u32) -> u32>(func: F) {
307        let num_keys: u32 = 10;
308        let num_signers: u32 = 4;
309        let threshold: u32 = 7;
310        let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3, 4], vec![5, 6, 7], vec![8, 9], vec![10]];
311        let mut rng = create_rng();
312        let mut signers: Vec<Signer> = signer_ids
313            .iter()
314            .enumerate()
315            .map(|(id, ids)| {
316                if *ids == vec![10] {
317                    Signer::new(
318                        id.try_into().unwrap(),
319                        ids,
320                        num_signers,
321                        num_keys,
322                        func(threshold),
323                        &mut rng,
324                    )
325                } else {
326                    Signer::new(
327                        id.try_into().unwrap(),
328                        ids,
329                        num_signers,
330                        num_keys,
331                        threshold,
332                        &mut rng,
333                    )
334                }
335            })
336            .collect();
337
338        if dkg(&mut signers, &mut rng).is_ok() {
339            panic!("DKG should have failed")
340        }
341    }
342
343    /// Check that bad polynomial commitments are properly caught as errors during DKG
344    pub fn bad_polynomial_commitment<Signer: super::Signer>() {
345        let num_keys: u32 = 10;
346        let num_signers: u32 = 4;
347        let threshold: u32 = 7;
348        let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3, 4], vec![5, 6, 7], vec![8, 9], vec![10]];
349        let mut rng = create_rng();
350        let mut signers: Vec<Signer> = signer_ids
351            .iter()
352            .enumerate()
353            .map(|(id, ids)| {
354                Signer::new(
355                    id.try_into().unwrap(),
356                    ids,
357                    num_signers,
358                    num_keys,
359                    threshold,
360                    &mut rng,
361                )
362            })
363            .collect();
364
365        // The code that follows is essentially the same code that we have
366        // in the `dkg` helper function above, except we've corrupted the
367        // schnorr proof so that we can test verification would fail at
368        // the end.
369        let ctx = 0u64.to_be_bytes();
370        let bad_party_id = 2u32;
371        let public_shares: HashMap<u32, PolyCommitment> = signers
372            .iter()
373            .flat_map(|s| s.get_poly_commitments(&ctx, &mut rng))
374            .map(|comm| {
375                let party_id = comm.id.id.get_u32();
376                if party_id == bad_party_id {
377                    // alter the schnorr proof so it will fail verification
378                    let mut bad_comm = comm.clone();
379                    bad_comm.id.proof.s += Scalar::from(1);
380                    (party_id, bad_comm)
381                } else {
382                    (party_id, comm)
383                }
384            })
385            .collect();
386        let mut private_shares = HashMap::new();
387
388        for signer in signers.iter() {
389            for (signer_id, signer_shares) in signer.get_shares() {
390                private_shares.insert(signer_id, signer_shares);
391            }
392        }
393
394        let mut secret_errors = HashMap::new();
395        for signer in signers.iter_mut() {
396            if let Err(signer_secret_errors) =
397                signer.compute_secrets(&private_shares, &public_shares, &ctx)
398            {
399                secret_errors.extend(signer_secret_errors);
400            }
401        }
402
403        assert!(!secret_errors.is_empty());
404    }
405}