1use core::iter::zip;
2use num_traits::{One, Zero};
3use sha2::{Digest, Sha256};
4
5use crate::{
6 common::PublicNonce,
7 curve::{
8 ecdsa,
9 point::{Compressed, Error as PointError, Point, G},
10 scalar::Scalar,
11 },
12 util::{expand_to_scalar, hash_to_scalar},
13};
14
15#[allow(non_snake_case)]
16pub fn group_commitment(commitment_list: &[(Scalar, PublicNonce)]) -> Scalar {
18 let prefix = b"WSTS/group_commitment";
19
20 let mut buf = Vec::new();
21 for (id, public_nonce) in commitment_list {
22 buf.extend_from_slice(&id.to_bytes());
23 buf.extend_from_slice(public_nonce.D.compress().as_bytes());
24 buf.extend_from_slice(public_nonce.E.compress().as_bytes());
25 }
26
27 expand_to_scalar(&buf, prefix)
28 .expect("FATAL: DST is less than 256 bytes so operation should not fail")
29}
30
31#[allow(non_snake_case)]
32pub fn group_commitment_compressed(commitment_list: &[(Scalar, Compressed, Compressed)]) -> Scalar {
34 let prefix = b"WSTS/group_commitment";
35
36 let mut buf = Vec::new();
37 for (id, hiding_commitment, binding_commitment) in commitment_list {
38 buf.extend_from_slice(&id.to_bytes());
39 buf.extend_from_slice(hiding_commitment.as_bytes());
40 buf.extend_from_slice(binding_commitment.as_bytes());
41 }
42
43 expand_to_scalar(&buf, prefix)
44 .expect("FATAL: DST is less than 256 bytes so operation should not fail")
45}
46
47#[allow(non_snake_case)]
48pub fn binding(
50 id: &Scalar,
51 group_public_key: Point,
52 commitment_list: &[(Scalar, PublicNonce)],
53 msg: &[u8],
54) -> Scalar {
55 let prefix = b"WSTS/binding";
56 let encoded_group_commitment = group_commitment(commitment_list);
57
58 let mut buf = Vec::new();
59 buf.extend_from_slice(&id.to_bytes());
60 buf.extend_from_slice(group_public_key.compress().as_bytes());
61 buf.extend_from_slice(msg);
62 buf.extend_from_slice(&encoded_group_commitment.to_bytes());
63
64 expand_to_scalar(&buf, prefix)
65 .expect("FATAL: DST is less than 256 bytes so operation should not fail")
66}
67
68#[allow(non_snake_case)]
69pub fn binding_compressed(
71 id: &Scalar,
72 group_public_key: Point,
73 commitment_list: &[(Scalar, Compressed, Compressed)],
74 msg: &[u8],
75) -> Scalar {
76 let prefix = b"WSTS/binding";
77 let encoded_group_commitment = group_commitment_compressed(commitment_list);
78
79 let mut buf = Vec::new();
80 buf.extend_from_slice(&id.to_bytes());
81 buf.extend_from_slice(group_public_key.compress().as_bytes());
82 buf.extend_from_slice(msg);
83 buf.extend_from_slice(&encoded_group_commitment.to_bytes());
84
85 expand_to_scalar(&buf, prefix)
86 .expect("FATAL: DST is less than 256 bytes so operation should not fail")
87}
88
89#[allow(non_snake_case)]
90pub fn challenge(publicKey: &Point, R: &Point, msg: &[u8]) -> Scalar {
92 let tag = "BIP0340/challenge";
93
94 let mut hasher = tagged_hash(tag);
95
96 hasher.update(R.x().to_bytes());
97 hasher.update(publicKey.x().to_bytes());
98 hasher.update(msg);
99
100 hash_to_scalar(&mut hasher)
101}
102
103pub fn lambda(i: u32, key_ids: &[u32]) -> Scalar {
105 let mut lambda = Scalar::one();
106 let i_scalar = id(i);
107 for j in key_ids {
108 if i != *j {
109 let j_scalar = id(*j);
110 lambda *= j_scalar / (j_scalar - i_scalar);
111 }
112 }
113 lambda
114}
115
116#[allow(non_snake_case)]
118pub fn intermediate(
120 msg: &[u8],
121 group_key: Point,
122 party_ids: &[u32],
123 nonces: &[PublicNonce],
124) -> (Vec<Point>, Point) {
125 let commitment_list: Vec<(Scalar, PublicNonce)> = party_ids
126 .iter()
127 .zip(nonces)
128 .map(|(i, nonce)| (Scalar::from(*i), nonce.clone()))
129 .collect();
130 let rhos: Vec<Scalar> = party_ids
131 .iter()
132 .map(|i| binding(&id(*i), group_key, &commitment_list, msg))
133 .collect();
134 let R_vec: Vec<Point> = zip(nonces, rhos)
135 .map(|(nonce, rho)| nonce.D + rho * nonce.E)
136 .collect();
137
138 let R = R_vec.iter().fold(Point::zero(), |R, &R_i| R + R_i);
139 (R_vec, R)
140}
141
142#[allow(non_snake_case)]
143pub fn aggregate_nonce(
145 msg: &[u8],
146 group_key: Point,
147 party_ids: &[u32],
148 nonces: &[PublicNonce],
149) -> Result<Point, PointError> {
150 let commitment_list: Vec<(Scalar, Compressed, Compressed)> = party_ids
151 .iter()
152 .zip(nonces)
153 .map(|(id, nonce)| (Scalar::from(*id), nonce.D.compress(), nonce.E.compress()))
154 .collect();
155 let scalars: Vec<Scalar> = party_ids
156 .iter()
157 .flat_map(|&i| {
158 [
159 Scalar::from(1),
160 binding_compressed(&id(i), group_key, &commitment_list, msg),
161 ]
162 })
163 .collect();
164 let points: Vec<Point> = nonces.iter().flat_map(|nonce| [nonce.D, nonce.E]).collect();
165
166 Point::multimult(scalars, points)
167}
168
169pub fn id(i: u32) -> Scalar {
171 Scalar::from(i)
172}
173
174#[allow(clippy::ptr_arg)]
176pub fn poly(x: &Scalar, f: &Vec<Point>) -> Result<Point, PointError> {
177 let mut s = Vec::with_capacity(f.len());
178 let mut pow = Scalar::one();
179 for _ in 0..f.len() {
180 s.push(pow);
181 pow *= x;
182 }
183
184 Point::multimult(s, f.clone())
185}
186
187pub fn tagged_hash(tag: &str) -> Sha256 {
189 let mut hasher = Sha256::new();
190 let mut tag_hasher = Sha256::new();
191
192 tag_hasher.update(tag.as_bytes());
193 let tag_hash = tag_hasher.finalize();
194
195 hasher.update(tag_hash);
196 hasher.update(tag_hash);
197
198 hasher
199}
200
201pub fn tweak(public_key: &Point, merkle_root: Option<[u8; 32]>) -> Scalar {
203 let mut hasher = tagged_hash("TapTweak");
204
205 hasher.update(public_key.x().to_bytes());
206 if let Some(root) = merkle_root {
207 hasher.update(root);
208 }
209
210 hash_to_scalar(&mut hasher)
211}
212
213pub fn tweaked_public_key(public_key: &Point, merkle_root: Option<[u8; 32]>) -> Point {
215 tweaked_public_key_from_tweak(public_key, tweak(public_key, merkle_root))
216}
217
218pub fn tweaked_public_key_from_tweak(public_key: &Point, tweak: Scalar) -> Point {
224 Point::lift_x(&public_key.x()).unwrap() + tweak * G
225}
226
227pub fn merkle_root(data: &[u8]) -> [u8; 32] {
229 let mut hasher = tagged_hash("TapLeaf");
230
231 hasher.update(data);
232
233 hasher.finalize().into()
234}
235
236pub fn point(key: &ecdsa::PublicKey) -> Result<Point, PointError> {
238 let compressed = Compressed::from(key.to_bytes());
239 Point::try_from(&compressed)
241}