wsts/
v2.rs

1use std::fmt;
2
3use hashbrown::{HashMap, HashSet};
4use num_traits::{One, Zero};
5use polynomial::Polynomial;
6use rand_core::{CryptoRng, RngCore};
7use tracing::warn;
8
9use crate::{
10    common::{check_public_shares, Nonce, PolyCommitment, PublicNonce, Signature, SignatureShare},
11    compute,
12    curve::{
13        point::{Point, G},
14        scalar::Scalar,
15    },
16    errors::{AggregatorError, DkgError},
17    schnorr::ID,
18    taproot::SchnorrProof,
19    traits,
20    vss::VSS,
21};
22
23#[derive(Clone, Eq, PartialEq)]
24/// A WSTS party, which encapsulates a single polynomial, nonce, and one private key per key ID
25pub struct Party {
26    /// The party ID
27    pub party_id: u32,
28    /// The key IDs for this party
29    pub key_ids: Vec<u32>,
30    /// The public keys for this party, indexed by ID
31    num_keys: u32,
32    num_parties: u32,
33    threshold: u32,
34    f: Option<Polynomial<Scalar>>,
35    private_keys: HashMap<u32, Scalar>,
36    group_key: Point,
37    nonce: Nonce,
38}
39
40impl fmt::Debug for Party {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.debug_struct("Party")
43            .field("part_id", &self.party_id)
44            .field("key_ids", &self.key_ids)
45            .field("num_keys", &self.num_keys)
46            .field("num_parties", &self.num_parties)
47            .field("threshold", &self.threshold)
48            .field("group_key", &self.group_key)
49            .finish_non_exhaustive()
50    }
51}
52
53impl Party {
54    /// Construct a random Party with the passed party ID, key IDs, and parameters
55    pub fn new<RNG: RngCore + CryptoRng>(
56        party_id: u32,
57        key_ids: &[u32],
58        num_parties: u32,
59        num_keys: u32,
60        threshold: u32,
61        rng: &mut RNG,
62    ) -> Self {
63        // create a nonce using just the passed RNG to avoid having a zero nonce used against us
64        let secret_key = Scalar::random(rng);
65        let nonce = Nonce::random(&secret_key, rng);
66
67        Self {
68            party_id,
69            key_ids: key_ids.to_vec(),
70            num_keys,
71            num_parties,
72            threshold,
73            f: Some(VSS::random_poly(threshold - 1, rng)),
74            private_keys: Default::default(),
75            group_key: Point::zero(),
76            nonce,
77        }
78    }
79
80    /// Generate and store a private nonce for a signing round
81    pub fn gen_nonce<RNG: RngCore + CryptoRng>(
82        &mut self,
83        secret_key: &Scalar,
84        rng: &mut RNG,
85    ) -> PublicNonce {
86        self.nonce = Nonce::random(secret_key, rng);
87        PublicNonce::from(&self.nonce)
88    }
89
90    /// Get a public commitment to the private polynomial
91    pub fn get_poly_commitment<RNG: RngCore + CryptoRng>(
92        &self,
93        ctx: &[u8],
94        rng: &mut RNG,
95    ) -> Option<PolyCommitment> {
96        if let Some(poly) = &self.f {
97            Some(PolyCommitment {
98                id: ID::new(&self.id(), &poly.data()[0], ctx, rng),
99                poly: (0..poly.data().len())
100                    .map(|i| &poly.data()[i] * G)
101                    .collect(),
102            })
103        } else {
104            warn!("get_poly_commitment called with no polynomial");
105            None
106        }
107    }
108
109    /// Get the shares of this party's private polynomial for all keys
110    pub fn get_shares(&self) -> HashMap<u32, Scalar> {
111        let mut shares = HashMap::new();
112        if let Some(poly) = &self.f {
113            for i in 1..self.num_keys + 1 {
114                shares.insert(i, poly.eval(compute::id(i)));
115            }
116        } else {
117            warn!("get_poly_commitment called with no polynomial");
118        }
119        shares
120    }
121
122    /// Compute this party's share of the group secret key, but first check that the data is valid
123    /// and consistent.  This raises an issue though: what if we have private_shares and
124    /// public_shares from different parties?
125    /// To resolve the ambiguity, assume that the public_shares represent the correct group of
126    /// parties.  
127    pub fn compute_secret(
128        &mut self,
129        private_shares: &HashMap<u32, HashMap<u32, Scalar>>,
130        public_shares: &HashMap<u32, PolyCommitment>,
131        ctx: &[u8],
132    ) -> Result<(), DkgError> {
133        self.private_keys.clear();
134        self.group_key = Point::zero();
135
136        let threshold: usize = self.threshold.try_into()?;
137
138        let mut bad_ids = Vec::new();
139        for (i, comm) in public_shares.iter() {
140            if !check_public_shares(comm, threshold, ctx) {
141                bad_ids.push(*i);
142            } else {
143                self.group_key += comm.poly[0];
144            }
145        }
146        if !bad_ids.is_empty() {
147            return Err(DkgError::BadPublicShares(bad_ids));
148        }
149
150        let mut missing_shares = Vec::new();
151        for dst_key_id in &self.key_ids {
152            for src_key_id in public_shares.keys() {
153                match private_shares.get(dst_key_id) {
154                    Some(shares) => {
155                        if shares.get(src_key_id).is_none() {
156                            missing_shares.push((*dst_key_id, *src_key_id));
157                        }
158                    }
159                    None => {
160                        missing_shares.push((*dst_key_id, *src_key_id));
161                    }
162                }
163            }
164        }
165        if !missing_shares.is_empty() {
166            return Err(DkgError::MissingPrivateShares(missing_shares));
167        }
168
169        let mut bad_shares = Vec::new();
170        for key_id in &self.key_ids {
171            if let Some(shares) = private_shares.get(key_id) {
172                for (sender, s) in shares {
173                    if let Some(comm) = public_shares.get(sender) {
174                        if s * G != compute::poly(&compute::id(*key_id), &comm.poly)? {
175                            bad_shares.push(*sender);
176                        }
177                    } else {
178                        warn!("unable to check private share from {}: no corresponding public share, even though we checked for it above", sender);
179                    }
180                }
181            } else {
182                warn!(
183                    "no private shares for key_id {}, even though we checked for it above",
184                    key_id
185                );
186            }
187        }
188        if !bad_shares.is_empty() {
189            return Err(DkgError::BadPrivateShares(bad_shares));
190        }
191
192        for key_id in &self.key_ids {
193            self.private_keys.insert(*key_id, Scalar::zero());
194            if let Some(shares) = private_shares.get(key_id) {
195                let secret = shares.values().sum();
196                self.private_keys.insert(*key_id, secret);
197            } else {
198                warn!(
199                    "no private shares for key_id {}, even though we checked for it above",
200                    key_id
201                );
202            }
203        }
204
205        Ok(())
206    }
207
208    /// Compute a Scalar from this party's ID
209    pub fn id(&self) -> Scalar {
210        compute::id(self.party_id)
211    }
212
213    /// Sign `msg` with this party's shares of the group private key, using the set of `party_ids`, `key_ids` and corresponding `nonces`
214    pub fn sign(
215        &self,
216        msg: &[u8],
217        party_ids: &[u32],
218        key_ids: &[u32],
219        nonces: &[PublicNonce],
220    ) -> SignatureShare {
221        self.sign_with_tweak(msg, party_ids, key_ids, nonces, None)
222    }
223
224    /// Sign `msg` with this party's shares of the group private key, using the set of `party_ids`, `key_ids` and corresponding `nonces` with a tweaked public key. The posible values for tweak are
225    /// None    - standard FROST signature
226    /// Some(0) - BIP-340 schnorr signature using 32-byte private key adjustments
227    /// Some(t) - BIP-340 schnorr signature with BIP-341 tweaked keys, using 32-byte private key adjustments
228    #[allow(non_snake_case)]
229    pub fn sign_with_tweak(
230        &self,
231        msg: &[u8],
232        party_ids: &[u32],
233        key_ids: &[u32],
234        nonces: &[PublicNonce],
235        tweak: Option<Scalar>,
236    ) -> SignatureShare {
237        // When using BIP-340 32-byte public keys, we have to invert the private key if the
238        // public key is odd.  But if we're also using BIP-341 tweaked keys, we have to do
239        // the same thing if the tweaked public key is odd.  In that case, only invert the
240        // public key if exactly one of the internal or tweaked public keys is odd
241        let mut cx_sign = Scalar::one();
242        let tweaked_public_key = if let Some(t) = tweak {
243            if t != Scalar::zero() {
244                let key = compute::tweaked_public_key_from_tweak(&self.group_key, t);
245                if key.has_even_y() ^ self.group_key.has_even_y() {
246                    cx_sign = -cx_sign;
247                }
248
249                key
250            } else {
251                if !self.group_key.has_even_y() {
252                    cx_sign = -cx_sign;
253                }
254                self.group_key
255            }
256        } else {
257            self.group_key
258        };
259        let (_, R) = compute::intermediate(msg, party_ids, nonces);
260        let c = compute::challenge(&tweaked_public_key, &R, msg);
261        let mut r = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg);
262        if tweak.is_some() && !R.has_even_y() {
263            r = -r;
264        }
265
266        let mut cx = Scalar::zero();
267        for key_id in self.key_ids.iter() {
268            cx += c * &self.private_keys[key_id] * compute::lambda(*key_id, key_ids);
269        }
270
271        cx = cx_sign * cx;
272
273        let z = r + cx;
274
275        SignatureShare {
276            id: self.party_id,
277            z_i: z,
278            key_ids: self.key_ids.clone(),
279        }
280    }
281}
282
/// The group signature aggregator
///
/// Combines the parties' signature shares into a single group signature and, when that
/// signature does not verify, identifies which shares were faulty.
#[derive(Clone, Debug, PartialEq)]
pub struct Aggregator {
    /// The total number of keys
    pub num_keys: u32,
    /// The threshold of signing keys needed to construct a valid signature
    pub threshold: u32,
    /// The aggregate group polynomial; `poly[0]` is the group public key
    ///
    /// Empty until `init` fills it with `threshold` points, each the sum of the parties'
    /// commitments for that coefficient.
    pub poly: Vec<Point>,
}
293
294impl Aggregator {
295    /// Aggregate the party signatures using a tweak.  The posible values for tweak are
296    /// None    - standard FROST signature
297    /// Some(0) - BIP-340 schnorr signature using 32-byte private key adjustments
298    /// Some(t) - BIP-340 schnorr signature with BIP-341 tweaked keys, using 32-byte private key adjustments
299    #[allow(non_snake_case)]
300    pub fn sign_with_tweak(
301        &mut self,
302        msg: &[u8],
303        nonces: &[PublicNonce],
304        sig_shares: &[SignatureShare],
305        _key_ids: &[u32],
306        tweak: Option<Scalar>,
307    ) -> Result<(Point, Signature), AggregatorError> {
308        if nonces.len() != sig_shares.len() {
309            return Err(AggregatorError::BadNonceLen(nonces.len(), sig_shares.len()));
310        }
311
312        let party_ids: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
313        let (_Rs, R) = compute::intermediate(msg, &party_ids, nonces);
314        let mut z = Scalar::zero();
315        let mut cx_sign = Scalar::one();
316        let aggregate_public_key = self.poly[0];
317        let tweaked_public_key = if let Some(t) = tweak {
318            if t != Scalar::zero() {
319                let key = compute::tweaked_public_key_from_tweak(&aggregate_public_key, t);
320                if !key.has_even_y() {
321                    cx_sign = -cx_sign;
322                }
323                key
324            } else {
325                aggregate_public_key
326            }
327        } else {
328            aggregate_public_key
329        };
330        let c = compute::challenge(&tweaked_public_key, &R, msg);
331        // optimistically try to create the aggregate signature without checking for bad keys or sig shares
332        for sig_share in sig_shares {
333            z += sig_share.z_i;
334        }
335
336        // The signature shares have already incorporated the private key adjustments, so we just have to add the tweak.  But the tweak itself needs to be adjusted if the tweaked public key is odd
337        if let Some(t) = tweak {
338            z += cx_sign * c * t;
339        }
340
341        let sig = Signature { R, z };
342
343        Ok((tweaked_public_key, sig))
344    }
345
346    /// Check the party signatures after a failed group signature. The posible values for tweak are
347    /// None    - standard FROST signature
348    /// Some(0) - BIP-340 schnorr signature using 32-byte private key adjustments
349    /// Some(t) - BIP-340 schnorr signature with BIP-341 tweaked keys, using 32-byte private key adjustments
350    #[allow(non_snake_case)]
351    pub fn check_signature_shares(
352        &mut self,
353        msg: &[u8],
354        nonces: &[PublicNonce],
355        sig_shares: &[SignatureShare],
356        key_ids: &[u32],
357        tweak: Option<Scalar>,
358    ) -> AggregatorError {
359        if nonces.len() != sig_shares.len() {
360            return AggregatorError::BadNonceLen(nonces.len(), sig_shares.len());
361        }
362
363        let party_ids: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
364        let (Rs, R) = compute::intermediate(msg, &party_ids, nonces);
365        let mut bad_party_keys = Vec::new();
366        let mut bad_party_sigs = Vec::new();
367        let aggregate_public_key = self.poly[0];
368        let tweaked_public_key = if let Some(t) = tweak {
369            if t != Scalar::zero() {
370                compute::tweaked_public_key_from_tweak(&aggregate_public_key, t)
371            } else {
372                aggregate_public_key
373            }
374        } else {
375            aggregate_public_key
376        };
377        let c = compute::challenge(&tweaked_public_key, &R, msg);
378        let mut r_sign = Scalar::one();
379        let mut cx_sign = Scalar::one();
380        if let Some(t) = tweak {
381            if !R.has_even_y() {
382                r_sign = -Scalar::one();
383            }
384            if t != Scalar::zero() {
385                if !tweaked_public_key.has_even_y() ^ !aggregate_public_key.has_even_y() {
386                    cx_sign = -Scalar::one();
387                }
388            } else if !aggregate_public_key.has_even_y() {
389                cx_sign = -Scalar::one();
390            }
391        }
392
393        for i in 0..sig_shares.len() {
394            let z_i = sig_shares[i].z_i;
395            let mut cx = Point::zero();
396
397            for key_id in &sig_shares[i].key_ids {
398                let kid = compute::id(*key_id);
399                let public_key = match compute::poly(&kid, &self.poly) {
400                    Ok(p) => p,
401                    Err(_) => {
402                        bad_party_keys.push(sig_shares[i].id);
403                        Point::zero()
404                    }
405                };
406
407                cx += compute::lambda(*key_id, key_ids) * c * public_key;
408            }
409
410            if z_i * G != (r_sign * Rs[i] + cx_sign * cx) {
411                bad_party_sigs.push(sig_shares[i].id);
412            }
413        }
414        if !bad_party_keys.is_empty() {
415            AggregatorError::BadPartyKeys(bad_party_keys)
416        } else if !bad_party_sigs.is_empty() {
417            AggregatorError::BadPartySigs(bad_party_sigs)
418        } else {
419            AggregatorError::BadGroupSig
420        }
421    }
422}
423
424impl traits::Aggregator for Aggregator {
425    /// Construct an Aggregator with the passed parameters
426    fn new(num_keys: u32, threshold: u32) -> Self {
427        Self {
428            num_keys,
429            threshold,
430            poly: Default::default(),
431        }
432    }
433
434    /// Initialize the Aggregator polynomial
435    fn init(&mut self, comms: &HashMap<u32, PolyCommitment>) -> Result<(), AggregatorError> {
436        let threshold: usize = self.threshold.try_into()?;
437        let mut poly = Vec::with_capacity(threshold);
438
439        for i in 0..poly.capacity() {
440            poly.push(Point::zero());
441            for (_, comm) in comms {
442                poly[i] += &comm.poly[i];
443            }
444        }
445
446        self.poly = poly;
447
448        Ok(())
449    }
450
451    /// Check and aggregate the party signatures
452    fn sign(
453        &mut self,
454        msg: &[u8],
455        nonces: &[PublicNonce],
456        sig_shares: &[SignatureShare],
457        key_ids: &[u32],
458    ) -> Result<Signature, AggregatorError> {
459        let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, key_ids, None)?;
460
461        if sig.verify(&key, msg) {
462            Ok(sig)
463        } else {
464            Err(self.check_signature_shares(msg, nonces, sig_shares, key_ids, None))
465        }
466    }
467
468    /// Check and aggregate the party signatures
469    fn sign_schnorr(
470        &mut self,
471        msg: &[u8],
472        nonces: &[PublicNonce],
473        sig_shares: &[SignatureShare],
474        key_ids: &[u32],
475    ) -> Result<SchnorrProof, AggregatorError> {
476        let tweak = Scalar::from(0);
477        let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, key_ids, Some(tweak))?;
478        let proof = SchnorrProof::new(&sig);
479
480        if proof.verify(&key.x(), msg) {
481            Ok(proof)
482        } else {
483            Err(self.check_signature_shares(msg, nonces, sig_shares, key_ids, Some(tweak)))
484        }
485    }
486
487    /// Check and aggregate the party signatures
488    fn sign_taproot(
489        &mut self,
490        msg: &[u8],
491        nonces: &[PublicNonce],
492        sig_shares: &[SignatureShare],
493        key_ids: &[u32],
494        merkle_root: Option<[u8; 32]>,
495    ) -> Result<SchnorrProof, AggregatorError> {
496        let tweak = compute::tweak(&self.poly[0], merkle_root);
497        let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, key_ids, Some(tweak))?;
498        let proof = SchnorrProof::new(&sig);
499
500        if proof.verify(&key.x(), msg) {
501            Ok(proof)
502        } else {
503            Err(self.check_signature_shares(msg, nonces, sig_shares, key_ids, Some(tweak)))
504        }
505    }
506}
507
/// Typedef so we can use the same tokens for v1 and v2
///
/// A v2 signer is a single `Party` that holds all of the signer's key IDs, so the signer
/// type is simply the party type.
pub type Signer = Party;
510
511impl traits::Signer for Party {
512    fn new<RNG: RngCore + CryptoRng>(
513        party_id: u32,
514        key_ids: &[u32],
515        num_signers: u32,
516        num_keys: u32,
517        threshold: u32,
518        rng: &mut RNG,
519    ) -> Self {
520        Party::new(party_id, key_ids, num_signers, num_keys, threshold, rng)
521    }
522
523    fn load(state: &traits::SignerState) -> Self {
524        // v2 signer contains single party
525        assert_eq!(state.parties.len(), 1);
526
527        let party_state = &state.parties[0].1;
528
529        Self {
530            party_id: state.id,
531            key_ids: state.key_ids.clone(),
532            num_keys: state.num_keys,
533            num_parties: state.num_parties,
534            threshold: state.threshold,
535            f: party_state.polynomial.clone(),
536            private_keys: party_state
537                .private_keys
538                .iter()
539                .map(|(k, v)| (*k, *v))
540                .collect(),
541            group_key: state.group_key,
542            nonce: party_state.nonce.clone(),
543        }
544    }
545
546    fn save(&self) -> traits::SignerState {
547        let party_state = traits::PartyState {
548            polynomial: self.f.clone(),
549            private_keys: self.private_keys.iter().map(|(k, v)| (*k, *v)).collect(),
550            nonce: self.nonce.clone(),
551        };
552        traits::SignerState {
553            id: self.party_id,
554            key_ids: self.key_ids.clone(),
555            num_keys: self.num_keys,
556            num_parties: self.num_parties,
557            threshold: self.threshold,
558            group_key: self.group_key,
559            parties: vec![(self.party_id, party_state)],
560        }
561    }
562
563    fn get_id(&self) -> u32 {
564        self.party_id
565    }
566
567    fn get_key_ids(&self) -> Vec<u32> {
568        self.key_ids.clone()
569    }
570
571    fn get_num_parties(&self) -> u32 {
572        self.num_parties
573    }
574
575    fn get_poly_commitments<RNG: RngCore + CryptoRng>(
576        &self,
577        ctx: &[u8],
578        rng: &mut RNG,
579    ) -> Vec<PolyCommitment> {
580        if let Some(poly) = self.get_poly_commitment(ctx, rng) {
581            vec![poly.clone()]
582        } else {
583            vec![]
584        }
585    }
586
587    fn reset_polys<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG) {
588        self.f = Some(VSS::random_poly(self.threshold - 1, rng));
589    }
590
591    fn clear_polys(&mut self) {
592        self.f = None;
593    }
594
595    fn get_shares(&self) -> HashMap<u32, HashMap<u32, Scalar>> {
596        let mut shares = HashMap::new();
597
598        shares.insert(self.party_id, self.get_shares());
599
600        shares
601    }
602
603    fn compute_secrets(
604        &mut self,
605        private_shares: &HashMap<u32, HashMap<u32, Scalar>>,
606        polys: &HashMap<u32, PolyCommitment>,
607        ctx: &[u8],
608    ) -> Result<(), HashMap<u32, DkgError>> {
609        // go through the shares, looking for this party's
610        let mut key_shares = HashMap::new();
611        for dest_key_id in self.get_key_ids() {
612            let mut shares = HashMap::new();
613            for (src_party_id, signer_shares) in private_shares.iter() {
614                if let Some(signer_share) = signer_shares.get(&dest_key_id) {
615                    shares.insert(*src_party_id, *signer_share);
616                }
617            }
618            key_shares.insert(dest_key_id, shares);
619        }
620
621        match self.compute_secret(&key_shares, polys, ctx) {
622            Ok(()) => Ok(()),
623            Err(dkg_error) => {
624                let mut dkg_errors = HashMap::new();
625                dkg_errors.insert(self.party_id, dkg_error);
626                Err(dkg_errors)
627            }
628        }
629    }
630
631    fn gen_nonces<RNG: RngCore + CryptoRng>(
632        &mut self,
633        secret_key: &Scalar,
634        rng: &mut RNG,
635    ) -> Vec<PublicNonce> {
636        vec![self.gen_nonce(secret_key, rng)]
637    }
638
639    fn compute_intermediate(
640        msg: &[u8],
641        signer_ids: &[u32],
642        _key_ids: &[u32],
643        nonces: &[PublicNonce],
644    ) -> (Vec<Point>, Point) {
645        compute::intermediate(msg, signer_ids, nonces)
646    }
647
648    fn validate_party_id(
649        signer_id: u32,
650        party_id: u32,
651        _signer_key_ids: &HashMap<u32, HashSet<u32>>,
652    ) -> bool {
653        signer_id == party_id
654    }
655
656    fn sign(
657        &self,
658        msg: &[u8],
659        signer_ids: &[u32],
660        key_ids: &[u32],
661        nonces: &[PublicNonce],
662    ) -> Vec<SignatureShare> {
663        vec![self.sign(msg, signer_ids, key_ids, nonces)]
664    }
665
666    fn sign_schnorr(
667        &self,
668        msg: &[u8],
669        signer_ids: &[u32],
670        key_ids: &[u32],
671        nonces: &[PublicNonce],
672    ) -> Vec<SignatureShare> {
673        vec![self.sign_with_tweak(msg, signer_ids, key_ids, nonces, Some(Scalar::from(0)))]
674    }
675
676    fn sign_taproot(
677        &self,
678        msg: &[u8],
679        signer_ids: &[u32],
680        key_ids: &[u32],
681        nonces: &[PublicNonce],
682        merkle_root: Option<[u8; 32]>,
683    ) -> Vec<SignatureShare> {
684        let tweak = compute::tweak(&self.group_key, merkle_root);
685        vec![self.sign_with_tweak(msg, signer_ids, key_ids, nonces, Some(tweak))]
686    }
687}
688
689/// Helper functions for tests
690pub mod test_helpers {
691    use super::Scalar;
692    use crate::common::{PolyCommitment, PublicNonce};
693    use crate::errors::DkgError;
694    use crate::traits::Signer;
695    use crate::v2;
696    use crate::v2::SignatureShare;
697
698    use hashbrown::HashMap;
699    use rand_core::{CryptoRng, RngCore};
700
701    /// Run a distributed key generation round
702    pub fn dkg<RNG: RngCore + CryptoRng>(
703        signers: &mut [v2::Party],
704        rng: &mut RNG,
705    ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
706        let ctx = 0u64.to_be_bytes();
707        let mut polys: HashMap<u32, PolyCommitment> = Default::default();
708        for signer in signers.iter() {
709            if let Some(poly) = signer.get_poly_commitment(&ctx, rng) {
710                polys.insert(signer.get_id(), poly);
711            }
712        }
713
714        // each party broadcasts their commitments
715        let mut broadcast_shares = Vec::new();
716        for party in signers.iter() {
717            broadcast_shares.push((party.party_id, party.get_shares()));
718        }
719
720        // each party collects its shares from the broadcasts
721        // maybe this should collect into a hashmap first?
722        let mut secret_errors = HashMap::new();
723        for party in signers.iter_mut() {
724            let mut party_shares = HashMap::new();
725            for key_id in party.key_ids.clone() {
726                let mut key_shares = HashMap::new();
727
728                for (id, shares) in &broadcast_shares {
729                    if let Some(share) = shares.get(&key_id) {
730                        key_shares.insert(*id, *share);
731                    }
732                }
733
734                party_shares.insert(key_id, key_shares);
735            }
736
737            if let Err(secret_error) = party.compute_secret(&party_shares, &polys, &ctx) {
738                secret_errors.insert(party.party_id, secret_error);
739            }
740        }
741
742        if secret_errors.is_empty() {
743            Ok(polys)
744        } else {
745            Err(secret_errors)
746        }
747    }
748
749    /// Run a signing round for the passed `msg`
750    pub fn sign<RNG: RngCore + CryptoRng>(
751        msg: &[u8],
752        signers: &mut [v2::Party],
753        rng: &mut RNG,
754    ) -> (Vec<PublicNonce>, Vec<SignatureShare>, Vec<u32>) {
755        let secret_key = Scalar::random(rng);
756        let party_ids: Vec<u32> = signers.iter().map(|s| s.party_id).collect();
757        let key_ids: Vec<u32> = signers.iter().flat_map(|s| s.key_ids.clone()).collect();
758        let nonces: Vec<PublicNonce> = signers
759            .iter_mut()
760            .map(|s| s.gen_nonce(&secret_key, rng))
761            .collect();
762        let shares = signers
763            .iter()
764            .map(|s| s.sign(msg, &party_ids, &key_ids, &nonces))
765            .collect();
766
767        (nonces, shares, key_ids)
768    }
769}
770
#[cfg(test)]
mod tests {
    use hashbrown::{HashMap, HashSet};
    use num_traits::Zero;

    use crate::util::create_rng;
    use crate::{
        curve::scalar::Scalar,
        traits::{
            self, test_helpers::run_compute_secrets_missing_private_shares, Aggregator, Signer,
        },
        v2,
    };

    /// Assert that the signer currently holds a nonzero, valid nonce
    fn assert_nonce_usable(signer: &v2::Signer) {
        assert!(!signer.nonce.is_zero());
        assert!(signer.nonce.is_valid());
    }

    #[test]
    /// A new signer starts with a usable nonce, and regenerating yields one public nonce
    fn signer_gen_nonces() {
        let mut rng = create_rng();
        let secret_key = Scalar::random(&mut rng);
        let signer_id = 1;
        let owned_key_ids = [1, 2, 3];
        let num_keys: u32 = 10;
        let num_parties: u32 = 3;
        let threshold: u32 = 7;

        let mut signer = v2::Signer::new(
            signer_id,
            &owned_key_ids,
            num_parties,
            num_keys,
            threshold,
            &mut rng,
        );
        assert_nonce_usable(&signer);

        let public_nonces = signer.gen_nonces(&secret_key, &mut rng);
        assert_eq!(public_nonces.len(), 1);
        assert_nonce_usable(&signer);
    }

    #[test]
    /// Saving and reloading a party yields an identical party
    fn party_save_load() {
        let mut rng = create_rng();
        let num_keys: u32 = 10;
        let threshold: u32 = 7;

        let original = v2::Party::new(0, &[1, 2, 3], 1, num_keys, threshold, &mut rng);
        let restored = v2::Party::load(&original.save());

        assert_eq!(original, restored);
    }

    #[test]
    /// Clearing the polynomial removes both the commitment and the shares
    fn clear_polys() {
        let ctx = 0u64.to_be_bytes();
        let mut rng = create_rng();
        let num_keys: u32 = 10;
        let threshold: u32 = 7;
        let expected_share_count = usize::try_from(num_keys).unwrap();

        let mut party = v2::Party::new(0, &[1, 2, 3], 1, num_keys, threshold, &mut rng);

        assert_eq!(party.get_poly_commitments(&ctx, &mut rng).len(), 1);
        assert_eq!(party.get_shares().len(), expected_share_count);

        party.clear_polys();

        assert!(party.get_poly_commitments(&ctx, &mut rng).is_empty());
        assert!(party.get_shares().is_empty());
    }

    #[test]
    /// Run DKG, then sign with a subset of parties holding exactly `threshold` keys
    fn aggregator_sign() {
        let mut rng = create_rng();
        let msg = "It was many and many a year ago".as_bytes();
        let num_keys: u32 = 10;
        let threshold: u32 = 7;
        let key_id_sets: Vec<Vec<u32>> =
            vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8], vec![9, 10]];
        let num_parties: u32 = key_id_sets.len().try_into().unwrap();

        let mut parties: Vec<v2::Party> = key_id_sets
            .iter()
            .enumerate()
            .map(|(index, ids)| {
                let party_id: u32 = index.try_into().unwrap();
                v2::Party::new(party_id, ids, num_parties, num_keys, threshold, &mut rng)
            })
            .collect();

        let comms = traits::test_helpers::dkg(&mut parties, &mut rng).unwrap_or_else(
            |secret_errors| panic!("Got secret errors from DKG: {secret_errors:?}"),
        );

        // parties 0, 1 and 3 together hold exactly `threshold` keys (3 + 2 + 2)
        let mut signing_parties: Vec<v2::Party> =
            [0, 1, 3].iter().map(|&i| parties[i].clone()).collect();
        let mut aggregator = v2::Aggregator::new(num_keys, threshold);
        aggregator.init(&comms).expect("aggregator init failed");

        let (nonces, sig_shares, key_ids) =
            v2::test_helpers::sign(msg, &mut signing_parties, &mut rng);
        if let Err(e) = aggregator.sign(msg, &nonces, &sig_shares, &key_ids) {
            panic!("Aggregator sign failed: {e:?}");
        }
    }

    #[test]
    /// Run a distributed key generation round with not enough shares
    fn run_compute_secrets_missing_shares() {
        run_compute_secrets_missing_private_shares::<v2::Signer>()
    }

    #[test]
    /// Run DKG and aggregator init with a bad polynomial length
    fn bad_polynomial_length() {
        let too_long = |t| t + 1;
        let too_short = |t| t - 1;
        traits::test_helpers::bad_polynomial_length::<v2::Signer, _>(too_long);
        traits::test_helpers::bad_polynomial_length::<v2::Signer, _>(too_short);
    }

    #[test]
    /// Run DKG and aggregator init with a bad polynomial commitment
    fn bad_polynomial_commitment() {
        traits::test_helpers::bad_polynomial_commitment::<v2::Signer>();
    }

    #[test]
    /// Check that party_ids can be properly validated
    fn validate_party_id() {
        let mut owned = HashSet::new();
        owned.insert(1);
        let signer_key_ids: HashMap<u32, HashSet<u32>> = std::iter::once((0, owned)).collect();

        assert!(v2::Signer::validate_party_id(0, 0, &signer_key_ids));
        assert!(!v2::Signer::validate_party_id(0, 1, &signer_key_ids));
    }
}