1use std::{collections::HashMap, time::Duration};
2
3use bincode::{Decode, Encode};
4use log::trace;
5
6use karyon_core::async_util::timeout;
7
8use karyon_net::FramedReader;
9
10use crate::{
11 codec::PeerNetMsgCodec,
12 message::{PeerNetCmd, PeerNetMsg},
13 protocol::{ProtocolID, ProtocolKind, ProtocolMeta},
14 util::decode,
15 version::{version_match, VersionInt},
16 Error, PeerID, Result, Version,
17};
18
19pub type NegotiatedProtocols = Vec<ProtocolID>;
21
22pub type HandshakeWriter = karyon_net::FramedWriter<PeerNetMsgCodec>;
24
/// Version message sent by the outbound side to open the handshake.
///
/// NOTE(review): encoding is derived via bincode `Encode`/`Decode`, so field
/// order is part of the wire format — do not reorder fields.
#[derive(Decode, Encode, Debug, Clone)]
pub struct VerMsg {
    /// ID of the sending peer.
    pub peer_id: PeerID,
    /// Sender's system version; the receiver checks it against its own version requirement.
    pub version: VersionInt,
    /// Protocols the sender registered, each mapped to its version.
    pub protocols: HashMap<ProtocolID, VersionInt>,
}
32
/// Reply to a [`VerMsg`].
///
/// NOTE(review): encoding is derived via bincode `Encode`/`Decode`, so field
/// order is part of the wire format — do not reorder fields.
#[derive(Decode, Encode, Debug, Clone)]
pub struct VerAckMsg {
    /// ID of the replying peer.
    pub peer_id: PeerID,
    /// `true` when the replying peer accepted the version message; `false` when it
    /// rejected it as incompatible.
    pub ack: bool,
}
40
41pub struct HandshakeParams<'a> {
43 pub own_id: &'a PeerID,
44 pub is_inbound: bool,
45 pub config_version: &'a Version,
46 pub protocols: &'a HashMap<ProtocolID, ProtocolMeta>,
49 pub timeout_secs: u64,
50 pub verified_peer_id: Option<&'a PeerID>,
54}
55
56pub async fn handshake(
59 reader: &mut FramedReader<PeerNetMsgCodec>,
60 writer: &mut HandshakeWriter,
61 p: &HandshakeParams<'_>,
62) -> Result<(PeerID, NegotiatedProtocols)> {
63 trace!("Init Handshake");
64
65 if !p.is_inbound {
66 send_vermsg(writer, p.own_id, p.config_version, p.protocols).await?;
67 }
68
69 let t = Duration::from_secs(p.timeout_secs);
70 let msg: PeerNetMsg = timeout(t, reader.recv_msg()).await??;
71
72 match msg.header.command {
73 PeerNetCmd::Version => {
74 let result =
75 validate_version_msg(&msg, p.config_version, p.protocols, p.verified_peer_id).await;
76 match &result {
77 Ok(_) => send_verack(writer, p.own_id, true).await?,
78 Err(Error::IncompatibleVersion(_)) | Err(Error::IncompatiblePeer) => {
79 send_verack(writer, p.own_id, false).await?;
80 }
81 _ => {}
82 };
83 result
84 }
85 PeerNetCmd::Verack => {
86 let pid = validate_verack_msg(&msg, p.verified_peer_id).await?;
87 let all_protos = p.protocols.keys().cloned().collect();
92 Ok((pid, all_protos))
93 }
94 cmd => Err(Error::InvalidMsg(format!("unexpected msg found {cmd:?}"))),
95 }
96}
97
98async fn send_vermsg(
100 writer: &mut HandshakeWriter,
101 own_id: &PeerID,
102 config_version: &Version,
103 protocols: &HashMap<ProtocolID, ProtocolMeta>,
104) -> Result<()> {
105 let proto_versions = protocols
106 .iter()
107 .map(|(k, m)| (k.clone(), m.version.v.clone()))
108 .collect();
109
110 let vermsg = VerMsg {
111 peer_id: own_id.clone(),
112 protocols: proto_versions,
113 version: config_version.v.clone(),
114 };
115
116 trace!("Send VerMsg");
117 writer
118 .send_msg(PeerNetMsg::new(PeerNetCmd::Version, &vermsg)?)
119 .await?;
120 Ok(())
121}
122
123async fn send_verack(writer: &mut HandshakeWriter, own_id: &PeerID, ack: bool) -> Result<()> {
125 let verack = VerAckMsg {
126 peer_id: own_id.clone(),
127 ack,
128 };
129
130 trace!("Send VerAckMsg {verack:?}");
131 writer
132 .send_msg(PeerNetMsg::new(PeerNetCmd::Verack, &verack)?)
133 .await?;
134 Ok(())
135}
136
137async fn validate_version_msg(
140 msg: &PeerNetMsg,
141 config_version: &Version,
142 protocols: &HashMap<ProtocolID, ProtocolMeta>,
143 verified_peer_id: Option<&PeerID>,
144) -> Result<(PeerID, NegotiatedProtocols)> {
145 let (vermsg, _) = decode::<VerMsg>(&msg.payload)?;
146
147 if !version_match(&config_version.req, &vermsg.version) {
148 return Err(Error::IncompatibleVersion("system version".into()));
149 }
150
151 if let Some(vpid) = verified_peer_id {
154 if vpid != &vermsg.peer_id {
155 return Err(Error::IncompatiblePeer);
156 }
157 }
158
159 let shared = protocols_intersection(protocols, &vermsg.protocols);
160
161 if shared.is_empty() {
162 return Err(Error::IncompatiblePeer);
163 }
164
165 for (id, meta) in protocols.iter() {
168 if matches!(meta.kind, ProtocolKind::Mandatory) && !shared.iter().any(|p| p == id) {
169 return Err(Error::IncompatiblePeer);
170 }
171 }
172
173 trace!("Received VerMsg from: {}", vermsg.peer_id);
174 Ok((vermsg.peer_id, shared))
175}
176
177async fn validate_verack_msg(
179 msg: &PeerNetMsg,
180 verified_peer_id: Option<&PeerID>,
181) -> Result<PeerID> {
182 let (verack, _) = decode::<VerAckMsg>(&msg.payload)?;
183
184 if !verack.ack {
185 return Err(Error::IncompatiblePeer);
186 }
187
188 if let Some(vpid) = verified_peer_id {
189 if vpid != &verack.peer_id {
190 return Err(Error::IncompatiblePeer);
191 }
192 }
193
194 trace!("Received VerAckMsg from: {}", verack.peer_id);
195 Ok(verack.peer_id)
196}
197
198fn protocols_intersection(
201 our_protocols: &HashMap<ProtocolID, ProtocolMeta>,
202 their_protocols: &HashMap<String, VersionInt>,
203) -> NegotiatedProtocols {
204 let mut shared = Vec::new();
205 for (name, their_version) in their_protocols.iter() {
206 if let Some(our_meta) = our_protocols.get(name) {
207 if version_match(&our_meta.version.req, their_version) {
208 shared.push(name.clone());
209 }
210 }
211 }
212 shared
213}