wsts/
v2.rs

1use std::fmt;
2
3use hashbrown::{HashMap, HashSet};
4use num_traits::{One, Zero};
5use polynomial::Polynomial;
6use rand_core::{CryptoRng, RngCore};
7use tracing::warn;
8
9use crate::{
10    common::{check_public_shares, Nonce, PolyCommitment, PublicNonce, Signature, SignatureShare},
11    compute,
12    curve::{
13        point::{Point, G},
14        scalar::Scalar,
15    },
16    errors::{AggregatorError, DkgError},
17    schnorr::ID,
18    taproot::SchnorrProof,
19    traits,
20    vss::VSS,
21};
22
23#[derive(Clone, Eq, PartialEq)]
24/// A WSTS party, which encapsulates a single polynomial, nonce, and one private key per key ID
25pub struct Party {
26    /// The party ID
27    pub party_id: u32,
28    /// The key IDs for this party
29    pub key_ids: Vec<u32>,
30    /// The public keys for this party, indexed by ID
31    num_keys: u32,
32    num_parties: u32,
33    threshold: u32,
34    f: Option<Polynomial<Scalar>>,
35    private_keys: HashMap<u32, Scalar>,
36    group_key: Point,
37    nonce: Nonce,
38}
39
40impl fmt::Debug for Party {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.debug_struct("Party")
43            .field("part_id", &self.party_id)
44            .field("key_ids", &self.key_ids)
45            .field("num_keys", &self.num_keys)
46            .field("num_parties", &self.num_parties)
47            .field("threshold", &self.threshold)
48            .field("group_key", &self.group_key)
49            .finish_non_exhaustive()
50    }
51}
52
53impl Party {
54    /// Construct a random Party with the passed party ID, key IDs, and parameters
55    pub fn new<RNG: RngCore + CryptoRng>(
56        party_id: u32,
57        key_ids: &[u32],
58        num_parties: u32,
59        num_keys: u32,
60        threshold: u32,
61        rng: &mut RNG,
62    ) -> Self {
63        // create a nonce using just the passed RNG to avoid having a zero nonce used against us
64        let secret_key = Scalar::random(rng);
65        let nonce = Nonce::random(&secret_key, rng);
66
67        Self {
68            party_id,
69            key_ids: key_ids.to_vec(),
70            num_keys,
71            num_parties,
72            threshold,
73            f: Some(VSS::random_poly(threshold - 1, rng)),
74            private_keys: Default::default(),
75            group_key: Point::zero(),
76            nonce,
77        }
78    }
79
80    /// Generate and store a private nonce for a signing round
81    pub fn gen_nonce<RNG: RngCore + CryptoRng>(
82        &mut self,
83        secret_key: &Scalar,
84        rng: &mut RNG,
85    ) -> PublicNonce {
86        self.nonce = Nonce::random(secret_key, rng);
87        PublicNonce::from(&self.nonce)
88    }
89
90    /// Get a public commitment to the private polynomial
91    pub fn get_poly_commitment<RNG: RngCore + CryptoRng>(
92        &self,
93        ctx: &[u8],
94        rng: &mut RNG,
95    ) -> Option<PolyCommitment> {
96        if let Some(poly) = &self.f {
97            Some(PolyCommitment {
98                id: ID::new(&self.id(), &poly.data()[0], ctx, rng),
99                poly: (0..poly.data().len())
100                    .map(|i| &poly.data()[i] * G)
101                    .collect(),
102            })
103        } else {
104            warn!("get_poly_commitment called with no polynomial");
105            None
106        }
107    }
108
109    /// Get the shares of this party's private polynomial for all keys
110    pub fn get_shares(&self) -> HashMap<u32, Scalar> {
111        let mut shares = HashMap::new();
112        if let Some(poly) = &self.f {
113            for i in 1..self.num_keys + 1 {
114                shares.insert(i, poly.eval(compute::id(i)));
115            }
116        } else {
117            warn!("get_poly_commitment called with no polynomial");
118        }
119        shares
120    }
121
122    /// Compute this party's share of the group secret key, but first check that the data is valid
123    /// and consistent.  This raises an issue though: what if we have private_shares and
124    /// public_shares from different parties?
125    /// To resolve the ambiguity, assume that the public_shares represent the correct group of
126    /// parties.  
127    pub fn compute_secret(
128        &mut self,
129        private_shares: &HashMap<u32, HashMap<u32, Scalar>>,
130        public_shares: &HashMap<u32, PolyCommitment>,
131        ctx: &[u8],
132    ) -> Result<(), DkgError> {
133        self.private_keys.clear();
134        self.group_key = Point::zero();
135
136        let threshold: usize = self.threshold.try_into()?;
137
138        let mut bad_ids = Vec::new();
139        for (i, comm) in public_shares.iter() {
140            if !check_public_shares(comm, threshold, ctx) {
141                bad_ids.push(*i);
142            } else {
143                self.group_key += comm.poly[0];
144            }
145        }
146        if !bad_ids.is_empty() {
147            return Err(DkgError::BadPublicShares(bad_ids));
148        }
149
150        let mut missing_shares = Vec::new();
151        for dst_key_id in &self.key_ids {
152            for src_key_id in public_shares.keys() {
153                match private_shares.get(dst_key_id) {
154                    Some(shares) => {
155                        if shares.get(src_key_id).is_none() {
156                            missing_shares.push((*dst_key_id, *src_key_id));
157                        }
158                    }
159                    None => {
160                        missing_shares.push((*dst_key_id, *src_key_id));
161                    }
162                }
163            }
164        }
165        if !missing_shares.is_empty() {
166            return Err(DkgError::MissingPrivateShares(missing_shares));
167        }
168
169        let mut bad_shares = Vec::new();
170        for key_id in &self.key_ids {
171            if let Some(shares) = private_shares.get(key_id) {
172                for (sender, s) in shares {
173                    if let Some(comm) = public_shares.get(sender) {
174                        if s * G != compute::poly(&compute::id(*key_id), &comm.poly)? {
175                            bad_shares.push(*sender);
176                        }
177                    } else {
178                        warn!("unable to check private share from {}: no corresponding public share, even though we checked for it above", sender);
179                    }
180                }
181            } else {
182                warn!(
183                    "no private shares for key_id {}, even though we checked for it above",
184                    key_id
185                );
186            }
187        }
188        if !bad_shares.is_empty() {
189            return Err(DkgError::BadPrivateShares(bad_shares));
190        }
191
192        for key_id in &self.key_ids {
193            self.private_keys.insert(*key_id, Scalar::zero());
194            if let Some(shares) = private_shares.get(key_id) {
195                let secret = shares.values().sum();
196                self.private_keys.insert(*key_id, secret);
197            } else {
198                warn!(
199                    "no private shares for key_id {}, even though we checked for it above",
200                    key_id
201                );
202            }
203        }
204
205        Ok(())
206    }
207
208    /// Compute a Scalar from this party's ID
209    pub fn id(&self) -> Scalar {
210        compute::id(self.party_id)
211    }
212
213    /// Sign `msg` with this party's shares of the group private key, using the set of `party_ids`, `key_ids` and corresponding `nonces`
214    pub fn sign(
215        &self,
216        msg: &[u8],
217        party_ids: &[u32],
218        key_ids: &[u32],
219        nonces: &[PublicNonce],
220    ) -> SignatureShare {
221        self.sign_with_tweak(msg, party_ids, key_ids, nonces, None)
222    }
223
224    /// Sign `msg` with this party's shares of the group private key, using the set of `party_ids`, `key_ids` and corresponding `nonces` with a tweaked public key. The posible values for tweak are
225    /// None    - standard FROST signature
226    /// Some(0) - BIP-340 schnorr signature using 32-byte private key adjustments
227    /// Some(t) - BIP-340 schnorr signature with BIP-341 tweaked keys, using 32-byte private key adjustments
228    #[allow(non_snake_case)]
229    pub fn sign_with_tweak(
230        &self,
231        msg: &[u8],
232        party_ids: &[u32],
233        key_ids: &[u32],
234        nonces: &[PublicNonce],
235        tweak: Option<Scalar>,
236    ) -> SignatureShare {
237        // When using BIP-340 32-byte public keys, we have to invert the private key if the
238        // public key is odd.  But if we're also using BIP-341 tweaked keys, we have to do
239        // the same thing if the tweaked public key is odd.  In that case, only invert the
240        // public key if exactly one of the internal or tweaked public keys is odd
241        let mut cx_sign = Scalar::one();
242        let tweaked_public_key = if let Some(t) = tweak {
243            if t != Scalar::zero() {
244                let key = compute::tweaked_public_key_from_tweak(&self.group_key, t);
245                if key.has_even_y() ^ self.group_key.has_even_y() {
246                    cx_sign = -cx_sign;
247                }
248
249                key
250            } else {
251                if !self.group_key.has_even_y() {
252                    cx_sign = -cx_sign;
253                }
254                self.group_key
255            }
256        } else {
257            self.group_key
258        };
259        let (_, R) = compute::intermediate(msg, self.group_key, party_ids, nonces);
260        let c = compute::challenge(&tweaked_public_key, &R, msg);
261        let commitment_list: Vec<(Scalar, PublicNonce)> = party_ids
262            .iter()
263            .zip(nonces)
264            .map(|(id, nonce)| (Scalar::from(*id), nonce.clone()))
265            .collect();
266        let mut r = &self.nonce.d
267            + &self.nonce.e * compute::binding(&self.id(), self.group_key, &commitment_list, msg);
268        if tweak.is_some() && !R.has_even_y() {
269            r = -r;
270        }
271        let mut cx = Scalar::zero();
272        for key_id in self.key_ids.iter() {
273            cx += c * &self.private_keys[key_id] * compute::lambda(*key_id, key_ids);
274        }
275
276        cx = cx_sign * cx;
277
278        let z = r + cx;
279
280        SignatureShare {
281            id: self.party_id,
282            z_i: z,
283            key_ids: self.key_ids.clone(),
284        }
285    }
286}
287
288/// The group signature aggregator
289#[derive(Clone, Debug, PartialEq)]
290pub struct Aggregator {
291    /// The total number of keys
292    pub num_keys: u32,
293    /// The threshold of signing keys needed to construct a valid signature
294    pub threshold: u32,
295    /// The aggregate group polynomial; `poly[0]` is the group public key
296    pub poly: Vec<Point>,
297}
298
299impl Aggregator {
300    /// Aggregate the party signatures using a tweak.  The posible values for tweak are
301    /// None    - standard FROST signature
302    /// Some(0) - BIP-340 schnorr signature using 32-byte private key adjustments
303    /// Some(t) - BIP-340 schnorr signature with BIP-341 tweaked keys, using 32-byte private key adjustments
304    #[allow(non_snake_case)]
305    pub fn sign_with_tweak(
306        &mut self,
307        msg: &[u8],
308        nonces: &[PublicNonce],
309        sig_shares: &[SignatureShare],
310        _key_ids: &[u32],
311        tweak: Option<Scalar>,
312    ) -> Result<(Point, Signature), AggregatorError> {
313        if nonces.len() != sig_shares.len() {
314            return Err(AggregatorError::BadNonceLen(nonces.len(), sig_shares.len()));
315        }
316
317        let party_ids: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
318        let (_Rs, R) = compute::intermediate(msg, self.poly[0], &party_ids, nonces);
319        let mut z = Scalar::zero();
320        let mut cx_sign = Scalar::one();
321        let aggregate_public_key = self.poly[0];
322        let tweaked_public_key = if let Some(t) = tweak {
323            if t != Scalar::zero() {
324                let key = compute::tweaked_public_key_from_tweak(&aggregate_public_key, t);
325                if !key.has_even_y() {
326                    cx_sign = -cx_sign;
327                }
328                key
329            } else {
330                aggregate_public_key
331            }
332        } else {
333            aggregate_public_key
334        };
335        let c = compute::challenge(&tweaked_public_key, &R, msg);
336        // optimistically try to create the aggregate signature without checking for bad keys or sig shares
337        for sig_share in sig_shares {
338            z += sig_share.z_i;
339        }
340
341        // The signature shares have already incorporated the private key adjustments, so we just have to add the tweak.  But the tweak itself needs to be adjusted if the tweaked public key is odd
342        if let Some(t) = tweak {
343            z += cx_sign * c * t;
344        }
345
346        let sig = Signature { R, z };
347
348        Ok((tweaked_public_key, sig))
349    }
350
351    /// Check the party signatures after a failed group signature. The posible values for tweak are
352    /// None    - standard FROST signature
353    /// Some(0) - BIP-340 schnorr signature using 32-byte private key adjustments
354    /// Some(t) - BIP-340 schnorr signature with BIP-341 tweaked keys, using 32-byte private key adjustments
355    #[allow(non_snake_case)]
356    pub fn check_signature_shares(
357        &mut self,
358        msg: &[u8],
359        nonces: &[PublicNonce],
360        sig_shares: &[SignatureShare],
361        key_ids: &[u32],
362        tweak: Option<Scalar>,
363    ) -> AggregatorError {
364        if nonces.len() != sig_shares.len() {
365            return AggregatorError::BadNonceLen(nonces.len(), sig_shares.len());
366        }
367
368        let party_ids: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
369        let (Rs, R) = compute::intermediate(msg, self.poly[0], &party_ids, nonces);
370        let mut bad_party_keys = Vec::new();
371        let mut bad_party_sigs = Vec::new();
372        let aggregate_public_key = self.poly[0];
373        let tweaked_public_key = if let Some(t) = tweak {
374            if t != Scalar::zero() {
375                compute::tweaked_public_key_from_tweak(&aggregate_public_key, t)
376            } else {
377                aggregate_public_key
378            }
379        } else {
380            aggregate_public_key
381        };
382        let c = compute::challenge(&tweaked_public_key, &R, msg);
383        let mut r_sign = Scalar::one();
384        let mut cx_sign = Scalar::one();
385        if let Some(t) = tweak {
386            if !R.has_even_y() {
387                r_sign = -Scalar::one();
388            }
389            if t != Scalar::zero() {
390                if !tweaked_public_key.has_even_y() ^ !aggregate_public_key.has_even_y() {
391                    cx_sign = -Scalar::one();
392                }
393            } else if !aggregate_public_key.has_even_y() {
394                cx_sign = -Scalar::one();
395            }
396        }
397
398        for i in 0..sig_shares.len() {
399            let z_i = sig_shares[i].z_i;
400            let mut cx = Point::zero();
401
402            for key_id in &sig_shares[i].key_ids {
403                let kid = compute::id(*key_id);
404                let public_key = match compute::poly(&kid, &self.poly) {
405                    Ok(p) => p,
406                    Err(_) => {
407                        bad_party_keys.push(sig_shares[i].id);
408                        Point::zero()
409                    }
410                };
411
412                cx += compute::lambda(*key_id, key_ids) * c * public_key;
413            }
414
415            if z_i * G != (r_sign * Rs[i] + cx_sign * cx) {
416                bad_party_sigs.push(sig_shares[i].id);
417            }
418        }
419        if !bad_party_keys.is_empty() {
420            AggregatorError::BadPartyKeys(bad_party_keys)
421        } else if !bad_party_sigs.is_empty() {
422            AggregatorError::BadPartySigs(bad_party_sigs)
423        } else {
424            AggregatorError::BadGroupSig
425        }
426    }
427}
428
429impl traits::Aggregator for Aggregator {
430    /// Construct an Aggregator with the passed parameters
431    fn new(num_keys: u32, threshold: u32) -> Self {
432        Self {
433            num_keys,
434            threshold,
435            poly: Default::default(),
436        }
437    }
438
439    /// Initialize the Aggregator polynomial
440    fn init(&mut self, comms: &HashMap<u32, PolyCommitment>) -> Result<(), AggregatorError> {
441        let threshold: usize = self.threshold.try_into()?;
442        let mut poly = Vec::with_capacity(threshold);
443
444        for i in 0..poly.capacity() {
445            poly.push(Point::zero());
446            for (_, comm) in comms {
447                poly[i] += &comm.poly[i];
448            }
449        }
450
451        self.poly = poly;
452
453        Ok(())
454    }
455
456    /// Check and aggregate the party signatures
457    fn sign(
458        &mut self,
459        msg: &[u8],
460        nonces: &[PublicNonce],
461        sig_shares: &[SignatureShare],
462        key_ids: &[u32],
463    ) -> Result<Signature, AggregatorError> {
464        let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, key_ids, None)?;
465
466        if sig.verify(&key, msg) {
467            Ok(sig)
468        } else {
469            Err(self.check_signature_shares(msg, nonces, sig_shares, key_ids, None))
470        }
471    }
472
473    /// Check and aggregate the party signatures
474    fn sign_schnorr(
475        &mut self,
476        msg: &[u8],
477        nonces: &[PublicNonce],
478        sig_shares: &[SignatureShare],
479        key_ids: &[u32],
480    ) -> Result<SchnorrProof, AggregatorError> {
481        let tweak = Scalar::from(0);
482        let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, key_ids, Some(tweak))?;
483        let proof = SchnorrProof::new(&sig);
484
485        if proof.verify(&key.x(), msg) {
486            Ok(proof)
487        } else {
488            Err(self.check_signature_shares(msg, nonces, sig_shares, key_ids, Some(tweak)))
489        }
490    }
491
492    /// Check and aggregate the party signatures
493    fn sign_taproot(
494        &mut self,
495        msg: &[u8],
496        nonces: &[PublicNonce],
497        sig_shares: &[SignatureShare],
498        key_ids: &[u32],
499        merkle_root: Option<[u8; 32]>,
500    ) -> Result<SchnorrProof, AggregatorError> {
501        let tweak = compute::tweak(&self.poly[0], merkle_root);
502        let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, key_ids, Some(tweak))?;
503        let proof = SchnorrProof::new(&sig);
504
505        if proof.verify(&key.x(), msg) {
506            Ok(proof)
507        } else {
508            Err(self.check_signature_shares(msg, nonces, sig_shares, key_ids, Some(tweak)))
509        }
510    }
511}
512
513/// Typedef so we can use the same tokens for v1 and v2
514pub type Signer = Party;
515
516impl traits::Signer for Party {
517    fn new<RNG: RngCore + CryptoRng>(
518        party_id: u32,
519        key_ids: &[u32],
520        num_signers: u32,
521        num_keys: u32,
522        threshold: u32,
523        rng: &mut RNG,
524    ) -> Self {
525        Party::new(party_id, key_ids, num_signers, num_keys, threshold, rng)
526    }
527
528    fn load(state: &traits::SignerState) -> Self {
529        // v2 signer contains single party
530        assert_eq!(state.parties.len(), 1);
531
532        let party_state = &state.parties[0].1;
533
534        Self {
535            party_id: state.id,
536            key_ids: state.key_ids.clone(),
537            num_keys: state.num_keys,
538            num_parties: state.num_parties,
539            threshold: state.threshold,
540            f: party_state.polynomial.clone(),
541            private_keys: party_state
542                .private_keys
543                .iter()
544                .map(|(k, v)| (*k, *v))
545                .collect(),
546            group_key: state.group_key,
547            nonce: party_state.nonce.clone(),
548        }
549    }
550
551    fn save(&self) -> traits::SignerState {
552        let party_state = traits::PartyState {
553            polynomial: self.f.clone(),
554            private_keys: self.private_keys.iter().map(|(k, v)| (*k, *v)).collect(),
555            nonce: self.nonce.clone(),
556        };
557        traits::SignerState {
558            id: self.party_id,
559            key_ids: self.key_ids.clone(),
560            num_keys: self.num_keys,
561            num_parties: self.num_parties,
562            threshold: self.threshold,
563            group_key: self.group_key,
564            parties: vec![(self.party_id, party_state)],
565        }
566    }
567
568    fn get_id(&self) -> u32 {
569        self.party_id
570    }
571
572    fn get_key_ids(&self) -> Vec<u32> {
573        self.key_ids.clone()
574    }
575
576    fn get_num_parties(&self) -> u32 {
577        self.num_parties
578    }
579
580    fn get_poly_commitments<RNG: RngCore + CryptoRng>(
581        &self,
582        ctx: &[u8],
583        rng: &mut RNG,
584    ) -> Vec<PolyCommitment> {
585        if let Some(poly) = self.get_poly_commitment(ctx, rng) {
586            vec![poly.clone()]
587        } else {
588            vec![]
589        }
590    }
591
592    fn reset_polys<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG) {
593        self.f = Some(VSS::random_poly(self.threshold - 1, rng));
594    }
595
596    fn clear_polys(&mut self) {
597        self.f = None;
598    }
599
600    fn get_shares(&self) -> HashMap<u32, HashMap<u32, Scalar>> {
601        let mut shares = HashMap::new();
602
603        shares.insert(self.party_id, self.get_shares());
604
605        shares
606    }
607
608    fn compute_secrets(
609        &mut self,
610        private_shares: &HashMap<u32, HashMap<u32, Scalar>>,
611        polys: &HashMap<u32, PolyCommitment>,
612        ctx: &[u8],
613    ) -> Result<(), HashMap<u32, DkgError>> {
614        // go through the shares, looking for this party's
615        let mut key_shares = HashMap::new();
616        for dest_key_id in self.get_key_ids() {
617            let mut shares = HashMap::new();
618            for (src_party_id, signer_shares) in private_shares.iter() {
619                if let Some(signer_share) = signer_shares.get(&dest_key_id) {
620                    shares.insert(*src_party_id, *signer_share);
621                }
622            }
623            key_shares.insert(dest_key_id, shares);
624        }
625
626        match self.compute_secret(&key_shares, polys, ctx) {
627            Ok(()) => Ok(()),
628            Err(dkg_error) => {
629                let mut dkg_errors = HashMap::new();
630                dkg_errors.insert(self.party_id, dkg_error);
631                Err(dkg_errors)
632            }
633        }
634    }
635
636    fn gen_nonces<RNG: RngCore + CryptoRng>(
637        &mut self,
638        secret_key: &Scalar,
639        rng: &mut RNG,
640    ) -> Vec<PublicNonce> {
641        vec![self.gen_nonce(secret_key, rng)]
642    }
643
644    fn compute_intermediate(
645        &self,
646        msg: &[u8],
647        signer_ids: &[u32],
648        _key_ids: &[u32],
649        nonces: &[PublicNonce],
650    ) -> (Vec<Point>, Point) {
651        compute::intermediate(msg, self.group_key, signer_ids, nonces)
652    }
653
654    fn validate_party_id(
655        signer_id: u32,
656        party_id: u32,
657        _signer_key_ids: &HashMap<u32, HashSet<u32>>,
658    ) -> bool {
659        signer_id == party_id
660    }
661
662    fn sign(
663        &self,
664        msg: &[u8],
665        signer_ids: &[u32],
666        key_ids: &[u32],
667        nonces: &[PublicNonce],
668    ) -> Vec<SignatureShare> {
669        vec![self.sign(msg, signer_ids, key_ids, nonces)]
670    }
671
672    fn sign_schnorr(
673        &self,
674        msg: &[u8],
675        signer_ids: &[u32],
676        key_ids: &[u32],
677        nonces: &[PublicNonce],
678    ) -> Vec<SignatureShare> {
679        vec![self.sign_with_tweak(msg, signer_ids, key_ids, nonces, Some(Scalar::from(0)))]
680    }
681
682    fn sign_taproot(
683        &self,
684        msg: &[u8],
685        signer_ids: &[u32],
686        key_ids: &[u32],
687        nonces: &[PublicNonce],
688        merkle_root: Option<[u8; 32]>,
689    ) -> Vec<SignatureShare> {
690        let tweak = compute::tweak(&self.group_key, merkle_root);
691        vec![self.sign_with_tweak(msg, signer_ids, key_ids, nonces, Some(tweak))]
692    }
693}
694
695/// Helper functions for tests
696pub mod test_helpers {
697    use super::Scalar;
698    use crate::common::{PolyCommitment, PublicNonce};
699    use crate::errors::DkgError;
700    use crate::traits::Signer;
701    use crate::v2;
702    use crate::v2::SignatureShare;
703
704    use hashbrown::HashMap;
705    use rand_core::{CryptoRng, RngCore};
706
707    /// Run a distributed key generation round
708    pub fn dkg<RNG: RngCore + CryptoRng>(
709        signers: &mut [v2::Party],
710        rng: &mut RNG,
711    ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
712        let ctx = 0u64.to_be_bytes();
713        let mut polys: HashMap<u32, PolyCommitment> = Default::default();
714        for signer in signers.iter() {
715            if let Some(poly) = signer.get_poly_commitment(&ctx, rng) {
716                polys.insert(signer.get_id(), poly);
717            }
718        }
719
720        // each party broadcasts their commitments
721        let mut broadcast_shares = Vec::new();
722        for party in signers.iter() {
723            broadcast_shares.push((party.party_id, party.get_shares()));
724        }
725
726        // each party collects its shares from the broadcasts
727        // maybe this should collect into a hashmap first?
728        let mut secret_errors = HashMap::new();
729        for party in signers.iter_mut() {
730            let mut party_shares = HashMap::new();
731            for key_id in party.key_ids.clone() {
732                let mut key_shares = HashMap::new();
733
734                for (id, shares) in &broadcast_shares {
735                    if let Some(share) = shares.get(&key_id) {
736                        key_shares.insert(*id, *share);
737                    }
738                }
739
740                party_shares.insert(key_id, key_shares);
741            }
742
743            if let Err(secret_error) = party.compute_secret(&party_shares, &polys, &ctx) {
744                secret_errors.insert(party.party_id, secret_error);
745            }
746        }
747
748        if secret_errors.is_empty() {
749            Ok(polys)
750        } else {
751            Err(secret_errors)
752        }
753    }
754
755    /// Run a signing round for the passed `msg`
756    pub fn sign<RNG: RngCore + CryptoRng>(
757        msg: &[u8],
758        signers: &mut [v2::Party],
759        rng: &mut RNG,
760    ) -> (Vec<PublicNonce>, Vec<SignatureShare>, Vec<u32>) {
761        let secret_key = Scalar::random(rng);
762        let party_ids: Vec<u32> = signers.iter().map(|s| s.party_id).collect();
763        let key_ids: Vec<u32> = signers.iter().flat_map(|s| s.key_ids.clone()).collect();
764        let nonces: Vec<PublicNonce> = signers
765            .iter_mut()
766            .map(|s| s.gen_nonce(&secret_key, rng))
767            .collect();
768        let shares = signers
769            .iter()
770            .map(|s| s.sign(msg, &party_ids, &key_ids, &nonces))
771            .collect();
772
773        (nonces, shares, key_ids)
774    }
775}
776
#[cfg(test)]
mod tests {
    use hashbrown::{HashMap, HashSet};
    use num_traits::Zero;

    use crate::util::create_rng;
    use crate::{
        curve::scalar::Scalar,
        traits::{
            self, test_helpers::run_compute_secrets_missing_private_shares, Aggregator, Signer,
        },
        v2,
    };

    /// A new signer starts with a valid, non-zero nonce, and `gen_nonces` keeps it that way
    /// while returning exactly one public nonce.
    #[test]
    fn signer_gen_nonces() {
        let mut rng = create_rng();
        let secret_key = Scalar::random(&mut rng);
        let party_id = 1;
        let key_ids = [1, 2, 3];
        let num_keys: u32 = 10;
        let num_parties: u32 = 3;
        let threshold: u32 = 7;

        let mut signer =
            v2::Signer::new(party_id, &key_ids, num_parties, num_keys, threshold, &mut rng);
        assert!(!signer.nonce.is_zero());
        assert!(signer.nonce.is_valid());

        let public_nonces = signer.gen_nonces(&secret_key, &mut rng);
        assert_eq!(public_nonces.len(), 1);
        assert!(!signer.nonce.is_zero());
        assert!(signer.nonce.is_valid());
    }

    /// Saving and reloading a party yields an identical party.
    #[test]
    fn party_save_load() {
        let mut rng = create_rng();
        let num_keys: u32 = 10;
        let threshold: u32 = 7;

        let original = v2::Party::new(0, &[1, 2, 3], 1, num_keys, threshold, &mut rng);
        let restored = v2::Party::load(&original.save());

        assert_eq!(original, restored);
    }

    /// After `clear_polys` the party produces neither commitments nor shares.
    #[test]
    fn clear_polys() {
        let ctx = 0u64.to_be_bytes();
        let mut rng = create_rng();
        let num_keys: u32 = 10;
        let threshold: u32 = 7;

        let mut signer = v2::Party::new(0, &[1, 2, 3], 1, num_keys, threshold, &mut rng);
        let expected_shares = usize::try_from(num_keys).unwrap();

        assert_eq!(signer.get_poly_commitments(&ctx, &mut rng).len(), 1);
        assert_eq!(signer.get_shares().len(), expected_shares);

        signer.clear_polys();

        assert!(signer.get_poly_commitments(&ctx, &mut rng).is_empty());
        assert!(signer.get_shares().is_empty());
    }

    /// A threshold-sized subset of parties produces a signature the aggregator accepts.
    #[test]
    fn aggregator_sign() {
        let mut rng = create_rng();
        let msg = "It was many and many a year ago".as_bytes();
        let num_keys: u32 = 10;
        let threshold: u32 = 7;
        let party_key_ids: Vec<Vec<u32>> =
            vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8], vec![9, 10]];
        let num_parties = party_key_ids.len().try_into().unwrap();

        let mut signers: Vec<v2::Party> = party_key_ids
            .iter()
            .enumerate()
            .map(|(party_id, key_ids)| {
                v2::Party::new(
                    party_id.try_into().unwrap(),
                    key_ids,
                    num_parties,
                    num_keys,
                    threshold,
                    &mut rng,
                )
            })
            .collect();

        let comms = traits::test_helpers::dkg(&mut signers, &mut rng).unwrap_or_else(
            |secret_errors| panic!("Got secret errors from DKG: {secret_errors:?}"),
        );

        // parties 0, 1 and 3 together hold exactly `threshold` keys
        let mut signing_set: Vec<v2::Party> =
            [0, 1, 3].iter().map(|&i| signers[i].clone()).collect();
        let mut sig_agg = v2::Aggregator::new(num_keys, threshold);

        sig_agg.init(&comms).expect("aggregator init failed");

        let (nonces, sig_shares, key_ids) =
            v2::test_helpers::sign(msg, &mut signing_set, &mut rng);
        if let Err(e) = sig_agg.sign(msg, &nonces, &sig_shares, &key_ids) {
            panic!("Aggregator sign failed: {e:?}");
        }
    }

    /// Run a distributed key generation round with not enough shares
    #[test]
    pub fn run_compute_secrets_missing_shares() {
        run_compute_secrets_missing_private_shares::<v2::Signer>()
    }

    /// Run DKG and aggregator init with a bad polynomial length, both too long and too short
    #[test]
    pub fn bad_polynomial_length() {
        traits::test_helpers::bad_polynomial_length::<v2::Signer, _>(|t| t + 1);
        traits::test_helpers::bad_polynomial_length::<v2::Signer, _>(|t| t - 1);
    }

    /// Run DKG and aggregator init with a bad polynomial commitment
    #[test]
    pub fn bad_polynomial_commitment() {
        traits::test_helpers::bad_polynomial_commitment::<v2::Signer>();
    }

    /// Check that party_ids can be properly validated
    #[test]
    fn validate_party_id() {
        let key_ids: HashSet<u32> = std::iter::once(1).collect();
        let signer_key_ids: HashMap<u32, HashSet<u32>> = std::iter::once((0, key_ids)).collect();

        assert!(v2::Signer::validate_party_id(0, 0, &signer_key_ids));
        assert!(!v2::Signer::validate_party_id(0, 1, &signer_key_ids));
    }
}