1use std::collections::{HashMap, HashSet};
148use std::fs::File;
149use std::io::{BufReader, BufWriter, Write};
150use std::net::IpAddr;
151use std::path::Path;
152use std::time::{SystemTime, UNIX_EPOCH};
153
154use ipnetwork::IpNetwork;
155use rand::Rng;
156use serde::{Deserialize, Serialize};
157use stem_rs::descriptor::router_status::RouterStatusEntry;
158
159use crate::config::VanguardsConfig;
160use crate::error::{Error, Result};
161use crate::node_selection::{is_valid_country_code, is_valid_fingerprint, BwWeightedGenerator};
162
/// Seconds in one hour; used to convert lifetimes configured in hours into seconds.
const SEC_PER_HOUR: f64 = 3600.0;
165
/// A relay pinned into one of the vanguard layers, together with the time it
/// was picked and the time it should be rotated out.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GuardNode {
    /// Relay fingerprint; state validation requires it to pass `is_valid_fingerprint`.
    pub idhex: String,
    /// Moment the guard was selected, in seconds since the Unix epoch.
    pub chosen_at: f64,
    /// Moment the guard expires, in seconds since the Unix epoch.
    pub expires_at: f64,
}
217
218impl GuardNode {
219 pub fn new(idhex: String, chosen_at: f64, expires_at: f64) -> Self {
243 Self {
244 idhex,
245 chosen_at,
246 expires_at,
247 }
248 }
249
250 pub fn is_expired(&self) -> bool {
279 let now = SystemTime::now()
280 .duration_since(UNIX_EPOCH)
281 .unwrap_or_default()
282 .as_secs_f64();
283 self.expires_at < now
284 }
285}
286
/// Usage counter for a single entry tracked by `RendGuard`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RendUseCount {
    /// Key this counter belongs to: a relay fingerprint or the
    /// `NOT_IN_CONSENSUS` placeholder.
    pub idhex: String,
    /// Number of recorded uses; fractional because counts are halved when scaled.
    pub used: f64,
    /// Expected share of uses as a fraction (reported multiplied by 100 as a percentage).
    pub weight: f64,
}
300
301impl RendUseCount {
302 pub fn new(idhex: String, weight: f64) -> Self {
304 Self {
305 idhex,
306 used: 0.0,
307 weight,
308 }
309 }
310}
311
/// Per-relay usage bookkeeping used to detect relays that are used more often
/// than their weight suggests.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RendGuard {
    /// Counters keyed by relay fingerprint, plus a `NOT_IN_CONSENSUS` bucket
    /// for relays that are not otherwise tracked.
    pub use_counts: HashMap<String, RendUseCount>,
    /// Running total of recorded uses across all counters.
    pub total_use_counts: f64,
    /// Revision marker stored with the counters; initialised to `1.0`.
    pub pickle_revision: f64,
}
326
327impl Default for RendGuard {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
333impl RendGuard {
334 pub fn new() -> Self {
336 Self {
337 use_counts: HashMap::new(),
338 total_use_counts: 0.0,
339 pickle_revision: 1.0,
340 }
341 }
342
343 pub fn scale_counts(&mut self) {
348 for count in self.use_counts.values_mut() {
349 count.used /= 2.0;
350 }
351 self.total_use_counts = self.use_counts.values().map(|c| c.used).sum();
352 }
353
354 pub fn valid_rend_use(
368 &mut self,
369 fingerprint: &str,
370 config: &crate::config::RendguardConfig,
371 ) -> bool {
372 const NOT_IN_CONSENSUS_ID: &str = "NOT_IN_CONSENSUS";
373
374 let relay_id = if self.use_counts.contains_key(fingerprint) {
375 fingerprint.to_string()
376 } else {
377 if !self.use_counts.contains_key(NOT_IN_CONSENSUS_ID) {
379 self.use_counts.insert(
380 NOT_IN_CONSENSUS_ID.to_string(),
381 RendUseCount::new(NOT_IN_CONSENSUS_ID.to_string(), 0.0),
382 );
383 }
384 NOT_IN_CONSENSUS_ID.to_string()
385 };
386
387 if let Some(count) = self.use_counts.get_mut(&relay_id) {
389 count.used += 1.0;
390 }
391 self.total_use_counts += 1.0;
392
393 if let Some(count) = self.use_counts.get(&relay_id) {
395 if self.total_use_counts >= config.use_global_start_count as f64
396 && count.used >= config.use_relay_start_count as f64
397 && count.used / self.total_use_counts
398 > count.weight * config.use_max_use_to_bw_ratio
399 {
400 return false; }
402 }
403
404 true }
406
407 pub fn xfer_use_counts(
411 &mut self,
412 generator: &BwWeightedGenerator,
413 config: &crate::config::RendguardConfig,
414 ) {
415 const NOT_IN_CONSENSUS_ID: &str = "NOT_IN_CONSENSUS";
416
417 let old_counts = std::mem::take(&mut self.use_counts);
418 let should_scale = self.total_use_counts >= config.use_scale_at_count as f64;
419
420 let routers = generator.routers();
422 let node_weights = generator.node_weights();
423 let weight_total = generator.weight_total();
424 let exit_total = generator.exit_total();
425
426 for (i, router) in routers.iter().enumerate() {
427 let weight = if router.flags.contains(&"Exit".to_string()) && exit_total > 0.0 {
428 node_weights[i] / exit_total
429 } else if weight_total > 0.0 {
430 node_weights[i] / weight_total
431 } else {
432 0.0
433 };
434
435 self.use_counts.insert(
436 router.fingerprint.clone(),
437 RendUseCount::new(router.fingerprint.clone(), weight),
438 );
439 }
440
441 self.use_counts.insert(
443 NOT_IN_CONSENSUS_ID.to_string(),
444 RendUseCount::new(
445 NOT_IN_CONSENSUS_ID.to_string(),
446 config.use_max_consensus_weight_churn / 100.0,
447 ),
448 );
449
450 for (fp, old_count) in old_counts {
452 if fp == NOT_IN_CONSENSUS_ID || self.use_counts.contains_key(&fp) {
453 if let Some(new_count) = self.use_counts.get_mut(&fp) {
454 new_count.used = if should_scale {
455 old_count.used / 2.0
456 } else {
457 old_count.used
458 };
459 }
460 }
461 }
462
463 self.total_use_counts = self.use_counts.values().map(|c| c.used).sum();
465 }
466
467 pub fn usage_rate(&self, fingerprint: &str) -> f64 {
469 if self.total_use_counts <= 0.0 {
470 return 0.0;
471 }
472 self.use_counts
473 .get(fingerprint)
474 .map(|c| 100.0 * c.used / self.total_use_counts)
475 .unwrap_or(0.0)
476 }
477
478 pub fn expected_weight(&self, fingerprint: &str) -> f64 {
480 self.use_counts
481 .get(fingerprint)
482 .map(|c| 100.0 * c.weight)
483 .unwrap_or(0.0)
484 }
485
486 pub fn is_overused(&self, fingerprint: &str, config: &crate::config::RendguardConfig) -> bool {
488 if self.total_use_counts < config.use_global_start_count as f64 {
489 return false;
490 }
491
492 if let Some(count) = self.use_counts.get(fingerprint) {
493 if count.used < config.use_relay_start_count as f64 {
494 return false;
495 }
496 count.used / self.total_use_counts > count.weight * config.use_max_use_to_bw_ratio
497 } else {
498 false
499 }
500 }
501}
502
/// Persistent vanguard state: the guards of both layers plus the rendezvous
/// usage counters.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VanguardState {
    /// Guards currently assigned to layer 2.
    pub layer2: Vec<GuardNode>,
    /// Guards currently assigned to layer 3.
    pub layer3: Vec<GuardNode>,
    /// Path of the state file; `load_or_create` overwrites it with the path
    /// the state was loaded from.
    pub state_file: String,
    /// Rendezvous-point usage counters.
    pub rendguard: RendGuard,
    /// Revision marker stored with the state; initialised to `1`.
    pub pickle_revision: u32,
    /// Not serialized. Because of `#[serde(skip)]`, a deserialized state gets
    /// the `bool` default (`false`) here, while `new` sets `true`.
    /// NOTE(review): confirm callers set this after loading from disk.
    #[serde(skip)]
    pub enable_vanguards: bool,
}
593
594impl Default for VanguardState {
595 fn default() -> Self {
596 Self::new("vanguards.state")
597 }
598}
599
600impl VanguardState {
601 pub fn new(state_file: &str) -> Self {
603 Self {
604 layer2: Vec::new(),
605 layer3: Vec::new(),
606 state_file: state_file.to_string(),
607 rendguard: RendGuard::new(),
608 pickle_revision: 1,
609 enable_vanguards: true,
610 }
611 }
612
613 pub fn load_or_create(path: &Path) -> Self {
623 match Self::read_from_file(path) {
624 Ok(mut state) => {
625 state.state_file = path.to_string_lossy().to_string();
626 state
627 }
628 Err(_) => Self::new(&path.to_string_lossy()),
629 }
630 }
631
632 pub fn read_from_file(path: &Path) -> Result<Self> {
643 let file =
644 File::open(path).map_err(|e| Error::State(format!("cannot open state file: {}", e)))?;
645 let reader = BufReader::new(file);
646 let state: Self = serde_pickle::from_reader(reader, Default::default())
647 .map_err(|e| Error::State(format!("cannot parse state file: {}", e)))?;
648
649 state.validate()?;
651
652 Ok(state)
653 }
654
655 pub fn validate(&self) -> Result<()> {
665 let now = SystemTime::now()
666 .duration_since(UNIX_EPOCH)
667 .unwrap_or_default()
668 .as_secs_f64();
669
670 let max_timestamp = now + 3600.0;
672
673 for guard in &self.layer2 {
675 if !is_valid_fingerprint(&guard.idhex) {
676 return Err(Error::State(format!(
677 "invalid fingerprint in layer2: {}",
678 guard.idhex
679 )));
680 }
681 if guard.chosen_at > max_timestamp {
682 return Err(Error::State(format!(
683 "future timestamp in layer2 guard {}: chosen_at {} > now {}",
684 guard.idhex, guard.chosen_at, now
685 )));
686 }
687 if guard.expires_at > max_timestamp + 86400.0 * 365.0 {
688 return Err(Error::State(format!(
690 "unreasonable future expiration in layer2 guard {}: expires_at {}",
691 guard.idhex, guard.expires_at
692 )));
693 }
694 }
695
696 for guard in &self.layer3 {
698 if !is_valid_fingerprint(&guard.idhex) {
699 return Err(Error::State(format!(
700 "invalid fingerprint in layer3: {}",
701 guard.idhex
702 )));
703 }
704 if guard.chosen_at > max_timestamp {
705 return Err(Error::State(format!(
706 "future timestamp in layer3 guard {}: chosen_at {} > now {}",
707 guard.idhex, guard.chosen_at, now
708 )));
709 }
710 if guard.expires_at > max_timestamp + 86400.0 * 365.0 {
711 return Err(Error::State(format!(
712 "unreasonable future expiration in layer3 guard {}: expires_at {}",
713 guard.idhex, guard.expires_at
714 )));
715 }
716 }
717
718 for fp in self.rendguard.use_counts.keys() {
720 if fp == "NOT_IN_CONSENSUS" {
722 continue;
723 }
724 if !is_valid_fingerprint(fp) {
725 return Err(Error::State(format!(
726 "invalid fingerprint in rendguard: {}",
727 fp
728 )));
729 }
730 }
731
732 Ok(())
733 }
734
735 pub fn write_to_file(&self, path: &Path) -> Result<()> {
744 let temp_path = path.with_extension("tmp");
746
747 #[cfg(unix)]
749 let file = {
750 use std::os::unix::fs::OpenOptionsExt;
751 std::fs::OpenOptions::new()
752 .write(true)
753 .create(true)
754 .truncate(true)
755 .mode(0o600)
756 .open(&temp_path)
757 .map_err(|e| Error::State(format!("cannot create temp state file: {}", e)))?
758 };
759
760 #[cfg(not(unix))]
761 let file = File::create(&temp_path)
762 .map_err(|e| Error::State(format!("cannot create temp state file: {}", e)))?;
763
764 let mut writer = BufWriter::new(file);
765 serde_pickle::to_writer(&mut writer, self, Default::default())
766 .map_err(|e| Error::State(format!("cannot write state file: {}", e)))?;
767
768 writer
770 .flush()
771 .map_err(|e| Error::State(format!("cannot flush state file: {}", e)))?;
772 drop(writer);
773
774 std::fs::rename(&temp_path, path)
776 .map_err(|e| Error::State(format!("cannot rename temp state file: {}", e)))?;
777
778 Ok(())
779 }
780
781 pub fn layer2_guardset(&self) -> String {
783 self.layer2
784 .iter()
785 .map(|g| g.idhex.as_str())
786 .collect::<Vec<_>>()
787 .join(",")
788 }
789
790 pub fn layer3_guardset(&self) -> String {
792 self.layer3
793 .iter()
794 .map(|g| g.idhex.as_str())
795 .collect::<Vec<_>>()
796 .join(",")
797 }
798
799 pub fn calculate_guard_lifetime(min_hours: u32, max_hours: u32) -> f64 {
813 let mut rng = rand::thread_rng();
814 let min_secs = min_hours as f64 * SEC_PER_HOUR;
815 let max_secs = max_hours as f64 * SEC_PER_HOUR;
816 let sample1 = rng.gen_range(min_secs..=max_secs);
817 let sample2 = rng.gen_range(min_secs..=max_secs);
818 sample1.max(sample2)
819 }
820
821 pub fn add_new_layer2(
826 &mut self,
827 generator: &BwWeightedGenerator,
828 excluded: &ExcludeNodes,
829 config: &VanguardsConfig,
830 ) -> Result<()> {
831 let existing: HashSet<_> = self.layer2.iter().map(|g| g.idhex.as_str()).collect();
832
833 for _ in 0..1000 {
834 let guard = generator.generate()?;
835 if existing.contains(guard.fingerprint.as_str()) {
836 continue;
837 }
838 if excluded.router_is_excluded(guard) {
839 continue;
840 }
841
842 let now = SystemTime::now()
843 .duration_since(UNIX_EPOCH)
844 .unwrap_or_default()
845 .as_secs_f64();
846 let lifetime = Self::calculate_guard_lifetime(
847 config.min_layer2_lifetime_hours,
848 config.max_layer2_lifetime_hours,
849 );
850 let expires = now + lifetime;
851
852 self.layer2
853 .push(GuardNode::new(guard.fingerprint.clone(), now, expires));
854 return Ok(());
855 }
856
857 Err(Error::NoNodesRemain)
858 }
859
860 pub fn add_new_layer3(
865 &mut self,
866 generator: &BwWeightedGenerator,
867 excluded: &ExcludeNodes,
868 config: &VanguardsConfig,
869 ) -> Result<()> {
870 let existing: HashSet<_> = self.layer3.iter().map(|g| g.idhex.as_str()).collect();
871
872 for _ in 0..1000 {
873 let guard = generator.generate()?;
874 if existing.contains(guard.fingerprint.as_str()) {
875 continue;
876 }
877 if excluded.router_is_excluded(guard) {
878 continue;
879 }
880
881 let now = SystemTime::now()
882 .duration_since(UNIX_EPOCH)
883 .unwrap_or_default()
884 .as_secs_f64();
885 let lifetime = Self::calculate_guard_lifetime(
886 config.min_layer3_lifetime_hours,
887 config.max_layer3_lifetime_hours,
888 );
889 let expires = now + lifetime;
890
891 self.layer3
892 .push(GuardNode::new(guard.fingerprint.clone(), now, expires));
893 return Ok(());
894 }
895
896 Err(Error::NoNodesRemain)
897 }
898
899 pub fn remove_down_from_layer(layer: &mut Vec<GuardNode>, consensus_fps: &HashSet<String>) {
901 layer.retain(|g| consensus_fps.contains(&g.idhex));
902 }
903
904 pub fn remove_expired_from_layer(layer: &mut Vec<GuardNode>) {
906 let now = SystemTime::now()
907 .duration_since(UNIX_EPOCH)
908 .unwrap_or_default()
909 .as_secs_f64();
910 layer.retain(|g| g.expires_at >= now);
911 }
912
913 pub fn remove_excluded_from_layer(
915 layer: &mut Vec<GuardNode>,
916 router_map: &HashMap<String, &RouterStatusEntry>,
917 excluded: &ExcludeNodes,
918 ) {
919 layer.retain(|g| {
920 if let Some(router) = router_map.get(&g.idhex) {
921 !excluded.router_is_excluded(router)
922 } else {
923 true
924 }
925 });
926 }
927
928 pub fn replenish_layers(
933 &mut self,
934 generator: &BwWeightedGenerator,
935 excluded: &ExcludeNodes,
936 config: &VanguardsConfig,
937 ) -> Result<()> {
938 self.layer2.truncate(config.num_layer2_guards as usize);
939 self.layer3.truncate(config.num_layer3_guards as usize);
940
941 while self.layer2.len() < config.num_layer2_guards as usize {
942 self.add_new_layer2(generator, excluded, config)?;
943 }
944
945 while self.layer3.len() < config.num_layer3_guards as usize {
946 self.add_new_layer3(generator, excluded, config)?;
947 }
948
949 Ok(())
950 }
951}
952
/// Parsed relay exclusion list: networks, fingerprints, nicknames and
/// country codes that must not be picked as guards.
#[derive(Debug, Clone, Default)]
pub struct ExcludeNodes {
    /// Excluded address ranges; single addresses are stored as host networks.
    pub networks: Vec<IpNetwork>,
    /// Excluded relay fingerprints, stored upper-cased.
    pub idhexes: HashSet<String>,
    /// Excluded relay nicknames, stored as written.
    pub nicks: HashSet<String>,
    /// Excluded country codes, stored lower-cased; may also hold the
    /// `??` and `a1` markers added for unknown locations.
    pub countries: HashSet<String>,
    /// Raw "exclude unknown" setting passed to `parse`; the values `"1"` and
    /// `"auto"` are the ones the parser reacts to.
    pub exclude_unknowns: Option<String>,
}
1021
1022impl ExcludeNodes {
1023 pub fn new() -> Self {
1025 Self::default()
1026 }
1027
1028 pub fn parse(conf_line: &str, exclude_unknowns: Option<&str>) -> Self {
1049 let mut result = Self::new();
1050 result.exclude_unknowns = exclude_unknowns.map(|s| s.to_string());
1051
1052 if let Some(ref setting) = result.exclude_unknowns {
1053 if setting == "1" {
1054 result.countries.insert("??".to_string());
1055 result.countries.insert("a1".to_string());
1056 }
1057 }
1058
1059 if conf_line.is_empty() {
1060 return result;
1061 }
1062
1063 result.parse_line(conf_line);
1064 result
1065 }
1066
1067 fn parse_line(&mut self, conf_line: &str) {
1069 for part in conf_line.split(',') {
1070 let mut p = part.trim().to_string();
1071 if p.is_empty() {
1072 continue;
1073 }
1074
1075 if p.starts_with('$') {
1076 p = p[1..].to_string();
1077 }
1078
1079 if let Some(idx) = p.find('~') {
1080 p = p[..idx].to_string();
1081 }
1082 if let Some(idx) = p.find('=') {
1083 p = p[..idx].to_string();
1084 }
1085
1086 if is_valid_fingerprint(&p) {
1087 self.idhexes.insert(p.to_uppercase());
1088 } else if p.starts_with('{') && p.ends_with('}') && p.len() >= 3 {
1089 let cc = &p[1..p.len() - 1];
1090 if is_valid_country_code(cc) {
1091 self.countries.insert(cc.to_lowercase());
1092 }
1093 } else if p.contains(':') || p.contains('.') {
1094 if let Ok(network) = p.parse::<IpNetwork>() {
1095 self.networks.push(network);
1096 } else if let Ok(ip) = p.parse::<IpAddr>() {
1097 let network = match ip {
1098 IpAddr::V4(_) => format!("{}/32", ip).parse().ok(),
1099 IpAddr::V6(_) => format!("{}/128", ip).parse().ok(),
1100 };
1101 if let Some(net) = network {
1102 self.networks.push(net);
1103 }
1104 }
1105 } else {
1106 self.nicks.insert(p);
1107 }
1108 }
1109
1110 if let Some(ref setting) = self.exclude_unknowns {
1111 if setting == "auto" && !self.countries.is_empty() {
1112 self.countries.insert("??".to_string());
1113 self.countries.insert("a1".to_string());
1114 }
1115 }
1116 }
1117
1118 pub fn router_is_excluded(&self, router: &RouterStatusEntry) -> bool {
1128 if self.idhexes.contains(&router.fingerprint.to_uppercase()) {
1129 return true;
1130 }
1131
1132 if self.nicks.contains(&router.nickname) {
1133 return true;
1134 }
1135
1136 let addresses = self.get_router_addresses(router);
1137 for (addr, _port, _is_ipv6) in &addresses {
1138 for network in &self.networks {
1139 if network.contains(*addr) {
1140 return true;
1141 }
1142 }
1143 }
1144
1145 false
1146 }
1147
1148 fn get_router_addresses(&self, router: &RouterStatusEntry) -> Vec<(IpAddr, u16, bool)> {
1150 let mut addresses = vec![(router.address, router.or_port, router.address.is_ipv6())];
1151 addresses.extend(router.or_addresses.iter().cloned());
1152 addresses
1153 }
1154
1155 pub fn has_exclusions(&self) -> bool {
1157 !self.networks.is_empty()
1158 || !self.idhexes.is_empty()
1159 || !self.nicks.is_empty()
1160 || !self.countries.is_empty()
1161 }
1162}
1163
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use stem_rs::descriptor::router_status::RouterStatusEntryType;

    /// Fingerprint shared by the exclusion tests.
    const SAMPLE_FP: &str = "AABBCCDD00112233445566778899AABBCCDDEEFF";
    /// The same fingerprint with the `$` prefix used in configuration lines.
    const DOLLAR_SAMPLE_FP: &str = "$AABBCCDD00112233445566778899AABBCCDDEEFF";

    /// Current wall-clock time as fractional seconds since the Unix epoch.
    fn unix_now() -> f64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs_f64()
    }

    /// A 40-character fingerprint made of one repeated letter.
    fn fp(letter: &str) -> String {
        letter.repeat(40)
    }

    /// Builds a V3 router status entry listening on port 9001.
    fn create_test_router(fingerprint: &str, nickname: &str, address: &str) -> RouterStatusEntry {
        RouterStatusEntry::new(
            RouterStatusEntryType::V3,
            nickname.to_string(),
            fingerprint.to_string(),
            Utc::now(),
            address.parse().unwrap(),
            9001,
        )
    }

    #[test]
    fn test_guard_node_creation() {
        let (chosen, expiry) = (1000000.0, 2000000.0);
        let guard = GuardNode::new(fp("A"), chosen, expiry);
        assert_eq!(guard.idhex, fp("A"));
        assert_eq!(guard.chosen_at, chosen);
        assert_eq!(guard.expires_at, expiry);
    }

    #[test]
    fn test_guard_node_expired() {
        let now = unix_now();

        let stale = GuardNode::new(fp("A"), now - 1000.0, now - 100.0);
        assert!(stale.is_expired());

        let fresh = GuardNode::new(fp("B"), now, now + 86400.0);
        assert!(!fresh.is_expired());
    }

    #[test]
    fn test_vanguard_state_new() {
        let state = VanguardState::new("test.state");
        assert!(state.layer2.is_empty());
        assert!(state.layer3.is_empty());
        assert_eq!(state.state_file, "test.state");
        assert_eq!(state.pickle_revision, 1);
    }

    #[test]
    fn test_vanguard_state_guardset() {
        let mut state = VanguardState::new("test.state");
        for letter in ["A", "B"] {
            state.layer2.push(GuardNode::new(fp(letter), 0.0, 1000.0));
        }

        let guardset = state.layer2_guardset();
        assert!(guardset.contains(&fp("A")));
        assert!(guardset.contains(&fp("B")));
        assert!(guardset.contains(','));
    }

    #[test]
    fn test_calculate_guard_lifetime() {
        let min_secs = 24.0 * SEC_PER_HOUR;
        let max_secs = 1080.0 * SEC_PER_HOUR;
        for _ in 0..100 {
            let lifetime = VanguardState::calculate_guard_lifetime(24, 1080);
            assert!(lifetime >= min_secs);
            assert!(lifetime <= max_secs);
        }
    }

    #[test]
    fn test_remove_expired_from_layer() {
        let now = unix_now();

        let mut layer = vec![
            GuardNode::new(fp("A"), now - 1000.0, now - 100.0),
            GuardNode::new(fp("B"), now, now + 86400.0),
            GuardNode::new(fp("C"), now - 2000.0, now - 500.0),
        ];

        VanguardState::remove_expired_from_layer(&mut layer);
        assert_eq!(layer.len(), 1);
        assert_eq!(layer[0].idhex, fp("B"));
    }

    #[test]
    fn test_remove_down_from_layer() {
        let mut layer: Vec<GuardNode> = vec!["A", "B", "C"]
            .into_iter()
            .map(|letter| GuardNode::new(fp(letter), 0.0, 1000.0))
            .collect();

        let consensus_fps: HashSet<String> = vec![fp("A"), fp("C")].into_iter().collect();

        VanguardState::remove_down_from_layer(&mut layer, &consensus_fps);
        assert_eq!(layer.len(), 2);
        assert!(layer.iter().any(|g| g.idhex == fp("A")));
        assert!(layer.iter().any(|g| g.idhex == fp("C")));
        assert!(!layer.iter().any(|g| g.idhex == fp("B")));
    }

    #[test]
    fn test_exclude_nodes_parse_fingerprint() {
        let exclude = ExcludeNodes::parse(DOLLAR_SAMPLE_FP, None);
        assert!(exclude.idhexes.contains(SAMPLE_FP));
    }

    #[test]
    fn test_exclude_nodes_parse_fingerprint_without_dollar() {
        let exclude = ExcludeNodes::parse(SAMPLE_FP, None);
        assert!(exclude.idhexes.contains(SAMPLE_FP));
    }

    #[test]
    fn test_exclude_nodes_parse_fingerprint_with_suffix() {
        let tilde_line = format!("{}~nickname", DOLLAR_SAMPLE_FP);
        let exclude = ExcludeNodes::parse(&tilde_line, None);
        assert!(exclude.idhexes.contains(SAMPLE_FP));
        assert!(!exclude.nicks.contains("nickname"));

        let equals_line = format!("{}=nickname", DOLLAR_SAMPLE_FP);
        let exclude2 = ExcludeNodes::parse(&equals_line, None);
        assert!(exclude2.idhexes.contains(SAMPLE_FP));
    }

    #[test]
    fn test_exclude_nodes_parse_country_code() {
        for spelling in ["{us}", "{US}"] {
            let exclude = ExcludeNodes::parse(spelling, None);
            assert!(exclude.countries.contains("us"));
        }
    }

    #[test]
    fn test_exclude_nodes_parse_network() {
        for cidr in ["192.168.0.0/24", "2001:db8::/32"] {
            let exclude = ExcludeNodes::parse(cidr, None);
            assert_eq!(exclude.networks.len(), 1);
        }
    }

    #[test]
    fn test_exclude_nodes_parse_ip_address() {
        let exclude = ExcludeNodes::parse("192.168.1.1", None);
        assert_eq!(exclude.networks.len(), 1);
    }

    #[test]
    fn test_exclude_nodes_parse_nickname() {
        let exclude = ExcludeNodes::parse("BadRelay", None);
        assert!(exclude.nicks.contains("BadRelay"));
    }

    #[test]
    fn test_exclude_nodes_parse_mixed() {
        let line = format!("{},{{us}},192.168.0.0/16,BadRelay", DOLLAR_SAMPLE_FP);
        let exclude = ExcludeNodes::parse(&line, None);
        assert!(exclude.idhexes.contains(SAMPLE_FP));
        assert!(exclude.countries.contains("us"));
        assert_eq!(exclude.networks.len(), 1);
        assert!(exclude.nicks.contains("BadRelay"));
    }

    #[test]
    fn test_exclude_nodes_geoip_exclude_unknown_1() {
        let exclude = ExcludeNodes::parse("", Some("1"));
        assert!(exclude.countries.contains("??"));
        assert!(exclude.countries.contains("a1"));
    }

    #[test]
    fn test_exclude_nodes_geoip_exclude_unknown_auto() {
        let exclude = ExcludeNodes::parse("{us}", Some("auto"));
        assert!(exclude.countries.contains("us"));
        assert!(exclude.countries.contains("??"));
        assert!(exclude.countries.contains("a1"));
    }

    #[test]
    fn test_exclude_nodes_geoip_exclude_unknown_auto_no_countries() {
        let exclude = ExcludeNodes::parse("BadRelay", Some("auto"));
        assert!(!exclude.countries.contains("??"));
        assert!(!exclude.countries.contains("a1"));
    }

    #[test]
    fn test_router_is_excluded_by_fingerprint() {
        let exclude = ExcludeNodes::parse(DOLLAR_SAMPLE_FP, None);
        let router = create_test_router(SAMPLE_FP, "test", "192.0.2.1");
        assert!(exclude.router_is_excluded(&router));
    }

    #[test]
    fn test_router_is_excluded_by_nickname() {
        let exclude = ExcludeNodes::parse("BadRelay", None);
        let router = create_test_router(&fp("A"), "BadRelay", "192.0.2.1");
        assert!(exclude.router_is_excluded(&router));
    }

    #[test]
    fn test_router_is_excluded_by_network() {
        let exclude = ExcludeNodes::parse("192.168.0.0/16", None);

        let inside = create_test_router(&fp("A"), "test", "192.168.1.1");
        assert!(exclude.router_is_excluded(&inside));

        let outside = create_test_router(&fp("B"), "test2", "10.0.0.1");
        assert!(!exclude.router_is_excluded(&outside));
    }

    #[test]
    fn test_router_not_excluded() {
        let exclude = ExcludeNodes::parse("$BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB,{de}", None);
        let router = create_test_router(&fp("A"), "GoodRelay", "192.0.2.1");
        assert!(!exclude.router_is_excluded(&router));
    }

    #[test]
    fn test_rendguard_new() {
        let rg = RendGuard::new();
        assert!(rg.use_counts.is_empty());
        assert_eq!(rg.total_use_counts, 0.0);
        assert_eq!(rg.pickle_revision, 1.0);
    }

    #[test]
    fn test_rendguard_scale_counts() {
        let mut rg = RendGuard::new();
        for (letter, used) in [("A", 100.0), ("B", 200.0)] {
            rg.use_counts.insert(
                fp(letter),
                RendUseCount {
                    idhex: fp(letter),
                    used,
                    weight: 0.5,
                },
            );
        }
        rg.total_use_counts = 300.0;

        rg.scale_counts();

        assert_eq!(rg.use_counts.get(&fp("A")).unwrap().used, 50.0);
        assert_eq!(rg.use_counts.get(&fp("B")).unwrap().used, 100.0);
        assert_eq!(rg.total_use_counts, 150.0);
    }

    #[test]
    fn test_exclude_nodes_has_exclusions() {
        let empty = ExcludeNodes::new();
        assert!(!empty.has_exclusions());

        let with_fp = ExcludeNodes::parse(DOLLAR_SAMPLE_FP, None);
        assert!(with_fp.has_exclusions());

        let with_country = ExcludeNodes::parse("{us}", None);
        assert!(with_country.has_exclusions());

        let with_network = ExcludeNodes::parse("192.168.0.0/24", None);
        assert!(with_network.has_exclusions());

        let with_nick = ExcludeNodes::parse("BadRelay", None);
        assert!(with_nick.has_exclusions());
    }

    #[test]
    fn test_exclude_nodes_empty_string() {
        let exclude = ExcludeNodes::parse("", None);
        assert!(!exclude.has_exclusions());
    }

    #[test]
    fn test_exclude_nodes_whitespace_handling() {
        let exclude =
            ExcludeNodes::parse(" $AABBCCDD00112233445566778899AABBCCDDEEFF , {us} ", None);
        assert!(exclude.idhexes.contains(SAMPLE_FP));
        assert!(exclude.countries.contains("us"));
    }

    #[test]
    fn test_vanguard_state_validation_valid() {
        let now = unix_now();

        let mut state = VanguardState::new("test.state");
        state
            .layer2
            .push(GuardNode::new(fp("A"), now - 1000.0, now + 86400.0));
        state
            .layer3
            .push(GuardNode::new(fp("B"), now - 500.0, now + 3600.0));

        assert!(state.validate().is_ok());
    }

    #[test]
    fn test_vanguard_state_validation_invalid_fingerprint() {
        let now = unix_now();

        let mut state = VanguardState::new("test.state");
        state
            .layer2
            .push(GuardNode::new("invalid".to_string(), now, now + 86400.0));

        assert!(state.validate().is_err());
    }

    #[test]
    fn test_vanguard_state_validation_future_timestamp() {
        let now = unix_now();

        let mut state = VanguardState::new("test.state");
        state
            .layer2
            .push(GuardNode::new(fp("A"), now + 10000.0, now + 86400.0));

        assert!(state.validate().is_err());
    }
}
1529
1530#[cfg(test)]
1531mod proptests {
1532 use super::*;
1533 use crate::node_selection::is_valid_fingerprint;
1534 use proptest::prelude::*;
1535
1536 fn arb_fingerprint() -> impl Strategy<Value = String> {
1537 "[0-9A-F]{40}".prop_map(|s| s.to_uppercase())
1538 }
1539
1540 fn arb_country_code() -> impl Strategy<Value = String> {
1541 "[a-z]{2}"
1542 }
1543
1544 fn arb_ipv4() -> impl Strategy<Value = String> {
1545 (1u8..=254, 0u8..=255, 0u8..=255, 1u8..=254)
1546 .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d))
1547 }
1548
1549 fn arb_cidr() -> impl Strategy<Value = String> {
1550 (arb_ipv4(), 8u8..=30).prop_map(|(ip, prefix)| format!("{}/{}", ip, prefix))
1551 }
1552
1553 fn arb_nickname() -> impl Strategy<Value = String> {
1554 "[A-Za-z][A-Za-z0-9]{0,18}"
1555 }
1556
1557 fn arb_guard_node() -> impl Strategy<Value = GuardNode> {
1558 let now = std::time::SystemTime::now()
1559 .duration_since(std::time::UNIX_EPOCH)
1560 .unwrap()
1561 .as_secs_f64();
1562 let chosen_min = now - 365.0 * 86400.0;
1563 let chosen_max = now;
1564 let expires_max = now + 365.0 * 86400.0;
1565
1566 (
1567 arb_fingerprint(),
1568 chosen_min..chosen_max,
1569 chosen_max..expires_max,
1570 )
1571 .prop_map(|(idhex, chosen_at, expires_at)| GuardNode::new(idhex, chosen_at, expires_at))
1572 }
1573
1574 fn arb_rend_use_count() -> impl Strategy<Value = RendUseCount> {
1575 (arb_fingerprint(), 0.0f64..10000.0, 0.0f64..1.0).prop_map(|(idhex, used, weight)| {
1576 RendUseCount {
1577 idhex,
1578 used,
1579 weight,
1580 }
1581 })
1582 }
1583
1584 fn arb_rendguard() -> impl Strategy<Value = RendGuard> {
1585 (
1586 prop::collection::vec(arb_rend_use_count(), 0..10),
1587 0.0f64..100000.0,
1588 )
1589 .prop_map(|(counts, total)| {
1590 let mut rg = RendGuard::new();
1591 for count in counts {
1592 rg.use_counts.insert(count.idhex.clone(), count);
1593 }
1594 rg.total_use_counts = total;
1595 rg
1596 })
1597 }
1598
1599 fn arb_vanguard_state() -> impl Strategy<Value = VanguardState> {
1600 (
1601 prop::collection::vec(arb_guard_node(), 0..8),
1602 prop::collection::vec(arb_guard_node(), 0..16),
1603 arb_rendguard(),
1604 )
1605 .prop_map(|(layer2, layer3, rendguard)| {
1606 let mut state = VanguardState::new("test.state");
1607 state.layer2 = layer2;
1608 state.layer3 = layer3;
1609 state.rendguard = rendguard;
1610 state
1611 })
1612 }
1613
1614 proptest! {
1615 #![proptest_config(ProptestConfig::with_cases(100))]
1616
1617 #[test]
1618 fn exclude_nodes_parsing(
1619 fingerprints in prop::collection::vec(arb_fingerprint(), 0..5),
1620 countries in prop::collection::vec(arb_country_code(), 0..5),
1621 networks in prop::collection::vec(arb_cidr(), 0..3),
1622 nicknames in prop::collection::vec(arb_nickname(), 0..5),
1623 ) {
1624 let mut parts = Vec::new();
1625
1626 for fp in &fingerprints {
1627 parts.push(format!("${}", fp));
1628 }
1629 for cc in &countries {
1630 parts.push(format!("{{{}}}", cc));
1631 }
1632 for net in &networks {
1633 parts.push(net.clone());
1634 }
1635 for nick in &nicknames {
1636 parts.push(nick.clone());
1637 }
1638
1639 let conf_line = parts.join(",");
1640 let exclude = ExcludeNodes::parse(&conf_line, None);
1641
1642 for fp in &fingerprints {
1643 prop_assert!(exclude.idhexes.contains(&fp.to_uppercase()),
1644 "Fingerprint {} not found in parsed idhexes", fp);
1645 }
1646
1647 for cc in &countries {
1648 prop_assert!(exclude.countries.contains(&cc.to_lowercase()),
1649 "Country code {} not found in parsed countries", cc);
1650 }
1651
1652 prop_assert_eq!(exclude.networks.len(), networks.len(),
1653 "Expected {} networks, got {}", networks.len(), exclude.networks.len());
1654
1655 for nick in &nicknames {
1656 if !is_valid_fingerprint(nick) && !nick.contains('.') && !nick.contains(':') {
1657 prop_assert!(exclude.nicks.contains(nick),
1658 "Nickname {} not found in parsed nicks", nick);
1659 }
1660 }
1661 }
1662
1663 #[test]
1664 fn state_serialization_round_trip(state in arb_vanguard_state()) {
1665 let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
1666 let state_path = temp_dir.path().join("test.state");
1667
1668 state.write_to_file(&state_path).expect("Failed to write state");
1669 let loaded = VanguardState::read_from_file(&state_path).expect("Failed to read state");
1670
1671 prop_assert_eq!(state.layer2.len(), loaded.layer2.len());
1672 prop_assert_eq!(state.layer3.len(), loaded.layer3.len());
1673
1674 for (orig, load) in state.layer2.iter().zip(loaded.layer2.iter()) {
1675 prop_assert_eq!(&orig.idhex, &load.idhex);
1676 prop_assert!((orig.chosen_at - load.chosen_at).abs() < 0.001);
1677 prop_assert!((orig.expires_at - load.expires_at).abs() < 0.001);
1678 }
1679
1680 for (orig, load) in state.layer3.iter().zip(loaded.layer3.iter()) {
1681 prop_assert_eq!(&orig.idhex, &load.idhex);
1682 prop_assert!((orig.chosen_at - load.chosen_at).abs() < 0.001);
1683 prop_assert!((orig.expires_at - load.expires_at).abs() < 0.001);
1684 }
1685
1686 prop_assert_eq!(state.rendguard.use_counts.len(), loaded.rendguard.use_counts.len());
1687 prop_assert!((state.rendguard.total_use_counts - loaded.rendguard.total_use_counts).abs() < 0.001);
1688 }
1689
1690 #[test]
1691 fn guard_lifetime_distribution(
1692 min_hours in 1u32..100,
1693 max_hours in 100u32..2000,
1694 ) {
1695 prop_assume!(min_hours < max_hours);
1696
1697 let min_secs = min_hours as f64 * 3600.0;
1698 let max_secs = max_hours as f64 * 3600.0;
1699
1700 let mut lifetimes = Vec::new();
1701 for _ in 0..100 {
1702 let lifetime = VanguardState::calculate_guard_lifetime(min_hours, max_hours);
1703 prop_assert!(lifetime >= min_secs, "Lifetime {} below min {}", lifetime, min_secs);
1704 prop_assert!(lifetime <= max_secs, "Lifetime {} above max {}", lifetime, max_secs);
1705 lifetimes.push(lifetime);
1706 }
1707
1708 let avg = lifetimes.iter().sum::<f64>() / lifetimes.len() as f64;
1709 let midpoint = (min_secs + max_secs) / 2.0;
1710 prop_assert!(avg > midpoint,
1711 "Average lifetime {} should be above midpoint {} (max of two uniforms)", avg, midpoint);
1712 }
1713
1714 #[test]
1715 fn expired_guard_removal(
1716 num_expired in 0usize..5,
1717 num_valid in 0usize..5,
1718 ) {
1719 let now = SystemTime::now()
1720 .duration_since(UNIX_EPOCH)
1721 .unwrap()
1722 .as_secs_f64();
1723
1724 let mut layer = Vec::new();
1725 let mut expected_remaining = HashSet::new();
1726
1727 for i in 0..num_expired {
1728 let fp = format!("{:0>40X}", i);
1729 layer.push(GuardNode::new(fp, now - 10000.0, now - 1000.0));
1730 }
1731
1732 for i in 0..num_valid {
1733 let fp = format!("{:0>40X}", 100 + i);
1734 layer.push(GuardNode::new(fp.clone(), now - 1000.0, now + 86400.0));
1735 expected_remaining.insert(fp);
1736 }
1737
1738 VanguardState::remove_expired_from_layer(&mut layer);
1739
1740 prop_assert_eq!(layer.len(), num_valid,
1741 "Expected {} guards after removal, got {}", num_valid, layer.len());
1742
1743 for guard in &layer {
1744 prop_assert!(expected_remaining.contains(&guard.idhex),
1745 "Unexpected guard {} in layer", guard.idhex);
1746 prop_assert!(guard.expires_at >= now,
1747 "Guard {} should not be expired", guard.idhex);
1748 }
1749 }
1750
1751 #[test]
1752 fn down_guard_removal(
1753 num_in_consensus in 0usize..5,
1754 num_not_in_consensus in 0usize..5,
1755 ) {
1756 let now = SystemTime::now()
1757 .duration_since(UNIX_EPOCH)
1758 .unwrap()
1759 .as_secs_f64();
1760
1761 let mut layer = Vec::new();
1762 let mut consensus_fps = HashSet::new();
1763
1764 for i in 0..num_in_consensus {
1765 let fp = format!("{:0>40X}", i);
1766 layer.push(GuardNode::new(fp.clone(), now, now + 86400.0));
1767 consensus_fps.insert(fp);
1768 }
1769
1770 for i in 0..num_not_in_consensus {
1771 let fp = format!("{:0>40X}", 100 + i);
1772 layer.push(GuardNode::new(fp, now, now + 86400.0));
1773 }
1774
1775 VanguardState::remove_down_from_layer(&mut layer, &consensus_fps);
1776
1777 prop_assert_eq!(layer.len(), num_in_consensus,
1778 "Expected {} guards after removal, got {}", num_in_consensus, layer.len());
1779
1780 for guard in &layer {
1781 prop_assert!(consensus_fps.contains(&guard.idhex),
1782 "Guard {} should be in consensus", guard.idhex);
1783 }
1784 }
1785
1786 #[test]
1787 fn layer_trimming(
1788 initial_layer2 in 0usize..20,
1789 initial_layer3 in 0usize..30,
1790 target_layer2 in 1u8..10,
1791 target_layer3 in 1u8..15,
1792 ) {
1793 let now = SystemTime::now()
1794 .duration_since(UNIX_EPOCH)
1795 .unwrap()
1796 .as_secs_f64();
1797
1798 let mut state = VanguardState::new("test.state");
1799
1800 for i in 0..initial_layer2 {
1801 let fp = format!("{:0>40X}", i);
1802 state.layer2.push(GuardNode::new(fp, now, now + 86400.0));
1803 }
1804
1805 for i in 0..initial_layer3 {
1806 let fp = format!("{:0>40X}", 100 + i);
1807 state.layer3.push(GuardNode::new(fp, now, now + 86400.0));
1808 }
1809
1810 state.layer2.truncate(target_layer2 as usize);
1811 state.layer3.truncate(target_layer3 as usize);
1812
1813 let expected_layer2 = initial_layer2.min(target_layer2 as usize);
1814 let expected_layer3 = initial_layer3.min(target_layer3 as usize);
1815
1816 prop_assert_eq!(state.layer2.len(), expected_layer2,
1817 "Layer2 should have {} guards, got {}", expected_layer2, state.layer2.len());
1818 prop_assert_eq!(state.layer3.len(), expected_layer3,
1819 "Layer3 should have {} guards, got {}", expected_layer3, state.layer3.len());
1820 }
1821 }
1822}