WaveGuardClient 2.2.0__tar.gz → 2.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: WaveGuardClient
3
- Version: 2.2.0
3
+ Version: 2.3.0
4
4
  Summary: Python SDK for WaveGuard — physics-based anomaly detection API
5
5
  Author: Greg Partin
6
6
  License-Expression: MIT
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: WaveGuardClient
3
- Version: 2.2.0
3
+ Version: 2.3.0
4
4
  Summary: Python SDK for WaveGuard — physics-based anomaly detection API
5
5
  Author: Greg Partin
6
6
  License-Expression: MIT
@@ -4,8 +4,11 @@ pyproject.toml
4
4
  WaveGuardClient.egg-info/PKG-INFO
5
5
  WaveGuardClient.egg-info/SOURCES.txt
6
6
  WaveGuardClient.egg-info/dependency_links.txt
7
+ WaveGuardClient.egg-info/entry_points.txt
7
8
  WaveGuardClient.egg-info/requires.txt
8
9
  WaveGuardClient.egg-info/top_level.txt
10
+ mcp_server/__init__.py
11
+ mcp_server/server.py
9
12
  tests/test_client.py
10
13
  tests/test_mcp_server.py
11
14
  waveguard/__init__.py
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ waveguard-mcp = mcp_server.server:main
@@ -0,0 +1,2 @@
1
+ mcp_server
2
+ waveguard
File without changes
@@ -0,0 +1,528 @@
1
+ """
2
+ WaveGuard MCP Server — Model Context Protocol for Claude Desktop & AI Agents.
3
+
4
+ Stateless anomaly detection via wave physics simulation.
5
+ One tool: ``waveguard_scan``. Send training + test data, get anomaly scores.
6
+
7
+ Transports
8
+ ----------
9
+ - **stdio** (default) — add to Claude Desktop config
10
+ - **HTTP** — ``waveguard-mcp --http --port 3001`` (or ``python -m mcp_server.server --http --port 3001``)
11
+
12
+ Claude Desktop config
13
+ ~~~~~~~~~~~~~~~~~~~~~
14
+ Add to ``~/Library/Application Support/Claude/claude_desktop_config.json``
15
+ (macOS) or ``%APPDATA%\\Claude\\claude_desktop_config.json`` (Windows)::
16
+
17
+ {
18
+ "mcpServers": {
19
+ "waveguard": {
20
+ "command": "python",
21
+ "args": ["/path/to/WaveGuardClient/mcp_server/server.py"],
22
+ "env": {
23
+ "WAVEGUARD_API_KEY": "your-key-here"
24
+ }
25
+ }
26
+ }
27
+ }
28
+
29
+ Smithery / Glama config
30
+ ~~~~~~~~~~~~~~~~~~~~~~~
31
+ ::
32
+
33
+ {
34
+ "mcpServers": {
35
+ "waveguard": {
36
+ "url": "https://gpartin--waveguard-api-fastapi-app.modal.run/mcp",
37
+ "transport": "http"
38
+ }
39
+ }
40
+ }
41
+ """
42
+
43
+ from __future__ import annotations
44
+
45
+ import os
46
+ import sys
47
+ import json
48
+ import argparse
49
+ from typing import Any, Dict, List, Optional
50
+
51
+ # ── Configuration ──────────────────────────────────────────────────────────
52
+
53
# Base URL of the WaveGuard API. Request paths start with "/", so a trailing
# slash here (easy to paste into an env var) would yield "//v1/..." URLs; the
# Python client strips it the same way.
API_URL = os.environ.get(
    "WAVEGUARD_API_URL",
    "https://gpartin--waveguard-api-fastapi-app.modal.run",
).rstrip("/")
# Optional API key; sent as X-API-Key when non-empty.
API_KEY = os.environ.get("WAVEGUARD_API_KEY", "")

# ── HTTP client ───────────────────────────────────────────────────────────

try:
    import requests
except ImportError:
    # requests is only needed once a tool actually calls the API; the
    # _api_post/_api_get helpers raise a RuntimeError when _session is None.
    _session = None  # type: ignore[assignment]
else:
    _session = requests.Session()
    if API_KEY:
        _session.headers["X-API-Key"] = API_KEY
69
+
70
+
71
def _api_post(path: str, body: dict) -> dict:
    """POST ``body`` as JSON to ``API_URL + path`` and return the decoded reply.

    Raises:
        RuntimeError: if the ``requests`` package is not installed.
        requests.HTTPError: on a 4xx/5xx status. The message includes the
            start of the response body, so the calling agent sees the API's
            explanation (e.g. which field failed validation) rather than
            just the status line.
    """
    if _session is None:
        raise RuntimeError("requests library required: pip install requests")
    resp = _session.post(f"{API_URL}{path}", json=body, timeout=90)
    try:
        resp.raise_for_status()
    except requests.HTTPError as exc:
        detail = resp.text.strip()[:500]
        message = f"{exc} — {detail}" if detail else str(exc)
        raise requests.HTTPError(message, response=resp) from exc
    return resp.json()
77
+
78
+
79
def _api_get(path: str) -> Any:
    """GET ``API_URL + path`` and return the decoded JSON reply.

    Raises:
        RuntimeError: if the ``requests`` package is not installed.
        requests.HTTPError: on a 4xx/5xx status, with the start of the
            response body appended to the message for diagnosability.
    """
    if _session is None:
        raise RuntimeError("requests library required: pip install requests")
    resp = _session.get(f"{API_URL}{path}", timeout=30)
    try:
        resp.raise_for_status()
    except requests.HTTPError as exc:
        detail = resp.text.strip()[:500]
        message = f"{exc} — {detail}" if detail else str(exc)
        raise requests.HTTPError(message, response=resp) from exc
    return resp.json()
85
+
86
+
87
+ # ═══════════════════════════════════════════════════════════════════════════
88
+ # MCP Tool Definitions
89
+ # ═══════════════════════════════════════════════════════════════════════════
90
+
91
# MCP tool catalogue returned by ``tools/list``. Each entry follows the MCP
# tool shape: name, human/agent-facing description, and a JSON Schema for
# the arguments.
TOOLS: List[Dict[str, Any]] = [
    {
        "name": "waveguard_scan",
        "description": (
            "Detect anomalies in data using GPU-accelerated wave physics simulation. "
            "Fully stateless — send training data (normal examples) and test data "
            "(samples to check) in ONE call. Returns per-sample anomaly scores, "
            "confidence levels, and the top features explaining WHY each anomaly "
            "was flagged. Works on any data type: JSON objects, numbers, text, "
            "time series, arrays. No separate training step required.\n\n"
            "Example: to check if server metrics are anomalous, send 3-5 normal "
            "readings as training, and the suspect readings as test."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "training": {
                    "type": "array",
                    "description": (
                        "2+ examples of NORMAL/expected data. These define what "
                        "'normal' looks like. All samples should be the same type "
                        "and shape. More samples = better baseline (4-10 is ideal)."
                    ),
                    "minItems": 2,
                },
                "test": {
                    "type": "array",
                    "description": (
                        "1+ data points to check for anomalies. Same type/shape "
                        "as training data. Each sample is scored independently."
                    ),
                    "minItems": 1,
                },
                "sensitivity": {
                    "type": "number",
                    "description": (
                        "Anomaly threshold multiplier (default: 2.0). Lower = more "
                        "sensitive (flags more anomalies). Higher = less sensitive. "
                        "Range: 0.5 to 5.0."
                    ),
                },
                "encoder_type": {
                    "type": "string",
                    "enum": [
                        "json",
                        "numeric",
                        "text",
                        "timeseries",
                        "tabular",
                        "image",
                        "correlation",
                    ],
                    "description": (
                        "Data encoder type. Omit to auto-detect from data shape. "
                        "Auto-detection works well for most data."
                    ),
                },
            },
            "required": ["training", "test"],
        },
    },
    {
        "name": "waveguard_scan_timeseries",
        "description": (
            "Detect anomalies in time-series data using GPU-accelerated wave "
            "physics simulation. Send a flat array of numeric values and a "
            "window size. The tool automatically creates overlapping windows "
            "(stride 1), uses the earlier windows as training (normal "
            "baseline), and scores the trailing windows as test samples. "
            "Returns per-window anomaly scores, confidence, and p-values.\n\n"
            "Example: send 100 CPU-usage readings with window_size=10. "
            "That gives 91 overlapping windows; by default the last 36 "
            "(~40%) are tested against the first 55."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "data": {
                    "type": "array",
                    "items": {"type": "number"},
                    "description": (
                        "Flat array of numeric time-series values in "
                        "chronological order."
                    ),
                    "minItems": 4,
                },
                "window_size": {
                    "type": "integer",
                    "minimum": 1,
                    "description": (
                        "Number of data points per window (default: 10). "
                        "Smaller windows = finer resolution."
                    ),
                },
                "test_windows": {
                    "type": "integer",
                    "minimum": 1,
                    "description": (
                        "Number of trailing windows to test (default: auto, "
                        "uses the last ~40% of windows). At least 2 windows "
                        "are always kept for training."
                    ),
                },
                "sensitivity": {
                    "type": "number",
                    "description": (
                        "Anomaly threshold multiplier (default: 2.0). Lower = "
                        "more sensitive. Range: 0.5 to 5.0."
                    ),
                },
            },
            "required": ["data"],
        },
    },
    {
        "name": "waveguard_health",
        "description": (
            "Check WaveGuard API health, GPU availability, version, and engine "
            "status. No authentication required. Use this to verify the service "
            "is running before scanning."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {},
        },
    },
]
214
+
215
+
216
+ # ═══════════════════════════════════════════════════════════════════════════
217
+ # Tool Execution
218
+ # ═══════════════════════════════════════════════════════════════════════════
219
+
220
+
221
def _execute_timeseries(arguments: dict) -> dict:
    """Run a sliding-window time-series scan through the ``/v1/scan`` endpoint.

    The flat ``data`` series is cut into overlapping windows of
    ``window_size`` points (stride 1). The trailing ``test_windows`` windows
    (default: ~40% of them) are scored against the earlier windows, which act
    as the normal baseline. At least two windows are always kept for training
    and at least one is always tested.

    Args:
        arguments: MCP tool arguments: ``data`` (required list), optional
            ``window_size`` (default 10), ``test_windows`` and
            ``sensitivity`` (default 2.0).

    Returns:
        An MCP ``tools/call`` result holding a human-readable summary plus the
        raw API JSON, or an ``isError`` result when the arguments are invalid
        or cannot produce at least 3 windows.
    """

    def _error(text: str) -> dict:
        # Tool-level error: reported to the agent in the result payload.
        return {"content": [{"type": "text", "text": text}], "isError": True}

    def _fmt(value: Any, spec: str) -> str:
        # The API may omit a field or send null; a formatting TypeError here
        # would otherwise throw away an otherwise successful scan.
        return format(value, spec) if isinstance(value, (int, float)) else "n/a"

    data = arguments.get("data")
    if not isinstance(data, list):
        return _error("'data' must be an array of numeric values.")

    raw_window = arguments.get("window_size")
    window = 10 if raw_window is None else int(raw_window)
    if window < 1:
        return _error(f"window_size must be at least 1 (got {window}).")

    sensitivity = arguments.get("sensitivity")
    if sensitivity is None:
        sensitivity = 2.0

    # Overlapping windows with stride 1: len(data) - window + 1 of them.
    windows = [data[i : i + window] for i in range(len(data) - window + 1)]
    if len(windows) < 3:
        return _error(
            f"Not enough data: {len(data)} points with "
            f"window_size={window} gives {len(windows)} windows "
            f"(need at least 3)."
        )

    # Split into training / test: >= 1 test window, >= 2 training windows.
    requested = arguments.get("test_windows")
    test_count = len(windows) * 2 // 5 if requested is None else int(requested)
    test_count = max(1, min(test_count, len(windows) - 2))

    split = len(windows) - test_count
    training = windows[:split]
    test = windows[split:]

    body: dict = {
        "training": training,
        "test": test,
        "encoder_type": "timeseries",
        "sensitivity": sensitivity,
    }
    result = _api_post("/v1/scan", body)

    # Summarise; window indices are absolute positions in ``windows``.
    lines = [
        f"Time-series scan: {len(windows)} windows "
        f"(window_size={window}, {len(training)} train, {len(test)} test)",
        "",
    ]
    for offset, r in enumerate(result.get("results") or []):
        marker = "ANOMALY" if r.get("is_anomaly", False) else "Normal"
        conf = _fmt(r.get("confidence", 0), ".0%")
        pval = _fmt(r.get("p_value", 1.0), ".4f")
        lines.append(
            f" Window {split + offset}: {marker} (confidence: {conf}, "
            f"p-value: {pval})"
        )
    summary = "\n".join(lines)
    return {
        "content": [
            {"type": "text", "text": summary},
            {"type": "text", "text": json.dumps(result, indent=2)},
        ]
    }
284
+
285
+
286
def _format_number(value: Any, spec: str) -> str:
    """Format a numeric API field with ``spec``; missing/null/non-numeric → ``"n/a"``."""
    if isinstance(value, (int, float)):
        return format(value, spec)
    return "n/a"


def _scan_tool(arguments: dict) -> dict:
    """Run ``waveguard_scan``: forward training/test data to ``/v1/scan`` and summarise."""
    missing = [key for key in ("training", "test") if key not in arguments]
    if missing:
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Missing required argument(s): " + ", ".join(missing),
                }
            ],
            "isError": True,
        }

    body: dict = {
        "training": arguments["training"],
        "test": arguments["test"],
    }
    # Optional knobs are only forwarded when supplied, so the API defaults apply.
    for key in ("sensitivity", "encoder_type"):
        if key in arguments:
            body[key] = arguments[key]

    result = _api_post("/v1/scan", body)

    # Build a human-readable summary for the agent.
    summary_data = result.get("summary") or {}
    total = summary_data.get("total_samples", len(arguments["test"]))
    anomalies = summary_data.get("anomalies_found", 0)
    rate = _format_number(summary_data.get("anomaly_rate", 0), ".0%")

    lines = [
        f"Scanned {total} samples: {anomalies} anomalies ({rate} anomaly rate)",
        "",
    ]

    for number, r in enumerate(result.get("results") or [], start=1):
        is_anom = r.get("is_anomaly", False)
        marker = "ANOMALY" if is_anom else "Normal"
        conf = _format_number(r.get("confidence", 0), ".0%")
        score = _format_number(r.get("score", 0), ".1f")
        line = f" Sample {number}: {marker} (confidence: {conf}, score: {score})"

        features = r.get("top_features")
        if is_anom and features:
            feat_str = ", ".join(
                f"{f.get('label', '?')} (z={_format_number(f.get('z_score', 0), '.1f')})"
                for f in features[:3]
            )
            line += f"\n Top features: {feat_str}"

        lines.append(line)

    return {
        "content": [
            {"type": "text", "text": "\n".join(lines)},
            {"type": "text", "text": json.dumps(result, indent=2)},
        ]
    }


def _health_tool() -> dict:
    """Run ``waveguard_health``: one-line status summary from ``/v1/health``."""
    result = _api_get("/v1/health")
    status = (
        f"Status: {result.get('status', '?')} | "
        f"Version: {result.get('version', '?')} | "
        f"GPU: {result.get('gpu', 'N/A')}"
    )
    return {"content": [{"type": "text", "text": status}]}


def execute_tool(name: str, arguments: Optional[dict]) -> dict:
    """Execute an MCP tool and return its ``tools/call`` result.

    Args:
        name: Tool name as listed in ``TOOLS``.
        arguments: Tool arguments; ``None`` is treated as no arguments.

    Returns:
        An MCP result dict with a ``content`` list. Any failure, including
        an unknown tool name, bad arguments or an API error, is reported as
        ``isError: True`` instead of being raised. This way the agent sees
        the problem and the server keeps running.
    """
    arguments = arguments or {}
    try:
        if name == "waveguard_scan":
            return _scan_tool(arguments)
        if name == "waveguard_scan_timeseries":
            return _execute_timeseries(arguments)
        if name == "waveguard_health":
            return _health_tool()
        return {
            "content": [{"type": "text", "text": f"Unknown tool: {name}"}],
            "isError": True,
        }
    except Exception as e:  # tool boundary: surface every failure to the agent
        return {
            "content": [{"type": "text", "text": f"Error: {e}"}],
            "isError": True,
        }
361
+
362
+
363
+ # ═══════════════════════════════════════════════════════════════════════════
364
+ # MCP Protocol Handler (stdio JSON-RPC transport)
365
+ # ═══════════════════════════════════════════════════════════════════════════
366
+
367
+
368
class MCPStdioServer:
    """Minimal MCP server implementing JSON-RPC 2.0 over stdio.

    ``handle_message`` does not depend on the transport, and the HTTP
    transport reuses it. ``run_stdio`` reads newline-delimited JSON-RPC
    messages from stdin and writes one JSON response per line to stdout.
    Diagnostics go to stderr so they never corrupt the protocol stream.
    """

    VERSION = "2.3.0"
    PROTOCOL_VERSION = "2024-11-05"

    def __init__(self) -> None:
        self.server_info = {
            "name": "waveguard",
            "version": self.VERSION,
        }

    @staticmethod
    def _result(msg_id: Any, result: dict) -> dict:
        """Build a JSON-RPC 2.0 success response."""
        return {"jsonrpc": "2.0", "id": msg_id, "result": result}

    @staticmethod
    def _error(msg_id: Any, code: int, message: str) -> dict:
        """Build a JSON-RPC 2.0 error response."""
        return {
            "jsonrpc": "2.0",
            "id": msg_id,
            "error": {"code": code, "message": message},
        }

    def handle_message(self, msg: Any) -> Optional[dict]:
        """Process one decoded JSON-RPC 2.0 message.

        Args:
            msg: The decoded message. Anything other than a JSON object gets
                an Invalid Request (-32600) error.

        Returns:
            The response dict, or ``None`` when no reply may be sent. That
            covers notifications, i.e. messages without an ``id``, and
            anything in the ``notifications/`` namespace.
        """
        if not isinstance(msg, dict):
            return self._error(None, -32600, "Invalid Request: expected a JSON object")

        method = msg.get("method", "")
        if not isinstance(method, str):
            method = ""
        msg_id = msg.get("id")
        # "params" may be absent, null, or (per JSON-RPC) a positional array;
        # only by-name params are meaningful here.
        params = msg.get("params")
        if not isinstance(params, dict):
            params = {}

        # JSON-RPC 2.0: the server MUST NOT reply to a notification.
        if "id" not in msg or method.startswith("notifications/"):
            return None

        if method == "initialize":
            return self._result(
                msg_id,
                {
                    "protocolVersion": self.PROTOCOL_VERSION,
                    "capabilities": {
                        "tools": {"listChanged": False},
                    },
                    "serverInfo": self.server_info,
                },
            )

        if method == "tools/list":
            return self._result(msg_id, {"tools": TOOLS})

        if method == "tools/call":
            arguments = params.get("arguments")
            if not isinstance(arguments, dict):
                arguments = {}
            tool_result = execute_tool(params.get("name", ""), arguments)
            return self._result(msg_id, tool_result)

        if method == "ping":
            return self._result(msg_id, {})

        return self._error(msg_id, -32601, f"Method not found: {method}")

    def run_stdio(self) -> None:
        """Serve MCP on stdin/stdout until stdin is closed.

        Malformed JSON is answered with a Parse error (-32700). Unexpected
        exceptions while handling a request are answered with an Internal
        error (-32603) instead of terminating the server.
        """
        sys.stderr.write(
            f"WaveGuard MCP server v{self.VERSION} started (API: {API_URL})\n"
        )
        sys.stderr.flush()

        for raw_line in sys.stdin:
            line = raw_line.strip()
            if not line:
                continue

            response: Optional[dict]
            try:
                msg = json.loads(line)
            except json.JSONDecodeError as e:
                sys.stderr.write(f"Invalid JSON: {e}\n")
                sys.stderr.flush()
                response = self._error(None, -32700, f"Parse error: {e}")
            else:
                try:
                    response = self.handle_message(msg)
                except Exception as e:  # keep the server alive
                    sys.stderr.write(f"Internal error handling message: {e!r}\n")
                    sys.stderr.flush()
                    if isinstance(msg, dict) and "id" in msg:
                        response = self._error(
                            msg.get("id"), -32603, f"Internal error: {e}"
                        )
                    else:
                        response = None

            if response is not None:
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()
457
+
458
+
459
+ # ═══════════════════════════════════════════════════════════════════════════
460
+ # HTTP transport (for remote MCP clients / Smithery / Glama)
461
+ # ═══════════════════════════════════════════════════════════════════════════
462
+
463
+
464
def run_http_server(port: int = 3001, host: str = "0.0.0.0") -> None:
    """Serve MCP over HTTP (JSON-RPC via ``POST /mcp``) for remote agents.

    Requires the optional ``fastapi`` and ``uvicorn`` packages and exits with
    status 1 when they are missing. Blocks until the uvicorn server stops.

    Args:
        port: TCP port to listen on.
        host: Interface to bind. The default ``"0.0.0.0"`` listens on all
            interfaces.
    """
    try:
        from fastapi import FastAPI, Response
        import uvicorn
    except ImportError:
        print("HTTP transport requires: pip install fastapi uvicorn", file=sys.stderr)
        sys.exit(1)

    mcp_app = FastAPI(title="WaveGuard MCP Server", version="2.3.0")
    server = MCPStdioServer()

    # Plain ``def`` rather than ``async def``: handle_message makes blocking
    # ``requests`` calls to the WaveGuard API, so FastAPI must run it in its
    # threadpool instead of stalling the event loop for every client.
    # response_model=None because notifications produce no JSON-RPC reply,
    # and the endpoint then returns a bare 202 rather than a dict.
    @mcp_app.post("/mcp", response_model=None)
    def mcp_endpoint(request: dict) -> Any:
        response = server.handle_message(request)
        if response is None:
            return Response(status_code=202)
        return response

    @mcp_app.get("/mcp/tools")
    async def mcp_tools() -> dict:  # type: ignore[type-arg]
        return {"tools": TOOLS}

    print(f"WaveGuard MCP HTTP server v2.3.0 on port {port}")
    uvicorn.run(mcp_app, host=host, port=port)
486
+
487
+
488
+ # ═══════════════════════════════════════════════════════════════════════════
489
+ # Entry point
490
+ # ═══════════════════════════════════════════════════════════════════════════
491
+
492
def main(argv: Optional[List[str]] = None) -> None:
    """Entry point for the ``waveguard-mcp`` console script.

    Parses command-line options and applies ``--api-url`` (if given) to the
    module-level ``API_URL`` used by every API call. It then serves MCP over
    stdio (default) or HTTP (``--http``).

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="WaveGuard MCP Server v2.3.0"
    )
    parser.add_argument(
        "--http",
        action="store_true",
        help="Use HTTP transport instead of stdio",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=3001,
        help="HTTP port (default: 3001)",
    )
    parser.add_argument(
        "--api-url",
        type=str,
        default=None,
        help="WaveGuard API URL (overrides $WAVEGUARD_API_URL)",
    )
    args = parser.parse_args(argv)

    global API_URL
    if args.api_url:
        # Request paths start with "/", so drop a trailing slash.
        API_URL = args.api_url.rstrip("/")

    if args.http:
        run_http_server(args.port)
    else:
        MCPStdioServer().run_stdio()


if __name__ == "__main__":
    main()
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "WaveGuardClient"
7
- version = "2.2.0"
7
+ version = "2.3.0"
8
8
  description = "Python SDK for WaveGuard — physics-based anomaly detection API"
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -59,9 +59,12 @@ Repository = "https://github.com/gpartin/WaveGuardClient"
59
59
  Issues = "https://github.com/gpartin/WaveGuardClient/issues"
60
60
  API = "https://gpartin--waveguard-api-fastapi-app.modal.run/docs"
61
61
 
62
+ [project.scripts]
63
+ waveguard-mcp = "mcp_server.server:main"
64
+
62
65
  [tool.setuptools.packages.find]
63
- include = ["waveguard*"]
64
- exclude = ["tests*", "examples*", "docs*", "mcp_server*"]
66
+ include = ["waveguard*", "mcp_server*"]
67
+ exclude = ["tests*", "examples*", "docs*"]
65
68
 
66
69
  [tool.pytest.ini_options]
67
70
  testpaths = ["tests"]
@@ -35,6 +35,7 @@ def _mock_response(status_code=200, json_data=None, text=""):
35
35
  resp.status_code = status_code
36
36
  resp.text = text or json.dumps(json_data or {})
37
37
  resp.json.return_value = json_data or {}
38
+ resp.headers = {}
38
39
  return resp
39
40
 
40
41
 
@@ -92,10 +93,10 @@ class TestScan:
92
93
  @patch("waveguard.client.requests.Session")
93
94
  def test_scan_parses_results(self, mock_session_cls):
94
95
  session = MagicMock()
95
- session.post.return_value = _mock_response(200, SCAN_RESPONSE)
96
+ session.request.return_value = _mock_response(200, SCAN_RESPONSE)
96
97
  mock_session_cls.return_value = session
97
98
 
98
- wg = WaveGuard(api_key="test")
99
+ wg = WaveGuard(api_key="test", max_retries=0)
99
100
  result = wg.scan(
100
101
  training=[{"a": 1}, {"a": 2}, {"a": 3}],
101
102
  test=[{"a": 2}, {"a": 100}],
@@ -113,10 +114,10 @@ class TestScan:
113
114
  @patch("waveguard.client.requests.Session")
114
115
  def test_scan_top_features(self, mock_session_cls):
115
116
  session = MagicMock()
116
- session.post.return_value = _mock_response(200, SCAN_RESPONSE)
117
+ session.request.return_value = _mock_response(200, SCAN_RESPONSE)
117
118
  mock_session_cls.return_value = session
118
119
 
119
- wg = WaveGuard(api_key="test")
120
+ wg = WaveGuard(api_key="test", max_retries=0)
120
121
  result = wg.scan(training=[{"a": 1}, {"a": 2}], test=[{"a": 100}])
121
122
 
122
123
  anomaly = result.results[1]
@@ -127,10 +128,10 @@ class TestScan:
127
128
  @patch("waveguard.client.requests.Session")
128
129
  def test_scan_sends_optional_params(self, mock_session_cls):
129
130
  session = MagicMock()
130
- session.post.return_value = _mock_response(200, SCAN_RESPONSE)
131
+ session.request.return_value = _mock_response(200, SCAN_RESPONSE)
131
132
  mock_session_cls.return_value = session
132
133
 
133
- wg = WaveGuard(api_key="test")
134
+ wg = WaveGuard(api_key="test", max_retries=0)
134
135
  wg.scan(
135
136
  training=[{"a": 1}, {"a": 2}],
136
137
  test=[{"a": 3}],
@@ -138,7 +139,7 @@ class TestScan:
138
139
  sensitivity=0.5,
139
140
  )
140
141
 
141
- call_args = session.post.call_args
142
+ call_args = session.request.call_args
142
143
  body = call_args.kwargs.get("json") or call_args[1].get("json")
143
144
  assert body["encoder_type"] == "text"
144
145
  assert body["sensitivity"] == 0.5
@@ -151,40 +152,40 @@ class TestErrors:
151
152
  @patch("waveguard.client.requests.Session")
152
153
  def test_401_raises_auth_error(self, mock_session_cls):
153
154
  session = MagicMock()
154
- session.post.return_value = _mock_response(401, text="Unauthorized")
155
+ session.request.return_value = _mock_response(401, text="Unauthorized")
155
156
  mock_session_cls.return_value = session
156
157
 
157
- wg = WaveGuard(api_key="bad-key")
158
+ wg = WaveGuard(api_key="bad-key", max_retries=0)
158
159
  with pytest.raises(AuthenticationError):
159
160
  wg.scan(training=[{"a": 1}, {"a": 2}], test=[{"a": 3}])
160
161
 
161
162
  @patch("waveguard.client.requests.Session")
162
163
  def test_422_raises_validation_error(self, mock_session_cls):
163
164
  session = MagicMock()
164
- session.post.return_value = _mock_response(422, text="Empty training")
165
+ session.request.return_value = _mock_response(422, text="Empty training")
165
166
  mock_session_cls.return_value = session
166
167
 
167
- wg = WaveGuard(api_key="test")
168
+ wg = WaveGuard(api_key="test", max_retries=0)
168
169
  with pytest.raises(ValidationError):
169
170
  wg.scan(training=[], test=[{"a": 1}])
170
171
 
171
172
  @patch("waveguard.client.requests.Session")
172
173
  def test_429_raises_rate_limit_error(self, mock_session_cls):
173
174
  session = MagicMock()
174
- session.post.return_value = _mock_response(429, text="Rate limited")
175
+ session.request.return_value = _mock_response(429, text="Rate limited")
175
176
  mock_session_cls.return_value = session
176
177
 
177
- wg = WaveGuard(api_key="test")
178
+ wg = WaveGuard(api_key="test", max_retries=0)
178
179
  with pytest.raises(RateLimitError):
179
180
  wg.scan(training=[{"a": 1}, {"a": 2}], test=[{"a": 3}])
180
181
 
181
182
  @patch("waveguard.client.requests.Session")
182
183
  def test_500_raises_server_error(self, mock_session_cls):
183
184
  session = MagicMock()
184
- session.post.return_value = _mock_response(500, text="Internal error")
185
+ session.request.return_value = _mock_response(500, text="Internal error")
185
186
  mock_session_cls.return_value = session
186
187
 
187
- wg = WaveGuard(api_key="test")
188
+ wg = WaveGuard(api_key="test", max_retries=0)
188
189
  with pytest.raises(ServerError):
189
190
  wg.scan(training=[{"a": 1}, {"a": 2}], test=[{"a": 3}])
190
191
 
@@ -193,10 +194,10 @@ class TestErrors:
193
194
  import requests as req
194
195
 
195
196
  session = MagicMock()
196
- session.post.side_effect = req.ConnectionError("Cannot connect")
197
+ session.request.side_effect = req.ConnectionError("Cannot connect")
197
198
  mock_session_cls.return_value = session
198
199
 
199
- wg = WaveGuard(api_key="test")
200
+ wg = WaveGuard(api_key="test", max_retries=0)
200
201
  with pytest.raises(WaveGuardError, match="Cannot connect"):
201
202
  wg.scan(training=[{"a": 1}, {"a": 2}], test=[{"a": 3}])
202
203
 
@@ -214,7 +215,7 @@ class TestHealth:
214
215
  @patch("waveguard.client.requests.Session")
215
216
  def test_health_parses(self, mock_session_cls):
216
217
  session = MagicMock()
217
- session.get.return_value = _mock_response(
218
+ session.request.return_value = _mock_response(
218
219
  200,
219
220
  {
220
221
  "status": "healthy",
@@ -226,7 +227,7 @@ class TestHealth:
226
227
  )
227
228
  mock_session_cls.return_value = session
228
229
 
229
- wg = WaveGuard(api_key="test")
230
+ wg = WaveGuard(api_key="test", max_retries=0)
230
231
  h = wg.health()
231
232
  assert isinstance(h, HealthStatus)
232
233
  assert h.status == "healthy"
@@ -235,7 +236,7 @@ class TestHealth:
235
236
  @patch("waveguard.client.requests.Session")
236
237
  def test_tier_parses(self, mock_session_cls):
237
238
  session = MagicMock()
238
- session.get.return_value = _mock_response(
239
+ session.request.return_value = _mock_response(
239
240
  200,
240
241
  {
241
242
  "tier": "PRO",
@@ -248,7 +249,7 @@ class TestHealth:
248
249
  )
249
250
  mock_session_cls.return_value = session
250
251
 
251
- wg = WaveGuard(api_key="test")
252
+ wg = WaveGuard(api_key="test", max_retries=0)
252
253
  t = wg.tier()
253
254
  assert isinstance(t, TierInfo)
254
255
  assert t.tier == "PRO"
@@ -26,6 +26,10 @@ Usage::
26
26
 
27
27
  from __future__ import annotations
28
28
 
29
+ import logging
30
+ import os
31
+ import time
32
+ import random
29
33
  import requests
30
34
  from dataclasses import dataclass, field
31
35
  from typing import Any, Dict, List, Optional
@@ -38,7 +42,9 @@ from .exceptions import (
38
42
  ServerError,
39
43
  )
40
44
 
41
- __version__ = "2.2.0"
45
+ __version__ = "2.3.0"
46
+
47
+ logger = logging.getLogger("waveguard")
42
48
 
43
49
 
44
50
  # ─────────────────────────────── Data Classes ─────────────────────────────
@@ -140,17 +146,23 @@ class WaveGuard:
140
146
 
141
147
  Parameters
142
148
  ----------
143
- api_key : str
144
- Your WaveGuard API key.
149
+ api_key : str, optional
150
+ Your WaveGuard API key. If not provided, reads from the
151
+ ``WAVEGUARD_API_KEY`` environment variable. Free-tier scans
152
+ work without a key (rate-limited).
145
153
  base_url : str, optional
146
154
  API base URL. Defaults to the production Modal endpoint.
147
155
  timeout : float, optional
148
156
  Request timeout in seconds. Default ``120`` (generous for GPU
149
157
  cold starts).
158
+ max_retries : int, optional
159
+ Number of automatic retries on transient errors (429, 500, 502,
160
+ 503, 504, connection errors, timeouts). Default ``2``.
161
+ Set to ``0`` to disable retries.
150
162
 
151
163
  Examples
152
164
  --------
153
- >>> wg = WaveGuard(api_key="wg_test_key")
165
+ >>> wg = WaveGuard() # reads WAVEGUARD_API_KEY from env
154
166
  >>> result = wg.scan(
155
167
  ... training=[{"a": 1}, {"a": 2}, {"a": 3}],
156
168
  ... test=[{"a": 100}],
@@ -161,22 +173,27 @@ class WaveGuard:
161
173
 
162
174
  DEFAULT_URL = "https://gpartin--waveguard-api-fastapi-app.modal.run"
163
175
 
176
+ # Status codes that trigger automatic retry
177
+ _RETRYABLE_STATUS = {429, 500, 502, 503, 504}
178
+
164
179
def __init__(
    self,
    api_key: Optional[str] = None,
    base_url: str = DEFAULT_URL,
    timeout: float = 120.0,
    max_retries: int = 2,
):
    """Configure credentials, endpoint, request timeout and retry budget."""
    # An explicitly passed key wins; otherwise fall back to the environment.
    key = api_key or os.environ.get("WAVEGUARD_API_KEY")
    self.api_key = key
    self.base_url = base_url.rstrip("/")
    self.timeout = timeout
    self.max_retries = max_retries

    self._session = requests.Session()
    self._session.headers.update(
        {
            "Content-Type": "application/json",
            "User-Agent": f"waveguard-python/{__version__}",
        }
    )
    if key:
        # Authenticate only when a key is available; keyless calls fall
        # under the rate-limited free tier.
        self._session.headers.update({"X-API-Key": key})
181
198
 
182
199
  # ── Core API ──────────────────────────────────────────────────────
@@ -302,66 +319,127 @@ class WaveGuard:
302
319
  # ── Internal HTTP ─────────────────────────────────────────────────
303
320
 
304
321
def _post(self, path: str, body: dict) -> dict:
    """Send ``body`` as a JSON POST to ``path`` through the retrying transport."""
    return self._request(method="POST", path=path, json=body)
315
323
 
316
324
def _get(self, path: str) -> dict:
    """Issue a GET for ``path`` through the retrying transport."""
    return self._request(method="GET", path=path)
326
+
327
def _request(
    self,
    method: str,
    path: str,
    json: Optional[dict] = None,
) -> dict:
    """Execute an HTTP request with automatic retry and exponential backoff.

    Parameters
    ----------
    method : str
        HTTP verb, e.g. ``"GET"`` or ``"POST"``.
    path : str
        Path appended to ``self.base_url`` (e.g. ``"/v1/scan"``).
    json : dict, optional
        JSON request body.

    Returns
    -------
    dict
        The decoded JSON response body.

    Raises
    ------
    AuthenticationError
        On 401/403 (never retried).
    ValidationError
        On 422 (never retried).
    RateLimitError
        On 429 once retries are exhausted.
    ServerError
        On 500/502/503/504 once retries are exhausted.
    WaveGuardError
        On any other 4xx, a non-JSON success body, or a connection error /
        timeout once retries are exhausted.
    """
    url = f"{self.base_url}{path}"
    # Clamp so a negative max_retries still makes exactly one attempt instead
    # of zero (which would fall through to a bare "Request failed").
    retries = max(0, self.max_retries)
    last_exc: Optional[Exception] = None

    for attempt in range(retries + 1):
        final_attempt = attempt == retries
        logger.debug(
            "%s %s (attempt %d/%d)",
            method, path, attempt + 1, retries + 1,
        )

        # Only the network call can raise the transient errors we retry on.
        try:
            r = self._session.request(
                method, url, json=json, timeout=self.timeout
            )
        except (requests.ConnectionError, requests.Timeout) as e:
            last_exc = e
            if not final_attempt:
                delay = self._backoff_delay(attempt)
                logger.info(
                    "%s on %s — retrying in %.1fs",
                    type(e).__name__, path, delay,
                )
                time.sleep(delay)
                continue
            if isinstance(e, requests.Timeout):
                raise WaveGuardError(
                    f"Request timed out after {self.timeout}s "
                    f"({retries} retries exhausted)"
                ) from e
            raise WaveGuardError(
                f"Cannot connect to {self.base_url} "
                f"({retries} retries exhausted)"
            ) from e

        status = r.status_code

        # Non-retryable errors — raise immediately
        if status in (401, 403):
            raise AuthenticationError(
                "Invalid or missing API key. "
                "Set WAVEGUARD_API_KEY or pass api_key= to WaveGuard().",
                status_code=status,
                detail=r.text,
            )
        if status == 422:
            raise ValidationError(
                f"Validation failed: {r.text}",
                status_code=422,
                detail=r.text,
            )

        # Retryable errors
        if status in self._RETRYABLE_STATUS:
            if status == 429 and final_attempt:
                raise RateLimitError(
                    f"Rate or tier limit exceeded. "
                    f"Upgrade at {RateLimitError.UPGRADE_URL}\n"
                    f"Detail: {r.text}",
                    status_code=429,
                    detail=r.text,
                )
            if not final_attempt:
                delay = self._backoff_delay(
                    attempt, r.headers.get("Retry-After")
                )
                logger.info(
                    "Retryable %d from %s — retrying in %.1fs",
                    status, path, delay,
                )
                time.sleep(delay)
                continue
            # Final attempt — raise appropriate error
            if status >= 500:
                raise ServerError(
                    f"Server error {status} after "
                    f"{retries} retries",
                    status_code=status,
                    detail=r.text,
                )

        # Other 4xx errors
        if status >= 400:
            raise WaveGuardError(
                f"API error {status}: {r.text}",
                status_code=status,
                detail=r.text,
            )

        # Success
        try:
            return r.json()
        except ValueError as exc:
            raise WaveGuardError(
                f"Unexpected non-JSON response from {path}",
                status_code=status,
                detail=r.text,
            ) from exc

    # Unreachable: every iteration returns, raises or continues to a later
    # attempt, and the final attempt always returns or raises.
    raise WaveGuardError("Request failed") from last_exc
430
+
431
@staticmethod
def _backoff_delay(
    attempt: int, retry_after: Optional[str] = None
) -> float:
    """Return how many seconds to wait before retrying after ``attempt``.

    A server-supplied ``Retry-After`` header wins when it parses, either as
    delay-seconds (``"5"``) or as an HTTP-date (RFC 9110). The result is
    clamped to ``[0, 60]`` seconds. A negative, past-dated or huge value
    therefore can neither crash ``time.sleep`` (which rejects negative and
    NaN values) nor stall the caller. Without a usable header, exponential
    backoff with up to 50% jitter is used: 1, 2, 4, 8, ... seconds, capped
    at 30 s before jitter.
    """
    import math
    from datetime import datetime, timezone
    from email.utils import parsedate_to_datetime

    if retry_after:
        seconds: Optional[float]
        try:
            seconds = float(retry_after)
        except ValueError:
            try:
                when = parsedate_to_datetime(retry_after)
            except (TypeError, ValueError, IndexError, OverflowError):
                seconds = None
            else:
                if when.tzinfo is None:
                    when = when.replace(tzinfo=timezone.utc)
                seconds = (when - datetime.now(timezone.utc)).total_seconds()
        if seconds is not None and not math.isnan(seconds):
            return min(max(seconds, 0.0), 60.0)
    base = min(2 ** attempt, 30)  # 1, 2, 4, 8, ... capped at 30s
    return base + random.uniform(0, base * 0.5)
365
443
 
366
444
def __repr__(self) -> str:
    """Return a debug representation showing the configured API base URL."""
    url = self.base_url
    return f"WaveGuard(base_url='{url}')"
@@ -1 +0,0 @@
1
- waveguard
File without changes