// wsts/taproot.rs

1use crate::{
2    common::Signature,
3    compute,
4    curve::{
5        field,
6        point::{Point, G},
7        scalar::Scalar,
8    },
9};
10
11/// A SchnorrProof in BIP-340 format
12#[allow(non_snake_case)]
13#[derive(Clone, Debug, PartialEq, Eq)]
14pub struct SchnorrProof {
15    /// The schnorr public commitment (FROST Signature R)
16    pub r: field::Element,
17    /// The schnorr response (FROST Signature z)
18    pub s: Scalar,
19}
20
21impl SchnorrProof {
22    /// Construct a BIP-340 schnorr proof from a FROST signature
23    pub fn new(sig: &Signature) -> Self {
24        Self {
25            r: sig.R.x(),
26            s: sig.z,
27        }
28    }
29
30    /// Verify a BIP-340 schnorr proof
31    #[allow(non_snake_case)]
32    pub fn verify(&self, public_key: &field::Element, msg: &[u8]) -> bool {
33        let Ok(Y) = Point::lift_x(public_key) else {
34            return false;
35        };
36        let Ok(R) = Point::lift_x(&self.r) else {
37            return false;
38        };
39        let c = compute::challenge(&Y, &R, msg);
40        let Rp = self.s * G - c * Y;
41
42        Rp.has_even_y() && Rp.x() == self.r
43    }
44
45    /// Serialize this proof into a 64-byte buffer
46    pub fn to_bytes(&self) -> [u8; 64] {
47        let mut bytes = [0u8; 64];
48
49        bytes[0..32].copy_from_slice(&self.r.to_bytes());
50        bytes[32..64].copy_from_slice(&self.s.to_bytes());
51
52        bytes
53    }
54}
55
56impl From<[u8; 64]> for SchnorrProof {
57    fn from(bytes: [u8; 64]) -> Self {
58        let mut r_bytes = [0u8; 32];
59        let mut s_bytes = [0u8; 32];
60
61        r_bytes.copy_from_slice(&bytes[0..32]);
62        s_bytes.copy_from_slice(&bytes[32..64]);
63
64        Self {
65            r: field::Element::from(r_bytes),
66            s: Scalar::from(s_bytes),
67        }
68    }
69}
70
71/// Helper functions for tests
72pub mod test_helpers {
73    use super::*;
74    use crate::{
75        common::{PolyCommitment, PublicNonce, SignatureShare},
76        errors::DkgError,
77        traits,
78    };
79
80    use hashbrown::HashMap;
81    use rand_core::{CryptoRng, RngCore};
82
83    /// Run a distributed key generation round
84    #[allow(non_snake_case)]
85    pub fn dkg<RNG: RngCore + CryptoRng, Signer: traits::Signer>(
86        signers: &mut [Signer],
87        rng: &mut RNG,
88    ) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
89        let ctx = 0u64.to_be_bytes();
90        let polys: HashMap<u32, PolyCommitment> = signers
91            .iter()
92            .flat_map(|s| s.get_poly_commitments(&ctx, rng))
93            .map(|comm| (comm.id.id.get_u32(), comm))
94            .collect();
95
96        let mut private_shares = HashMap::new();
97        for signer in signers.iter() {
98            for (signer_id, signer_shares) in signer.get_shares() {
99                private_shares.insert(signer_id, signer_shares);
100            }
101        }
102
103        let mut secret_errors = HashMap::new();
104        for signer in signers.iter_mut() {
105            if let Err(signer_secret_errors) = signer.compute_secrets(&private_shares, &polys, &ctx)
106            {
107                secret_errors.extend(signer_secret_errors.into_iter());
108            }
109        }
110
111        if secret_errors.is_empty() {
112            Ok(polys)
113        } else {
114            Err(secret_errors)
115        }
116    }
117
118    fn sign_params<RNG: RngCore + CryptoRng, Signer: traits::Signer>(
119        signers: &mut [Signer],
120        rng: &mut RNG,
121    ) -> (Vec<u32>, Vec<u32>, Vec<PublicNonce>) {
122        let secret_key = Scalar::random(rng);
123        let signer_ids: Vec<u32> = signers.iter().map(|s| s.get_id()).collect();
124        let key_ids: Vec<u32> = signers.iter().flat_map(|s| s.get_key_ids()).collect();
125        let nonces: Vec<PublicNonce> = signers
126            .iter_mut()
127            .flat_map(|s| s.gen_nonces(&secret_key, rng))
128            .collect();
129
130        (signer_ids, key_ids, nonces)
131    }
132
133    /// Run a signing round for the passed `msg`
134    pub fn sign<RNG: RngCore + CryptoRng, Signer: traits::Signer>(
135        msg: &[u8],
136        signers: &mut [Signer],
137        rng: &mut RNG,
138        merkle_root: Option<[u8; 32]>,
139    ) -> (Vec<PublicNonce>, Vec<SignatureShare>) {
140        let (signer_ids, key_ids, nonces) = sign_params(signers, rng);
141        let shares = signers
142            .iter()
143            .flat_map(|s| s.sign_taproot(msg, &signer_ids, &key_ids, &nonces, merkle_root))
144            .collect();
145
146        (nonces, shares)
147    }
148
149    /// Run a signing round for the passed `msg`
150    pub fn sign_schnorr<RNG: RngCore + CryptoRng, Signer: traits::Signer>(
151        msg: &[u8],
152        signers: &mut [Signer],
153        rng: &mut RNG,
154    ) -> (Vec<PublicNonce>, Vec<SignatureShare>) {
155        let (signer_ids, key_ids, nonces) = sign_params(signers, rng);
156        let shares = signers
157            .iter()
158            .flat_map(|s| s.sign_schnorr(msg, &signer_ids, &key_ids, &nonces))
159            .collect();
160
161        (nonces, shares)
162    }
163}
164
#[cfg(test)]
mod test {
    use super::{test_helpers, Point, Scalar, SchnorrProof, G};

    #[cfg(feature = "with_v1")]
    use crate::v1;
    use crate::{compute, traits::Aggregator, traits::Signer, util::create_rng, v2};

    /// Hand-checks the algebra behind BIP-340 x-only keys and taproot tweaks.
    ///
    /// Signs the same message with a plain key, with an odd-y key and its x-only
    /// (lifted) counterpart, and with the taproot-tweaked versions of both, checking
    /// the Schnorr equation and `SchnorrProof::verify` at every step.
    #[test]
    #[allow(non_snake_case)]
    fn key_tweaks() {
        let mut rng = create_rng();
        let r = Scalar::random(&mut rng);
        let R = r * G;
        // Negating the nonce flips the y parity of R, so `rp * G` always has even y.
        let rp = if R.has_even_y() { r } else { -r };
        let mut d = Scalar::random(&mut rng);
        let mut P = d * G;
        let msg = "sign me";
        let c = compute::challenge(&P, &R, msg.as_bytes());

        println!("P.has_even_y {}", P.has_even_y());
        println!("R.has_even_y {}", R.has_even_y());

        let s = r - c * d;
        assert!(R == s * G + c * P);

        // Resample until the public key has an odd y-coordinate, the case that
        // needs the secret key negated under x-only semantics.
        while P.has_even_y() {
            d = Scalar::random(&mut rng);
            P = d * G;
        }

        println!("P.has_even_y {}", P.has_even_y());
        let c = compute::challenge(&P, &R, msg.as_bytes());
        let s = r - c * d;
        assert!(R == s * G + c * P);

        assert!(!P.has_even_y());
        assert_eq!(d * G, P);

        // With an odd P, lifting its x-coordinate yields -P, whose secret is -d.
        let s = rp + c * (-d);
        assert!(Point::lift_x(&R.x()).unwrap() == s * G - c * Point::lift_x(&P.x()).unwrap());

        let proof = SchnorrProof { r: R.x(), s };
        {
            let Pp = Point::lift_x(&P.x()).unwrap();
            assert!(Pp == (-d) * G);
            let R = Point::lift_x(&proof.r).unwrap();
            let e = compute::challenge(&P, &R, msg.as_bytes());
            let Rp = proof.s * G - e * Pp;
            assert!(Rp.has_even_y());
            assert_eq!(Rp.x(), proof.r);
        }
        assert!(proof.verify(&P.x(), msg.as_bytes()));

        // Q is the x-only (lifted) form of P, i.e. the point with the same x and even y.
        let Q = Point::lift_x(&P.x()).unwrap();
        let c = compute::challenge(&Q, &R, msg.as_bytes());
        println!("Q.has_even_y {}", Q.has_even_y());

        assert!(Q != P);
        assert!(d * G != Q);

        let e = -d;

        assert!(e * G == Q);

        let s = r + c * e;
        assert!(R == s * G - c * Q);

        let s = rp + c * e;
        let proof = SchnorrProof { r: R.x(), s };
        assert!(proof.verify(&Q.x(), msg.as_bytes()));

        {
            let P = Point::lift_x(&Q.x()).unwrap();
            let R = Point::lift_x(&proof.r).unwrap();
            let e = compute::challenge(&Q, &R, msg.as_bytes());
            let Rp = proof.s * G - e * P;
            assert!(Rp.has_even_y());
            assert_eq!(Rp.x(), proof.r);
        }

        // Taproot-tweak the odd key P: the tweaked secret must be (±d) + t,
        // negating d first when P has odd y.
        let S = compute::tweaked_public_key(&P, None);
        println!("S.has_even_y {}", S.has_even_y());
        let t = compute::tweak(&P, None);
        let d = if !P.has_even_y() { -d + t } else { d + t };
        assert!((d * G).x() == S.x());
        assert!((d * G) == S);

        let c = compute::challenge(&S, &R, msg.as_bytes());
        let s = r - c * d;
        assert!(R == s * G + c * S);

        // Signing against the x-only form of S requires negating the tweaked secret
        // when S has odd y.
        let d = if !S.has_even_y() { -d } else { d };

        let s = rp + c * d;
        let proof = SchnorrProof { r: R.x(), s };
        {
            let P = Point::lift_x(&S.x()).unwrap();
            let R = Point::lift_x(&proof.r).unwrap();
            let e = compute::challenge(&S, &R, msg.as_bytes());
            let Rp = proof.s * G - e * P;
            assert!(Rp.has_even_y());
            assert_eq!(Rp.x(), proof.r);
        }
        assert!(proof.verify(&S.x(), msg.as_bytes()));

        // Repeat the tweak check starting from the even-y key Q.
        let T = compute::tweaked_public_key(&Q, None);
        println!("T.has_even_y {}", T.has_even_y());
        let t = compute::tweak(&Q, None);
        let e = if !Q.has_even_y() { -e + t } else { e + t };
        assert!((e * G).x() == T.x());
        assert!((e * G) == T);

        let c = compute::challenge(&T, &R, msg.as_bytes());
        let s = r - c * e;
        assert!(R == s * G + c * T);

        let e = if !T.has_even_y() { -e } else { e };

        let s = rp + c * e;
        let schnorr_proof = SchnorrProof { r: R.x(), s };
        assert!(schnorr_proof.verify(&T.x(), msg.as_bytes()));
    }

    /// v1 taproot signing with a script-path merkle root (`OP_1`).
    #[test]
    #[cfg(feature = "with_v1")]
    fn taproot_sign_verify_v1_with_merkle_root() {
        let script = "OP_1".as_bytes();
        let merkle_root = compute::merkle_root(script);

        taproot_sign_verify_v1(Some(merkle_root));
    }

    /// v1 taproot signing with key-path only (no merkle root).
    #[test]
    #[cfg(feature = "with_v1")]
    fn taproot_sign_verify_v1_no_merkle_root() {
        taproot_sign_verify_v1(None);
    }

    /// Run DKG over four v1 signers holding 10 keys (threshold 7), sign `msg` with
    /// signers 0, 1 and 3 (7 keys in total), aggregate a taproot proof, and check it
    /// round-trips through bytes and verifies against the tweaked public key.
    #[allow(non_snake_case)]
    #[cfg(feature = "with_v1")]
    fn taproot_sign_verify_v1(merkle_root: Option<[u8; 32]>) {
        let mut rng = create_rng();

        // First create and verify a frost signature
        let msg = "It was many and many a year ago".as_bytes();
        let N: u32 = 10;
        let T: u32 = 7;
        let signer_ids: Vec<Vec<u32>> = [
            [1, 2, 3].to_vec(),
            [4, 5].to_vec(),
            [6, 7, 8].to_vec(),
            [9, 10].to_vec(),
        ]
        .to_vec();
        let mut signers: Vec<v1::Signer> = signer_ids
            .iter()
            .enumerate()
            .map(|(id, ids)| v1::Signer::new(id.try_into().unwrap(), ids, N, T, &mut rng))
            .collect();

        let polys = match test_helpers::dkg(&mut signers, &mut rng) {
            Ok(polys) => polys,
            Err(secret_errors) => {
                panic!("Got secret errors from DKG: {secret_errors:?}");
            }
        };

        let mut S = [signers[0].clone(), signers[1].clone(), signers[3].clone()].to_vec();
        let mut sig_agg = v1::Aggregator::new(N, T);
        sig_agg.init(&polys).expect("aggregator init failed");
        let aggregate_public_key = sig_agg.poly[0];
        println!(
            "sign_verify:  agg_pubkey    {}",
            &hex::encode(sig_agg.poly[0].compress().as_bytes())
        );
        println!("sign_verify:  agg_pubkey.x  {}", &sig_agg.poly[0].x());
        let tweaked_public_key = compute::tweaked_public_key(&aggregate_public_key, merkle_root);
        println!(
            "sign_verify: tweaked_key    {}",
            &hex::encode(tweaked_public_key.compress().as_bytes())
        );
        println!("sign_verify: tweaked_key.x  {}", &tweaked_public_key.x());
        let (nonces, sig_shares) = test_helpers::sign(msg, &mut S, &mut rng, merkle_root);
        let proof = match sig_agg.sign_taproot(msg, &nonces, &sig_shares, &[], merkle_root) {
            Err(e) => panic!("Aggregator sign failed: {e:?}"),
            Ok(proof) => proof,
        };

        // now ser/de the proof
        let proof_bytes = proof.to_bytes();
        let proof_deser = SchnorrProof::from(proof_bytes);

        assert_eq!(proof, proof_deser);
        assert!(proof_deser.verify(&tweaked_public_key.x(), msg));
    }

    /// v2 taproot signing with a script-path merkle root (`OP_1`).
    #[test]
    fn taproot_sign_verify_v2_with_merkle_root() {
        let script = "OP_1".as_bytes();
        let merkle_root = compute::merkle_root(script);

        taproot_sign_verify_v2(Some(merkle_root));
    }

    /// v2 taproot signing with key-path only (no merkle root).
    #[test]
    fn taproot_sign_verify_v2_no_merkle_root() {
        taproot_sign_verify_v2(None);
    }

    /// Run DKG over four v2 parties holding 10 keys (threshold 7), sign `msg` with
    /// parties 0, 1 and 3 (7 keys in total), aggregate a taproot proof, and check it
    /// round-trips through bytes and verifies against the tweaked public key.
    #[allow(non_snake_case)]
    fn taproot_sign_verify_v2(merkle_root: Option<[u8; 32]>) {
        let mut rng = create_rng();

        // First create and verify a frost signature
        let msg = "It was many and many a year ago".as_bytes();
        let Nk: u32 = 10;
        let Np: u32 = 4;
        let T: u32 = 7;
        let signer_ids: Vec<Vec<u32>> = [
            [1, 2, 3].to_vec(),
            [4, 5].to_vec(),
            [6, 7, 8].to_vec(),
            [9, 10].to_vec(),
        ]
        .to_vec();
        let mut signers: Vec<v2::Signer> = signer_ids
            .iter()
            .enumerate()
            .map(|(id, ids)| v2::Signer::new(id.try_into().unwrap(), ids, Np, Nk, T, &mut rng))
            .collect();

        let polys = match test_helpers::dkg(&mut signers, &mut rng) {
            Ok(polys) => polys,
            Err(secret_errors) => {
                panic!("Got secret errors from DKG: {secret_errors:?}")
            }
        };

        let mut S = [signers[0].clone(), signers[1].clone(), signers[3].clone()].to_vec();
        let key_ids = S.iter().flat_map(|s| s.get_key_ids()).collect::<Vec<u32>>();
        let mut sig_agg = v2::Aggregator::new(Nk, T);
        sig_agg.init(&polys).expect("aggregator init failed");
        let tweaked_public_key = compute::tweaked_public_key(&sig_agg.poly[0], merkle_root);
        let (nonces, sig_shares) = test_helpers::sign(msg, &mut S, &mut rng, merkle_root);
        let proof = match sig_agg.sign_taproot(msg, &nonces, &sig_shares, &key_ids, merkle_root) {
            Err(e) => panic!("Aggregator sign failed: {e:?}"),
            Ok(proof) => proof,
        };

        // now ser/de the proof
        let proof_bytes = proof.to_bytes();
        let proof_deser = SchnorrProof::from(proof_bytes);

        assert_eq!(proof, proof_deser);
        assert!(proof_deser.verify(&tweaked_public_key.x(), msg));
    }
}