Skip to main content

karyon_p2p/
handshake.rs

1use std::{collections::HashMap, time::Duration};
2
3use bincode::{Decode, Encode};
4use log::trace;
5
6use karyon_core::async_util::timeout;
7
8use karyon_net::FramedReader;
9
10use crate::{
11    codec::PeerNetMsgCodec,
12    message::{PeerNetCmd, PeerNetMsg},
13    protocol::{ProtocolID, ProtocolKind, ProtocolMeta},
14    util::decode,
15    version::{version_match, VersionInt},
16    Error, PeerID, Result, Version,
17};
18
/// Negotiated protocol set from a successful handshake.
///
/// In `handshake`, the side that validates the peer's `VerMsg` gets the
/// intersection of both peers' version-compatible protocols. The side that
/// receives a `VerAck` gets all of its own local protocol IDs, because
/// `VerAckMsg` carries no protocol list.
pub type NegotiatedProtocols = Vec<ProtocolID>;

/// Writer type for the handshake: a framed writer using `PeerNetMsgCodec`,
/// on which `VerMsg` and `VerAckMsg` frames are sent.
pub type HandshakeWriter = karyon_net::FramedWriter<PeerNetMsgCodec>;
24
/// Version-exchange message kicked off by the outbound side.
///
/// Carries the sender's identity, its system version and the version of
/// every protocol it runs. The receiver checks it in `validate_version_msg`
/// and replies with a [`VerAckMsg`] (`ack=false` when it rejects the sender).
///
/// NOTE(review): this type is (de)serialized through bincode's derive. The
/// field order is presumably part of the wire format, so reordering fields
/// likely breaks compatibility with existing peers. Confirm before editing.
#[derive(Decode, Encode, Debug, Clone)]
pub struct VerMsg {
    /// PeerID the sender claims for itself. When the transport provides a
    /// verified PeerID, the receiver requires the two to be equal.
    pub peer_id: PeerID,
    /// Sender's system version, matched against the receiver's
    /// `Version::req`.
    pub version: VersionInt,
    /// Map of protocol ID to the version of that protocol on the sender.
    pub protocols: HashMap<ProtocolID, VersionInt>,
}
32
/// Acknowledges a Version message. `ack=false` means the responder rejected
/// the sender: an incompatible system version, a claimed PeerID that does not
/// match the transport-verified one, or no acceptable shared protocol set.
///
/// NOTE(review): bincode-derived like `VerMsg`. The field order is presumably
/// part of the wire format, so do not reorder without confirming.
#[derive(Decode, Encode, Debug, Clone)]
pub struct VerAckMsg {
    /// PeerID of the responder. When the transport provides a verified
    /// PeerID, the receiver of this ack requires the two to be equal.
    pub peer_id: PeerID,
    /// `true` if the responder accepted the sender's `VerMsg`.
    pub ack: bool,
}
40
41/// Non-IO inputs to the handshake.
42pub struct HandshakeParams<'a> {
43    pub own_id: &'a PeerID,
44    pub is_inbound: bool,
45    pub config_version: &'a Version,
46    /// Local protocols with metadata (version + kind). Drives both
47    /// version negotiation and the mandatory-subset check.
48    pub protocols: &'a HashMap<ProtocolID, ProtocolMeta>,
49    pub timeout_secs: u64,
50    /// PeerID derived from the secure transport (TLS cert), if any.
51    /// When `Some`, the handshake asserts the peer's claimed `vermsg.peer_id`
52    /// matches it - so a peer can't claim an identity it can't prove.
53    pub verified_peer_id: Option<&'a PeerID>,
54}
55
56/// Run the handshake on split reader/writer. Returns the remote peer ID
57/// and the set of protocols both sides support.
58pub async fn handshake(
59    reader: &mut FramedReader<PeerNetMsgCodec>,
60    writer: &mut HandshakeWriter,
61    p: &HandshakeParams<'_>,
62) -> Result<(PeerID, NegotiatedProtocols)> {
63    trace!("Init Handshake");
64
65    if !p.is_inbound {
66        send_vermsg(writer, p.own_id, p.config_version, p.protocols).await?;
67    }
68
69    let t = Duration::from_secs(p.timeout_secs);
70    let msg: PeerNetMsg = timeout(t, reader.recv_msg()).await??;
71
72    match msg.header.command {
73        PeerNetCmd::Version => {
74            let result =
75                validate_version_msg(&msg, p.config_version, p.protocols, p.verified_peer_id).await;
76            match &result {
77                Ok(_) => send_verack(writer, p.own_id, true).await?,
78                Err(Error::IncompatibleVersion(_)) | Err(Error::IncompatiblePeer) => {
79                    send_verack(writer, p.own_id, false).await?;
80                }
81                _ => {}
82            };
83            result
84        }
85        PeerNetCmd::Verack => {
86            let pid = validate_verack_msg(&msg, p.verified_peer_id).await?;
87            // Outbound side: we sent VerMsg, got VerAck. We don't know
88            // the intersection yet - return all our protocols. The inbound
89            // side already computed the intersection and will only run
90            // the shared set.
91            let all_protos = p.protocols.keys().cloned().collect();
92            Ok((pid, all_protos))
93        }
94        cmd => Err(Error::InvalidMsg(format!("unexpected msg found {cmd:?}"))),
95    }
96}
97
98/// Sends a Version message.
99async fn send_vermsg(
100    writer: &mut HandshakeWriter,
101    own_id: &PeerID,
102    config_version: &Version,
103    protocols: &HashMap<ProtocolID, ProtocolMeta>,
104) -> Result<()> {
105    let proto_versions = protocols
106        .iter()
107        .map(|(k, m)| (k.clone(), m.version.v.clone()))
108        .collect();
109
110    let vermsg = VerMsg {
111        peer_id: own_id.clone(),
112        protocols: proto_versions,
113        version: config_version.v.clone(),
114    };
115
116    trace!("Send VerMsg");
117    writer
118        .send_msg(PeerNetMsg::new(PeerNetCmd::Version, &vermsg)?)
119        .await?;
120    Ok(())
121}
122
123/// Sends a Verack message.
124async fn send_verack(writer: &mut HandshakeWriter, own_id: &PeerID, ack: bool) -> Result<()> {
125    let verack = VerAckMsg {
126        peer_id: own_id.clone(),
127        ack,
128    };
129
130    trace!("Send VerAckMsg {verack:?}");
131    writer
132        .send_msg(PeerNetMsg::new(PeerNetCmd::Verack, &verack)?)
133        .await?;
134    Ok(())
135}
136
137/// Validates the given version msg. Returns the remote peer ID and
138/// the intersection of compatible protocols.
139async fn validate_version_msg(
140    msg: &PeerNetMsg,
141    config_version: &Version,
142    protocols: &HashMap<ProtocolID, ProtocolMeta>,
143    verified_peer_id: Option<&PeerID>,
144) -> Result<(PeerID, NegotiatedProtocols)> {
145    let (vermsg, _) = decode::<VerMsg>(&msg.payload)?;
146
147    if !version_match(&config_version.req, &vermsg.version) {
148        return Err(Error::IncompatibleVersion("system version".into()));
149    }
150
151    // Bind the claimed PeerID to the secure transport's identity. A peer
152    // with a valid TLS cert for keypair X cannot claim PeerID Y.
153    if let Some(vpid) = verified_peer_id {
154        if vpid != &vermsg.peer_id {
155            return Err(Error::IncompatiblePeer);
156        }
157    }
158
159    let shared = protocols_intersection(protocols, &vermsg.protocols);
160
161    if shared.is_empty() {
162        return Err(Error::IncompatiblePeer);
163    }
164
165    // Every protocol the local node marked `Mandatory` must be in the
166    // negotiated intersection. Otherwise reject the handshake.
167    for (id, meta) in protocols.iter() {
168        if matches!(meta.kind, ProtocolKind::Mandatory) && !shared.iter().any(|p| p == id) {
169            return Err(Error::IncompatiblePeer);
170        }
171    }
172
173    trace!("Received VerMsg from: {}", vermsg.peer_id);
174    Ok((vermsg.peer_id, shared))
175}
176
177/// Validates the given verack msg.
178async fn validate_verack_msg(
179    msg: &PeerNetMsg,
180    verified_peer_id: Option<&PeerID>,
181) -> Result<PeerID> {
182    let (verack, _) = decode::<VerAckMsg>(&msg.payload)?;
183
184    if !verack.ack {
185        return Err(Error::IncompatiblePeer);
186    }
187
188    if let Some(vpid) = verified_peer_id {
189        if vpid != &verack.peer_id {
190            return Err(Error::IncompatiblePeer);
191        }
192    }
193
194    trace!("Received VerAckMsg from: {}", verack.peer_id);
195    Ok(verack.peer_id)
196}
197
198/// Compute the intersection of protocols both sides support with
199/// compatible versions. Returns the list of shared protocol IDs.
200fn protocols_intersection(
201    our_protocols: &HashMap<ProtocolID, ProtocolMeta>,
202    their_protocols: &HashMap<String, VersionInt>,
203) -> NegotiatedProtocols {
204    let mut shared = Vec::new();
205    for (name, their_version) in their_protocols.iter() {
206        if let Some(our_meta) = our_protocols.get(name) {
207            if version_match(&our_meta.version.req, their_version) {
208                shared.push(name.clone());
209            }
210        }
211    }
212    shared
213}