audnet 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- audnet/__init__.py +8 -0
- audnet/cli.py +222 -0
- audnet/collector.py +180 -0
- audnet/collector_async.py +199 -0
- audnet/compliance.py +250 -0
- audnet/config.py +113 -0
- audnet/exceptions.py +25 -0
- audnet/models.py +105 -0
- audnet/parser.py +67 -0
- audnet/reporter.py +61 -0
- audnet/templates/__init__.py +0 -0
- audnet/textfsm_templates/__init__.py +0 -0
- audnet/textfsm_templates/cisco_ios_show_cdp_neighbors_detail.textfsm +12 -0
- audnet/textfsm_templates/cisco_ios_show_interface_status.textfsm +10 -0
- audnet/textfsm_templates/cisco_ios_show_ip_interface_brief.textfsm +10 -0
- audnet/textfsm_templates/cisco_ios_show_running_config.textfsm +4 -0
- audnet/textfsm_templates/cisco_ios_show_version.textfsm +9 -0
- audnet/vendor_registry.py +165 -0
- audnet-0.1.2.dist-info/METADATA +826 -0
- audnet-0.1.2.dist-info/RECORD +23 -0
- audnet-0.1.2.dist-info/WHEEL +4 -0
- audnet-0.1.2.dist-info/entry_points.txt +2 -0
- audnet-0.1.2.dist-info/licenses/LICENSE +21 -0
audnet/__init__.py
ADDED
audnet/cli.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""CLI entry point for audnet."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import structlog
|
|
9
|
+
import typer
|
|
10
|
+
from rich.console import Console
|
|
11
|
+
from rich.table import Table
|
|
12
|
+
|
|
13
|
+
from audnet import __version__
|
|
14
|
+
from audnet.collector import collect_all
|
|
15
|
+
from audnet.compliance import run_checks
|
|
16
|
+
from audnet.config import load_inventory, load_baseline
|
|
17
|
+
from audnet.models import AuditReport
|
|
18
|
+
from audnet.reporter import render_markdown, render_html
|
|
19
|
+
|
|
20
|
+
# Typer application object; every @app.command() below registers a subcommand.
app = typer.Typer(help="Network Security & Compliance State Auditor")
# Shared Rich console used for all human-readable terminal output.
console = Console()
# structlog logger named "audnet"; its processors are configured in _setup_logging().
logger = structlog.get_logger("audnet")

# Placeholder for the async collector. `audit` resolves it lazily so that
# asyncssh is only imported when --async is actually requested.
_collect_all_async = None

# Lower-cased event-dict keys whose values _redact_secrets() masks.
_SECRET_KEYS = frozenset(("password", "key_file", "secret", "passwd", "token"))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _redact_secrets(
    _logger: logging.Logger, _method_name: str, event_dict: dict[str, Any]
) -> dict[str, Any]:
    """Mask secret-looking fields of a structlog event dict, in place.

    A value is replaced with ``"***REDACTED***"`` when its key, compared
    case-insensitively, is in ``_SECRET_KEYS`` and the value is not ``None``.
    Only top-level keys are inspected. The same dict object is returned so the
    function can sit in a structlog processor chain.
    """
    sensitive = [
        name
        for name, value in event_dict.items()
        if name.lower() in _SECRET_KEYS and value is not None
    ]
    for name in sensitive:
        event_dict[name] = "***REDACTED***"
    return event_dict
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _setup_logging(verbose: bool = False) -> None:
    """Configure structlog and the stdlib root logger for the CLI.

    ``verbose=True`` selects DEBUG level and structlog's human-friendly console
    renderer; otherwise INFO level and one JSON object per event. In both
    modes each event gets a log level, logger name and timestamp, and passes
    through ``_redact_secrets`` before it is rendered.
    """
    if verbose:
        level, renderer = logging.DEBUG, structlog.dev.ConsoleRenderer()
    else:
        level, renderer = logging.INFO, structlog.processors.JSONRenderer()

    processors: list[Any] = [
        structlog.stdlib.add_log_level,
        structlog.stdlib.add_logger_name,
        structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S"),
        _redact_secrets,
        renderer,  # must stay last: it turns the event dict into the output line
    ]
    structlog.configure(
        processors=processors,
        wrapper_class=structlog.stdlib.BoundLogger,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )
    # Note: logging.basicConfig only takes effect if the root logger has no
    # handlers yet.
    logging.basicConfig(format="%(message)s", level=level)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Accepted values for ``audit --format``.
_REPORT_FORMATS = ("md", "html", "both")


def _parse_check_filter(check: list[str], check_names: set[str]) -> set[str]:
    """Flatten ``--check`` values into the set of requested check names.

    Each value may itself be a comma-separated list. Surrounding whitespace is
    stripped and empty entries (e.g. from a trailing comma) are dropped.
    Names missing from *check_names* trigger a console warning but are kept,
    so they simply match no results. An empty return value means "no filter".
    """
    requested = {
        name.strip() for item in check for name in item.split(",") if name.strip()
    }
    unknown = requested - check_names
    if unknown:
        console.print(
            f"[yellow]Warning: unknown check(s) {', '.join(sorted(unknown))} — "
            f"available: {', '.join(sorted(check_names))}[/yellow]"
        )
    return requested


@app.command()
def audit(
    inventory: str = typer.Option("inventories/devices.yaml", help="Device inventory YAML"),
    baseline: str = typer.Option("baselines/security_baseline.yaml", help="Security baseline YAML"),
    output: str = typer.Option("audit_report", help="Output file prefix"),
    format: str = typer.Option("both", help="Output format: md, html, or both"),
    workers: int = typer.Option(4, help="Max parallel SSH connections"),
    verbose: bool = typer.Option(False, "-v", "--verbose", help="Enable debug logging"),
    device: str | None = typer.Option(None, "--device", help="Filter to single device by name"),
    check: list[str] = typer.Option(
        [],
        "--check",
        help="Filter to specific checks (repeatable; supports comma-separated in one arg)",
    ),
    json_out: bool = typer.Option(False, "--json", help="Output JSON summary to stdout"),
    dry_run: bool = typer.Option(
        False,
        "-n",
        "--dry-run",
        help="Validate config and show what would be audited without connecting to devices",
    ),
    strict: bool = typer.Option(
        False,
        "--strict",
        help="Fail if any device has a plaintext password (no ${ENV_VAR} reference)",
    ),
    no_fail: bool = typer.Option(
        False,
        "--no-fail",
        help="Always exit 0 even on compliance failures (informational mode)",
    ),
    async_mode: bool = typer.Option(
        False,
        "--async",
        help="Use asyncio-based SSH collector (recommended for >20 devices)",
    ),
) -> None:
    """Run a full compliance audit against all (or filtered) devices.

    Supports device/check filters, JSON output, dry-run mode, and strict secret handling for CI/automation.

    Exit codes set by this command:
        0 -- every audited device passed, or ``--no-fail`` was given.
        1 -- a device failed, ``--device`` matched nothing, or ``--async``
             was requested but the async collector could not be imported.
        2 -- ``--format`` is not one of md, html, both.
    """
    _setup_logging(verbose)
    console.print(f"[bold blue]audnet v{__version__} — Starting audit...[/bold blue]")

    # Validate before any I/O so a typo doesn't cost a full fleet collection
    # only to write no report at the end.
    if format not in _REPORT_FORMATS:
        console.print(
            f"[red]Invalid --format '{format}': expected one of "
            f"{', '.join(_REPORT_FORMATS)}[/red]"
        )
        raise typer.Exit(code=2)

    _, devices = load_inventory(inventory, strict=strict)
    baseline_data = load_baseline(baseline)

    if device:
        devices = [d for d in devices if d.name == device]
        if not devices:
            console.print(f"[red]Device '{device}' not found in inventory[/red]")
            # Non-zero so CI does not treat "audited nothing" as a pass.
            raise typer.Exit(code=1)

    check_names = set(baseline_data.get("checks", {}).keys())
    console.print(f"Loaded {len(devices)} devices, {len(check_names)} checks")
    # Parsed once (and warned about once) before any SSH work starts.
    check_set = _parse_check_filter(check, check_names)

    if dry_run:
        console.print("[bold yellow]DRY RUN — no device connections will be made[/bold yellow]")
        console.print("[yellow]Devices that would be audited:[/yellow]")
        for d in devices:
            console.print(f"  • {d.name} ({d.host}) — {d.device_type}")
        console.print("[yellow]Checks that would be run:[/yellow]")
        for name in sorted(check_names):
            console.print(f"  • {name}")
        console.print("[green]Dry run complete — config and baseline are valid[/green]")
        return

    # Collect with status
    console.print("[yellow]Collecting device data...[/yellow]")
    if async_mode:
        import asyncio

        global _collect_all_async
        if _collect_all_async is None:
            try:
                from audnet.collector_async import collect_all_async
            except ImportError as exc:
                console.print(f"[red]--async requires the asyncssh package: {exc}[/red]")
                raise typer.Exit(code=1) from exc

            _collect_all_async = collect_all_async
        snapshots = asyncio.run(_collect_all_async(devices, max_workers=workers))
    else:
        snapshots = collect_all(devices, max_workers=workers)

    # Audit
    reports = []
    for snap in snapshots:
        if snap.collection_error:
            console.print(f"[red]ERROR {snap.device_name}: {snap.collection_error}[/red]")
            reports.append(AuditReport(device_name=snap.device_name, overall_pass=False, checks=[]))
            continue

        results = run_checks(snap, baseline_data)
        if check_set:
            results = [r for r in results if r.check_name in check_set]
        # A device with no (remaining) results is not considered passing.
        overall = all(r.passed for r in results) if results else False
        reports.append(
            AuditReport(device_name=snap.device_name, overall_pass=overall, checks=results)
        )

    # Terminal summary
    table = Table(title="Audit Results")
    table.add_column("Device")
    table.add_column("Status")
    table.add_column("Passed")
    table.add_column("Failed")
    for r in reports:
        status = "[green]PASS[/green]" if r.overall_pass else "[red]FAIL[/red]"
        table.add_row(r.device_name, status, str(r.pass_count), str(r.fail_count))
    console.print(table)

    # Write reports. Explicit UTF-8 so output does not depend on the platform
    # default encoding (device configs and the templates may contain non-ASCII).
    if format in ("md", "both"):
        md_path = Path(f"{output}.md")
        md_path.write_text(render_markdown(reports), encoding="utf-8")
        console.print(f"[green]Markdown report: {md_path}[/green]")

    if format in ("html", "both"):
        html_path = Path(f"{output}.html")
        html_path.write_text(render_html(reports), encoding="utf-8")
        console.print(f"[green]HTML report: {html_path}[/green]")

    if json_out:
        json_data = [r.model_dump(mode="json") for r in reports]
        console.print_json(json.dumps(json_data))

    if not no_fail and reports and not all(r.overall_pass for r in reports):
        raise typer.Exit(code=1)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@app.command()
def version() -> None:
    """Print the installed audnet version string."""
    console.print("audnet {}".format(__version__))
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
if __name__ == "__main__":
    # Allow running this module directly, e.g. ``python -m audnet.cli``.
    app()
|
audnet/collector.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
"""Parallel SSH collector for network device data.
|
|
2
|
+
|
|
3
|
+
Uses the vendor registry for multi-vendor command dispatch.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import concurrent.futures
|
|
8
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
9
|
+
from typing import cast
|
|
10
|
+
|
|
11
|
+
from netmiko import ConnectHandler
|
|
12
|
+
from netmiko.exceptions import (
|
|
13
|
+
NetmikoTimeoutException,
|
|
14
|
+
NetmikoAuthenticationException,
|
|
15
|
+
ConfigInvalidException,
|
|
16
|
+
ConnectionException,
|
|
17
|
+
ReadException,
|
|
18
|
+
NetmikoParsingException,
|
|
19
|
+
)
|
|
20
|
+
from paramiko.ssh_exception import SSHException
|
|
21
|
+
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception
|
|
22
|
+
|
|
23
|
+
from audnet.exceptions import ParseError
|
|
24
|
+
from audnet.models import Device, DeviceSnapshot, ParsedInterfaces, ParsedVersion, ParsedConfig
|
|
25
|
+
from audnet.parser import parse_interfaces, parse_version, parse_config
|
|
26
|
+
from audnet.vendor_registry import Slot, get_commands
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)

# Exception types treated as transient (timeouts, dropped/refused connections,
# read failures) and therefore retried by _do_ssh_collect. ConnectionError is
# already an OSError subclass; both are listed for clarity. Authentication
# failures are excluded separately in _is_retryable().
# NOTE(review): NetmikoParsingException is not obviously transient -- confirm
# that retrying it is intended.
_RETRYABLE_EXCEPTIONS = (
    NetmikoTimeoutException,
    ConnectionException,
    ReadException,
    SSHException,
    NetmikoParsingException,
    OSError,
    ConnectionError,
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _is_retryable(exc: BaseException) -> bool:
    """Decide whether tenacity should retry after *exc*.

    Authentication failures are never retried, even when they also match a
    transient type. Any other exception is retried only if it is an instance
    of one of ``_RETRYABLE_EXCEPTIONS``.
    """
    is_auth_failure = isinstance(exc, NetmikoAuthenticationException)
    return not is_auth_failure and isinstance(exc, _RETRYABLE_EXCEPTIONS)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception(_is_retryable),
    reraise=True,
)
def _do_ssh_collect(device: Device) -> dict[Slot, str]:
    """Connect to *device* with Netmiko and run its vendor collection commands.

    Transient errors (see ``_is_retryable``) are retried with exponential
    backoff, for at most 3 attempts in total; the last error is re-raised.

    Returns:
        A dict mapping Slot.INTERFACES, Slot.VERSION and Slot.RUNNING_CONFIG to
        the raw CLI output of the corresponding vendor command.

    Raises:
        ValueError: if the vendor registry supplies fewer commands than there
            are slots. Raised before connecting and not retried.
    """
    slot_map = (Slot.INTERFACES, Slot.VERSION, Slot.RUNNING_CONFIG)
    commands = tuple(get_commands(device.device_type))
    # zip() would silently drop slots here, which later surfaces as an
    # uncaught KeyError in collect_device and takes down the worker pool.
    if len(commands) < len(slot_map):
        raise ValueError(
            f"Vendor registry returned {len(commands)} command(s) for device_type "
            f"{device.device_type!r}; expected at least {len(slot_map)}"
        )

    params = {
        "device_type": device.device_type,
        "host": device.host,
        "username": device.username,
        "password": device.get_password(),
        "port": device.port,
        "timeout": device.timeout,
    }
    if device.use_keys:
        params["use_keys"] = True
    if device.key_file:
        params["key_file"] = device.key_file

    with ConnectHandler(**params) as conn:
        return {slot: cast(str, conn.send_command(cmd)) for slot, cmd in zip(slot_map, commands)}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# Failures collect_device() converts into an error snapshot instead of raising.
_COLLECTION_ERRORS = (
    NetmikoTimeoutException,
    NetmikoAuthenticationException,
    ConfigInvalidException,
    ConnectionException,
    ReadException,
    NetmikoParsingException,
    SSHException,
    OSError,
    ValueError,
    ConnectionError,
    ParseError,
)


def collect_device(device: Device) -> DeviceSnapshot:
    """Collect and parse one device's data (with internal retry for transient SSH issues).

    Raw output comes from ``_do_ssh_collect``, which retries transient errors.
    It is parsed into interfaces, version and config sections. If collection
    or parsing raises one of ``_COLLECTION_ERRORS``, the error is logged and
    an empty snapshot carrying ``collection_error`` is returned instead.
    """
    logger.info("Collecting data from %s (%s)", device.name, device.host)
    try:
        raw = _do_ssh_collect(device)

        logger.info("Successfully collected from %s", device.name)
        version_fields = parse_version(raw[Slot.VERSION], device_type=device.device_type)
        interface_section = ParsedInterfaces(
            interfaces=parse_interfaces(raw[Slot.INTERFACES], device_type=device.device_type)
        )
        version_section = ParsedVersion(**version_fields, raw=raw[Slot.VERSION])
        config_section = ParsedConfig(
            lines=parse_config(raw[Slot.RUNNING_CONFIG]),
            raw=raw[Slot.RUNNING_CONFIG],
        )
        return DeviceSnapshot(
            device_name=device.name,
            interfaces=interface_section,
            version=version_section,
            config=config_section,
        )
    except _COLLECTION_ERRORS as exc:
        logger.error("Failed to collect from %s: %s", device.name, exc)
        return DeviceSnapshot(
            device_name=device.name,
            interfaces=ParsedInterfaces(),
            version=ParsedVersion(),
            config=ParsedConfig(),
            collection_error=str(exc),
        )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def collect_all(
    devices: list[Device],
    max_workers: int = 4,
    timeout: float | None = None,
) -> list[DeviceSnapshot]:
    """Run parallel collection across devices.

    Args:
        devices: List of devices to collect from.
        max_workers: Maximum parallel SSH connections.
        timeout: Optional per-device timeout in seconds, measured from the
            moment a worker thread starts that device (time spent queued
            behind ``max_workers`` does not count). A device exceeding it gets
            an error snapshot. ``None`` (or 0) means no timeout.

    Returns:
        One DeviceSnapshot per entry in *devices*, in the same order.

    Notes:
        Threads cannot be interrupted: a timed-out SSH session keeps running
        in its worker and its eventual result is discarded. Leaving the
        executor still waits for such sessions to finish. An unexpected
        exception from a single device is logged and turned into an error
        snapshot instead of aborting the whole run.
    """
    from time import monotonic

    timeout_message = f"Collection timed out after {timeout}s"
    # Indexed by position so duplicate device names cannot overwrite each other.
    results: list[DeviceSnapshot | None] = [None] * len(devices)
    # Written by worker threads, read here; entries are only ever added.
    started_at: dict[int, float] = {}

    def _error_snapshot(dev: Device, message: str) -> DeviceSnapshot:
        """Build an empty snapshot recording *message* as the collection error."""
        return DeviceSnapshot(
            device_name=dev.name,
            interfaces=ParsedInterfaces(),
            version=ParsedVersion(),
            config=ParsedConfig(),
            collection_error=message,
        )

    def _run(index: int) -> DeviceSnapshot:
        """Worker entry point: record the start time, then collect one device."""
        started_at[index] = monotonic()
        return collect_device(devices[index])

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        future_to_index = {pool.submit(_run, i): i for i in range(len(devices))}
        pending = set(future_to_index)
        while pending:
            wait_timeout: float | None = None
            if timeout:
                now = monotonic()
                # Sleep at most one full timeout: a device that starts after
                # `now` cannot expire before `now + timeout`.
                wait_timeout = timeout
                for fut in pending:
                    start = started_at.get(future_to_index[fut])
                    if start is not None:
                        wait_timeout = min(wait_timeout, start + timeout - now)
                wait_timeout = max(0.0, wait_timeout)

            done, pending = concurrent.futures.wait(
                pending,
                timeout=wait_timeout,
                return_when=concurrent.futures.FIRST_COMPLETED,
            )
            for fut in done:
                index = future_to_index[fut]
                try:
                    results[index] = fut.result()
                except Exception as exc:  # isolate per-device failures
                    logger.exception(
                        "Unexpected error collecting from %s", devices[index].name
                    )
                    results[index] = _error_snapshot(devices[index], str(exc))

            if timeout:
                now = monotonic()
                expired = set()
                for fut in pending:
                    start = started_at.get(future_to_index[fut])
                    if start is not None and now - start >= timeout:
                        expired.add(fut)
                for fut in expired:
                    index = future_to_index[fut]
                    logger.error(
                        "Collection from %s timed out after %ss", devices[index].name, timeout
                    )
                    results[index] = _error_snapshot(devices[index], timeout_message)
                pending -= expired

    # Every slot is filled: each future leaves `pending` either done or expired.
    return cast(list[DeviceSnapshot], results)
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
"""Async SSH collector prototype for network device data.
|
|
2
|
+
|
|
3
|
+
This module provides an asyncio-based alternative to the ThreadPool + Netmiko
|
|
4
|
+
collector. It uses asyncssh for SSH connections and is designed to scale to
|
|
5
|
+
hundreds of devices with lower memory and thread overhead.
|
|
6
|
+
|
|
7
|
+
Architecture:
|
|
8
|
+
- asyncio event loop manages all concurrent SSH sessions
|
|
9
|
+
- asyncssh handles SSH transport (no thread per connection)
|
|
10
|
+
- Same DeviceSnapshot output format as sync collector
|
|
11
|
+
- Same retry logic via tenacity (async-compatible)
|
|
12
|
+
|
|
13
|
+
Trade-offs vs sync collector (collector.py):
|
|
14
|
+
+ Single-threaded: no GIL contention, lower memory per connection
|
|
15
|
+
+ Native concurrency: scales to 100s of devices without thread overhead
|
|
16
|
+
+ No thread pool sizing: concurrency limited by semaphore, not OS threads
|
|
17
|
+
- Requires asyncssh dependency (not in current dependency tree)
|
|
18
|
+
- No Netmiko device-type abstraction: commands sent raw
|
|
19
|
+
- Opt-in: the CLI uses this collector only when ``--async`` is passed; the sync collector remains the default
|
|
20
|
+
|
|
21
|
+
Migration path:
|
|
22
|
+
1. Install asyncssh: uv add asyncssh
|
|
23
|
+
2. Run the CLI with ``--async``; cli.py lazily imports ``collect_all_async`` from this module
|
|
24
|
+
3. The ``--workers`` flag is passed as ``max_workers``, i.e. the asyncio.Semaphore limit
|
|
25
|
+
4. Keep sync collector as fallback for environments without asyncssh
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
import asyncio
|
|
29
|
+
import logging
|
|
30
|
+
from typing import cast
|
|
31
|
+
|
|
32
|
+
import asyncssh
|
|
33
|
+
from asyncssh import (
|
|
34
|
+
ChannelOpenError,
|
|
35
|
+
DisconnectError,
|
|
36
|
+
PermissionDenied,
|
|
37
|
+
TimeoutError as AsyncSshTimeoutError,
|
|
38
|
+
)
|
|
39
|
+
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception
|
|
40
|
+
|
|
41
|
+
from audnet.models import Device, DeviceSnapshot, ParsedInterfaces, ParsedVersion, ParsedConfig
|
|
42
|
+
from audnet.parser import parse_interfaces, parse_version, parse_config
|
|
43
|
+
from audnet.vendor_registry import Slot, get_commands
|
|
44
|
+
|
|
45
|
+
logger = logging.getLogger(__name__)

# asyncssh / socket exception types treated as transient and therefore retried
# by _do_ssh_collect. ConnectionError is already an OSError subclass; both
# are listed for clarity. Authentication failures (PermissionDenied) are
# excluded separately in _is_retryable().
_RETRYABLE_EXCEPTIONS = (
    DisconnectError,
    ChannelOpenError,
    AsyncSshTimeoutError,
    OSError,
    ConnectionError,
)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _is_retryable(exc: BaseException) -> bool:
    """Decide whether tenacity should retry after *exc*.

    Rejected credentials (``PermissionDenied``) are never retried, even when
    they also match a transient type. Anything else is retried only if it is
    an instance of one of ``_RETRYABLE_EXCEPTIONS``.
    """
    auth_rejected = isinstance(exc, PermissionDenied)
    return not auth_rejected and isinstance(exc, _RETRYABLE_EXCEPTIONS)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception(_is_retryable),
    reraise=True,
)
async def _do_ssh_collect(device: Device, known_hosts: str | None = None) -> dict[Slot, str]:
    """Perform async SSH collection with retry for transient errors.

    At most 3 attempts in total with exponential backoff; authentication
    failures are not retried (see ``_is_retryable``).

    Args:
        device: Device to collect from.
        known_hosts: Path to known_hosts file. ``None`` uses asyncssh's default
            (``~/.ssh/known_hosts``). Pass an empty string to disable
            host-key verification (lab/testing only).

    Returns:
        A dict mapping each Slot to the stdout of the corresponding command.

    Raises:
        ValueError: if the vendor registry supplies fewer commands than there
            are slots (raised before connecting; not retried).
    """
    slot_map = (Slot.INTERFACES, Slot.VERSION, Slot.RUNNING_CONFIG)
    commands = tuple(get_commands(device.device_type))
    # zip() would silently drop slots, surfacing later as a KeyError.
    if len(commands) < len(slot_map):
        raise ValueError(
            f"Vendor registry returned {len(commands)} command(s) for device_type "
            f"{device.device_type!r}; expected at least {len(slot_map)}"
        )
    password = device.get_password()

    connect_kwargs: dict[str, object] = {}
    # asyncssh treats an explicit ``known_hosts=None`` as "disable host-key
    # verification", while omitting the argument selects ~/.ssh/known_hosts.
    # Only pass it when the caller asked for something specific; "" is the
    # documented lab-only opt-out and maps to asyncssh's None.
    if known_hosts is not None:
        connect_kwargs["known_hosts"] = known_hosts or None
    # Honour an explicit private key, as the Netmiko collector does.
    if device.key_file:
        connect_kwargs["client_keys"] = [device.key_file]

    async with asyncssh.connect(
        device.host,
        port=device.port,
        username=device.username,
        password=password,
        connect_timeout=device.timeout or 30,
        **connect_kwargs,
    ) as conn:
        results: dict[Slot, str] = {}
        for slot, cmd in zip(slot_map, commands):
            result = await conn.run(cmd, timeout=device.timeout)
            results[slot] = cast(str, result.stdout)
        return results
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
async def collect_device_async(
    device: Device,
    known_hosts: str | None = None,
) -> DeviceSnapshot:
    """Collect data from one device asynchronously.

    Same interface as sync collect_device(), but uses asyncio + asyncssh
    instead of ThreadPool + Netmiko. Connection, timeout and parse failures
    are logged and returned as an empty snapshot with ``collection_error``
    set, so one bad device cannot abort a surrounding ``asyncio.gather``.

    Args:
        device: Device to collect from.
        known_hosts: Path to known_hosts file. ``None`` uses the system default
            (``~/.ssh/known_hosts``). Pass an empty string to disable verification
            (lab/testing only).
    """
    # The shared parsers can raise the project ParseError; the sync
    # collect_device() catches it too, so handle it the same way here.
    from audnet.exceptions import ParseError

    logger.info("Collecting data from %s (%s)", device.name, device.host)
    try:
        raw_outputs = await _do_ssh_collect(device, known_hosts=known_hosts)

        logger.info("Successfully collected from %s", device.name)
        parsed_version = parse_version(raw_outputs[Slot.VERSION], device_type=device.device_type)
        return DeviceSnapshot(
            device_name=device.name,
            interfaces=ParsedInterfaces(
                interfaces=parse_interfaces(
                    raw_outputs[Slot.INTERFACES], device_type=device.device_type
                )
            ),
            version=ParsedVersion(**parsed_version, raw=raw_outputs[Slot.VERSION]),
            config=ParsedConfig(
                lines=parse_config(raw_outputs[Slot.RUNNING_CONFIG]),
                raw=raw_outputs[Slot.RUNNING_CONFIG],
            ),
        )
    except (
        PermissionDenied,
        DisconnectError,
        ChannelOpenError,
        AsyncSshTimeoutError,
        OSError,
        ValueError,
        ConnectionError,
        ParseError,
    ) as exc:
        logger.error("Failed to collect from %s: %s", device.name, exc)
        return DeviceSnapshot(
            device_name=device.name,
            interfaces=ParsedInterfaces(),
            version=ParsedVersion(),
            config=ParsedConfig(),
            collection_error=str(exc),
        )
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
async def collect_all_async(
    devices: list[Device],
    max_workers: int = 50,
    timeout: float | None = None,
    known_hosts: str | None = None,
) -> list[DeviceSnapshot]:
    """Run async collection across all devices concurrently.

    Uses an asyncio.Semaphore to limit concurrent connections, which is
    more memory-efficient than a ThreadPool for large inventories.

    Args:
        devices: List of devices to collect from.
        max_workers: Maximum concurrent SSH connections (semaphore limit).
            Defaults to 50 -- much higher than the sync default of 4
            because async connections have minimal per-connection overhead.
        timeout: Optional per-device timeout in seconds, counted from when
            the device acquires a semaphore slot. ``None`` (or 0) disables it.
        known_hosts: Path to known_hosts file. ``None`` uses the system default
            (``~/.ssh/known_hosts``). Pass an empty string to disable
            verification (lab/testing only).

    Returns:
        List of DeviceSnapshot results, one per device, in input order. An
        unexpected error for one device yields an error snapshot for it
        instead of failing the whole gather.

    Raises:
        ValueError: if ``max_workers`` is less than 1 (a semaphore with no
            permits would block forever).
    """
    if max_workers < 1:
        raise ValueError(f"max_workers must be at least 1, got {max_workers}")
    semaphore = asyncio.Semaphore(max_workers)

    def _error_snapshot(device: Device, message: str) -> DeviceSnapshot:
        """Build an empty snapshot recording *message* as the collection error."""
        return DeviceSnapshot(
            device_name=device.name,
            interfaces=ParsedInterfaces(),
            version=ParsedVersion(),
            config=ParsedConfig(),
            collection_error=message,
        )

    async def _bounded_collect(device: Device) -> DeviceSnapshot:
        """Collect one device while holding a semaphore slot."""
        async with semaphore:
            try:
                if not timeout:
                    return await collect_device_async(device, known_hosts=known_hosts)
                try:
                    return await asyncio.wait_for(
                        collect_device_async(device, known_hosts=known_hosts),
                        timeout=timeout,
                    )
                except asyncio.TimeoutError:
                    logger.error("Collection from %s timed out after %ss", device.name, timeout)
                    return _error_snapshot(device, f"Collection timed out after {timeout}s")
            except Exception as exc:  # isolate per-device failures; CancelledError still propagates
                logger.exception("Unexpected error collecting from %s", device.name)
                return _error_snapshot(device, str(exc))

    tasks = [asyncio.create_task(_bounded_collect(d)) for d in devices]
    return list(await asyncio.gather(*tasks))
|