1use std::sync::Arc;
2
3#[cfg(feature = "smol")]
4use futures_rustls::rustls;
5
6#[cfg(feature = "tokio")]
7use tokio_rustls::rustls;
8
9use rustls::{
10 client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
11 crypto::{
12 aws_lc_rs::{self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, kx_group},
13 CryptoProvider, SupportedKxGroup,
14 },
15 server::danger::{ClientCertVerified, ClientCertVerifier},
16 CertificateError, DigitallySignedStruct, DistinguishedName,
17 Error::InvalidCertificate,
18 SignatureScheme, SupportedCipherSuite, SupportedProtocolVersion,
19};
20
21use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
22
23use log::error;
24use rcgen::PublicKeyData;
25use x509_parser::{asn1_rs::Oid, certificate::X509Certificate, parse_x509_certificate};
26
27use karyon_core::crypto::{KeyPair, KeyPairType, PublicKey};
28
29use crate::{PeerID, Result};
30
/// Protocol versions handed to both the client and server config builders:
/// TLS 1.3 only.
static PROTOCOL_VERSIONS: &[&SupportedProtocolVersion] = &[&rustls::version::TLS13];
/// The single cipher suite installed in the crypto provider.
static CIPHER_SUITES: &[SupportedCipherSuite] = &[TLS13_CHACHA20_POLY1305_SHA256];
/// The single key-exchange group installed in the crypto provider.
static KX_GROUPS: &[&dyn SupportedKxGroup] = &[kx_group::X25519];
/// Signature schemes reported by both certificate verifiers.
static SIGNATURE_SCHEMES: &[SignatureScheme] = &[SignatureScheme::ED25519];

/// OID of the custom certificate extension that carries the peer's public key
/// together with its signature over the certificate's public key info.
const PEER_ID_EXT_OID: &[u64] = &[0, 0, 0, 0];

/// Error returned when a signature check on a certificate fails.
const BAD_SIGNATURE_ERR: rustls::Error = InvalidCertificate(CertificateError::BadSignature);
/// Error returned when a certificate (or its peer-id extension) cannot be decoded.
const BAD_ENCODING_ERR: rustls::Error = InvalidCertificate(CertificateError::BadEncoding);
44
45pub fn tls_client_config(
47 key_pair: &KeyPair,
48 peer_id: Option<PeerID>,
49) -> Result<rustls::ClientConfig> {
50 let (cert, private_key) = generate_cert(key_pair)?;
51 let server_verifier = SrvrCertVerifier { peer_id };
52
53 let client_config = rustls::ClientConfig::builder_with_provider(
54 CryptoProvider {
55 kx_groups: KX_GROUPS.to_vec(),
56 cipher_suites: CIPHER_SUITES.to_vec(),
57 ..aws_lc_rs::default_provider()
58 }
59 .into(),
60 )
61 .with_protocol_versions(PROTOCOL_VERSIONS)?
62 .dangerous()
63 .with_custom_certificate_verifier(Arc::new(server_verifier))
64 .with_client_auth_cert(vec![cert], private_key)?;
65
66 Ok(client_config)
67}
68
69pub fn tls_server_config(key_pair: &KeyPair) -> Result<rustls::ServerConfig> {
71 let (cert, private_key) = generate_cert(key_pair)?;
72 let client_verifier = CliCertVerifier {};
73 let server_config = rustls::ServerConfig::builder_with_provider(
74 CryptoProvider {
75 kx_groups: KX_GROUPS.to_vec(),
76 cipher_suites: CIPHER_SUITES.to_vec(),
77 ..aws_lc_rs::default_provider()
78 }
79 .into(),
80 )
81 .with_protocol_versions(PROTOCOL_VERSIONS)?
82 .with_client_cert_verifier(Arc::new(client_verifier))
83 .with_single_cert(vec![cert], private_key)?;
84
85 Ok(server_config)
86}
87
88fn generate_cert<'a>(key_pair: &KeyPair) -> Result<(CertificateDer<'a>, PrivateKeyDer<'a>)> {
90 let cert_key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519)?;
91 let private_key = PrivateKeyDer::Pkcs8(cert_key_pair.serialize_der().into());
92
93 let signature = key_pair.sign(&cert_key_pair.subject_public_key_info());
97 let ext_content = yasna::encode_der(&(key_pair.public().as_bytes().to_vec(), signature));
98 let mut ext = rcgen::CustomExtension::from_oid_content(PEER_ID_EXT_OID, ext_content);
99 ext.set_criticality(false);
103
104 let mut params = rcgen::CertificateParams::new(vec![])?;
105 params.custom_extensions.push(ext);
106
107 let cert = CertificateDer::from(params.self_signed(&cert_key_pair)?);
108 Ok((cert, private_key))
109}
110
111pub(crate) fn peer_id_from_certs(certs: &[CertificateDer<'_>]) -> Option<PeerID> {
115 let end_entity = certs.first()?;
116 verify_cert(end_entity).ok()
117}
118
119fn verify_cert(end_entity: &CertificateDer<'_>) -> std::result::Result<PeerID, rustls::Error> {
121 let cert = parse_cert(end_entity)?;
123
124 let want_oid = Oid::from(PEER_ID_EXT_OID).map_err(|_| BAD_ENCODING_ERR)?;
126 let ext = cert
127 .extensions()
128 .iter()
129 .find(|e| e.oid == want_oid)
130 .ok_or(BAD_ENCODING_ERR)?;
131
132 let (public_key, signature): (Vec<u8>, Vec<u8>) =
134 yasna::decode_der(ext.value).map_err(|_| BAD_ENCODING_ERR)?;
135
136 let public_key =
138 PublicKey::from_bytes(&KeyPairType::Ed25519, &public_key).map_err(|_| BAD_ENCODING_ERR)?;
139 public_key
140 .verify(cert.public_key().raw, &signature)
141 .map_err(|_| BAD_SIGNATURE_ERR)?;
142
143 verify_cert_signature(
145 &cert,
146 cert.tbs_certificate.as_ref(),
147 cert.signature_value.as_ref(),
148 )?;
149
150 PeerID::try_from(public_key).map_err(|_| BAD_ENCODING_ERR)
151}
152
153fn parse_cert<'a>(
155 end_entity: &'a CertificateDer<'a>,
156) -> std::result::Result<X509Certificate<'a>, rustls::Error> {
157 let (_, cert) = parse_x509_certificate(end_entity.as_ref()).map_err(|_| BAD_ENCODING_ERR)?;
158
159 if !cert.validity().is_valid() {
160 return Err(InvalidCertificate(CertificateError::NotValidYet));
161 }
162
163 Ok(cert)
164}
165
166fn verify_cert_signature(
168 cert: &X509Certificate,
169 message: &[u8],
170 signature: &[u8],
171) -> std::result::Result<(), rustls::Error> {
172 let public_key = PublicKey::from_bytes(
173 &KeyPairType::Ed25519,
174 cert.tbs_certificate.subject_pki.subject_public_key.as_ref(),
175 )
176 .map_err(|_| BAD_ENCODING_ERR)?;
177
178 public_key
179 .verify(message, signature)
180 .map_err(|_| BAD_SIGNATURE_ERR)
181}
182
183#[derive(Debug)]
184struct SrvrCertVerifier {
185 peer_id: Option<PeerID>,
186}
187
188impl ServerCertVerifier for SrvrCertVerifier {
189 fn verify_server_cert(
190 &self,
191 end_entity: &CertificateDer<'_>,
192 _intermediates: &[CertificateDer<'_>],
193 _server_name: &ServerName,
194 _ocsp_response: &[u8],
195 _now: UnixTime,
196 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
197 let peer_id = match verify_cert(end_entity) {
198 Ok(pid) => pid,
199 Err(err) => {
200 error!("Failed to verify cert: {err}");
201 return Err(err);
202 }
203 };
204
205 if let Some(pid) = &self.peer_id {
209 if pid != &peer_id {
210 return Err(InvalidCertificate(
211 CertificateError::ApplicationVerificationFailure,
212 ));
213 }
214 }
215
216 Ok(ServerCertVerified::assertion())
217 }
218
219 fn verify_tls12_signature(
220 &self,
221 _message: &[u8],
222 _cert: &CertificateDer<'_>,
223 _dss: &DigitallySignedStruct,
224 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
225 unreachable!("ONLY SUPPORT tls 13 VERSION")
226 }
227
228 fn verify_tls13_signature(
229 &self,
230 message: &[u8],
231 cert: &CertificateDer<'_>,
232 dss: &DigitallySignedStruct,
233 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
234 let cert = parse_cert(cert)?;
235 verify_cert_signature(&cert, message, dss.signature())?;
236 Ok(HandshakeSignatureValid::assertion())
237 }
238
239 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
240 SIGNATURE_SCHEMES.to_vec()
241 }
242}
243
244#[derive(Debug)]
245struct CliCertVerifier {}
246impl ClientCertVerifier for CliCertVerifier {
247 fn verify_client_cert(
248 &self,
249 end_entity: &CertificateDer<'_>,
250 _intermediates: &[CertificateDer<'_>],
251 _now: UnixTime,
252 ) -> std::result::Result<ClientCertVerified, rustls::Error> {
253 if let Err(err) = verify_cert(end_entity) {
254 error!("Failed to verify cert: {err}");
255 return Err(err);
256 };
257 Ok(ClientCertVerified::assertion())
258 }
259
260 fn root_hint_subjects(&self) -> &[DistinguishedName] {
261 &[]
262 }
263
264 fn verify_tls12_signature(
265 &self,
266 _message: &[u8],
267 _cert: &CertificateDer<'_>,
268 _dss: &DigitallySignedStruct,
269 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
270 unreachable!("ONLY SUPPORT tls 13 VERSION")
271 }
272
273 fn verify_tls13_signature(
274 &self,
275 message: &[u8],
276 cert: &CertificateDer<'_>,
277 dss: &DigitallySignedStruct,
278 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
279 let cert = parse_cert(cert)?;
280 verify_cert_signature(&cert, message, dss.signature())?;
281 Ok(HandshakeSignatureValid::assertion())
282 }
283
284 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
285 SIGNATURE_SCHEMES.to_vec()
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// A certificate produced by `generate_cert` must verify and yield the
    /// peer id of the key pair it was generated for.
    #[test]
    fn verify_generated_certificate() {
        let key_pair = KeyPair::generate(&KeyPairType::Ed25519);
        let cert = generate_cert(&key_pair).unwrap().0;

        let expected_id = PeerID::try_from(key_pair.public()).unwrap();
        let peer_id = verify_cert(&cert).unwrap();

        assert_eq!(peer_id, expected_id);
        assert_eq!(peer_id.0, key_pair.public().as_bytes());
    }
}