wsts/
compute.rs

1use core::iter::zip;
2use num_traits::{One, Zero};
3use sha2::{Digest, Sha256};
4
5use crate::{
6    common::PublicNonce,
7    curve::{
8        ecdsa,
9        point::{Compressed, Error as PointError, Point, G},
10        scalar::Scalar,
11    },
12    util::{expand_to_scalar, hash_to_scalar},
13};
14
15#[allow(non_snake_case)]
16/// Compute the group commitment from the list of PartyIDs and nonce commitments using XMD-based expansion.
17pub fn group_commitment(commitment_list: &[(Scalar, PublicNonce)]) -> Scalar {
18    let prefix = b"WSTS/group_commitment";
19
20    let mut buf = Vec::new();
21    for (id, public_nonce) in commitment_list {
22        buf.extend_from_slice(&id.to_bytes());
23        buf.extend_from_slice(public_nonce.D.compress().as_bytes());
24        buf.extend_from_slice(public_nonce.E.compress().as_bytes());
25    }
26
27    expand_to_scalar(&buf, prefix)
28        .expect("FATAL: DST is less than 256 bytes so operation should not fail")
29}
30
31#[allow(non_snake_case)]
32/// Compute the group commitment from the list of PartyIDs and nonce commitments using XMD-based expansion.
33pub fn group_commitment_compressed(commitment_list: &[(Scalar, Compressed, Compressed)]) -> Scalar {
34    let prefix = b"WSTS/group_commitment";
35
36    let mut buf = Vec::new();
37    for (id, hiding_commitment, binding_commitment) in commitment_list {
38        buf.extend_from_slice(&id.to_bytes());
39        buf.extend_from_slice(hiding_commitment.as_bytes());
40        buf.extend_from_slice(binding_commitment.as_bytes());
41    }
42
43    expand_to_scalar(&buf, prefix)
44        .expect("FATAL: DST is less than 256 bytes so operation should not fail")
45}
46
47#[allow(non_snake_case)]
48/// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion.
49pub fn binding(
50    id: &Scalar,
51    group_public_key: Point,
52    commitment_list: &[(Scalar, PublicNonce)],
53    msg: &[u8],
54) -> Scalar {
55    let prefix = b"WSTS/binding";
56    let encoded_group_commitment = group_commitment(commitment_list);
57
58    let mut buf = Vec::new();
59    buf.extend_from_slice(&id.to_bytes());
60    buf.extend_from_slice(group_public_key.compress().as_bytes());
61    buf.extend_from_slice(msg);
62    buf.extend_from_slice(&encoded_group_commitment.to_bytes());
63
64    expand_to_scalar(&buf, prefix)
65        .expect("FATAL: DST is less than 256 bytes so operation should not fail")
66}
67
68#[allow(non_snake_case)]
69/// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion.
70pub fn binding_compressed(
71    id: &Scalar,
72    group_public_key: Point,
73    commitment_list: &[(Scalar, Compressed, Compressed)],
74    msg: &[u8],
75) -> Scalar {
76    let prefix = b"WSTS/binding";
77    let encoded_group_commitment = group_commitment_compressed(commitment_list);
78
79    let mut buf = Vec::new();
80    buf.extend_from_slice(&id.to_bytes());
81    buf.extend_from_slice(group_public_key.compress().as_bytes());
82    buf.extend_from_slice(msg);
83    buf.extend_from_slice(&encoded_group_commitment.to_bytes());
84
85    expand_to_scalar(&buf, prefix)
86        .expect("FATAL: DST is less than 256 bytes so operation should not fail")
87}
88
89#[allow(non_snake_case)]
90/// Compute the schnorr challenge from the public key, aggregated commitments, and the signed message using XMD-based expansion.
91pub fn challenge(publicKey: &Point, R: &Point, msg: &[u8]) -> Scalar {
92    let tag = "BIP0340/challenge";
93
94    let mut hasher = tagged_hash(tag);
95
96    hasher.update(R.x().to_bytes());
97    hasher.update(publicKey.x().to_bytes());
98    hasher.update(msg);
99
100    hash_to_scalar(&mut hasher)
101}
102
103/// Compute the Lagrange interpolation value
104pub fn lambda(i: u32, key_ids: &[u32]) -> Scalar {
105    let mut lambda = Scalar::one();
106    let i_scalar = id(i);
107    for j in key_ids {
108        if i != *j {
109            let j_scalar = id(*j);
110            lambda *= j_scalar / (j_scalar - i_scalar);
111        }
112    }
113    lambda
114}
115
116// Is this the best way to return these values?
117#[allow(non_snake_case)]
118/// Compute the intermediate values used in both the parties and the aggregator
119pub fn intermediate(
120    msg: &[u8],
121    group_key: Point,
122    party_ids: &[u32],
123    nonces: &[PublicNonce],
124) -> (Vec<Point>, Point) {
125    let commitment_list: Vec<(Scalar, PublicNonce)> = party_ids
126        .iter()
127        .zip(nonces)
128        .map(|(i, nonce)| (Scalar::from(*i), nonce.clone()))
129        .collect();
130    let rhos: Vec<Scalar> = party_ids
131        .iter()
132        .map(|i| binding(&id(*i), group_key, &commitment_list, msg))
133        .collect();
134    let R_vec: Vec<Point> = zip(nonces, rhos)
135        .map(|(nonce, rho)| nonce.D + rho * nonce.E)
136        .collect();
137
138    let R = R_vec.iter().fold(Point::zero(), |R, &R_i| R + R_i);
139    (R_vec, R)
140}
141
142#[allow(non_snake_case)]
143/// Compute the aggregate nonce
144pub fn aggregate_nonce(
145    msg: &[u8],
146    group_key: Point,
147    party_ids: &[u32],
148    nonces: &[PublicNonce],
149) -> Result<Point, PointError> {
150    let commitment_list: Vec<(Scalar, Compressed, Compressed)> = party_ids
151        .iter()
152        .zip(nonces)
153        .map(|(id, nonce)| (Scalar::from(*id), nonce.D.compress(), nonce.E.compress()))
154        .collect();
155    let scalars: Vec<Scalar> = party_ids
156        .iter()
157        .flat_map(|&i| {
158            [
159                Scalar::from(1),
160                binding_compressed(&id(i), group_key, &commitment_list, msg),
161            ]
162        })
163        .collect();
164    let points: Vec<Point> = nonces.iter().flat_map(|nonce| [nonce.D, nonce.E]).collect();
165
166    Point::multimult(scalars, points)
167}
168
169/// Compute a one-based Scalar from a zero-based integer
170pub fn id(i: u32) -> Scalar {
171    Scalar::from(i)
172}
173
174/// Evaluate the public polynomial `f` at scalar `x` using multi-exponentiation
175#[allow(clippy::ptr_arg)]
176pub fn poly(x: &Scalar, f: &Vec<Point>) -> Result<Point, PointError> {
177    let mut s = Vec::with_capacity(f.len());
178    let mut pow = Scalar::one();
179    for _ in 0..f.len() {
180        s.push(pow);
181        pow *= x;
182    }
183
184    Point::multimult(s, f.clone())
185}
186
187/// Create a BIP340 compliant tagged hash by double hashing the tag
188pub fn tagged_hash(tag: &str) -> Sha256 {
189    let mut hasher = Sha256::new();
190    let mut tag_hasher = Sha256::new();
191
192    tag_hasher.update(tag.as_bytes());
193    let tag_hash = tag_hasher.finalize();
194
195    hasher.update(tag_hash);
196    hasher.update(tag_hash);
197
198    hasher
199}
200
201/// Create a BIP341 compliant taproot tweak from a public key and merkle root
202pub fn tweak(public_key: &Point, merkle_root: Option<[u8; 32]>) -> Scalar {
203    let mut hasher = tagged_hash("TapTweak");
204
205    hasher.update(public_key.x().to_bytes());
206    if let Some(root) = merkle_root {
207        hasher.update(root);
208    }
209
210    hash_to_scalar(&mut hasher)
211}
212
213/// Create a BIP341 compliant taproot tweak from a public key and merkle root
214pub fn tweaked_public_key(public_key: &Point, merkle_root: Option<[u8; 32]>) -> Point {
215    tweaked_public_key_from_tweak(public_key, tweak(public_key, merkle_root))
216}
217
218/// Create a BIP341 compliant taproot tweak from a public key and a pre-calculated tweak
219///
220/// We should never trigger the unwrap here, because Point::lift_x only returns an error
221/// when the x-coordinate is not on the secp256k1 curve, but we know that public_key.x()
222/// is on the curve because it is a Point.
223pub fn tweaked_public_key_from_tweak(public_key: &Point, tweak: Scalar) -> Point {
224    Point::lift_x(&public_key.x()).unwrap() + tweak * G
225}
226
227/// Create a taproot style merkle root from the serialized script data
228pub fn merkle_root(data: &[u8]) -> [u8; 32] {
229    let mut hasher = tagged_hash("TapLeaf");
230
231    hasher.update(data);
232
233    hasher.finalize().into()
234}
235
236/// Get a Point from an ecdsa::PublicKey
237pub fn point(key: &ecdsa::PublicKey) -> Result<Point, PointError> {
238    let compressed = Compressed::from(key.to_bytes());
239    // this should not fail as long as the public key above was valid
240    Point::try_from(&compressed)
241}