karyon_p2p/routing_table/mod.rs

1mod bucket;
2mod entry;
3
4use std::net::IpAddr;
5
6use parking_lot::RwLock;
7
8use rand::{rngs::OsRng, seq::SliceRandom};
9
10use karyon_net::Addr;
11
12pub use bucket::{
13    Bucket, BucketEntry, EntryStatusFlag, CONNECTED_ENTRY, DISCONNECTED_ENTRY, INCOMPATIBLE_ENTRY,
14    PENDING_ENTRY, UNREACHABLE_ENTRY, UNSTABLE_ENTRY,
15};
16pub use entry::{xor_distance, Entry, Key};
17
18use bucket::BUCKET_SIZE;
19use entry::KEY_SIZE;
20
/// The total number of buckets in the routing table.
///
/// `bucket_index` yields values in `0..KEY_SIZE` (one bucket per byte of the
/// XOR distance), and those values index directly into the bucket list, so this
/// must be at least `KEY_SIZE` (presumably equal to it — TODO confirm against
/// `entry::KEY_SIZE`).
const TABLE_SIZE: usize = 32;

/// The distance limit for the closest buckets.
///
/// `bucket_indexes` visits neighbouring buckets at offsets `1..DISTANCE_LIMIT`
/// on both sides of the target's bucket. With this equal to `TABLE_SIZE`,
/// every offset that can land inside the table is considered.
const DISTANCE_LIMIT: usize = 32;

/// The maximum number of matched subnets allowed within a single bucket.
///
/// `subnet_restricted` rejects a new entry once this many entries from the
/// same subnet are already present in the entry's target bucket.
const MAX_MATCHED_SUBNET_IN_BUCKET: usize = 1;

/// The maximum number of matched subnets across the entire routing table.
///
/// `subnet_restricted` rejects a new entry once this many entries from the
/// same subnet are already present anywhere in the table.
const MAX_MATCHED_SUBNET_IN_TABLE: usize = 6;
32
/// Represents the possible result when adding a new entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddEntryResult {
    /// The entry was added.
    Added,
    /// An entry with the same key already exists.
    Exists,
    /// The entry was ignored, e.g. because its key equals the local key or
    /// because its bucket is full and nothing could be evicted.
    Ignored,
    /// The entry is restricted and not allowed (a subnet limit was reached).
    Restricted,
}
45
/// This is a modified version of the Kademlia Distributed Hash Table (DHT).
/// <https://en.wikipedia.org/wiki/Kademlia>
///
/// Entries are spread over `TABLE_SIZE` buckets according to the XOR distance
/// between their key and the table's own `key`. All buckets sit behind a
/// single `RwLock`, so readers share access while any mutation locks the
/// whole table.
#[derive(Debug)]
pub struct RoutingTable {
    /// The local node's key; every bucket placement is computed relative to it.
    key: Key,
    /// The buckets, indexed by `bucket_index`. Keys whose XOR distance to
    /// `key` has its first non-zero byte later in the array land in
    /// lower-numbered buckets.
    buckets: RwLock<Vec<Bucket>>,
}
53
54impl RoutingTable {
55    /// Creates a new RoutingTable
56    pub fn new(key: Key) -> Self {
57        let buckets: Vec<Bucket> = (0..TABLE_SIZE).map(|_| Bucket::new()).collect();
58        Self {
59            key,
60            buckets: RwLock::new(buckets),
61        }
62    }
63
64    /// Adds a new entry to the table and returns a result indicating success,
65    /// failure, or restrictions.
66    pub fn add_entry(&self, entry: Entry) -> AddEntryResult {
67        // Determine the index of the bucket where the entry should be placed.
68        let bucket_idx = match self.bucket_index(&entry.key) {
69            Some(i) => i,
70            None => return AddEntryResult::Ignored,
71        };
72
73        // Check if the entry already exists in the bucket.
74        if self.contains_key(&entry.key) {
75            return AddEntryResult::Exists;
76        }
77
78        // Check if the entry is restricted.
79        if self.subnet_restricted(bucket_idx, &entry) {
80            return AddEntryResult::Restricted;
81        }
82
83        let mut buckets = self.buckets.write();
84        let bucket = &mut buckets[bucket_idx];
85
86        // If the bucket has free space, add the entry and return success.
87        if bucket.len() < BUCKET_SIZE {
88            bucket.add(&entry);
89            return AddEntryResult::Added;
90        }
91
92        // Replace it with an incompatible entry if one exists.
93        let incompatible_entry = bucket.iter().find(|e| e.is_incompatible()).cloned();
94        if let Some(e) = incompatible_entry {
95            bucket.remove(&e.entry.key);
96            bucket.add(&entry);
97            return AddEntryResult::Added;
98        }
99
100        // If the bucket is full, the entry is ignored.
101        AddEntryResult::Ignored
102    }
103
104    /// Check if the table contains the given key.
105    pub fn contains_key(&self, key: &Key) -> bool {
106        let buckets = self.buckets.read();
107        // Determine the bucket index for the given key.
108        let bucket_idx = match self.bucket_index(key) {
109            Some(bi) => bi,
110            None => return false,
111        };
112
113        let bucket = &buckets[bucket_idx];
114        bucket.contains_key(key)
115    }
116
117    /// Updates the status of an entry in the routing table identified
118    /// by the given key.
119    ///
120    /// If the key is not found, no action is taken.
121    pub fn update_entry(&self, key: &Key, entry_flag: EntryStatusFlag) {
122        let mut buckets = self.buckets.write();
123        // Determine the bucket index for the given key.
124        let bucket_idx = match self.bucket_index(key) {
125            Some(bi) => bi,
126            None => return,
127        };
128
129        let bucket = &mut buckets[bucket_idx];
130        bucket.update_entry(key, entry_flag);
131    }
132
133    /// Returns a list of bucket indexes that are closest to the given target key.
134    pub fn bucket_indexes(&self, target_key: &Key) -> Vec<usize> {
135        let mut indexes = vec![];
136
137        // Determine the primary bucket index for the target key.
138        let bucket_idx = self.bucket_index(target_key).unwrap_or(0);
139
140        indexes.push(bucket_idx);
141
142        // Add additional bucket indexes within a certain distance limit.
143        for i in 1..DISTANCE_LIMIT {
144            if bucket_idx >= i && bucket_idx - i >= 1 {
145                indexes.push(bucket_idx - i);
146            }
147
148            if bucket_idx + i < (TABLE_SIZE - 1) {
149                indexes.push(bucket_idx + i);
150            }
151        }
152
153        indexes
154    }
155
156    /// Returns a list of the closest entries to the given target key, limited by max_entries.
157    pub fn closest_entries(&self, target_key: &Key, max_entries: usize) -> Vec<Entry> {
158        let buckets = self.buckets.read();
159        let mut entries: Vec<Entry> = vec![];
160
161        // Collect entries
162        'outer: for idx in self.bucket_indexes(target_key) {
163            let bucket = &buckets[idx];
164            for bucket_entry in bucket.iter() {
165                if bucket_entry.is_unreachable() || bucket_entry.is_unstable() {
166                    continue;
167                }
168
169                entries.push(bucket_entry.entry.clone());
170                if entries.len() == max_entries {
171                    break 'outer;
172                }
173            }
174        }
175
176        // Sort the entries by their distance to the target key.
177        entries.sort_by(|a, b| {
178            xor_distance(target_key, &a.key).cmp(&xor_distance(target_key, &b.key))
179        });
180
181        entries
182    }
183
184    /// Removes an entry with the given key from the routing table, if it exists.
185    pub fn remove_entry(&self, key: &Key) {
186        let mut buckets = self.buckets.write();
187        // Determine the bucket index for the given key.
188        let bucket_idx = match self.bucket_index(key) {
189            Some(bi) => bi,
190            None => return,
191        };
192
193        let bucket = &mut buckets[bucket_idx];
194        bucket.remove(key);
195    }
196
197    /// Returns an iterator of entries.
198    /// FIXME: TODO: avoid cloning the data
199    pub fn buckets(&self) -> Vec<Bucket> {
200        self.buckets.read().clone()
201    }
202
203    /// Returns a random entry from the routing table.
204    pub fn random_entry(&self, entry_flag: EntryStatusFlag) -> Option<Entry> {
205        let buckets = self.buckets.read();
206        for bucket in buckets.choose_multiple(&mut OsRng, buckets.len()) {
207            for entry in bucket.random_iter(bucket.len()) {
208                if entry.status & entry_flag == 0 {
209                    continue;
210                }
211                return Some(entry.entry.clone());
212            }
213        }
214
215        None
216    }
217
218    // Returns the bucket index for a given key in the table.
219    fn bucket_index(&self, key: &Key) -> Option<usize> {
220        // Calculate the XOR distance between the self key and the provided key.
221        let distance = xor_distance(&self.key, key);
222
223        for (i, b) in distance.iter().enumerate() {
224            if *b != 0 {
225                let lz = i * 8 + b.leading_zeros() as usize;
226                let bits = KEY_SIZE * 8 - 1;
227                let idx = (bits - lz) / 8;
228                return Some(idx);
229            }
230        }
231        None
232    }
233
234    /// This function iterate through the routing table and counts how many
235    /// entries in the same subnet as the given Entry are already present.
236    ///
237    /// If the number of matching entries in the same bucket exceeds a
238    /// threshold (MAX_MATCHED_SUBNET_IN_BUCKET), or if the total count of
239    /// matching entries in the entire table exceeds a threshold
240    /// (MAX_MATCHED_SUBNET_IN_TABLE), the addition of the Entry
241    /// is considered restricted and returns true.
242    fn subnet_restricted(&self, idx: usize, entry: &Entry) -> bool {
243        let buckets = self.buckets.read();
244        let mut bucket_count = 0;
245        let mut table_count = 0;
246
247        // Iterate through the routing table's buckets and entries to check
248        // for subnet matches.
249        for (i, bucket) in buckets.iter().enumerate() {
250            for e in bucket.iter() {
251                // If there is a subnet match, update the counts.
252                let matched = subnet_match(&e.entry.addr, &entry.addr);
253                if matched {
254                    if i == idx {
255                        bucket_count += 1;
256                    }
257                    table_count += 1;
258                }
259
260                // If the number of matched entries in the same bucket exceeds
261                // the limit, return true
262                if bucket_count >= MAX_MATCHED_SUBNET_IN_BUCKET {
263                    return true;
264                }
265            }
266
267            // If the total matched entries in the table exceed the limit,
268            // return true.
269            if table_count >= MAX_MATCHED_SUBNET_IN_TABLE {
270                return true;
271            }
272        }
273
274        // If no subnet restrictions are encountered, return false.
275        false
276    }
277}
278
279/// Check if two addresses belong to the same subnet.
280fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool {
281    match (addr, other_addr) {
282        (Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => {
283            // TODO: Consider moving this to a different place
284            if other_ip.is_loopback() && ip.is_loopback() {
285                return false;
286            }
287            ip.octets()[0..3] == other_ip.octets()[0..3]
288        }
289        // Assume that they have /64 prefix
290        (Addr::Ip(IpAddr::V6(ip)), Addr::Ip(IpAddr::V6(other_ip))) => {
291            if other_ip.is_loopback() && ip.is_loopback() {
292                return false;
293            }
294            // Compare the first 4 segments (128 bits, 4 * 16 bits)
295            ip.segments()[0..4] == other_ip.segments()[0..4]
296        }
297
298        // If the address types don't match or are not handled
299        _ => false,
300    }
301}
302
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};

    use super::bucket::ALL_ENTRY;
    use super::*;

    use karyon_net::Addr;

    /// Fixture: an all-zero local key plus peer keys whose bucket indexes are
    /// pinned by `test_bucket_index`.
    struct Setup {
        local_key: Key,
        keys: Vec<Key>,
    }

    /// Builds an `Entry` from its parts.
    fn new_entry(key: &Key, addr: &Addr, port: u16, discovery_port: u16) -> Entry {
        Entry {
            key: key.clone(),
            addr: addr.clone(),
            port,
            discovery_port,
        }
    }

    impl Setup {
        fn new() -> Self {
            let keys = vec![
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 1,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    1, 1, 0, 1, 1, 2,
                ],
                [
                    0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 3,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 1, 18, 0, 0, 0,
                    0, 0, 0, 0, 0, 4,
                ],
                [
                    223, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 5,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 50, 1, 18, 0, 0, 0,
                    0, 0, 0, 0, 0, 6,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 50, 1, 18, 0, 0,
                    0, 0, 0, 0, 0, 0, 7,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 50, 1, 18, 0, 0,
                    0, 0, 0, 0, 0, 0, 8,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 50, 1, 18, 0, 0,
                    0, 0, 0, 0, 0, 0, 9,
                ],
            ];

            Self {
                local_key: [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 0,
                ],
                keys,
            }
        }

        /// One entry per key, each on a distinct loopback address
        /// (`127.0.{i}.1`). `subnet_match` never matches two loopback
        /// addresses, so the subnet limits do not interfere.
        fn entries(&self) -> Vec<Entry> {
            self.keys
                .iter()
                .enumerate()
                .map(|(i, key)| {
                    new_entry(
                        key,
                        &Addr::Ip(format!("127.0.{i}.1").parse().unwrap()),
                        3000,
                        3010,
                    )
                })
                .collect()
        }

        /// A table pre-populated with every fixture entry.
        fn table(&self) -> RoutingTable {
            let table = RoutingTable::new(self.local_key.clone());

            for entry in self.entries() {
                let res = table.add_entry(entry);
                assert!(
                    matches!(res, AddEntryResult::Added),
                    "unexpected result: {res:?}"
                );
            }

            table
        }
    }

    #[test]
    fn test_bucket_index() {
        let setup = Setup::new();
        let table = setup.table();

        assert_eq!(table.bucket_index(&setup.local_key), None);
        assert_eq!(table.bucket_index(&setup.keys[0]), Some(0));
        assert_eq!(table.bucket_index(&setup.keys[1]), Some(5));
        assert_eq!(table.bucket_index(&setup.keys[2]), Some(26));
        assert_eq!(table.bucket_index(&setup.keys[3]), Some(11));
        assert_eq!(table.bucket_index(&setup.keys[4]), Some(31));
        assert_eq!(table.bucket_index(&setup.keys[5]), Some(11));
        assert_eq!(table.bucket_index(&setup.keys[6]), Some(12));
        assert_eq!(table.bucket_index(&setup.keys[7]), Some(13));
        assert_eq!(table.bucket_index(&setup.keys[8]), Some(14));
    }

    #[test]
    fn test_closest_entries() {
        let setup = Setup::new();
        let table = setup.table();
        let entries = setup.entries();

        assert_eq!(
            table.closest_entries(&setup.keys[5], 8),
            vec![
                entries[5].clone(),
                entries[3].clone(),
                entries[1].clone(),
                entries[6].clone(),
                entries[7].clone(),
                entries[8].clone(),
                entries[2].clone(),
            ]
        );

        assert_eq!(
            table.closest_entries(&setup.keys[4], 2),
            vec![entries[4].clone(), entries[2].clone()]
        );
    }

    #[test]
    fn test_random_entry() {
        let setup = Setup::new();
        let table = setup.table();
        let entries = setup.entries();

        assert!(table.random_entry(ALL_ENTRY).is_some());

        // Freshly added entries are not connected.
        assert!(table.random_entry(CONNECTED_ENTRY).is_none());

        for entry in entries {
            table.remove_entry(&entry.key);
        }

        assert!(table.random_entry(ALL_ENTRY).is_none());
    }

    #[test]
    fn test_add_entries() {
        let setup = Setup::new();
        let table = setup.table();

        let key = [
            0, 0, 0, 0, 0, 0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 5,
        ];

        let key2 = [
            0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 5,
        ];

        let entry1 = new_entry(&key, &Addr::Ip("240.120.3.1".parse().unwrap()), 3000, 3010);
        assert!(matches!(
            table.add_entry(entry1.clone()),
            AddEntryResult::Added
        ));

        assert!(matches!(table.add_entry(entry1), AddEntryResult::Exists));

        // Same bucket and same /24 as `entry1`: rejected by the subnet limit.
        let entry2 = new_entry(&key2, &Addr::Ip("240.120.3.2".parse().unwrap()), 3000, 3010);
        assert!(matches!(
            table.add_entry(entry2),
            AddEntryResult::Restricted
        ));

        // Every key built here has a non-zero first byte, so all of them land
        // in the last bucket; this fills it up.
        let mut bucket_key: [u8; 32] = [0; 32];

        for i in 0..BUCKET_SIZE {
            bucket_key[i] += 1;
            let entry = new_entry(
                &bucket_key,
                &Addr::Ip(format!("127.0.{i}.1").parse().unwrap()),
                3000,
                3010,
            );
            table.add_entry(entry);
        }

        // The bucket is full and holds no incompatible entry to evict.
        bucket_key[BUCKET_SIZE] += 1;
        let entry = new_entry(
            &bucket_key,
            &Addr::Ip("125.20.0.1".parse().unwrap()),
            3000,
            3010,
        );
        assert!(matches!(table.add_entry(entry), AddEntryResult::Ignored));
    }

    #[test]
    fn test_add_local_key_is_ignored() {
        let setup = Setup::new();
        let table = setup.table();

        let entry = new_entry(
            &setup.local_key,
            &Addr::Ip("10.0.0.1".parse().unwrap()),
            3000,
            3010,
        );
        assert!(matches!(table.add_entry(entry), AddEntryResult::Ignored));
        assert!(!table.contains_key(&setup.local_key));
    }

    #[test]
    fn check_subnet_match() {
        let addr_v4 = Addr::Ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)));
        let other_addr_v4 = Addr::Ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)));

        let addr_v6 = Addr::Ip(IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)));
        let other_addr_v6 = Addr::Ip(IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2)));
        let diff_addr_v6 = Addr::Ip(IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb7, 0, 0, 0, 0, 0, 2)));

        assert!(subnet_match(&addr_v4, &other_addr_v4));
        assert!(subnet_match(&addr_v6, &other_addr_v6));
        assert!(!subnet_match(&addr_v6, &diff_addr_v6));
    }

    #[test]
    fn check_subnet_match_edge_cases() {
        // IPv4: a different third octet means a different /24.
        let a = Addr::Ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)));
        let b = Addr::Ip(IpAddr::V4(Ipv4Addr::new(192, 168, 2, 1)));
        assert!(!subnet_match(&a, &b));

        // IPv6: same /64 with a different interface id matches...
        let v6_a = Addr::Ip(IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0xffff, 0, 0, 1)));
        let v6_b = Addr::Ip(IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)));
        assert!(subnet_match(&v6_a, &v6_b));

        // ...but a different fourth segment is a different /64.
        let v6_c = Addr::Ip(IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 1, 0, 0, 0, 1)));
        assert!(!subnet_match(&v6_b, &v6_c));

        // Loopback addresses never match each other.
        let lo_a = Addr::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
        let lo_b = Addr::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)));
        assert!(!subnet_match(&lo_a, &lo_b));
        let lo6 = Addr::Ip(IpAddr::V6(Ipv6Addr::LOCALHOST));
        assert!(!subnet_match(&lo6, &lo6));

        // Mixed address families never match.
        assert!(!subnet_match(&a, &v6_a));
    }
}