1use core::{cmp::PartialEq, fmt::Debug};
2use hashbrown::{HashMap, HashSet};
3use polynomial::Polynomial;
4use rand_core::{CryptoRng, RngCore};
5use serde::{Deserialize, Serialize};
6use std::fmt;
7
8use crate::{
9 common::{MerkleRoot, Nonce, PolyCommitment, PublicNonce, Signature, SignatureShare},
10 curve::{point::Point, scalar::Scalar},
11 errors::{AggregatorError, DkgError},
12 taproot::SchnorrProof,
13};
14
/// Serializable snapshot of a single party's state, stored inside
/// [`SignerState::parties`].
///
/// Field order is part of the serialized form for non-self-describing serde
/// formats (e.g. sequence-based binary encodings), so do not reorder fields.
/// The `Debug` impl deliberately prints no field contents.
#[derive(Clone, Deserialize, Serialize, PartialEq)]
pub struct PartyState {
    /// The party's polynomial, or `None` when there is none (presumably after
    /// it has been cleared — confirm against `Signer::clear_polys` impls).
    pub polynomial: Option<Polynomial<Scalar>>,
    /// The party's private keys, presumably as `(key id, private key)` pairs.
    pub private_keys: Vec<(u32, Scalar)>,
    /// The party's current nonce.
    pub nonce: Nonce,
}
25
26impl fmt::Debug for PartyState {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("PartyState").finish_non_exhaustive()
29 }
30}
31
/// Serializable snapshot of a [`Signer`], produced by [`Signer::save`] and
/// restored with [`Signer::load`].
///
/// Field order is part of the serialized form for non-self-describing serde
/// formats, so do not reorder fields.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct SignerState {
    /// This signer's id.
    pub id: u32,
    /// The key ids owned by this signer.
    pub key_ids: Vec<u32>,
    /// Number of keys (presumably across the whole group, matching the
    /// `num_keys` argument of `Signer::new`).
    pub num_keys: u32,
    /// Number of parties (presumably across the whole group — confirm).
    pub num_parties: u32,
    /// Signing threshold.
    pub threshold: u32,
    /// The group public key (presumably the aggregate key produced by DKG).
    pub group_key: Point,
    /// Per-party state as `(party id, state)` pairs. The derived `Debug` of
    /// this struct prints each entry through `PartyState`'s redacting `Debug`.
    pub parties: Vec<(u32, PartyState)>,
}
50
/// A threshold-signing participant that owns one or more key ids and takes
/// part in distributed key generation (DKG) and signing rounds.
pub trait Signer: Clone + Debug + PartialEq {
    /// Creates signer `party_id`, owning `key_ids`, in a group of `num_signers`
    /// signers that together hold `num_keys` keys with signing threshold
    /// `threshold`.
    ///
    /// `rng` must be a cryptographically secure generator (`CryptoRng`).
    fn new<RNG: RngCore + CryptoRng>(
        party_id: u32,
        key_ids: &[u32],
        num_signers: u32,
        num_keys: u32,
        threshold: u32,
        rng: &mut RNG,
    ) -> Self;

    /// Rebuilds a signer from a [`SignerState`] previously produced by
    /// [`Signer::save`].
    fn load(state: &SignerState) -> Self;

    /// Snapshots this signer into a serializable [`SignerState`] that
    /// [`Signer::load`] can restore.
    fn save(&self) -> SignerState;

    /// Returns this signer's id.
    fn get_id(&self) -> u32;

    /// Returns the key ids owned by this signer.
    fn get_key_ids(&self) -> Vec<u32>;

    /// Returns the number of parties.
    ///
    /// NOTE(review): whether this counts only this signer's parties or every
    /// party in the group is implementation-defined here — confirm.
    fn get_num_parties(&self) -> u32;

    /// Returns the polynomial commitments of this signer's parties, bound to
    /// the DKG context bytes `ctx`. `rng` supplies any randomness needed
    /// (presumably for the proofs carried in each commitment).
    fn get_poly_commitments<RNG: RngCore + CryptoRng>(
        &self,
        ctx: &[u8],
        rng: &mut RNG,
    ) -> Vec<PolyCommitment>;

    /// Replaces this signer's polynomials with fresh ones drawn from `rng`
    /// (presumably to start a new DKG round).
    fn reset_polys<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG);

    /// Discards this signer's polynomials.
    fn clear_polys(&mut self);

    /// Returns the private shares this signer distributes during DKG. The outer
    /// map is keyed by the sending party id (presumably); the inner map is keyed
    /// by recipient key id.
    fn get_shares(&self) -> HashMap<u32, HashMap<u32, Scalar>>;

    /// Derives this signer's secrets from every party's private `shares` (laid
    /// out like [`Signer::get_shares`]) and the public `polys` commitments.
    /// `ctx` is presumably the same context the commitments were produced with.
    ///
    /// # Errors
    /// Returns a map of [`DkgError`]s (presumably keyed by party id), e.g.
    /// `DkgError::MissingPrivateShares` when expected shares are absent.
    fn compute_secrets(
        &mut self,
        shares: &HashMap<u32, HashMap<u32, Scalar>>,
        polys: &HashMap<u32, PolyCommitment>,
        ctx: &[u8],
    ) -> Result<(), HashMap<u32, DkgError>>;

    /// Generates fresh nonces for this signer and returns their public parts.
    ///
    /// NOTE(review): how `secret_key` feeds into nonce generation is up to the
    /// implementation — confirm before relying on it.
    fn gen_nonces<RNG: RngCore + CryptoRng>(
        &mut self,
        secret_key: &Scalar,
        rng: &mut RNG,
    ) -> Vec<PublicNonce>;

    /// Computes intermediate signing values for `msg` from the participating
    /// `signer_ids`, their `key_ids` and the round's public `nonces`.
    ///
    /// Returns a list of points plus a single point — presumably per-party
    /// nonce values and the aggregate nonce; TODO confirm.
    fn compute_intermediate(
        msg: &[u8],
        signer_ids: &[u32],
        key_ids: &[u32],
        nonces: &[PublicNonce],
    ) -> (Vec<Point>, Point);

    /// Returns whether `party_id` is valid for `signer_id`, given each signer's
    /// set of key ids in `signer_key_ids` (keyed by signer id).
    fn validate_party_id(
        signer_id: u32,
        party_id: u32,
        signer_key_ids: &HashMap<u32, HashSet<u32>>,
    ) -> bool;

    /// Produces this signer's signature shares over `msg` for a round among
    /// `signer_ids` / `key_ids`, using the round's public `nonces`.
    fn sign(
        &self,
        msg: &[u8],
        signer_ids: &[u32],
        key_ids: &[u32],
        nonces: &[PublicNonce],
    ) -> Vec<SignatureShare>;

    /// Like [`Signer::sign`], but for a Schnorr-proof style signature
    /// (presumably aggregated by [`Aggregator::sign_schnorr`]).
    fn sign_schnorr(
        &self,
        msg: &[u8],
        signer_ids: &[u32],
        key_ids: &[u32],
        nonces: &[PublicNonce],
    ) -> Vec<SignatureShare>;

    /// Like [`Signer::sign_schnorr`], but for a taproot signature. `merkle_root`
    /// is the optional script-tree root (presumably used for the key tweak).
    fn sign_taproot(
        &self,
        msg: &[u8],
        signer_ids: &[u32],
        key_ids: &[u32],
        nonces: &[PublicNonce],
        merkle_root: Option<MerkleRoot>,
    ) -> Vec<SignatureShare>;
}
152
/// Combines DKG poly commitments and signers' signature shares into a final
/// group signature.
pub trait Aggregator: Clone + Debug + PartialEq {
    /// Creates an aggregator for a group of `num_keys` keys with signing
    /// threshold `threshold`.
    fn new(num_keys: u32, threshold: u32) -> Self;

    /// Initializes the aggregator from the DKG poly commitments, keyed by
    /// party id.
    ///
    /// # Errors
    /// Returns an [`AggregatorError`] if the commitments are rejected.
    fn init(&mut self, poly_comms: &HashMap<u32, PolyCommitment>) -> Result<(), AggregatorError>;

    /// Aggregates `sig_shares` over `msg`, made with the round's public
    /// `nonces` by the holders of `key_ids`, into a [`Signature`].
    ///
    /// # Errors
    /// Returns an [`AggregatorError`] if aggregation fails.
    fn sign(
        &mut self,
        msg: &[u8],
        nonces: &[PublicNonce],
        sig_shares: &[SignatureShare],
        key_ids: &[u32],
    ) -> Result<Signature, AggregatorError>;

    /// Like [`Aggregator::sign`], but returns a [`SchnorrProof`].
    ///
    /// # Errors
    /// Returns an [`AggregatorError`] if aggregation fails.
    fn sign_schnorr(
        &mut self,
        msg: &[u8],
        nonces: &[PublicNonce],
        sig_shares: &[SignatureShare],
        key_ids: &[u32],
    ) -> Result<SchnorrProof, AggregatorError>;

    /// Like [`Aggregator::sign_schnorr`], but for a taproot signature with the
    /// optional `merkle_root` (presumably applied as the key tweak).
    ///
    /// # Errors
    /// Returns an [`AggregatorError`] if aggregation fails.
    fn sign_taproot(
        &mut self,
        msg: &[u8],
        nonces: &[PublicNonce],
        sig_shares: &[SignatureShare],
        key_ids: &[u32],
        merkle_root: Option<MerkleRoot>,
    ) -> Result<SchnorrProof, AggregatorError>;
}
192
193pub mod test_helpers {
195 use hashbrown::HashMap;
196 use rand_core::{CryptoRng, RngCore};
197
198 use crate::{common::PolyCommitment, errors::DkgError, traits::Scalar, util::create_rng};
199
200 pub fn dkg<RNG: RngCore + CryptoRng, Signer: super::Signer>(
202 signers: &mut [Signer],
203 rng: &mut RNG,
204 ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
205 let ctx = 0u64.to_be_bytes();
206 let public_shares: HashMap<u32, PolyCommitment> = signers
207 .iter()
208 .flat_map(|s| s.get_poly_commitments(&ctx, rng))
209 .map(|comm| (comm.id.id.get_u32(), comm))
210 .collect();
211 let mut private_shares = HashMap::new();
212
213 for signer in signers.iter() {
214 for (signer_id, signer_shares) in signer.get_shares() {
215 private_shares.insert(signer_id, signer_shares);
216 }
217 }
218
219 let mut secret_errors = HashMap::new();
220 for signer in signers.iter_mut() {
221 if let Err(signer_secret_errors) =
222 signer.compute_secrets(&private_shares, &public_shares, &ctx)
223 {
224 secret_errors.extend(signer_secret_errors.into_iter());
225 }
226 }
227
228 if secret_errors.is_empty() {
229 Ok(public_shares)
230 } else {
231 Err(secret_errors)
232 }
233 }
234
235 fn compute_secrets_missing_private_shares<RNG: RngCore + CryptoRng, Signer: super::Signer>(
237 signers: &mut [Signer],
238 rng: &mut RNG,
239 missing_key_ids: &[u32],
240 ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
241 assert!(
242 !missing_key_ids.is_empty(),
243 "Cannot run a missing shares test without specificying at least one missing key id"
244 );
245 let ctx = 0u64.to_be_bytes();
246 let polys: HashMap<u32, PolyCommitment> = signers
247 .iter()
248 .flat_map(|s| s.get_poly_commitments(&ctx, rng))
249 .map(|comm| (comm.id.id.get_u32(), comm))
250 .collect();
251 let mut private_shares = HashMap::new();
252
253 for signer in signers.iter() {
254 for (signer_id, mut signer_shares) in signer.get_shares() {
255 for key_id in missing_key_ids {
256 if signer.get_key_ids().contains(key_id) {
257 signer_shares.remove(key_id);
258 }
259 }
260 private_shares.insert(signer_id, signer_shares);
261 }
262 }
263
264 let mut secret_errors = HashMap::new();
265 for signer in signers.iter_mut() {
266 if let Err(signer_secret_errors) = signer.compute_secrets(&private_shares, &polys, &ctx)
267 {
268 secret_errors.extend(signer_secret_errors.into_iter());
269 }
270 }
271
272 if secret_errors.is_empty() {
273 Ok(polys)
274 } else {
275 Err(secret_errors)
276 }
277 }
278
279 #[allow(non_snake_case)]
280 pub fn run_compute_secrets_missing_private_shares<Signer: super::Signer>() {
282 let Nk: u32 = 10;
283 let Np: u32 = 4;
284 let T: u32 = 7;
285 let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8], vec![9, 10]];
286 let missing_key_ids = vec![1, 7];
287 let mut rng = create_rng();
288 let mut signers: Vec<Signer> = signer_ids
289 .iter()
290 .enumerate()
291 .map(|(id, ids)| Signer::new(id.try_into().unwrap(), ids, Nk, Np, T, &mut rng))
292 .collect();
293
294 match compute_secrets_missing_private_shares(&mut signers, &mut rng, &missing_key_ids) {
295 Ok(polys) => panic!("Got a result with missing public shares: {polys:?}"),
296 Err(secret_errors) => {
297 for (_, error) in secret_errors {
298 assert!(matches!(error, DkgError::MissingPrivateShares(_)));
299 }
300 }
301 }
302 }
303
304 pub fn bad_polynomial_length<Signer: super::Signer, F: Fn(u32) -> u32>(func: F) {
306 let num_keys: u32 = 10;
307 let num_signers: u32 = 4;
308 let threshold: u32 = 7;
309 let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3, 4], vec![5, 6, 7], vec![8, 9], vec![10]];
310 let mut rng = create_rng();
311 let mut signers: Vec<Signer> = signer_ids
312 .iter()
313 .enumerate()
314 .map(|(id, ids)| {
315 if *ids == vec![10] {
316 Signer::new(
317 id.try_into().unwrap(),
318 ids,
319 num_signers,
320 num_keys,
321 func(threshold),
322 &mut rng,
323 )
324 } else {
325 Signer::new(
326 id.try_into().unwrap(),
327 ids,
328 num_signers,
329 num_keys,
330 threshold,
331 &mut rng,
332 )
333 }
334 })
335 .collect();
336
337 if dkg(&mut signers, &mut rng).is_ok() {
338 panic!("DKG should have failed")
339 }
340 }
341
342 pub fn bad_polynomial_commitment<Signer: super::Signer>() {
344 let num_keys: u32 = 10;
345 let num_signers: u32 = 4;
346 let threshold: u32 = 7;
347 let signer_ids: Vec<Vec<u32>> = vec![vec![1, 2, 3, 4], vec![5, 6, 7], vec![8, 9], vec![10]];
348 let mut rng = create_rng();
349 let mut signers: Vec<Signer> = signer_ids
350 .iter()
351 .enumerate()
352 .map(|(id, ids)| {
353 Signer::new(
354 id.try_into().unwrap(),
355 ids,
356 num_signers,
357 num_keys,
358 threshold,
359 &mut rng,
360 )
361 })
362 .collect();
363
364 let ctx = 0u64.to_be_bytes();
369 let bad_party_id = 2u32;
370 let public_shares: HashMap<u32, PolyCommitment> = signers
371 .iter()
372 .flat_map(|s| s.get_poly_commitments(&ctx, &mut rng))
373 .map(|comm| {
374 let party_id = comm.id.id.get_u32();
375 if party_id == bad_party_id {
376 let mut bad_comm = comm.clone();
378 bad_comm.id.proof.s += Scalar::from(1);
379 (party_id, bad_comm)
380 } else {
381 (party_id, comm)
382 }
383 })
384 .collect();
385 let mut private_shares = HashMap::new();
386
387 for signer in signers.iter() {
388 for (signer_id, signer_shares) in signer.get_shares() {
389 private_shares.insert(signer_id, signer_shares);
390 }
391 }
392
393 let mut secret_errors = HashMap::new();
394 for signer in signers.iter_mut() {
395 if let Err(signer_secret_errors) =
396 signer.compute_secrets(&private_shares, &public_shares, &ctx)
397 {
398 secret_errors.extend(signer_secret_errors.into_iter());
399 }
400 }
401
402 assert!(!secret_errors.is_empty());
403 }
404}