karyon_p2p/protocols/
handshake.rs1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use log::trace;
5
6use karyon_core::{async_util::timeout, util::decode};
7
8use crate::{
9 message::{NetMsg, NetMsgCmd, VerAckMsg, VerMsg},
10 peer::Peer,
11 protocol::{InitProtocol, ProtocolID},
12 version::{version_match, VersionInt},
13 Error, PeerID, Result, Version,
14};
15
16pub struct HandshakeProtocol {
17 peer: Arc<Peer>,
18 protocols: HashMap<ProtocolID, Version>,
19}
20
21#[async_trait]
22impl InitProtocol for HandshakeProtocol {
23 type T = Result<PeerID>;
24 async fn init(self: Arc<Self>) -> Self::T {
26 trace!("Init Handshake: {}", self.peer.remote_endpoint());
27
28 if !self.peer.is_inbound() {
29 self.send_vermsg().await?;
30 }
31
32 let t = Duration::from_secs(self.peer.config().handshake_timeout);
33 let msg: NetMsg = timeout(t, self.peer.conn.recv_inner()).await??;
34 match msg.header.command {
35 NetMsgCmd::Version => {
36 let result = self.validate_version_msg(&msg).await;
37 match result {
38 Ok(_) => {
39 self.send_verack(true).await?;
40 }
41 Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => {
42 self.send_verack(false).await?;
43 }
44 _ => {}
45 };
46 result
47 }
48 NetMsgCmd::Verack => self.validate_verack_msg(&msg).await,
49 cmd => Err(Error::InvalidMsg(format!("unexpected msg found {:?}", cmd))),
50 }
51 }
52}
53
54impl HandshakeProtocol {
55 pub fn new(peer: Arc<Peer>, protocols: HashMap<ProtocolID, Version>) -> Arc<Self> {
56 Arc::new(Self { peer, protocols })
57 }
58
59 async fn send_vermsg(&self) -> Result<()> {
61 let protocols = self
62 .protocols
63 .clone()
64 .into_iter()
65 .map(|p| (p.0, p.1.v))
66 .collect();
67
68 let vermsg = VerMsg {
69 peer_id: self.peer.own_id().clone(),
70 protocols,
71 version: self.peer.config().version.v.clone(),
72 };
73
74 trace!("Send VerMsg");
75 self.peer
76 .conn
77 .send_inner(NetMsg::new(NetMsgCmd::Version, &vermsg)?)
78 .await?;
79 Ok(())
80 }
81
82 async fn send_verack(&self, ack: bool) -> Result<()> {
84 let verack = VerAckMsg {
85 peer_id: self.peer.own_id().clone(),
86 ack,
87 };
88
89 trace!("Send VerAckMsg {:?}", verack);
90 self.peer
91 .conn
92 .send_inner(NetMsg::new(NetMsgCmd::Verack, &verack)?)
93 .await?;
94 Ok(())
95 }
96
97 async fn validate_version_msg(&self, msg: &NetMsg) -> Result<PeerID> {
99 let (vermsg, _) = decode::<VerMsg>(&msg.payload)?;
100
101 if !version_match(&self.peer.config().version.req, &vermsg.version) {
102 return Err(Error::IncompatibleVersion("system: {}".into()));
103 }
104
105 self.protocols_match(&vermsg.protocols).await?;
106
107 trace!("Received VerMsg from: {}", vermsg.peer_id);
108 Ok(vermsg.peer_id)
109 }
110
111 async fn validate_verack_msg(&self, msg: &NetMsg) -> Result<PeerID> {
113 let (verack, _) = decode::<VerAckMsg>(&msg.payload)?;
114
115 if !verack.ack {
116 return Err(Error::IncompatiblePeer);
117 }
118
119 trace!("Received VerAckMsg from: {}", verack.peer_id);
120 Ok(verack.peer_id)
121 }
122
123 async fn protocols_match(&self, protocols: &HashMap<ProtocolID, VersionInt>) -> Result<()> {
125 for (n, pv) in protocols.iter() {
126 match self.protocols.get(n) {
127 Some(v) => {
128 if !version_match(&v.req, pv) {
129 return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}")));
130 }
131 }
132 None => {
133 return Err(Error::UnsupportedProtocol(n.to_string()));
134 }
135 }
136 }
137 Ok(())
138 }
139}