1use std::sync::Arc;
2
3#[cfg(feature = "smol")]
4use futures_rustls::rustls;
5
6#[cfg(feature = "tokio")]
7use tokio_rustls::rustls;
8
9use rustls::{
10 client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
11 crypto::{
12 aws_lc_rs::{self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, kx_group},
13 CryptoProvider, SupportedKxGroup,
14 },
15 server::danger::{ClientCertVerified, ClientCertVerifier},
16 CertificateError, DigitallySignedStruct, DistinguishedName,
17 Error::InvalidCertificate,
18 SignatureScheme, SupportedCipherSuite, SupportedProtocolVersion,
19};
20
21use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
22
23use log::error;
24use rcgen::PublicKeyData;
25use x509_parser::{certificate::X509Certificate, parse_x509_certificate};
26
27use karyon_core::crypto::{KeyPair, KeyPairType, PublicKey};
28
29use crate::{PeerID, Result};
30
31static PROTOCOL_VERSIONS: &[&SupportedProtocolVersion] = &[&rustls::version::TLS13];
34static CIPHER_SUITES: &[SupportedCipherSuite] = &[TLS13_CHACHA20_POLY1305_SHA256];
35static KX_GROUPS: &[&dyn SupportedKxGroup] = &[kx_group::X25519];
36static SIGNATURE_SCHEMES: &[SignatureScheme] = &[SignatureScheme::ED25519];
37
38const BAD_SIGNATURE_ERR: rustls::Error = InvalidCertificate(CertificateError::BadSignature);
39const BAD_ENCODING_ERR: rustls::Error = InvalidCertificate(CertificateError::BadEncoding);
40
41pub fn tls_client_config(
43 key_pair: &KeyPair,
44 peer_id: Option<PeerID>,
45) -> Result<rustls::ClientConfig> {
46 let (cert, private_key) = generate_cert(key_pair)?;
47 let server_verifier = SrvrCertVerifier { peer_id };
48
49 let client_config = rustls::ClientConfig::builder_with_provider(
50 CryptoProvider {
51 kx_groups: KX_GROUPS.to_vec(),
52 cipher_suites: CIPHER_SUITES.to_vec(),
53 ..aws_lc_rs::default_provider()
54 }
55 .into(),
56 )
57 .with_protocol_versions(PROTOCOL_VERSIONS)?
58 .dangerous()
59 .with_custom_certificate_verifier(Arc::new(server_verifier))
60 .with_client_auth_cert(vec![cert], private_key)?;
61
62 Ok(client_config)
63}
64
65pub fn tls_server_config(key_pair: &KeyPair) -> Result<rustls::ServerConfig> {
67 let (cert, private_key) = generate_cert(key_pair)?;
68 let client_verifier = CliCertVerifier {};
69 let server_config = rustls::ServerConfig::builder_with_provider(
70 CryptoProvider {
71 kx_groups: KX_GROUPS.to_vec(),
72 cipher_suites: CIPHER_SUITES.to_vec(),
73 ..aws_lc_rs::default_provider()
74 }
75 .into(),
76 )
77 .with_protocol_versions(PROTOCOL_VERSIONS)?
78 .with_client_cert_verifier(Arc::new(client_verifier))
79 .with_single_cert(vec![cert], private_key)?;
80
81 Ok(server_config)
82}
83
84fn generate_cert<'a>(key_pair: &KeyPair) -> Result<(CertificateDer<'a>, PrivateKeyDer<'a>)> {
86 let cert_key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519)?;
87 let private_key = PrivateKeyDer::Pkcs8(cert_key_pair.serialize_der().into());
88
89 let signature = key_pair.sign(&cert_key_pair.subject_public_key_info());
93 let ext_content = yasna::encode_der(&(key_pair.public().as_bytes().to_vec(), signature));
94 let mut ext = rcgen::CustomExtension::from_oid_content(&[0, 0, 0, 0], ext_content);
96 ext.set_criticality(true);
97
98 let mut params = rcgen::CertificateParams::new(vec![])?;
99 params.custom_extensions.push(ext);
100
101 let cert = CertificateDer::from(params.self_signed(&cert_key_pair)?);
102 Ok((cert, private_key))
103}
104
105fn verify_cert(end_entity: &CertificateDer<'_>) -> std::result::Result<PeerID, rustls::Error> {
107 let cert = parse_cert(end_entity)?;
109
110 match cert.extensions().first() {
111 Some(ext) => {
112 let (public_key, signature): (Vec<u8>, Vec<u8>) =
114 yasna::decode_der(ext.value).map_err(|_| BAD_ENCODING_ERR)?;
115
116 let public_key = PublicKey::from_bytes(&KeyPairType::Ed25519, &public_key)
118 .map_err(|_| BAD_ENCODING_ERR)?;
119 public_key
120 .verify(cert.public_key().raw, &signature)
121 .map_err(|_| BAD_SIGNATURE_ERR)?;
122
123 verify_cert_signature(
125 &cert,
126 cert.tbs_certificate.as_ref(),
127 cert.signature_value.as_ref(),
128 )?;
129
130 PeerID::try_from(public_key).map_err(|_| BAD_ENCODING_ERR)
131 }
132 None => Err(BAD_ENCODING_ERR),
133 }
134}
135
136fn parse_cert<'a>(
138 end_entity: &'a CertificateDer<'a>,
139) -> std::result::Result<X509Certificate<'a>, rustls::Error> {
140 let (_, cert) = parse_x509_certificate(end_entity.as_ref()).map_err(|_| BAD_ENCODING_ERR)?;
141
142 if !cert.validity().is_valid() {
143 return Err(InvalidCertificate(CertificateError::NotValidYet));
144 }
145
146 Ok(cert)
147}
148
149fn verify_cert_signature(
151 cert: &X509Certificate,
152 message: &[u8],
153 signature: &[u8],
154) -> std::result::Result<(), rustls::Error> {
155 let public_key = PublicKey::from_bytes(
156 &KeyPairType::Ed25519,
157 cert.tbs_certificate.subject_pki.subject_public_key.as_ref(),
158 )
159 .map_err(|_| BAD_ENCODING_ERR)?;
160
161 public_key
162 .verify(message, signature)
163 .map_err(|_| BAD_SIGNATURE_ERR)
164}
165
166#[derive(Debug)]
167struct SrvrCertVerifier {
168 peer_id: Option<PeerID>,
169}
170
171impl ServerCertVerifier for SrvrCertVerifier {
172 fn verify_server_cert(
173 &self,
174 end_entity: &CertificateDer<'_>,
175 _intermediates: &[CertificateDer<'_>],
176 _server_name: &ServerName,
177 _ocsp_response: &[u8],
178 _now: UnixTime,
179 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
180 let peer_id = match verify_cert(end_entity) {
181 Ok(pid) => pid,
182 Err(err) => {
183 error!("Failed to verify cert: {err}");
184 return Err(err);
185 }
186 };
187
188 if let Some(pid) = &self.peer_id {
192 if pid != &peer_id {
193 return Err(InvalidCertificate(
194 CertificateError::ApplicationVerificationFailure,
195 ));
196 }
197 }
198
199 Ok(ServerCertVerified::assertion())
200 }
201
202 fn verify_tls12_signature(
203 &self,
204 _message: &[u8],
205 _cert: &CertificateDer<'_>,
206 _dss: &DigitallySignedStruct,
207 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
208 unreachable!("ONLY SUPPORT tls 13 VERSION")
209 }
210
211 fn verify_tls13_signature(
212 &self,
213 message: &[u8],
214 cert: &CertificateDer<'_>,
215 dss: &DigitallySignedStruct,
216 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
217 let cert = parse_cert(cert)?;
218 verify_cert_signature(&cert, message, dss.signature())?;
219 Ok(HandshakeSignatureValid::assertion())
220 }
221
222 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
223 SIGNATURE_SCHEMES.to_vec()
224 }
225}
226
/// Stateless certificate verifier a server uses to check client certificates.
#[derive(Debug)]
struct CliCertVerifier;
229impl ClientCertVerifier for CliCertVerifier {
230 fn verify_client_cert(
231 &self,
232 end_entity: &CertificateDer<'_>,
233 _intermediates: &[CertificateDer<'_>],
234 _now: UnixTime,
235 ) -> std::result::Result<ClientCertVerified, rustls::Error> {
236 if let Err(err) = verify_cert(end_entity) {
237 error!("Failed to verify cert: {err}");
238 return Err(err);
239 };
240 Ok(ClientCertVerified::assertion())
241 }
242
243 fn root_hint_subjects(&self) -> &[DistinguishedName] {
244 &[]
245 }
246
247 fn verify_tls12_signature(
248 &self,
249 _message: &[u8],
250 _cert: &CertificateDer<'_>,
251 _dss: &DigitallySignedStruct,
252 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
253 unreachable!("ONLY SUPPORT tls 13 VERSION")
254 }
255
256 fn verify_tls13_signature(
257 &self,
258 message: &[u8],
259 cert: &CertificateDer<'_>,
260 dss: &DigitallySignedStruct,
261 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
262 let cert = parse_cert(cert)?;
263 verify_cert_signature(&cert, message, dss.signature())?;
264 Ok(HandshakeSignatureValid::assertion())
265 }
266
267 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
268 SIGNATURE_SCHEMES.to_vec()
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// A certificate produced by `generate_cert` must verify and resolve to
    /// the identity of the key pair that produced it.
    #[test]
    fn verify_generated_certificate() {
        let owner = KeyPair::generate(&KeyPairType::Ed25519);
        let (cert_der, _key_der) = generate_cert(&owner).unwrap();

        let peer_id =
            verify_cert(&cert_der).expect("a freshly generated certificate must verify");

        let expected = PeerID::try_from(owner.public()).unwrap();
        assert_eq!(peer_id, expected);
        assert_eq!(peer_id.0, owner.public().as_bytes());
    }
}