1use core::{cmp::PartialEq, fmt::Debug};
2use hashbrown::{HashMap, HashSet};
3use polynomial::Polynomial;
4use rand_core::{CryptoRng, RngCore};
5use serde::{Deserialize, Serialize};
6use std::fmt;
7
8use crate::{
9 common::{MerkleRoot, Nonce, PolyCommitment, PublicNonce, Signature, SignatureShare},
10 curve::{point::Point, scalar::Scalar},
11 errors::{AggregatorError, DkgError},
12 taproot::SchnorrProof,
13};
14
15#[derive(Clone, Deserialize, Serialize, PartialEq)]
16pub struct PartyState {
18 pub polynomial: Option<Polynomial<Scalar>>,
20 pub private_keys: Vec<(u32, Scalar)>,
22 pub nonce: Nonce,
24}
25
26impl fmt::Debug for PartyState {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("PartyState").finish_non_exhaustive()
29 }
30}
31
32#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
33pub struct SignerState {
35 pub id: u32,
37 pub key_ids: Vec<u32>,
39 pub num_keys: u32,
41 pub num_parties: u32,
43 pub threshold: u32,
45 pub group_key: Point,
47 pub parties: Vec<(u32, PartyState)>,
49}
50
51pub trait Signer: Clone + Debug + PartialEq {
53 fn new<RNG: RngCore + CryptoRng>(
55 party_id: u32,
56 key_ids: &[u32],
57 num_signers: u32,
58 num_keys: u32,
59 threshold: u32,
60 rng: &mut RNG,
61 ) -> Self;
62
63 fn load(state: &SignerState) -> Self;
65
66 fn save(&self) -> SignerState;
68
69 fn get_id(&self) -> u32;
71
72 fn get_key_ids(&self) -> Vec<u32>;
74
75 fn get_num_parties(&self) -> u32;
77
78 fn get_poly_commitments<RNG: RngCore + CryptoRng>(
80 &self,
81 ctx: &[u8],
82 rng: &mut RNG,
83 ) -> Vec<PolyCommitment>;
84
85 fn reset_polys<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG);
87
88 fn clear_polys(&mut self);
90
91 fn get_shares(&self) -> HashMap<u32, HashMap<u32, Scalar>>;
93
94 fn compute_secrets(
96 &mut self,
97 shares: &HashMap<u32, HashMap<u32, Scalar>>,
98 polys: &HashMap<u32, PolyCommitment>,
99 ctx: &[u8],
100 ) -> Result<(), HashMap<u32, DkgError>>;
101
102 fn gen_nonces<RNG: RngCore + CryptoRng>(
104 &mut self,
105 secret_key: &Scalar,
106 rng: &mut RNG,
107 ) -> Vec<PublicNonce>;
108
109 fn compute_intermediate(
111 &self,
112 msg: &[u8],
113 signer_ids: &[u32],
114 key_ids: &[u32],
115 nonces: &[PublicNonce],
116 ) -> (Vec<Point>, Point);
117
118 fn validate_party_id(
120 signer_id: u32,
121 party_id: u32,
122 signer_key_ids: &HashMap<u32, HashSet<u32>>,
123 ) -> bool;
124
125 fn sign(
127 &self,
128 msg: &[u8],
129 signer_ids: &[u32],
130 key_ids: &[u32],
131 nonces: &[PublicNonce],
132 ) -> Vec<SignatureShare>;
133
134 fn sign_schnorr(
136 &self,
137 msg: &[u8],
138 signer_ids: &[u32],
139 key_ids: &[u32],
140 nonces: &[PublicNonce],
141 ) -> Vec<SignatureShare>;
142
143 fn sign_taproot(
145 &self,
146 msg: &[u8],
147 signer_ids: &[u32],
148 key_ids: &[u32],
149 nonces: &[PublicNonce],
150 merkle_root: Option<MerkleRoot>,
151 ) -> Vec<SignatureShare>;
152}
153
154pub trait Aggregator: Clone + Debug + PartialEq {
156 fn new(num_keys: u32, threshold: u32) -> Self;
158
159 fn init(&mut self, poly_comms: &HashMap<u32, PolyCommitment>) -> Result<(), AggregatorError>;
161
162 fn sign(
164 &mut self,
165 msg: &[u8],
166 nonces: &[PublicNonce],
167 sig_shares: &[SignatureShare],
168 key_ids: &[u32],
169 ) -> Result<Signature, AggregatorError>;
170
171 fn sign_schnorr(
174 &mut self,
175 msg: &[u8],
176 nonces: &[PublicNonce],
177 sig_shares: &[SignatureShare],
178 key_ids: &[u32],
179 ) -> Result<SchnorrProof, AggregatorError>;
180
181 fn sign_taproot(
185 &mut self,
186 msg: &[u8],
187 nonces: &[PublicNonce],
188 sig_shares: &[SignatureShare],
189 key_ids: &[u32],
190 merkle_root: Option<MerkleRoot>,
191 ) -> Result<SchnorrProof, AggregatorError>;
192}
193
194pub mod test_helpers {
196 use hashbrown::HashMap;
197 use rand_core::{CryptoRng, RngCore};
198
199 use crate::{common::PolyCommitment, errors::DkgError, traits::Scalar, util::create_rng};
200
201 pub fn dkg<RNG: RngCore + CryptoRng, Signer: super::Signer>(
203 signers: &mut [Signer],
204 rng: &mut RNG,
205 ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
206 let ctx = 0u64.to_be_bytes();
207 let public_shares: HashMap<u32, PolyCommitment> = signers
208 .iter()
209 .flat_map(|s| s.get_poly_commitments(&ctx, rng))
210 .map(|comm| (comm.id.id.get_u32(), comm))
211 .collect();
212 let mut private_shares = HashMap::new();
213
214 for signer in signers.iter() {
215 for (signer_id, signer_shares) in signer.get_shares() {
216 private_shares.insert(signer_id, signer_shares);
217 }
218 }
219
220 let mut secret_errors = HashMap::new();
221 for signer in signers.iter_mut() {
222 if let Err(signer_secret_errors) =
223 signer.compute_secrets(&private_shares, &public_shares, &ctx)
224 {
225 secret_errors.extend(signer_secret_errors);
226 }
227 }
228
229 if secret_errors.is_empty() {
230 Ok(public_shares)
231 } else {
232 Err(secret_errors)
233 }
234 }
235
236 fn compute_secrets_missing_private_shares<RNG: RngCore + CryptoRng, Signer: super::Signer>(
238 signers: &mut [Signer],
239 rng: &mut RNG,
240 missing_key_ids: &[u32],
241 ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
242 assert!(
243 !missing_key_ids.is_empty(),
244 "Cannot run a missing shares test without specificying at least one missing key id"
245 );
246 let ctx = 0u64.to_be_bytes();
247 let polys: HashMap<u32, PolyCommitment> = signers
248 .iter()
249 .flat_map(|s| s.get_poly_commitments(&ctx, rng))
250 .map(|comm| (comm.id.id.get_u32(), comm))
251 .collect();
252 let mut private_shares = HashMap::new();
253
254 for signer in signers.iter() {
255 for (signer_id, mut signer_shares) in signer.get_shares() {
256 for key_id in missing_key_ids {
257 if signer.get_key_ids().contains(key_id) {
258 signer_shares.remove(key_id);
259 }
260 }
261 private_shares.insert(signer_id, signer_shares);
262 }
263 }
264
265 let mut secret_errors = HashMap::new();
266 for signer in signers.iter_mut() {
267 if let Err(signer_secret_errors) = signer.compute_secrets(&private_shares, &polys, &ctx)
268 {
269 secret_errors.extend(signer_secret_errors);
270 }
271 }
272
273 if secret_errors.is_empty() {
274 Ok(polys)
275 } else {
276 Err(secret_errors)
277 }
278 }
279
280 #[allow(non_snake_case)]
281 pub fn run_compute_secrets_missing_private_shares<Signer: super::Signer>() {
283 let Nk: u32 = 10;
284 let Np: u32 = 4;
285 let T: u32 = 7;
286 let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8], vec![9, 10]];
287 let missing_key_ids = vec![1, 7];
288 let mut rng = create_rng();
289 let mut signers: Vec<Signer> = signer_ids
290 .iter()
291 .enumerate()
292 .map(|(id, ids)| Signer::new(id.try_into().unwrap(), ids, Nk, Np, T, &mut rng))
293 .collect();
294
295 match compute_secrets_missing_private_shares(&mut signers, &mut rng, &missing_key_ids) {
296 Ok(polys) => panic!("Got a result with missing public shares: {polys:?}"),
297 Err(secret_errors) => {
298 for (_, error) in secret_errors {
299 assert!(matches!(error, DkgError::MissingPrivateShares(_)));
300 }
301 }
302 }
303 }
304
305 pub fn bad_polynomial_length<Signer: super::Signer, F: Fn(u32) -> u32>(func: F) {
307 let num_keys: u32 = 10;
308 let num_signers: u32 = 4;
309 let threshold: u32 = 7;
310 let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3, 4], vec![5, 6, 7], vec![8, 9], vec![10]];
311 let mut rng = create_rng();
312 let mut signers: Vec<Signer> = signer_ids
313 .iter()
314 .enumerate()
315 .map(|(id, ids)| {
316 if *ids == vec![10] {
317 Signer::new(
318 id.try_into().unwrap(),
319 ids,
320 num_signers,
321 num_keys,
322 func(threshold),
323 &mut rng,
324 )
325 } else {
326 Signer::new(
327 id.try_into().unwrap(),
328 ids,
329 num_signers,
330 num_keys,
331 threshold,
332 &mut rng,
333 )
334 }
335 })
336 .collect();
337
338 if dkg(&mut signers, &mut rng).is_ok() {
339 panic!("DKG should have failed")
340 }
341 }
342
343 pub fn bad_polynomial_commitment<Signer: super::Signer>() {
345 let num_keys: u32 = 10;
346 let num_signers: u32 = 4;
347 let threshold: u32 = 7;
348 let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3, 4], vec![5, 6, 7], vec![8, 9], vec![10]];
349 let mut rng = create_rng();
350 let mut signers: Vec<Signer> = signer_ids
351 .iter()
352 .enumerate()
353 .map(|(id, ids)| {
354 Signer::new(
355 id.try_into().unwrap(),
356 ids,
357 num_signers,
358 num_keys,
359 threshold,
360 &mut rng,
361 )
362 })
363 .collect();
364
365 let ctx = 0u64.to_be_bytes();
370 let bad_party_id = 2u32;
371 let public_shares: HashMap<u32, PolyCommitment> = signers
372 .iter()
373 .flat_map(|s| s.get_poly_commitments(&ctx, &mut rng))
374 .map(|comm| {
375 let party_id = comm.id.id.get_u32();
376 if party_id == bad_party_id {
377 let mut bad_comm = comm.clone();
379 bad_comm.id.proof.s += Scalar::from(1);
380 (party_id, bad_comm)
381 } else {
382 (party_id, comm)
383 }
384 })
385 .collect();
386 let mut private_shares = HashMap::new();
387
388 for signer in signers.iter() {
389 for (signer_id, signer_shares) in signer.get_shares() {
390 private_shares.insert(signer_id, signer_shares);
391 }
392 }
393
394 let mut secret_errors = HashMap::new();
395 for signer in signers.iter_mut() {
396 if let Err(signer_secret_errors) =
397 signer.compute_secrets(&private_shares, &public_shares, &ctx)
398 {
399 secret_errors.extend(signer_secret_errors);
400 }
401 }
402
403 assert!(!secret_errors.is_empty());
404 }
405}