wsts/
compute.rs

1use core::iter::zip;
2use num_traits::{One, Zero};
3use sha2::{Digest, Sha256};
4
5use crate::{
6    common::PublicNonce,
7    curve::{
8        ecdsa,
9        point::{Compressed, Error as PointError, Point, G},
10        scalar::Scalar,
11    },
12    util::{expand_to_scalar, hash_to_scalar},
13};
14
15#[allow(non_snake_case)]
16/// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion.
17pub fn binding(id: &Scalar, B: &[PublicNonce], msg: &[u8]) -> Scalar {
18    let prefix = b"WSTS/binding";
19
20    // Serialize all input into a buffer
21    let mut buf = Vec::new();
22    buf.extend_from_slice(&id.to_bytes());
23
24    for b in B {
25        buf.extend_from_slice(b.D.compress().as_bytes());
26        buf.extend_from_slice(b.E.compress().as_bytes());
27    }
28
29    buf.extend_from_slice(msg);
30
31    expand_to_scalar(&buf, prefix)
32        .expect("FATAL: DST is less than 256 bytes so operation should not fail")
33}
34
35#[allow(non_snake_case)]
36/// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion.
37pub fn binding_compressed(id: &Scalar, B: &[(Compressed, Compressed)], msg: &[u8]) -> Scalar {
38    let prefix = b"WSTS/binding";
39
40    // Serialize all input into a buffer
41    let mut buf = Vec::new();
42    buf.extend_from_slice(&id.to_bytes());
43
44    for (D, E) in B {
45        buf.extend_from_slice(D.as_bytes());
46        buf.extend_from_slice(E.as_bytes());
47    }
48
49    buf.extend_from_slice(msg);
50
51    expand_to_scalar(&buf, prefix)
52        .expect("FATAL: DST is less than 256 bytes so operation should not fail")
53}
54
55#[allow(non_snake_case)]
56/// Compute the schnorr challenge from the public key, aggregated commitments, and the signed message using XMD-based expansion.
57pub fn challenge(publicKey: &Point, R: &Point, msg: &[u8]) -> Scalar {
58    let tag = "BIP0340/challenge";
59
60    let mut hasher = tagged_hash(tag);
61
62    hasher.update(R.x().to_bytes());
63    hasher.update(publicKey.x().to_bytes());
64    hasher.update(msg);
65
66    hash_to_scalar(&mut hasher)
67}
68
69/// Compute the Lagrange interpolation value
70pub fn lambda(i: u32, key_ids: &[u32]) -> Scalar {
71    let mut lambda = Scalar::one();
72    let i_scalar = id(i);
73    for j in key_ids {
74        if i != *j {
75            let j_scalar = id(*j);
76            lambda *= j_scalar / (j_scalar - i_scalar);
77        }
78    }
79    lambda
80}
81
82// Is this the best way to return these values?
83#[allow(non_snake_case)]
84/// Compute the intermediate values used in both the parties and the aggregator
85pub fn intermediate(msg: &[u8], party_ids: &[u32], nonces: &[PublicNonce]) -> (Vec<Point>, Point) {
86    let rhos: Vec<Scalar> = party_ids
87        .iter()
88        .map(|&i| binding(&id(i), nonces, msg))
89        .collect();
90    let R_vec: Vec<Point> = zip(nonces, rhos)
91        .map(|(nonce, rho)| nonce.D + rho * nonce.E)
92        .collect();
93
94    let R = R_vec.iter().fold(Point::zero(), |R, &R_i| R + R_i);
95    (R_vec, R)
96}
97
98#[allow(non_snake_case)]
99/// Compute the aggregate nonce
100pub fn aggregate_nonce(
101    msg: &[u8],
102    party_ids: &[u32],
103    nonces: &[PublicNonce],
104) -> Result<Point, PointError> {
105    let compressed_nonces: Vec<(Compressed, Compressed)> = nonces
106        .iter()
107        .map(|nonce| (nonce.D.compress(), nonce.E.compress()))
108        .collect();
109    let scalars: Vec<Scalar> = party_ids
110        .iter()
111        .flat_map(|&i| {
112            [
113                Scalar::from(1),
114                binding_compressed(&id(i), &compressed_nonces, msg),
115            ]
116        })
117        .collect();
118    let points: Vec<Point> = nonces.iter().flat_map(|nonce| [nonce.D, nonce.E]).collect();
119
120    Point::multimult(scalars, points)
121}
122
123/// Compute a one-based Scalar from a zero-based integer
124pub fn id(i: u32) -> Scalar {
125    Scalar::from(i)
126}
127
128/// Evaluate the public polynomial `f` at scalar `x` using multi-exponentiation
129#[allow(clippy::ptr_arg)]
130pub fn poly(x: &Scalar, f: &Vec<Point>) -> Result<Point, PointError> {
131    let mut s = Vec::with_capacity(f.len());
132    let mut pow = Scalar::one();
133    for _ in 0..f.len() {
134        s.push(pow);
135        pow *= x;
136    }
137
138    Point::multimult(s, f.clone())
139}
140
141/// Create a BIP340 compliant tagged hash by double hashing the tag
142pub fn tagged_hash(tag: &str) -> Sha256 {
143    let mut hasher = Sha256::new();
144    let mut tag_hasher = Sha256::new();
145
146    tag_hasher.update(tag.as_bytes());
147    let tag_hash = tag_hasher.finalize();
148
149    hasher.update(tag_hash);
150    hasher.update(tag_hash);
151
152    hasher
153}
154
155/// Create a BIP341 compliant taproot tweak from a public key and merkle root
156pub fn tweak(public_key: &Point, merkle_root: Option<[u8; 32]>) -> Scalar {
157    let mut hasher = tagged_hash("TapTweak");
158
159    hasher.update(public_key.x().to_bytes());
160    if let Some(root) = merkle_root {
161        hasher.update(root);
162    }
163
164    hash_to_scalar(&mut hasher)
165}
166
167/// Create a BIP341 compliant taproot tweak from a public key and merkle root
168pub fn tweaked_public_key(public_key: &Point, merkle_root: Option<[u8; 32]>) -> Point {
169    tweaked_public_key_from_tweak(public_key, tweak(public_key, merkle_root))
170}
171
172/// Create a BIP341 compliant taproot tweak from a public key and a pre-calculated tweak
173///
174/// We should never trigger the unwrap here, because Point::lift_x only returns an error
175/// when the x-coordinate is not on the secp256k1 curve, but we know that public_key.x()
176/// is on the curve because it is a Point.
177pub fn tweaked_public_key_from_tweak(public_key: &Point, tweak: Scalar) -> Point {
178    Point::lift_x(&public_key.x()).unwrap() + tweak * G
179}
180
181/// Create a taproot style merkle root from the serialized script data
182pub fn merkle_root(data: &[u8]) -> [u8; 32] {
183    let mut hasher = tagged_hash("TapLeaf");
184
185    hasher.update(data);
186
187    hasher.finalize().into()
188}
189
190/// Get a Point from an ecdsa::PublicKey
191pub fn point(key: &ecdsa::PublicKey) -> Result<Point, PointError> {
192    let compressed = Compressed::from(key.to_bytes());
193    // this should not fail as long as the public key above was valid
194    Point::try_from(&compressed)
195}