vanguards_rs/
vanguards.rs

1//! Vanguard state management and ExcludeNodes parsing.
2//!
3//! This module provides persistent vanguard guard selection and state management,
4//! along with ExcludeNodes configuration parsing for relay exclusion.
5//!
6//! # Overview
7//!
8//! The vanguard system maintains persistent sets of guard relays at two layers:
9//!
10//! - **Layer 2 Guards**: Second-hop relays with longer lifetimes (1-45 days)
11//! - **Layer 3 Guards**: Third-hop relays with shorter lifetimes (1-48 hours)
12//!
13//! Guards are selected using bandwidth-weighted random selection and rotated
14//! based on configurable lifetime parameters.
15//!
16//! # Guard Layer Architecture
17//!
18//! Vanguards protect hidden services by restricting which relays can be used
19//! at each position in the circuit:
20//!
21//! ```text
22//! ┌────────────────────────────────────────────────────────────────────────────┐
23//! │                    Hidden Service Circuit Path                             │
24//! │                                                                            │
25//! │  ┌──────────┐     ┌──────────┐       ┌──────────┐     ┌──────────┐         │
26//! │  │  Client  │───▶│ Layer 1  │───▶ │ Layer 2  │───▶│ Layer 3  │───▶ HS  │
27//! │  │          │    │ (Entry)  │      │ (Middle) │     │ (Middle) │           │
28//! │  └──────────┘    └──────────┘      └──────────┘    └──────────┘            │
29//! │                       │               │               │                    │
30//! │                       ▼               ▼               ▼                    │
31//! │                  ┌─────────┐    ┌─────────┐    ┌─────────┐                 │
32//! │                  │ Tor's   │    │ 4-8     │    │ 4-8     │                 │
33//! │                  │ Guard   │    │ Guards  │    │ Guards  │                 │
34//! │                  │ System  │    │ 1-45    │    │ 1-48    │                 │
35//! │                  │         │    │ days    │    │ hours   │                 │
36//! │                  └─────────┘    └─────────┘    └─────────┘                 │
37//! └────────────────────────────────────────────────────────────────────────────┘
38//! ```
39//!
40//! # Guard Lifecycle
41//!
42//! Guards progress through the following states:
43//!
44//! ```text
45//!                    ┌─────────────────┐
46//!                    │    Selection    │
47//!                    │ (BW-weighted)   │
48//!                    └────────┬────────┘
49//!                             │
50//!                             ▼
51//!                    ┌─────────────────┐
52//!                    │     Active      │
53//!                    │ (in guardset)   │
54//!                    └────────┬────────┘
55//!                             │
56//!          ┌──────────────────┼──────────────────┐
57//!          │                  │                  │
58//!          ▼                  ▼                  ▼
59//!    ┌───────────┐     ┌───────────┐     ┌───────────┐
60//!    │  Expired  │     │   Down    │     │ Excluded  │
61//!    │(lifetime) │     │(consensus)│     │(ExcludeN) │
62//!    └───────────┘     └───────────┘     └───────────┘
63//!          │                  │                  │
64//!          └──────────────────┼──────────────────┘
65//!                             │
66//!                             ▼
67//!                    ┌─────────────────┐
68//!                    │    Removed      │
69//!                    │ (replenished)   │
70//!                    └─────────────────┘
71//! ```
72//!
73//! # State Persistence
74//!
75//! State is persisted in Python pickle format for compatibility with the
76//! Python vanguards implementation. This allows seamless migration between
77//! implementations.
78//!
79//! ```text
80//! ┌─────────────────────────────────────────────────────────────────────────┐
81//! │                        State File Format                                │
82//! │                                                                         │
83//! │  VanguardState {                                                        │
84//! │      layer2: [                                                          │
85//! │          GuardNode { idhex, chosen_at, expires_at },                    │
86//! │          ...                                                            │
87//! │      ],                                                                 │
88//! │      layer3: [                                                          │
89//! │          GuardNode { idhex, chosen_at, expires_at },                    │
90//! │          ...                                                            │
91//! │      ],                                                                 │
//! │      state_file: String,                                                │
//! │      rendguard: RendGuard { use_counts, total_use_counts },             │
93//! │      pickle_revision: 1,                                                │
94//! │  }                                                                      │
95//! └─────────────────────────────────────────────────────────────────────────┘
96//! ```
97//!
98//! # What This Module Does NOT Do
99//!
100//! - **Guard selection algorithm**: Use [`crate::node_selection`] for bandwidth-weighted selection
101//! - **Tor configuration**: Use [`crate::control::configure_tor`] to apply guards to Tor
102//! - **Attack detection**: Use [`crate::bandguards`] and [`crate::rendguard`] for monitoring
103//!
104//! # ExcludeNodes
105//!
106//! The [`ExcludeNodes`] struct parses Tor's ExcludeNodes configuration to
107//! filter out unwanted relays based on:
108//!
109//! - Fingerprints (40 hex characters, optionally prefixed with $)
110//! - Country codes ({cc} format)
111//! - IP networks (CIDR notation)
112//! - Nicknames
113//!
114//! # Example
115//!
116//! ```rust,no_run
117//! use vanguards_rs::vanguards::{VanguardState, GuardNode, ExcludeNodes};
118//! use std::path::Path;
119//!
120//! // Load or create vanguard state
121//! let mut state = VanguardState::load_or_create(Path::new("vanguards.state"));
122//!
123//! // Check current guards
124//! println!("Layer 2 guards: {}", state.layer2_guardset());
125//! println!("Layer 3 guards: {}", state.layer3_guardset());
126//!
127//! // Parse exclusion configuration
128//! let exclude = ExcludeNodes::parse("{us},{ru},BadRelay", None);
129//! println!("Excluding {} countries", exclude.countries.len());
130//! ```
131//!
132//! # Security Considerations
133//!
134//! - State files contain guard fingerprints - protect with appropriate permissions
135//! - Guard lifetimes use max-of-two-uniform distribution for better security
136//! - Atomic writes prevent state file corruption
137//! - Validation prevents loading corrupted or malicious state files
138//!
139//! # See Also
140//!
141//! - [`crate::node_selection`] - Bandwidth-weighted node selection
142//! - [`crate::config::VanguardsConfig`] - Vanguard configuration options
143//! - [`crate::rendguard`] - Rendezvous point monitoring (uses RendGuard from this module)
144//! - [Python vanguards](https://github.com/mikeperry-tor/vanguards) - Original implementation
145//! - [Vanguards proposal](https://github.com/torproject/torspec/blob/main/proposals/292-mesh-vanguards.txt) - Design specification
146
147use std::collections::{HashMap, HashSet};
148use std::fs::File;
149use std::io::{BufReader, BufWriter, Write};
150use std::net::IpAddr;
151use std::path::Path;
152use std::time::{SystemTime, UNIX_EPOCH};
153
154use ipnetwork::IpNetwork;
155use rand::Rng;
156use serde::{Deserialize, Serialize};
157use stem_rs::descriptor::router_status::RouterStatusEntry;
158
159use crate::config::VanguardsConfig;
160use crate::error::{Error, Result};
161use crate::node_selection::{is_valid_country_code, is_valid_fingerprint, BwWeightedGenerator};
162
/// Number of seconds in one hour, used to convert hour-based lifetimes to seconds.
const SEC_PER_HOUR: f64 = 60.0 * 60.0;
165
166/// A guard node selected as a vanguard with lifetime metadata.
167///
168/// Each guard node tracks when it was selected and when it should expire.
169/// Timestamps are stored as Unix timestamps (f64) for Python pickle compatibility.
170///
171/// # Fields
172///
173/// - `idhex`: The relay's 40-character uppercase hex fingerprint
174/// - `chosen_at`: Unix timestamp when this guard was selected
175/// - `expires_at`: Unix timestamp when this guard should be rotated
176///
177/// # Lifetime Calculation
178///
179/// Guard lifetimes are calculated using the max of two uniform random samples
180/// from the configured range. This distribution favors longer lifetimes,
181/// providing better security by reducing guard rotation frequency.
182///
183/// ```text
184/// Lifetime = max(uniform(min, max), uniform(min, max))
185/// ```
186///
187/// # Example
188///
189/// ```rust
190/// use vanguards_rs::vanguards::GuardNode;
191/// use std::time::{SystemTime, UNIX_EPOCH};
192///
193/// let now = SystemTime::now()
194///     .duration_since(UNIX_EPOCH)
195///     .unwrap()
196///     .as_secs_f64();
197/// let expires = now + 86400.0; // 24 hours
198///
199/// let guard = GuardNode::new("A".repeat(40), now, expires);
200/// assert!(!guard.is_expired());
201/// ```
202///
203/// # See Also
204///
205/// - [`VanguardState::calculate_guard_lifetime`] - Lifetime calculation
206/// - [`VanguardState::add_new_layer2`] - Layer 2 guard creation
207/// - [`VanguardState::add_new_layer3`] - Layer 3 guard creation
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GuardNode {
    /// The relay's 40-character hex fingerprint.
    ///
    /// [`VanguardState::validate`] rejects loaded state whose `idhex` fails
    /// the fingerprint check.
    pub idhex: String,
    /// Unix timestamp (seconds, as `f64`) when this guard was selected.
    pub chosen_at: f64,
    /// Unix timestamp (seconds, as `f64`) when this guard should be rotated.
    ///
    /// [`GuardNode::is_expired`] reports `true` once the current time is
    /// later than this value.
    pub expires_at: f64,
}
217
218impl GuardNode {
219    /// Creates a new guard node with the specified fingerprint and timestamps.
220    ///
221    /// # Arguments
222    ///
223    /// * `idhex` - The relay's 40-character hex fingerprint
224    /// * `chosen_at` - Unix timestamp when this guard was selected
225    /// * `expires_at` - Unix timestamp when this guard should be rotated
226    ///
227    /// # Returns
228    ///
229    /// A new `GuardNode` instance.
230    ///
231    /// # Example
232    ///
233    /// ```rust
234    /// use vanguards_rs::vanguards::GuardNode;
235    ///
236    /// let guard = GuardNode::new(
237    ///     "AABBCCDD00112233445566778899AABBCCDDEEFF".to_string(),
238    ///     1700000000.0,  // chosen_at
239    ///     1700086400.0,  // expires_at (24 hours later)
240    /// );
241    /// ```
242    pub fn new(idhex: String, chosen_at: f64, expires_at: f64) -> Self {
243        Self {
244            idhex,
245            chosen_at,
246            expires_at,
247        }
248    }
249
250    /// Returns true if this guard has expired.
251    ///
252    /// Compares the current time against `expires_at` to determine if
253    /// this guard should be rotated.
254    ///
255    /// # Returns
256    ///
257    /// `true` if the current time is past `expires_at`, `false` otherwise.
258    ///
259    /// # Example
260    ///
261    /// ```rust
262    /// use vanguards_rs::vanguards::GuardNode;
263    /// use std::time::{SystemTime, UNIX_EPOCH};
264    ///
265    /// let now = SystemTime::now()
266    ///     .duration_since(UNIX_EPOCH)
267    ///     .unwrap()
268    ///     .as_secs_f64();
269    ///
270    /// // Expired guard
271    /// let expired = GuardNode::new("A".repeat(40), now - 1000.0, now - 100.0);
272    /// assert!(expired.is_expired());
273    ///
274    /// // Active guard
275    /// let active = GuardNode::new("B".repeat(40), now, now + 86400.0);
276    /// assert!(!active.is_expired());
277    /// ```
278    pub fn is_expired(&self) -> bool {
279        let now = SystemTime::now()
280            .duration_since(UNIX_EPOCH)
281            .unwrap_or_default()
282            .as_secs_f64();
283        self.expires_at < now
284    }
285}
286
/// Rendezvous point usage count for a single relay.
///
/// Tracks how many times a relay has been used as a rendezvous point
/// and its expected weight based on consensus bandwidth. [`RendGuard`]
/// compares the relay's share of all recorded uses against `weight` to
/// decide whether the relay is overused.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RendUseCount {
    /// The relay's fingerprint.
    pub idhex: String,
    /// Number of times this relay has been used.
    ///
    /// Stored as a float because counts are halved when usage is scaled down.
    pub used: f64,
    /// Expected usage weight based on bandwidth, as a fraction (the
    /// percentage helpers on [`RendGuard`] multiply it by 100).
    pub weight: f64,
}
300
301impl RendUseCount {
302    /// Creates a new usage count entry.
303    pub fn new(idhex: String, weight: f64) -> Self {
304        Self {
305            idhex,
306            used: 0.0,
307            weight,
308        }
309    }
310}
311
/// Rendezvous point usage tracking for detecting statistical attacks.
///
/// Tracks usage counts for all relays used as rendezvous points and
/// detects when a relay is being used more than expected based on
/// its bandwidth weight.
///
/// A relay counts as overused once its share of all recorded uses exceeds
/// its weight multiplied by the configured ratio (see
/// [`RendGuard::is_overused`]). Uses of relays that have no entry of their
/// own are pooled under the `"NOT_IN_CONSENSUS"` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RendGuard {
    /// Usage counts per relay fingerprint, plus the `"NOT_IN_CONSENSUS"`
    /// pool entry once it has been created.
    pub use_counts: HashMap<String, RendUseCount>,
    /// Total usage count across all relays.
    ///
    /// Stored as a float because per-relay counts are halved when scaled.
    pub total_use_counts: f64,
    /// Version number for pickle compatibility.
    pub pickle_revision: f64,
}
326
327impl Default for RendGuard {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
333impl RendGuard {
334    /// Creates a new empty RendGuard.
335    pub fn new() -> Self {
336        Self {
337            use_counts: HashMap::new(),
338            total_use_counts: 0.0,
339            pickle_revision: 1.0,
340        }
341    }
342
343    /// Scales all usage counts by half.
344    ///
345    /// Called when total_use_counts reaches the scale threshold to prevent
346    /// unbounded growth and to avoid over-counting high-uptime relays.
347    pub fn scale_counts(&mut self) {
348        for count in self.use_counts.values_mut() {
349            count.used /= 2.0;
350        }
351        self.total_use_counts = self.use_counts.values().map(|c| c.used).sum();
352    }
353
354    /// Records a rendezvous point usage and checks for overuse.
355    ///
356    /// This method should be called each time a relay is used as a rendezvous
357    /// point for a hidden service circuit.
358    ///
359    /// # Arguments
360    ///
361    /// * `fingerprint` - The relay's fingerprint (40 hex characters)
362    /// * `config` - Rendguard configuration
363    ///
364    /// # Returns
365    ///
366    /// `true` if the usage is valid (not overused), `false` if overused.
367    pub fn valid_rend_use(
368        &mut self,
369        fingerprint: &str,
370        config: &crate::config::RendguardConfig,
371    ) -> bool {
372        const NOT_IN_CONSENSUS_ID: &str = "NOT_IN_CONSENSUS";
373
374        let relay_id = if self.use_counts.contains_key(fingerprint) {
375            fingerprint.to_string()
376        } else {
377            // Relay not in consensus - track under special ID
378            if !self.use_counts.contains_key(NOT_IN_CONSENSUS_ID) {
379                self.use_counts.insert(
380                    NOT_IN_CONSENSUS_ID.to_string(),
381                    RendUseCount::new(NOT_IN_CONSENSUS_ID.to_string(), 0.0),
382                );
383            }
384            NOT_IN_CONSENSUS_ID.to_string()
385        };
386
387        // Increment usage counts
388        if let Some(count) = self.use_counts.get_mut(&relay_id) {
389            count.used += 1.0;
390        }
391        self.total_use_counts += 1.0;
392
393        // Check for overuse
394        if let Some(count) = self.use_counts.get(&relay_id) {
395            if self.total_use_counts >= config.use_global_start_count as f64
396                && count.used >= config.use_relay_start_count as f64
397                && count.used / self.total_use_counts
398                    > count.weight * config.use_max_use_to_bw_ratio
399            {
400                return false; // Overused
401            }
402        }
403
404        true // Valid usage
405    }
406
407    /// Transfers and updates use counts on consensus change.
408    ///
409    /// This method should be called when a new consensus is received.
410    pub fn xfer_use_counts(
411        &mut self,
412        generator: &BwWeightedGenerator,
413        config: &crate::config::RendguardConfig,
414    ) {
415        const NOT_IN_CONSENSUS_ID: &str = "NOT_IN_CONSENSUS";
416
417        let old_counts = std::mem::take(&mut self.use_counts);
418        let should_scale = self.total_use_counts >= config.use_scale_at_count as f64;
419
420        // Create entries for all routers in new consensus
421        let routers = generator.routers();
422        let node_weights = generator.node_weights();
423        let weight_total = generator.weight_total();
424        let exit_total = generator.exit_total();
425
426        for (i, router) in routers.iter().enumerate() {
427            let weight = if router.flags.contains(&"Exit".to_string()) && exit_total > 0.0 {
428                node_weights[i] / exit_total
429            } else if weight_total > 0.0 {
430                node_weights[i] / weight_total
431            } else {
432                0.0
433            };
434
435            self.use_counts.insert(
436                router.fingerprint.clone(),
437                RendUseCount::new(router.fingerprint.clone(), weight),
438            );
439        }
440
441        // Add NOT_IN_CONSENSUS entry
442        self.use_counts.insert(
443            NOT_IN_CONSENSUS_ID.to_string(),
444            RendUseCount::new(
445                NOT_IN_CONSENSUS_ID.to_string(),
446                config.use_max_consensus_weight_churn / 100.0,
447            ),
448        );
449
450        // Transfer old counts
451        for (fp, old_count) in old_counts {
452            if fp == NOT_IN_CONSENSUS_ID || self.use_counts.contains_key(&fp) {
453                if let Some(new_count) = self.use_counts.get_mut(&fp) {
454                    new_count.used = if should_scale {
455                        old_count.used / 2.0
456                    } else {
457                        old_count.used
458                    };
459                }
460            }
461        }
462
463        // Recalculate total
464        self.total_use_counts = self.use_counts.values().map(|c| c.used).sum();
465    }
466
467    /// Returns the usage rate for a relay as a percentage.
468    pub fn usage_rate(&self, fingerprint: &str) -> f64 {
469        if self.total_use_counts <= 0.0 {
470            return 0.0;
471        }
472        self.use_counts
473            .get(fingerprint)
474            .map(|c| 100.0 * c.used / self.total_use_counts)
475            .unwrap_or(0.0)
476    }
477
478    /// Returns the expected weight for a relay as a percentage.
479    pub fn expected_weight(&self, fingerprint: &str) -> f64 {
480        self.use_counts
481            .get(fingerprint)
482            .map(|c| 100.0 * c.weight)
483            .unwrap_or(0.0)
484    }
485
486    /// Checks if a relay is currently overused.
487    pub fn is_overused(&self, fingerprint: &str, config: &crate::config::RendguardConfig) -> bool {
488        if self.total_use_counts < config.use_global_start_count as f64 {
489            return false;
490        }
491
492        if let Some(count) = self.use_counts.get(fingerprint) {
493            if count.used < config.use_relay_start_count as f64 {
494                return false;
495            }
496            count.used / self.total_use_counts > count.weight * config.use_max_use_to_bw_ratio
497        } else {
498            false
499        }
500    }
501}
502
503/// Persistent vanguard state containing guard layers and rendguard tracking.
504///
505/// Contains the layer 2 and layer 3 guard lists, along with rendguard state.
506/// This state is persisted to disk in Python pickle format for compatibility.
507///
508/// # Guard Layers
509///
510/// ```text
511/// ┌─────────────────────────────────────────────────────────────────────────┐
512/// │                         VanguardState                                   │
513/// │                                                                         │
514/// │  ┌─────────────────────────────────────────────────────────────────┐    │
515/// │  │ Layer 2 Guards (HSLayer2Nodes)                                  │    │
516/// │  │ • 4-8 guards (configurable)                                     │    │
517/// │  │ • Lifetime: 1-45 days (configurable)                            │    │
518/// │  │ • Used for second hop in HS circuits                            │    │
519/// │  └─────────────────────────────────────────────────────────────────┘    │
520/// │                                                                         │
521/// │  ┌─────────────────────────────────────────────────────────────────┐    │
522/// │  │ Layer 3 Guards (HSLayer3Nodes)                                  │    │
523/// │  │ • 4-8 guards (configurable)                                     │    │
524/// │  │ • Lifetime: 1-48 hours (configurable)                           │    │
525/// │  │ • Used for third hop in HS circuits                             │    │
526/// │  └─────────────────────────────────────────────────────────────────┘    │
527/// │                                                                         │
528/// │  ┌─────────────────────────────────────────────────────────────────┐    │
529/// │  │ RendGuard                                                       │    │
530/// │  │ • Tracks rendezvous point usage                                 │    │
531/// │  │ • Detects statistical attacks                                   │    │
532/// │  └─────────────────────────────────────────────────────────────────┘    │
533/// └─────────────────────────────────────────────────────────────────────────┘
534/// ```
535///
536/// # State File Format
537///
538/// The state file uses Python pickle format with the following structure:
539///
540/// ```text
541/// VanguardState {
542///     layer2: [GuardNode, ...],
543///     layer3: [GuardNode, ...],
544///     state_file: String,
545///     rendguard: RendGuard,
546///     pickle_revision: u32,
547/// }
548/// ```
549///
550/// # Thread Safety
551///
552/// `VanguardState` is not thread-safe. It should be accessed from a single
553/// task or protected with appropriate synchronization.
554///
555/// # Example
556///
557/// ```rust,no_run
558/// use vanguards_rs::vanguards::VanguardState;
559/// use std::path::Path;
560///
561/// // Load existing state or create new
562/// let mut state = VanguardState::load_or_create(Path::new("vanguards.state"));
563///
564/// // Check current guards
565/// println!("Layer 2: {}", state.layer2_guardset());
566/// println!("Layer 3: {}", state.layer3_guardset());
567///
568/// // Save state
569/// state.write_to_file(Path::new("vanguards.state")).unwrap();
570/// ```
571///
572/// # See Also
573///
574/// - [`GuardNode`] - Individual guard node
575/// - [`RendGuard`] - Rendezvous point tracking
576/// - [`crate::config::VanguardsConfig`] - Configuration options
577#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
578pub struct VanguardState {
579    /// Layer 2 guard nodes (second hop).
580    pub layer2: Vec<GuardNode>,
581    /// Layer 3 guard nodes (third hop).
582    pub layer3: Vec<GuardNode>,
583    /// Path to the state file.
584    pub state_file: String,
585    /// Rendezvous point usage tracking.
586    pub rendguard: RendGuard,
587    /// Version number for pickle compatibility.
588    pub pickle_revision: u32,
589    /// Whether vanguards are enabled (runtime flag, not persisted).
590    #[serde(skip)]
591    pub enable_vanguards: bool,
592}
593
594impl Default for VanguardState {
595    fn default() -> Self {
596        Self::new("vanguards.state")
597    }
598}
599
600impl VanguardState {
601    /// Creates a new empty vanguard state.
602    pub fn new(state_file: &str) -> Self {
603        Self {
604            layer2: Vec::new(),
605            layer3: Vec::new(),
606            state_file: state_file.to_string(),
607            rendguard: RendGuard::new(),
608            pickle_revision: 1,
609            enable_vanguards: true,
610        }
611    }
612
613    /// Loads state from a file or creates new state if the file doesn't exist.
614    ///
615    /// # Arguments
616    ///
617    /// * `path` - Path to the state file
618    ///
619    /// # Returns
620    ///
621    /// The loaded or newly created state.
622    pub fn load_or_create(path: &Path) -> Self {
623        match Self::read_from_file(path) {
624            Ok(mut state) => {
625                state.state_file = path.to_string_lossy().to_string();
626                state
627            }
628            Err(_) => Self::new(&path.to_string_lossy()),
629        }
630    }
631
632    /// Reads state from a pickle file with validation.
633    ///
634    /// Validates that:
635    /// - All fingerprints are valid 40-character hex strings
636    /// - No timestamps are in the future (with 1 hour tolerance)
637    /// - The file format is valid
638    ///
639    /// # Errors
640    ///
641    /// Returns [`Error::State`] if the file cannot be read, parsed, or fails validation.
642    pub fn read_from_file(path: &Path) -> Result<Self> {
643        let file =
644            File::open(path).map_err(|e| Error::State(format!("cannot open state file: {}", e)))?;
645        let reader = BufReader::new(file);
646        let state: Self = serde_pickle::from_reader(reader, Default::default())
647            .map_err(|e| Error::State(format!("cannot parse state file: {}", e)))?;
648
649        // Validate the loaded state
650        state.validate()?;
651
652        Ok(state)
653    }
654
655    /// Validates the state for integrity.
656    ///
657    /// Checks:
658    /// - All fingerprints are valid 40-character hex strings
659    /// - No timestamps are in the future (with 1 hour tolerance for clock skew)
660    ///
661    /// # Errors
662    ///
663    /// Returns [`Error::State`] if validation fails.
664    pub fn validate(&self) -> Result<()> {
665        let now = SystemTime::now()
666            .duration_since(UNIX_EPOCH)
667            .unwrap_or_default()
668            .as_secs_f64();
669
670        // Allow 1 hour tolerance for clock skew
671        let max_timestamp = now + 3600.0;
672
673        // Validate layer2 guards
674        for guard in &self.layer2 {
675            if !is_valid_fingerprint(&guard.idhex) {
676                return Err(Error::State(format!(
677                    "invalid fingerprint in layer2: {}",
678                    guard.idhex
679                )));
680            }
681            if guard.chosen_at > max_timestamp {
682                return Err(Error::State(format!(
683                    "future timestamp in layer2 guard {}: chosen_at {} > now {}",
684                    guard.idhex, guard.chosen_at, now
685                )));
686            }
687            if guard.expires_at > max_timestamp + 86400.0 * 365.0 {
688                // Allow up to 1 year in the future for expires_at
689                return Err(Error::State(format!(
690                    "unreasonable future expiration in layer2 guard {}: expires_at {}",
691                    guard.idhex, guard.expires_at
692                )));
693            }
694        }
695
696        // Validate layer3 guards
697        for guard in &self.layer3 {
698            if !is_valid_fingerprint(&guard.idhex) {
699                return Err(Error::State(format!(
700                    "invalid fingerprint in layer3: {}",
701                    guard.idhex
702                )));
703            }
704            if guard.chosen_at > max_timestamp {
705                return Err(Error::State(format!(
706                    "future timestamp in layer3 guard {}: chosen_at {} > now {}",
707                    guard.idhex, guard.chosen_at, now
708                )));
709            }
710            if guard.expires_at > max_timestamp + 86400.0 * 365.0 {
711                return Err(Error::State(format!(
712                    "unreasonable future expiration in layer3 guard {}: expires_at {}",
713                    guard.idhex, guard.expires_at
714                )));
715            }
716        }
717
718        // Validate rendguard fingerprints
719        for fp in self.rendguard.use_counts.keys() {
720            // Skip special NOT_IN_CONSENSUS_ID
721            if fp == "NOT_IN_CONSENSUS" {
722                continue;
723            }
724            if !is_valid_fingerprint(fp) {
725                return Err(Error::State(format!(
726                    "invalid fingerprint in rendguard: {}",
727                    fp
728                )));
729            }
730        }
731
732        Ok(())
733    }
734
735    /// Writes state to a pickle file with atomic write and secure permissions.
736    ///
737    /// Uses atomic write (write to temp file, then rename) to prevent corruption.
738    /// On Unix systems, sets file permissions to 0600 (owner read/write only).
739    ///
740    /// # Errors
741    ///
742    /// Returns [`Error::State`] if the file cannot be written.
743    pub fn write_to_file(&self, path: &Path) -> Result<()> {
744        // Create a temporary file in the same directory for atomic write
745        let temp_path = path.with_extension("tmp");
746
747        // Create file with secure permissions on Unix
748        #[cfg(unix)]
749        let file = {
750            use std::os::unix::fs::OpenOptionsExt;
751            std::fs::OpenOptions::new()
752                .write(true)
753                .create(true)
754                .truncate(true)
755                .mode(0o600)
756                .open(&temp_path)
757                .map_err(|e| Error::State(format!("cannot create temp state file: {}", e)))?
758        };
759
760        #[cfg(not(unix))]
761        let file = File::create(&temp_path)
762            .map_err(|e| Error::State(format!("cannot create temp state file: {}", e)))?;
763
764        let mut writer = BufWriter::new(file);
765        serde_pickle::to_writer(&mut writer, self, Default::default())
766            .map_err(|e| Error::State(format!("cannot write state file: {}", e)))?;
767
768        // Ensure all data is flushed
769        writer
770            .flush()
771            .map_err(|e| Error::State(format!("cannot flush state file: {}", e)))?;
772        drop(writer);
773
774        // Atomic rename
775        std::fs::rename(&temp_path, path)
776            .map_err(|e| Error::State(format!("cannot rename temp state file: {}", e)))?;
777
778        Ok(())
779    }
780
781    /// Returns the layer 2 guard fingerprints as a comma-separated string.
782    pub fn layer2_guardset(&self) -> String {
783        self.layer2
784            .iter()
785            .map(|g| g.idhex.as_str())
786            .collect::<Vec<_>>()
787            .join(",")
788    }
789
790    /// Returns the layer 3 guard fingerprints as a comma-separated string.
791    pub fn layer3_guardset(&self) -> String {
792        self.layer3
793            .iter()
794            .map(|g| g.idhex.as_str())
795            .collect::<Vec<_>>()
796            .join(",")
797    }
798
799    /// Calculates a guard lifetime using max of two uniform random samples.
800    ///
801    /// This distribution favors longer lifetimes, providing better security
802    /// by reducing guard rotation frequency.
803    ///
804    /// # Arguments
805    ///
806    /// * `min_hours` - Minimum lifetime in hours
807    /// * `max_hours` - Maximum lifetime in hours
808    ///
809    /// # Returns
810    ///
811    /// Lifetime in seconds.
812    pub fn calculate_guard_lifetime(min_hours: u32, max_hours: u32) -> f64 {
813        let mut rng = rand::thread_rng();
814        let min_secs = min_hours as f64 * SEC_PER_HOUR;
815        let max_secs = max_hours as f64 * SEC_PER_HOUR;
816        let sample1 = rng.gen_range(min_secs..=max_secs);
817        let sample2 = rng.gen_range(min_secs..=max_secs);
818        sample1.max(sample2)
819    }
820
821    /// Adds a new layer 2 guard.
822    ///
823    /// Selects a guard using the provided generator, avoiding duplicates
824    /// and excluded nodes.
825    pub fn add_new_layer2(
826        &mut self,
827        generator: &BwWeightedGenerator,
828        excluded: &ExcludeNodes,
829        config: &VanguardsConfig,
830    ) -> Result<()> {
831        let existing: HashSet<_> = self.layer2.iter().map(|g| g.idhex.as_str()).collect();
832
833        for _ in 0..1000 {
834            let guard = generator.generate()?;
835            if existing.contains(guard.fingerprint.as_str()) {
836                continue;
837            }
838            if excluded.router_is_excluded(guard) {
839                continue;
840            }
841
842            let now = SystemTime::now()
843                .duration_since(UNIX_EPOCH)
844                .unwrap_or_default()
845                .as_secs_f64();
846            let lifetime = Self::calculate_guard_lifetime(
847                config.min_layer2_lifetime_hours,
848                config.max_layer2_lifetime_hours,
849            );
850            let expires = now + lifetime;
851
852            self.layer2
853                .push(GuardNode::new(guard.fingerprint.clone(), now, expires));
854            return Ok(());
855        }
856
857        Err(Error::NoNodesRemain)
858    }
859
860    /// Adds a new layer 3 guard.
861    ///
862    /// Selects a guard using the provided generator, avoiding duplicates
863    /// and excluded nodes.
864    pub fn add_new_layer3(
865        &mut self,
866        generator: &BwWeightedGenerator,
867        excluded: &ExcludeNodes,
868        config: &VanguardsConfig,
869    ) -> Result<()> {
870        let existing: HashSet<_> = self.layer3.iter().map(|g| g.idhex.as_str()).collect();
871
872        for _ in 0..1000 {
873            let guard = generator.generate()?;
874            if existing.contains(guard.fingerprint.as_str()) {
875                continue;
876            }
877            if excluded.router_is_excluded(guard) {
878                continue;
879            }
880
881            let now = SystemTime::now()
882                .duration_since(UNIX_EPOCH)
883                .unwrap_or_default()
884                .as_secs_f64();
885            let lifetime = Self::calculate_guard_lifetime(
886                config.min_layer3_lifetime_hours,
887                config.max_layer3_lifetime_hours,
888            );
889            let expires = now + lifetime;
890
891            self.layer3
892                .push(GuardNode::new(guard.fingerprint.clone(), now, expires));
893            return Ok(());
894        }
895
896        Err(Error::NoNodesRemain)
897    }
898
899    /// Removes guards that are no longer in the consensus.
900    pub fn remove_down_from_layer(layer: &mut Vec<GuardNode>, consensus_fps: &HashSet<String>) {
901        layer.retain(|g| consensus_fps.contains(&g.idhex));
902    }
903
904    /// Removes guards whose rotation time has expired.
905    pub fn remove_expired_from_layer(layer: &mut Vec<GuardNode>) {
906        let now = SystemTime::now()
907            .duration_since(UNIX_EPOCH)
908            .unwrap_or_default()
909            .as_secs_f64();
910        layer.retain(|g| g.expires_at >= now);
911    }
912
913    /// Removes guards that match the ExcludeNodes configuration.
914    pub fn remove_excluded_from_layer(
915        layer: &mut Vec<GuardNode>,
916        router_map: &HashMap<String, &RouterStatusEntry>,
917        excluded: &ExcludeNodes,
918    ) {
919        layer.retain(|g| {
920            if let Some(router) = router_map.get(&g.idhex) {
921                !excluded.router_is_excluded(router)
922            } else {
923                true
924            }
925        });
926    }
927
928    /// Replenishes guard layers to configured counts.
929    ///
930    /// First trims layers if they exceed configured counts, then adds
931    /// new guards until the configured count is reached.
932    pub fn replenish_layers(
933        &mut self,
934        generator: &BwWeightedGenerator,
935        excluded: &ExcludeNodes,
936        config: &VanguardsConfig,
937    ) -> Result<()> {
938        self.layer2.truncate(config.num_layer2_guards as usize);
939        self.layer3.truncate(config.num_layer3_guards as usize);
940
941        while self.layer2.len() < config.num_layer2_guards as usize {
942            self.add_new_layer2(generator, excluded, config)?;
943        }
944
945        while self.layer3.len() < config.num_layer3_guards as usize {
946            self.add_new_layer3(generator, excluded, config)?;
947        }
948
949        Ok(())
950    }
951}
952
/// Parsed ExcludeNodes configuration for relay filtering.
///
/// Parses Tor's ExcludeNodes configuration option to filter out unwanted
/// relays based on various criteria. This is used to ensure vanguard guards
/// respect the user's exclusion preferences.
///
/// # Supported Entry Types
///
/// ```text
/// ┌─────────────────────────────────────────────────────────────────────────┐
/// │                    ExcludeNodes Entry Types                             │
/// │                                                                         │
/// │  Type          │ Format                     │ Example                   │
/// │  ──────────────┼────────────────────────────┼───────────────────────────│
/// │  Fingerprint   │ $FINGERPRINT or FINGERPRINT│ $AABB...EEFF              │
/// │  Country       │ {cc}                       │ {us}, {ru}                │
/// │  Network       │ IP/CIDR                    │ 192.168.0.0/16            │
/// │  IP Address    │ IP                         │ 192.168.1.1               │
/// │  Nickname      │ name                       │ BadRelay                  │
/// └─────────────────────────────────────────────────────────────────────────┘
/// ```
///
/// # GeoIPExcludeUnknown
///
/// The `exclude_unknowns` field controls handling of relays with unknown
/// country codes:
///
/// | Setting | Behavior |
/// |---------|----------|
/// | `"1"` | Always exclude `??` and `a1` country codes |
/// | `"auto"` | Exclude `??` and `a1` only if other countries are excluded |
/// | `None` | Don't exclude unknown countries |
///
/// Any other setting string is stored as given but adds no country codes.
///
/// # Matching
///
/// [`ExcludeNodes::router_is_excluded`] matches a router by fingerprint,
/// nickname, and address-in-network. The `countries` set is populated by
/// parsing but is not consulted by that method.
///
/// # Example
///
/// ```rust
/// use vanguards_rs::vanguards::ExcludeNodes;
///
/// // Parse mixed exclusion configuration
/// let exclude = ExcludeNodes::parse(
///     "$AABBCCDD00112233445566778899AABBCCDDEEFF,{us},192.168.0.0/16,BadRelay",
///     Some("auto")
/// );
///
/// assert!(exclude.idhexes.contains("AABBCCDD00112233445566778899AABBCCDDEEFF"));
/// assert!(exclude.countries.contains("us"));
/// assert!(exclude.countries.contains("??")); // auto-added due to {us}
/// assert_eq!(exclude.networks.len(), 1);
/// assert!(exclude.nicks.contains("BadRelay"));
/// ```
///
/// # See Also
///
/// - [`VanguardState::remove_excluded_from_layer`] - Uses this for filtering
/// - [Tor Manual - ExcludeNodes](https://2019.www.torproject.org/docs/tor-manual.html.en#ExcludeNodes)
#[derive(Debug, Clone, Default)]
pub struct ExcludeNodes {
    /// IP networks to exclude (CIDR notation).
    ///
    /// A bare IP address entry is stored here as a single-host network.
    pub networks: Vec<IpNetwork>,
    /// Relay fingerprints to exclude (uppercase hex).
    ///
    /// Stored without the leading `$` and without any `~nickname` or
    /// `=nickname` suffix.
    pub idhexes: HashSet<String>,
    /// Relay nicknames to exclude.
    ///
    /// Stored as written and compared against a router's nickname by exact
    /// (case-sensitive) string equality.
    pub nicks: HashSet<String>,
    /// Country codes to exclude (lowercase).
    ///
    /// May also contain the placeholders `??` and `a1`, depending on
    /// `exclude_unknowns`.
    pub countries: HashSet<String>,
    /// GeoIPExcludeUnknown setting ("1", "auto", or None).
    pub exclude_unknowns: Option<String>,
}
1021
1022impl ExcludeNodes {
1023    /// Creates a new empty ExcludeNodes.
1024    pub fn new() -> Self {
1025        Self::default()
1026    }
1027
1028    /// Parses an ExcludeNodes configuration line.
1029    ///
1030    /// # Arguments
1031    ///
1032    /// * `conf_line` - The ExcludeNodes configuration value (comma-separated)
1033    /// * `exclude_unknowns` - The GeoIPExcludeUnknown setting
1034    ///
1035    /// # Returns
1036    ///
1037    /// A parsed ExcludeNodes struct.
1038    ///
1039    /// # Entry Format
1040    ///
1041    /// Entries are comma-separated and can be:
1042    ///
1043    /// - `$FINGERPRINT` or `FINGERPRINT` - 40 hex character fingerprint
1044    /// - `$FINGERPRINT~nickname` or `$FINGERPRINT=nickname` - Fingerprint with suffix (suffix stripped)
1045    /// - `{cc}` - Country code (2 characters)
1046    /// - `192.168.0.0/24` or `2001:db8::/32` - IP network
1047    /// - `nickname` - Relay nickname
1048    pub fn parse(conf_line: &str, exclude_unknowns: Option<&str>) -> Self {
1049        let mut result = Self::new();
1050        result.exclude_unknowns = exclude_unknowns.map(|s| s.to_string());
1051
1052        if let Some(ref setting) = result.exclude_unknowns {
1053            if setting == "1" {
1054                result.countries.insert("??".to_string());
1055                result.countries.insert("a1".to_string());
1056            }
1057        }
1058
1059        if conf_line.is_empty() {
1060            return result;
1061        }
1062
1063        result.parse_line(conf_line);
1064        result
1065    }
1066
1067    /// Parses a single configuration line.
1068    fn parse_line(&mut self, conf_line: &str) {
1069        for part in conf_line.split(',') {
1070            let mut p = part.trim().to_string();
1071            if p.is_empty() {
1072                continue;
1073            }
1074
1075            if p.starts_with('$') {
1076                p = p[1..].to_string();
1077            }
1078
1079            if let Some(idx) = p.find('~') {
1080                p = p[..idx].to_string();
1081            }
1082            if let Some(idx) = p.find('=') {
1083                p = p[..idx].to_string();
1084            }
1085
1086            if is_valid_fingerprint(&p) {
1087                self.idhexes.insert(p.to_uppercase());
1088            } else if p.starts_with('{') && p.ends_with('}') && p.len() >= 3 {
1089                let cc = &p[1..p.len() - 1];
1090                if is_valid_country_code(cc) {
1091                    self.countries.insert(cc.to_lowercase());
1092                }
1093            } else if p.contains(':') || p.contains('.') {
1094                if let Ok(network) = p.parse::<IpNetwork>() {
1095                    self.networks.push(network);
1096                } else if let Ok(ip) = p.parse::<IpAddr>() {
1097                    let network = match ip {
1098                        IpAddr::V4(_) => format!("{}/32", ip).parse().ok(),
1099                        IpAddr::V6(_) => format!("{}/128", ip).parse().ok(),
1100                    };
1101                    if let Some(net) = network {
1102                        self.networks.push(net);
1103                    }
1104                }
1105            } else {
1106                self.nicks.insert(p);
1107            }
1108        }
1109
1110        if let Some(ref setting) = self.exclude_unknowns {
1111            if setting == "auto" && !self.countries.is_empty() {
1112                self.countries.insert("??".to_string());
1113                self.countries.insert("a1".to_string());
1114            }
1115        }
1116    }
1117
1118    /// Checks if a router should be excluded.
1119    ///
1120    /// # Arguments
1121    ///
1122    /// * `router` - The router status entry to check
1123    ///
1124    /// # Returns
1125    ///
1126    /// `true` if the router matches any exclusion criteria.
1127    pub fn router_is_excluded(&self, router: &RouterStatusEntry) -> bool {
1128        if self.idhexes.contains(&router.fingerprint.to_uppercase()) {
1129            return true;
1130        }
1131
1132        if self.nicks.contains(&router.nickname) {
1133            return true;
1134        }
1135
1136        let addresses = self.get_router_addresses(router);
1137        for (addr, _port, _is_ipv6) in &addresses {
1138            for network in &self.networks {
1139                if network.contains(*addr) {
1140                    return true;
1141                }
1142            }
1143        }
1144
1145        false
1146    }
1147
1148    /// Gets all addresses for a router.
1149    fn get_router_addresses(&self, router: &RouterStatusEntry) -> Vec<(IpAddr, u16, bool)> {
1150        let mut addresses = vec![(router.address, router.or_port, router.address.is_ipv6())];
1151        addresses.extend(router.or_addresses.iter().cloned());
1152        addresses
1153    }
1154
1155    /// Returns true if this ExcludeNodes has any exclusions configured.
1156    pub fn has_exclusions(&self) -> bool {
1157        !self.networks.is_empty()
1158            || !self.idhexes.is_empty()
1159            || !self.nicks.is_empty()
1160            || !self.countries.is_empty()
1161    }
1162}
1163
// Unit tests for guard bookkeeping, ExcludeNodes parsing/matching, rendguard
// scaling and state validation.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use stem_rs::descriptor::router_status::RouterStatusEntryType;

    // Builds a V3 router status entry with the given identity, nickname and
    // address; the OR port is fixed at 9001 and the timestamp is "now".
    fn create_test_router(fingerprint: &str, nickname: &str, address: &str) -> RouterStatusEntry {
        RouterStatusEntry::new(
            RouterStatusEntryType::V3,
            nickname.to_string(),
            fingerprint.to_string(),
            Utc::now(),
            address.parse().unwrap(),
            9001,
        )
    }

    // GuardNode::new stores its three arguments unchanged.
    #[test]
    fn test_guard_node_creation() {
        let now = 1000000.0;
        let expires = 2000000.0;
        let guard = GuardNode::new("A".repeat(40), now, expires);
        assert_eq!(guard.idhex, "A".repeat(40));
        assert_eq!(guard.chosen_at, now);
        assert_eq!(guard.expires_at, expires);
    }

    // is_expired is true for an expiry in the past and false for one in the future.
    #[test]
    fn test_guard_node_expired() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs_f64();

        let expired = GuardNode::new("A".repeat(40), now - 1000.0, now - 100.0);
        assert!(expired.is_expired());

        let not_expired = GuardNode::new("B".repeat(40), now, now + 86400.0);
        assert!(!not_expired.is_expired());
    }

    // A fresh state has empty layers, remembers its file name, and is at revision 1.
    #[test]
    fn test_vanguard_state_new() {
        let state = VanguardState::new("test.state");
        assert!(state.layer2.is_empty());
        assert!(state.layer3.is_empty());
        assert_eq!(state.state_file, "test.state");
        assert_eq!(state.pickle_revision, 1);
    }

    // The guard set string contains every fingerprint and a comma separator.
    #[test]
    fn test_vanguard_state_guardset() {
        let mut state = VanguardState::new("test.state");
        state
            .layer2
            .push(GuardNode::new("A".repeat(40), 0.0, 1000.0));
        state
            .layer2
            .push(GuardNode::new("B".repeat(40), 0.0, 1000.0));

        let guardset = state.layer2_guardset();
        assert!(guardset.contains(&"A".repeat(40)));
        assert!(guardset.contains(&"B".repeat(40)));
        assert!(guardset.contains(','));
    }

    // Sampled lifetimes always fall inside the requested bounds (in seconds).
    #[test]
    fn test_calculate_guard_lifetime() {
        for _ in 0..100 {
            let lifetime = VanguardState::calculate_guard_lifetime(24, 1080);
            let min_secs = 24.0 * SEC_PER_HOUR;
            let max_secs = 1080.0 * SEC_PER_HOUR;
            assert!(lifetime >= min_secs);
            assert!(lifetime <= max_secs);
        }
    }

    // Only the guard whose expiry lies in the future survives.
    #[test]
    fn test_remove_expired_from_layer() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs_f64();

        let mut layer = vec![
            GuardNode::new("A".repeat(40), now - 1000.0, now - 100.0),
            GuardNode::new("B".repeat(40), now, now + 86400.0),
            GuardNode::new("C".repeat(40), now - 2000.0, now - 500.0),
        ];

        VanguardState::remove_expired_from_layer(&mut layer);
        assert_eq!(layer.len(), 1);
        assert_eq!(layer[0].idhex, "B".repeat(40));
    }

    // Guards missing from the consensus fingerprint set are removed.
    #[test]
    fn test_remove_down_from_layer() {
        let mut layer = vec![
            GuardNode::new("A".repeat(40), 0.0, 1000.0),
            GuardNode::new("B".repeat(40), 0.0, 1000.0),
            GuardNode::new("C".repeat(40), 0.0, 1000.0),
        ];

        let mut consensus_fps = HashSet::new();
        consensus_fps.insert("A".repeat(40));
        consensus_fps.insert("C".repeat(40));

        VanguardState::remove_down_from_layer(&mut layer, &consensus_fps);
        assert_eq!(layer.len(), 2);
        assert!(layer.iter().any(|g| g.idhex == "A".repeat(40)));
        assert!(layer.iter().any(|g| g.idhex == "C".repeat(40)));
        assert!(!layer.iter().any(|g| g.idhex == "B".repeat(40)));
    }

    // A `$`-prefixed fingerprint is stored without the `$`.
    #[test]
    fn test_exclude_nodes_parse_fingerprint() {
        let exclude = ExcludeNodes::parse("$AABBCCDD00112233445566778899AABBCCDDEEFF", None);
        assert!(exclude
            .idhexes
            .contains("AABBCCDD00112233445566778899AABBCCDDEEFF"));
    }

    // The `$` prefix is optional.
    #[test]
    fn test_exclude_nodes_parse_fingerprint_without_dollar() {
        let exclude = ExcludeNodes::parse("AABBCCDD00112233445566778899AABBCCDDEEFF", None);
        assert!(exclude
            .idhexes
            .contains("AABBCCDD00112233445566778899AABBCCDDEEFF"));
    }

    // `~nickname` and `=nickname` suffixes are dropped and never become nicknames.
    #[test]
    fn test_exclude_nodes_parse_fingerprint_with_suffix() {
        let exclude =
            ExcludeNodes::parse("$AABBCCDD00112233445566778899AABBCCDDEEFF~nickname", None);
        assert!(exclude
            .idhexes
            .contains("AABBCCDD00112233445566778899AABBCCDDEEFF"));
        assert!(!exclude.nicks.contains("nickname"));

        let exclude2 =
            ExcludeNodes::parse("$AABBCCDD00112233445566778899AABBCCDDEEFF=nickname", None);
        assert!(exclude2
            .idhexes
            .contains("AABBCCDD00112233445566778899AABBCCDDEEFF"));
    }

    // Country codes are normalised to lowercase.
    #[test]
    fn test_exclude_nodes_parse_country_code() {
        let exclude = ExcludeNodes::parse("{us}", None);
        assert!(exclude.countries.contains("us"));

        let exclude2 = ExcludeNodes::parse("{US}", None);
        assert!(exclude2.countries.contains("us"));
    }

    // Both IPv4 and IPv6 CIDR entries are accepted.
    #[test]
    fn test_exclude_nodes_parse_network() {
        let exclude = ExcludeNodes::parse("192.168.0.0/24", None);
        assert_eq!(exclude.networks.len(), 1);

        let exclude2 = ExcludeNodes::parse("2001:db8::/32", None);
        assert_eq!(exclude2.networks.len(), 1);
    }

    // A bare IP address yields exactly one network entry.
    #[test]
    fn test_exclude_nodes_parse_ip_address() {
        let exclude = ExcludeNodes::parse("192.168.1.1", None);
        assert_eq!(exclude.networks.len(), 1);
    }

    // Anything that is not a fingerprint, country or address is a nickname.
    #[test]
    fn test_exclude_nodes_parse_nickname() {
        let exclude = ExcludeNodes::parse("BadRelay", None);
        assert!(exclude.nicks.contains("BadRelay"));
    }

    // One entry of each kind on a single line lands in the matching collection.
    #[test]
    fn test_exclude_nodes_parse_mixed() {
        let exclude = ExcludeNodes::parse(
            "$AABBCCDD00112233445566778899AABBCCDDEEFF,{us},192.168.0.0/16,BadRelay",
            None,
        );
        assert!(exclude
            .idhexes
            .contains("AABBCCDD00112233445566778899AABBCCDDEEFF"));
        assert!(exclude.countries.contains("us"));
        assert_eq!(exclude.networks.len(), 1);
        assert!(exclude.nicks.contains("BadRelay"));
    }

    // Setting "1" adds the unknown-country placeholders even for an empty line.
    #[test]
    fn test_exclude_nodes_geoip_exclude_unknown_1() {
        let exclude = ExcludeNodes::parse("", Some("1"));
        assert!(exclude.countries.contains("??"));
        assert!(exclude.countries.contains("a1"));
    }

    // Setting "auto" adds the placeholders when another country is excluded.
    #[test]
    fn test_exclude_nodes_geoip_exclude_unknown_auto() {
        let exclude = ExcludeNodes::parse("{us}", Some("auto"));
        assert!(exclude.countries.contains("us"));
        assert!(exclude.countries.contains("??"));
        assert!(exclude.countries.contains("a1"));
    }

    // Setting "auto" adds nothing when no country is excluded.
    #[test]
    fn test_exclude_nodes_geoip_exclude_unknown_auto_no_countries() {
        let exclude = ExcludeNodes::parse("BadRelay", Some("auto"));
        assert!(!exclude.countries.contains("??"));
        assert!(!exclude.countries.contains("a1"));
    }

    // A router is excluded when its fingerprint is listed.
    #[test]
    fn test_router_is_excluded_by_fingerprint() {
        let exclude = ExcludeNodes::parse("$AABBCCDD00112233445566778899AABBCCDDEEFF", None);
        let router = create_test_router(
            "AABBCCDD00112233445566778899AABBCCDDEEFF",
            "test",
            "192.0.2.1",
        );
        assert!(exclude.router_is_excluded(&router));
    }

    // A router is excluded when its nickname is listed.
    #[test]
    fn test_router_is_excluded_by_nickname() {
        let exclude = ExcludeNodes::parse("BadRelay", None);
        let router = create_test_router(&"A".repeat(40), "BadRelay", "192.0.2.1");
        assert!(exclude.router_is_excluded(&router));
    }

    // A router is excluded only when its address lies inside a listed network.
    #[test]
    fn test_router_is_excluded_by_network() {
        let exclude = ExcludeNodes::parse("192.168.0.0/16", None);
        let router = create_test_router(&"A".repeat(40), "test", "192.168.1.1");
        assert!(exclude.router_is_excluded(&router));

        let router2 = create_test_router(&"B".repeat(40), "test2", "10.0.0.1");
        assert!(!exclude.router_is_excluded(&router2));
    }

    // A router matching none of the configured fingerprint, nickname or
    // network rules is not excluded.
    #[test]
    fn test_router_not_excluded() {
        let exclude = ExcludeNodes::parse("$BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB,{de}", None);
        let router = create_test_router(&"A".repeat(40), "GoodRelay", "192.0.2.1");
        assert!(!exclude.router_is_excluded(&router));
    }

    // A fresh RendGuard has no use counts, a zero total, and revision 1.0.
    #[test]
    fn test_rendguard_new() {
        let rg = RendGuard::new();
        assert!(rg.use_counts.is_empty());
        assert_eq!(rg.total_use_counts, 0.0);
        assert_eq!(rg.pickle_revision, 1.0);
    }

    // scale_counts halves each per-relay `used` value and the running total.
    #[test]
    fn test_rendguard_scale_counts() {
        let mut rg = RendGuard::new();
        rg.use_counts.insert(
            "A".repeat(40),
            RendUseCount {
                idhex: "A".repeat(40),
                used: 100.0,
                weight: 0.5,
            },
        );
        rg.use_counts.insert(
            "B".repeat(40),
            RendUseCount {
                idhex: "B".repeat(40),
                used: 200.0,
                weight: 0.5,
            },
        );
        rg.total_use_counts = 300.0;

        rg.scale_counts();

        assert_eq!(rg.use_counts.get(&"A".repeat(40)).unwrap().used, 50.0);
        assert_eq!(rg.use_counts.get(&"B".repeat(40)).unwrap().used, 100.0);
        assert_eq!(rg.total_use_counts, 150.0);
    }

    // has_exclusions is false when empty and true for each kind of entry.
    #[test]
    fn test_exclude_nodes_has_exclusions() {
        let empty = ExcludeNodes::new();
        assert!(!empty.has_exclusions());

        let with_fp = ExcludeNodes::parse("$AABBCCDD00112233445566778899AABBCCDDEEFF", None);
        assert!(with_fp.has_exclusions());

        let with_country = ExcludeNodes::parse("{us}", None);
        assert!(with_country.has_exclusions());

        let with_network = ExcludeNodes::parse("192.168.0.0/24", None);
        assert!(with_network.has_exclusions());

        let with_nick = ExcludeNodes::parse("BadRelay", None);
        assert!(with_nick.has_exclusions());
    }

    // An empty line with no GeoIPExcludeUnknown setting configures nothing.
    #[test]
    fn test_exclude_nodes_empty_string() {
        let exclude = ExcludeNodes::parse("", None);
        assert!(!exclude.has_exclusions());
    }

    // Whitespace around entries is ignored.
    #[test]
    fn test_exclude_nodes_whitespace_handling() {
        let exclude =
            ExcludeNodes::parse(" $AABBCCDD00112233445566778899AABBCCDDEEFF , {us} ", None);
        assert!(exclude
            .idhexes
            .contains("AABBCCDD00112233445566778899AABBCCDDEEFF"));
        assert!(exclude.countries.contains("us"));
    }

    // Valid fingerprints with past selection times and future expiries validate.
    #[test]
    fn test_vanguard_state_validation_valid() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs_f64();

        let mut state = VanguardState::new("test.state");
        state
            .layer2
            .push(GuardNode::new("A".repeat(40), now - 1000.0, now + 86400.0));
        state
            .layer3
            .push(GuardNode::new("B".repeat(40), now - 500.0, now + 3600.0));

        assert!(state.validate().is_ok());
    }

    // A guard whose idhex is not a valid fingerprint fails validation.
    #[test]
    fn test_vanguard_state_validation_invalid_fingerprint() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs_f64();

        let mut state = VanguardState::new("test.state");
        state
            .layer2
            .push(GuardNode::new("invalid".to_string(), now, now + 86400.0));

        assert!(state.validate().is_err());
    }

    // A guard selected 10000 seconds in the future fails validation.
    #[test]
    fn test_vanguard_state_validation_future_timestamp() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs_f64();

        let mut state = VanguardState::new("test.state");
        state
            .layer2
            .push(GuardNode::new("A".repeat(40), now + 10000.0, now + 86400.0));

        assert!(state.validate().is_err());
    }
}
1529
1530#[cfg(test)]
1531mod proptests {
1532    use super::*;
1533    use crate::node_selection::is_valid_fingerprint;
1534    use proptest::prelude::*;
1535
    /// Strategy producing 40-character strings over `0-9A-F`.
    ///
    /// The regex already emits uppercase only, so `to_uppercase` leaves the
    /// generated value unchanged.
    fn arb_fingerprint() -> impl Strategy<Value = String> {
        "[0-9A-F]{40}".prop_map(|s| s.to_uppercase())
    }
1539
    /// Strategy producing two lowercase ASCII letters (a bare country code,
    /// without the surrounding braces).
    fn arb_country_code() -> impl Strategy<Value = String> {
        "[a-z]{2}"
    }
1543
    /// Strategy producing dotted-quad IPv4 address strings.
    ///
    /// The first and last octets are drawn from `1..=254`; the two middle
    /// octets from `0..=255`.
    fn arb_ipv4() -> impl Strategy<Value = String> {
        (1u8..=254, 0u8..=255, 0u8..=255, 1u8..=254)
            .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d))
    }
1548
    /// Strategy producing `address/prefix` strings with a prefix in `8..=30`.
    ///
    /// The address comes from `arb_ipv4` and is not masked to the prefix, so
    /// host bits may be set.
    fn arb_cidr() -> impl Strategy<Value = String> {
        (arb_ipv4(), 8u8..=30).prop_map(|(ip, prefix)| format!("{}/{}", ip, prefix))
    }
1552
    /// Strategy producing nicknames of 1 to 19 ASCII alphanumerics that
    /// start with a letter.
    fn arb_nickname() -> impl Strategy<Value = String> {
        "[A-Za-z][A-Za-z0-9]{0,18}"
    }
1556
    /// Strategy producing guard nodes chosen in the past and expiring in the future.
    ///
    /// The reference time is sampled once, when the strategy is built.
    /// `chosen_at` is drawn from the 365 days before it and `expires_at`
    /// from the 365 days after it.
    fn arb_guard_node() -> impl Strategy<Value = GuardNode> {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs_f64();
        // All bounds are seconds since the Unix epoch.
        let chosen_min = now - 365.0 * 86400.0;
        let chosen_max = now;
        let expires_max = now + 365.0 * 86400.0;

        (
            arb_fingerprint(),
            chosen_min..chosen_max,
            chosen_max..expires_max,
        )
            .prop_map(|(idhex, chosen_at, expires_at)| GuardNode::new(idhex, chosen_at, expires_at))
    }
1573
    /// Strategy producing a `RendUseCount` with a generated fingerprint,
    /// `used` in `0.0..10000.0` and `weight` in `0.0..1.0`.
    fn arb_rend_use_count() -> impl Strategy<Value = RendUseCount> {
        (arb_fingerprint(), 0.0f64..10000.0, 0.0f64..1.0).prop_map(|(idhex, used, weight)| {
            RendUseCount {
                idhex,
                used,
                weight,
            }
        })
    }
1583
    /// Strategy producing a `RendGuard` populated with generated use counts.
    ///
    /// Counts are keyed by their `idhex`, so a repeated fingerprint overwrites
    /// the earlier entry. `total_use_counts` is generated independently and
    /// is not the sum of the individual `used` values.
    fn arb_rendguard() -> impl Strategy<Value = RendGuard> {
        (
            prop::collection::vec(arb_rend_use_count(), 0..10),
            0.0f64..100000.0,
        )
            .prop_map(|(counts, total)| {
                let mut rg = RendGuard::new();
                for count in counts {
                    rg.use_counts.insert(count.idhex.clone(), count);
                }
                rg.total_use_counts = total;
                rg
            })
    }
1598
    /// Strategy producing a `VanguardState` named `test.state` whose layers
    /// and rendguard are filled from the generators above.
    ///
    /// Fingerprints are not de-duplicated within or across layers.
    fn arb_vanguard_state() -> impl Strategy<Value = VanguardState> {
        (
            prop::collection::vec(arb_guard_node(), 0..8),
            prop::collection::vec(arb_guard_node(), 0..16),
            arb_rendguard(),
        )
            .prop_map(|(layer2, layer3, rendguard)| {
                let mut state = VanguardState::new("test.state");
                state.layer2 = layer2;
                state.layer3 = layer3;
                state.rendguard = rendguard;
                state
            })
    }
1613
1614    proptest! {
1615        #![proptest_config(ProptestConfig::with_cases(100))]
1616
        // Property: every generated fingerprint, country code, network and
        // nickname placed on one ExcludeNodes line is found in the matching
        // collection after parsing.
        #[test]
        fn exclude_nodes_parsing(
            fingerprints in prop::collection::vec(arb_fingerprint(), 0..5),
            countries in prop::collection::vec(arb_country_code(), 0..5),
            networks in prop::collection::vec(arb_cidr(), 0..3),
            nicknames in prop::collection::vec(arb_nickname(), 0..5),
        ) {
            // Assemble the line in the syntax `ExcludeNodes::parse` expects:
            // `$FP`, `{cc}`, bare CIDR, bare nickname.
            let mut parts = Vec::new();

            for fp in &fingerprints {
                parts.push(format!("${}", fp));
            }
            for cc in &countries {
                parts.push(format!("{{{}}}", cc));
            }
            for net in &networks {
                parts.push(net.clone());
            }
            for nick in &nicknames {
                parts.push(nick.clone());
            }

            let conf_line = parts.join(",");
            let exclude = ExcludeNodes::parse(&conf_line, None);

            for fp in &fingerprints {
                prop_assert!(exclude.idhexes.contains(&fp.to_uppercase()),
                    "Fingerprint {} not found in parsed idhexes", fp);
            }

            for cc in &countries {
                prop_assert!(exclude.countries.contains(&cc.to_lowercase()),
                    "Country code {} not found in parsed countries", cc);
            }

            // Networks live in a Vec, so duplicates are kept and the counts must match.
            prop_assert_eq!(exclude.networks.len(), networks.len(),
                "Expected {} networks, got {}", networks.len(), exclude.networks.len());

            for nick in &nicknames {
                // Skip nicknames the parser would classify as a fingerprint or an address.
                if !is_valid_fingerprint(nick) && !nick.contains('.') && !nick.contains(':') {
                    prop_assert!(exclude.nicks.contains(nick),
                        "Nickname {} not found in parsed nicks", nick);
                }
            }
        }
1662
        // Property: writing a state to disk and reading it back preserves the
        // guard layers and the rendguard totals.
        #[test]
        fn state_serialization_round_trip(state in arb_vanguard_state()) {
            let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
            let state_path = temp_dir.path().join("test.state");

            state.write_to_file(&state_path).expect("Failed to write state");
            let loaded = VanguardState::read_from_file(&state_path).expect("Failed to read state");

            // Lengths are checked first so the `zip` loops below cover every guard.
            prop_assert_eq!(state.layer2.len(), loaded.layer2.len());
            prop_assert_eq!(state.layer3.len(), loaded.layer3.len());

            // Timestamps are floats, so they are compared with a 0.001 s tolerance.
            for (orig, load) in state.layer2.iter().zip(loaded.layer2.iter()) {
                prop_assert_eq!(&orig.idhex, &load.idhex);
                prop_assert!((orig.chosen_at - load.chosen_at).abs() < 0.001);
                prop_assert!((orig.expires_at - load.expires_at).abs() < 0.001);
            }

            for (orig, load) in state.layer3.iter().zip(loaded.layer3.iter()) {
                prop_assert_eq!(&orig.idhex, &load.idhex);
                prop_assert!((orig.chosen_at - load.chosen_at).abs() < 0.001);
                prop_assert!((orig.expires_at - load.expires_at).abs() < 0.001);
            }

            // Only the entry count and the total are compared for rendguard, not each entry.
            prop_assert_eq!(state.rendguard.use_counts.len(), loaded.rendguard.use_counts.len());
            prop_assert!((state.rendguard.total_use_counts - loaded.rendguard.total_use_counts).abs() < 0.001);
        }
1689
1690        #[test]
1691        fn guard_lifetime_distribution(
1692            min_hours in 1u32..100,
1693            max_hours in 100u32..2000,
1694        ) {
1695            prop_assume!(min_hours < max_hours);
1696
1697            let min_secs = min_hours as f64 * 3600.0;
1698            let max_secs = max_hours as f64 * 3600.0;
1699
1700            let mut lifetimes = Vec::new();
1701            for _ in 0..100 {
1702                let lifetime = VanguardState::calculate_guard_lifetime(min_hours, max_hours);
1703                prop_assert!(lifetime >= min_secs, "Lifetime {} below min {}", lifetime, min_secs);
1704                prop_assert!(lifetime <= max_secs, "Lifetime {} above max {}", lifetime, max_secs);
1705                lifetimes.push(lifetime);
1706            }
1707
1708            let avg = lifetimes.iter().sum::<f64>() / lifetimes.len() as f64;
1709            let midpoint = (min_secs + max_secs) / 2.0;
1710            prop_assert!(avg > midpoint,
1711                "Average lifetime {} should be above midpoint {} (max of two uniforms)", avg, midpoint);
1712        }
1713
1714        #[test]
1715        fn expired_guard_removal(
1716            num_expired in 0usize..5,
1717            num_valid in 0usize..5,
1718        ) {
1719            let now = SystemTime::now()
1720                .duration_since(UNIX_EPOCH)
1721                .unwrap()
1722                .as_secs_f64();
1723
1724            let mut layer = Vec::new();
1725            let mut expected_remaining = HashSet::new();
1726
1727            for i in 0..num_expired {
1728                let fp = format!("{:0>40X}", i);
1729                layer.push(GuardNode::new(fp, now - 10000.0, now - 1000.0));
1730            }
1731
1732            for i in 0..num_valid {
1733                let fp = format!("{:0>40X}", 100 + i);
1734                layer.push(GuardNode::new(fp.clone(), now - 1000.0, now + 86400.0));
1735                expected_remaining.insert(fp);
1736            }
1737
1738            VanguardState::remove_expired_from_layer(&mut layer);
1739
1740            prop_assert_eq!(layer.len(), num_valid,
1741                "Expected {} guards after removal, got {}", num_valid, layer.len());
1742
1743            for guard in &layer {
1744                prop_assert!(expected_remaining.contains(&guard.idhex),
1745                    "Unexpected guard {} in layer", guard.idhex);
1746                prop_assert!(guard.expires_at >= now,
1747                    "Guard {} should not be expired", guard.idhex);
1748            }
1749        }
1750
1751        #[test]
1752        fn down_guard_removal(
1753            num_in_consensus in 0usize..5,
1754            num_not_in_consensus in 0usize..5,
1755        ) {
1756            let now = SystemTime::now()
1757                .duration_since(UNIX_EPOCH)
1758                .unwrap()
1759                .as_secs_f64();
1760
1761            let mut layer = Vec::new();
1762            let mut consensus_fps = HashSet::new();
1763
1764            for i in 0..num_in_consensus {
1765                let fp = format!("{:0>40X}", i);
1766                layer.push(GuardNode::new(fp.clone(), now, now + 86400.0));
1767                consensus_fps.insert(fp);
1768            }
1769
1770            for i in 0..num_not_in_consensus {
1771                let fp = format!("{:0>40X}", 100 + i);
1772                layer.push(GuardNode::new(fp, now, now + 86400.0));
1773            }
1774
1775            VanguardState::remove_down_from_layer(&mut layer, &consensus_fps);
1776
1777            prop_assert_eq!(layer.len(), num_in_consensus,
1778                "Expected {} guards after removal, got {}", num_in_consensus, layer.len());
1779
1780            for guard in &layer {
1781                prop_assert!(consensus_fps.contains(&guard.idhex),
1782                    "Guard {} should be in consensus", guard.idhex);
1783            }
1784        }
1785
1786        #[test]
1787        fn layer_trimming(
1788            initial_layer2 in 0usize..20,
1789            initial_layer3 in 0usize..30,
1790            target_layer2 in 1u8..10,
1791            target_layer3 in 1u8..15,
1792        ) {
1793            let now = SystemTime::now()
1794                .duration_since(UNIX_EPOCH)
1795                .unwrap()
1796                .as_secs_f64();
1797
1798            let mut state = VanguardState::new("test.state");
1799
1800            for i in 0..initial_layer2 {
1801                let fp = format!("{:0>40X}", i);
1802                state.layer2.push(GuardNode::new(fp, now, now + 86400.0));
1803            }
1804
1805            for i in 0..initial_layer3 {
1806                let fp = format!("{:0>40X}", 100 + i);
1807                state.layer3.push(GuardNode::new(fp, now, now + 86400.0));
1808            }
1809
1810            state.layer2.truncate(target_layer2 as usize);
1811            state.layer3.truncate(target_layer3 as usize);
1812
1813            let expected_layer2 = initial_layer2.min(target_layer2 as usize);
1814            let expected_layer3 = initial_layer3.min(target_layer3 as usize);
1815
1816            prop_assert_eq!(state.layer2.len(), expected_layer2,
1817                "Layer2 should have {} guards, got {}", expected_layer2, state.layer2.len());
1818            prop_assert_eq!(state.layer3.len(), expected_layer3,
1819                "Layer3 should have {} guards, got {}", expected_layer3, state.layer3.len());
1820        }
1821    }
1822}