1use core::iter::zip;
2use num_traits::{One, Zero};
3use sha2::{Digest, Sha256};
4
5use crate::{
6 common::PublicNonce,
7 curve::{
8 ecdsa,
9 point::{Compressed, Error as PointError, Point, G},
10 scalar::Scalar,
11 },
12 util::{expand_to_scalar, hash_to_scalar},
13};
14
15#[allow(non_snake_case)]
16pub fn binding(id: &Scalar, B: &[PublicNonce], msg: &[u8]) -> Scalar {
18 let prefix = b"WSTS/binding";
19
20 let mut buf = Vec::new();
22 buf.extend_from_slice(&id.to_bytes());
23
24 for b in B {
25 buf.extend_from_slice(b.D.compress().as_bytes());
26 buf.extend_from_slice(b.E.compress().as_bytes());
27 }
28
29 buf.extend_from_slice(msg);
30
31 expand_to_scalar(&buf, prefix)
32 .expect("FATAL: DST is less than 256 bytes so operation should not fail")
33}
34
35#[allow(non_snake_case)]
36pub fn binding_compressed(id: &Scalar, B: &[(Compressed, Compressed)], msg: &[u8]) -> Scalar {
38 let prefix = b"WSTS/binding";
39
40 let mut buf = Vec::new();
42 buf.extend_from_slice(&id.to_bytes());
43
44 for (D, E) in B {
45 buf.extend_from_slice(D.as_bytes());
46 buf.extend_from_slice(E.as_bytes());
47 }
48
49 buf.extend_from_slice(msg);
50
51 expand_to_scalar(&buf, prefix)
52 .expect("FATAL: DST is less than 256 bytes so operation should not fail")
53}
54
55#[allow(non_snake_case)]
56pub fn challenge(publicKey: &Point, R: &Point, msg: &[u8]) -> Scalar {
58 let tag = "BIP0340/challenge";
59
60 let mut hasher = tagged_hash(tag);
61
62 hasher.update(R.x().to_bytes());
63 hasher.update(publicKey.x().to_bytes());
64 hasher.update(msg);
65
66 hash_to_scalar(&mut hasher)
67}
68
69pub fn lambda(i: u32, key_ids: &[u32]) -> Scalar {
71 let mut lambda = Scalar::one();
72 let i_scalar = id(i);
73 for j in key_ids {
74 if i != *j {
75 let j_scalar = id(*j);
76 lambda *= j_scalar / (j_scalar - i_scalar);
77 }
78 }
79 lambda
80}
81
82#[allow(non_snake_case)]
84pub fn intermediate(msg: &[u8], party_ids: &[u32], nonces: &[PublicNonce]) -> (Vec<Point>, Point) {
86 let rhos: Vec<Scalar> = party_ids
87 .iter()
88 .map(|&i| binding(&id(i), nonces, msg))
89 .collect();
90 let R_vec: Vec<Point> = zip(nonces, rhos)
91 .map(|(nonce, rho)| nonce.D + rho * nonce.E)
92 .collect();
93
94 let R = R_vec.iter().fold(Point::zero(), |R, &R_i| R + R_i);
95 (R_vec, R)
96}
97
98#[allow(non_snake_case)]
99pub fn aggregate_nonce(
101 msg: &[u8],
102 party_ids: &[u32],
103 nonces: &[PublicNonce],
104) -> Result<Point, PointError> {
105 let compressed_nonces: Vec<(Compressed, Compressed)> = nonces
106 .iter()
107 .map(|nonce| (nonce.D.compress(), nonce.E.compress()))
108 .collect();
109 let scalars: Vec<Scalar> = party_ids
110 .iter()
111 .flat_map(|&i| {
112 [
113 Scalar::from(1),
114 binding_compressed(&id(i), &compressed_nonces, msg),
115 ]
116 })
117 .collect();
118 let points: Vec<Point> = nonces.iter().flat_map(|nonce| [nonce.D, nonce.E]).collect();
119
120 Point::multimult(scalars, points)
121}
122
123pub fn id(i: u32) -> Scalar {
125 Scalar::from(i)
126}
127
128#[allow(clippy::ptr_arg)]
130pub fn poly(x: &Scalar, f: &Vec<Point>) -> Result<Point, PointError> {
131 let mut s = Vec::with_capacity(f.len());
132 let mut pow = Scalar::one();
133 for _ in 0..f.len() {
134 s.push(pow);
135 pow *= x;
136 }
137
138 Point::multimult(s, f.clone())
139}
140
141pub fn tagged_hash(tag: &str) -> Sha256 {
143 let mut hasher = Sha256::new();
144 let mut tag_hasher = Sha256::new();
145
146 tag_hasher.update(tag.as_bytes());
147 let tag_hash = tag_hasher.finalize();
148
149 hasher.update(tag_hash);
150 hasher.update(tag_hash);
151
152 hasher
153}
154
155pub fn tweak(public_key: &Point, merkle_root: Option<[u8; 32]>) -> Scalar {
157 let mut hasher = tagged_hash("TapTweak");
158
159 hasher.update(public_key.x().to_bytes());
160 if let Some(root) = merkle_root {
161 hasher.update(root);
162 }
163
164 hash_to_scalar(&mut hasher)
165}
166
167pub fn tweaked_public_key(public_key: &Point, merkle_root: Option<[u8; 32]>) -> Point {
169 tweaked_public_key_from_tweak(public_key, tweak(public_key, merkle_root))
170}
171
172pub fn tweaked_public_key_from_tweak(public_key: &Point, tweak: Scalar) -> Point {
178 Point::lift_x(&public_key.x()).unwrap() + tweak * G
179}
180
181pub fn merkle_root(data: &[u8]) -> [u8; 32] {
183 let mut hasher = tagged_hash("TapLeaf");
184
185 hasher.update(data);
186
187 hasher.finalize().into()
188}
189
190pub fn point(key: &ecdsa::PublicKey) -> Result<Point, PointError> {
192 let compressed = Compressed::from(key.to_bytes());
193 Point::try_from(&compressed)
195}