stackone-defender 0.7.0__tar.gz → 0.7.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. stackone_defender-0.7.1/.release-please-manifest.json +1 -0
  2. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/CHANGELOG.md +15 -0
  3. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/PKG-INFO +14 -1
  4. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/README.md +13 -0
  5. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/pyproject.toml +1 -1
  6. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/__init__.py +15 -1
  7. stackone_defender-0.7.1/src/stackone_defender/classifiers/tier3_orchestrator.py +27 -0
  8. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/core/prompt_defense.py +426 -20
  9. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/types.py +54 -1
  10. stackone_defender-0.7.1/tests/test_tier3.py +380 -0
  11. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/uv.lock +1 -1
  12. stackone_defender-0.7.0/.release-please-manifest.json +0 -1
  13. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/.github/workflows/ci.yaml +0 -0
  14. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/.github/workflows/release.yaml +0 -0
  15. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/.gitignore +0 -0
  16. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/.python-version +0 -0
  17. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/.release-please-config.json +0 -0
  18. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/models/minilm-full-aug/config.json +0 -0
  19. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/models/minilm-full-aug/model_quantized.onnx +0 -0
  20. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/models/minilm-full-aug/tokenizer.json +0 -0
  21. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/models/minilm-full-aug/tokenizer_config.json +0 -0
  22. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/classifiers/__init__.py +0 -0
  23. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/classifiers/onnx_classifier.py +0 -0
  24. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/classifiers/pattern_detector.py +0 -0
  25. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/classifiers/patterns.py +0 -0
  26. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/classifiers/tier2_classifier.py +0 -0
  27. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/config.py +0 -0
  28. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/core/__init__.py +0 -0
  29. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/core/tool_result_sanitizer.py +0 -0
  30. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/models/minilm-multihead-v5/classifier_config.json +0 -0
  31. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/models/minilm-multihead-v5/config.json +0 -0
  32. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/models/minilm-multihead-v5/model_quantized.onnx +0 -0
  33. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/models/minilm-multihead-v5/tokenizer.json +0 -0
  34. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/models/minilm-multihead-v5/tokenizer_config.json +0 -0
  35. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sanitizers/__init__.py +0 -0
  36. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sanitizers/encoding_detector.py +0 -0
  37. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sanitizers/leet_normalizer.py +0 -0
  38. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sanitizers/normalizer.py +0 -0
  39. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sanitizers/pattern_remover.py +0 -0
  40. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sanitizers/role_stripper.py +0 -0
  41. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sanitizers/sanitizer.py +0 -0
  42. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sfe/__init__.py +0 -0
  43. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sfe/model.ftz +0 -0
  44. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/sfe/preprocess.py +0 -0
  45. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/utils/__init__.py +0 -0
  46. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/utils/boundary.py +0 -0
  47. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/utils/field_detection.py +0 -0
  48. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/src/stackone_defender/utils/structure.py +0 -0
  49. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/tests/__init__.py +0 -0
  50. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/tests/test_integration.py +0 -0
  51. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/tests/test_onnx_classifier.py +0 -0
  52. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/tests/test_pattern_detector.py +0 -0
  53. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/tests/test_sanitizers.py +0 -0
  54. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/tests/test_sfe.py +0 -0
  55. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/tests/test_tier2_classifier.py +0 -0
  56. {stackone_defender-0.7.0 → stackone_defender-0.7.1}/tests/test_utils.py +0 -0
@@ -0,0 +1 @@
1
+ {".":"0.7.1"}
@@ -1,5 +1,20 @@
1
1
  # Changelog
2
2
 
3
+ ## [0.7.1](https://github.com/StackOneHQ/stackone-defender/compare/stackone-defender-v0.7.0...stackone-defender-v0.7.1) (2026-06-16)
4
+
5
+
6
+ ### Features
7
+
8
+ * add defend_tool_results_async for npm batch parity ([a05783c](https://github.com/StackOneHQ/stackone-defender/commit/a05783c5671548aa66dfead1f129584b249d8778))
9
+ * Python parity with @stackone/defender 0.7.1 (Tier 3) ([c58a17c](https://github.com/StackOneHQ/stackone-defender/commit/c58a17c9ba1a902148cde9204666f7f1a916d09b))
10
+ * Tier 3 provider interface and cascade orchestration (TS 0.7.1 parity) ([f2b4109](https://github.com/StackOneHQ/stackone-defender/commit/f2b41096db4ca65741b9d4ba62f3fad7591929ab))
11
+
12
+
13
+ ### Bug Fixes
14
+
15
+ * address Copilot PR review on Tier 3 orchestration ([570f567](https://github.com/StackOneHQ/stackone-defender/commit/570f56753292700a15b73725a12db426316468c6))
16
+ * tighten Tier3ClassifyResult type and batch doc wording ([2515772](https://github.com/StackOneHQ/stackone-defender/commit/2515772f894dd2cbdaa51e9d0b39e26f151d257f))
17
+
3
18
  ## [0.7.0](https://github.com/StackOneHQ/stackone-defender/compare/stackone-defender-v0.6.3...stackone-defender-v0.7.0) (2026-05-29)
4
19
 
5
20
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: stackone-defender
3
- Version: 0.7.0
3
+ Version: 0.7.1
4
4
  Summary: Indirect prompt injection defense for AI agents using tool calls
5
5
  Project-URL: Homepage, https://github.com/StackOneHQ/stackone-defender
6
6
  Project-URL: Repository, https://github.com/StackOneHQ/stackone-defender
@@ -204,6 +204,8 @@ class DefenseResult:
204
204
 
205
205
  ### `defense.defend_tool_results(items)`
206
206
 
207
+ Sync batch API. When `enable_tier3=True`, uses one `asyncio.run()` and defends items **concurrently** via `asyncio.gather` (same scheduling model as npm `defendToolResults`; blocking sync providers still run one at a time on the event-loop thread). From async code, prefer `defend_tool_results_async`.
208
+
207
209
  ```python
208
210
  results = defense.defend_tool_results([
209
211
  {"value": email_data, "tool_name": "gmail_get_message"},
@@ -215,6 +217,17 @@ for r in results:
215
217
  print("Blocked:", ", ".join(r.fields_sanitized))
216
218
  ```
217
219
 
220
+ ### `await defense.defend_tool_results_async(items)`
221
+
222
+ Async batch API — runs `defend_tool_result_async` per item concurrently via `asyncio.gather`. Required when Tier 3 is enabled inside a running event loop (e.g. FastAPI).
223
+
224
+ ```python
225
+ results = await defense.defend_tool_results_async([
226
+ {"value": email_data, "tool_name": "gmail_get_message"},
227
+ {"value": doc_data, "tool_name": "documents_get"},
228
+ ])
229
+ ```
230
+
218
231
  ### `defense.analyze(text)`
219
232
 
220
233
  Tier 1 only — useful for debugging pattern hits without full tool-result traversal.
@@ -178,6 +178,8 @@ class DefenseResult:
178
178
 
179
179
  ### `defense.defend_tool_results(items)`
180
180
 
181
+ Sync batch API. When `enable_tier3=True`, uses one `asyncio.run()` and defends items **concurrently** via `asyncio.gather` (same scheduling model as npm `defendToolResults`; blocking sync providers still run one at a time on the event-loop thread). From async code, prefer `defend_tool_results_async`.
182
+
181
183
  ```python
182
184
  results = defense.defend_tool_results([
183
185
  {"value": email_data, "tool_name": "gmail_get_message"},
@@ -189,6 +191,17 @@ for r in results:
189
191
  print("Blocked:", ", ".join(r.fields_sanitized))
190
192
  ```
191
193
 
194
+ ### `await defense.defend_tool_results_async(items)`
195
+
196
+ Async batch API — runs `defend_tool_result_async` per item concurrently via `asyncio.gather`. Required when Tier 3 is enabled inside a running event loop (e.g. FastAPI).
197
+
198
+ ```python
199
+ results = await defense.defend_tool_results_async([
200
+ {"value": email_data, "tool_name": "gmail_get_message"},
201
+ {"value": doc_data, "tool_name": "documents_get"},
202
+ ])
203
+ ```
204
+
192
205
  ### `defense.analyze(text)`
193
206
 
194
207
  Tier 1 only — useful for debugging pattern hits without full tool-result traversal.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "stackone-defender"
3
- version = "0.7.0"
3
+ version = "0.7.1"
4
4
  description = "Indirect prompt injection defense for AI agents using tool calls"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
@@ -12,6 +12,7 @@ Usage:
12
12
  """
13
13
 
14
14
  from .classifiers.onnx_classifier import get_default_model_path
15
+ from .classifiers.tier3_orchestrator import get_default_tier3_provider, set_default_tier3_provider
15
16
  from .core.prompt_defense import PromptDefense, create_prompt_defense
16
17
  from .sfe.preprocess import (
17
18
  DropDecision,
@@ -21,10 +22,19 @@ from .sfe.preprocess import (
21
22
  get_default_sfe_model_path,
22
23
  sfe_preprocess,
23
24
  )
24
- from .types import DefenseResult, MultiheadConfig, RiskLevel, Tier1Result
25
+ from .types import (
26
+ DefenderMode,
27
+ DefenseResult,
28
+ MultiheadConfig,
29
+ RiskLevel,
30
+ Tier1Result,
31
+ Tier3Provider,
32
+ Tier3Verdict,
33
+ )
25
34
  from .utils.boundary import contains_boundary_patterns, generate_boundary_instructions
26
35
 
27
36
  __all__ = [
37
+ "DefenderMode",
28
38
  "DefenseResult",
29
39
  "DropDecision",
30
40
  "MultiheadConfig",
@@ -33,11 +43,15 @@ __all__ = [
33
43
  "SfePredictor",
34
44
  "SfePreprocessResult",
35
45
  "Tier1Result",
46
+ "Tier3Provider",
47
+ "Tier3Verdict",
36
48
  "contains_boundary_patterns",
37
49
  "create_prompt_defense",
38
50
  "generate_boundary_instructions",
39
51
  "get_default_model_path",
40
52
  "get_default_predictor",
41
53
  "get_default_sfe_model_path",
54
+ "get_default_tier3_provider",
55
+ "set_default_tier3_provider",
42
56
  "sfe_preprocess",
43
57
  ]
@@ -0,0 +1,27 @@
1
+ """Tier 3 provider registry.
2
+
3
+ The defender package ships no Tier 3 implementations — proprietary model
4
+ endpoints (SageMaker, OpenAI, etc.) live in consumer code. Consumers call
5
+ ``set_default_tier3_provider(provider)`` once at app startup; ``PromptDefense``
6
+ picks the registered provider up when callers opt in via ``enable_tier3=True``.
7
+
8
+ Module-level singleton because the defender is often instantiated per-request
9
+ and we don't want to pipe a provider object through that boundary on every call.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from ..types import Tier3Provider
15
+
16
_default_provider: Tier3Provider | None = None


def set_default_tier3_provider(provider: Tier3Provider | None) -> None:
    """Install ``provider`` as the module-level default Tier 3 provider.

    Passing ``None`` removes whatever provider was registered before.
    """
    # Rebinds the module global; the previous value is simply replaced.
    global _default_provider
    _default_provider = provider


def get_default_tier3_provider() -> Tier3Provider | None:
    """Look up the module-level default Tier 3 provider.

    Returns ``None`` when nothing is currently registered.
    """
    return _default_provider
@@ -6,6 +6,8 @@ Provides a simple API for defending tool results against prompt injection.
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
+ import asyncio
10
+ import inspect
9
11
  import logging
10
12
  import math
11
13
  import time
@@ -14,13 +16,29 @@ from typing import Any
14
16
 
15
17
  from ..classifiers.pattern_detector import PatternDetector, create_pattern_detector
16
18
  from ..classifiers.tier2_classifier import Tier2Classifier, create_tier2_classifier
19
+ from ..classifiers.tier3_orchestrator import get_default_tier3_provider
17
20
  from ..config import MAX_TRAVERSAL_DEPTH, create_config
18
21
  from ..sfe.preprocess import SfePredictor, get_default_predictor, sfe_preprocess
19
- from ..types import DefenseResult, MultiheadConfig, PromptDefenseConfig, RiskLevel, Tier1Result
22
+ from ..types import (
23
+ DefenderMode,
24
+ DefenseResult,
25
+ MultiheadConfig,
26
+ PromptDefenseConfig,
27
+ RiskLevel,
28
+ Tier1Result,
29
+ Tier3EscalationBand,
30
+ Tier3Provider,
31
+ Tier3Result,
32
+ Tier3Skip,
33
+ Tier3Verdict,
34
+ )
20
35
  from .tool_result_sanitizer import ToolResultSanitizer, create_tool_result_sanitizer
21
36
 
22
37
  _logger = logging.getLogger(__name__)
23
38
 
39
+ _DEFAULT_TIER3_BAND = Tier3EscalationBand(lower=0.3, upper=0.85)
40
+ _DEFAULT_TIER3_MAX_TEXT_LENGTH = 10000
41
+
24
42
 
25
43
  @dataclass
26
44
  class _Tier2Aggregate:
@@ -118,6 +136,29 @@ def _extract_strings(
118
136
  return strings
119
137
 
120
138
 
139
+ def _bounded_join_strings(strings: list[str], max_len: int, sep: str = "\n") -> str:
140
+ """Join strings with ``sep``, capping total length at ``max_len`` without building the full join first."""
141
+ if max_len <= 0:
142
+ return ""
143
+ parts: list[str] = []
144
+ used = 0
145
+ sep_len = len(sep)
146
+ for s in strings:
147
+ if not s:
148
+ continue
149
+ prefix = sep_len if parts else 0
150
+ if used + prefix >= max_len:
151
+ break
152
+ remaining = max_len - used - prefix
153
+ if len(s) <= remaining:
154
+ parts.append(s)
155
+ used += prefix + len(s)
156
+ else:
157
+ parts.append(s[:remaining])
158
+ break
159
+ return sep.join(parts)
160
+
161
+
121
162
  _RISK_LEVELS: list[RiskLevel] = ["low", "medium", "high", "critical"]
122
163
 
123
164
 
@@ -136,6 +177,9 @@ class PromptDefense:
136
177
  block_high_risk: bool = False,
137
178
  default_risk_level: RiskLevel = "medium",
138
179
  annotate_boundary: bool = False,
180
+ enable_tier3: bool = False,
181
+ defender_mode: DefenderMode = "cascade",
182
+ tier3: dict[str, Any] | None = None,
139
183
  ):
140
184
  self._config: PromptDefenseConfig = create_config(config)
141
185
  if block_high_risk:
@@ -184,6 +228,53 @@ class PromptDefense:
184
228
  self._config.tier2.high_risk_threshold = float(effective["high_risk_threshold"])
185
229
  self._config.tier2.medium_risk_threshold = float(effective["medium_risk_threshold"])
186
230
 
231
+ self._tier3_enabled = enable_tier3
232
+ if defender_mode not in ("cascade", "tier3_only"):
233
+ _logger.warning(
234
+ '[defender] invalid defender_mode %r — must be "cascade" or "tier3_only". '
235
+ 'Falling back to "cascade".',
236
+ defender_mode,
237
+ )
238
+ defender_mode = "cascade"
239
+ self._defender_mode: DefenderMode = defender_mode
240
+ self._tier3_custom_provider: Tier3Provider | None = None
241
+ self._tier3_band = _DEFAULT_TIER3_BAND
242
+ self._tier3_max_text_length = _DEFAULT_TIER3_MAX_TEXT_LENGTH
243
+ self._tier3_missing_provider_warned = False
244
+ tier3_opts = tier3 or {}
245
+ if tier3_opts.get("provider") is not None:
246
+ self._tier3_custom_provider = tier3_opts["provider"]
247
+ max_text_length = tier3_opts.get("max_text_length", tier3_opts.get("maxTextLength"))
248
+ if max_text_length is not None:
249
+ if isinstance(max_text_length, (int, float)) and math.isfinite(max_text_length) and max_text_length > 0:
250
+ self._tier3_max_text_length = int(max_text_length)
251
+ else:
252
+ _logger.warning(
253
+ "[defender] invalid tier3.max_text_length %s — must be a positive finite number. "
254
+ "Falling back to default %s.",
255
+ max_text_length,
256
+ _DEFAULT_TIER3_MAX_TEXT_LENGTH,
257
+ )
258
+ escalation_band = tier3_opts.get("escalation_band", tier3_opts.get("escalationBand"))
259
+ if escalation_band is not None:
260
+ lower = escalation_band.get("lower")
261
+ upper = escalation_band.get("upper")
262
+ if (
263
+ isinstance(lower, (int, float))
264
+ and isinstance(upper, (int, float))
265
+ and math.isfinite(lower)
266
+ and math.isfinite(upper)
267
+ and 0 <= lower < upper <= 1
268
+ ):
269
+ self._tier3_band = Tier3EscalationBand(lower=float(lower), upper=float(upper))
270
+ else:
271
+ _logger.warning(
272
+ "[defender] invalid tier3.escalation_band { lower: %s, upper: %s } — "
273
+ "must satisfy 0 <= lower < upper <= 1. Falling back to default { lower: 0.3, upper: 0.85 }.",
274
+ lower,
275
+ upper,
276
+ )
277
+
187
278
  def warmup_tier2(self) -> None:
188
279
  if self._tier2:
189
280
  self._tier2.warmup()
@@ -198,16 +289,311 @@ class PromptDefense:
198
289
  def is_tier2_ready(self) -> bool:
199
290
  return self._tier2.is_ready() if self._tier2 else False
200
291
 
292
+ def _resolve_tier3_provider(self) -> Tier3Provider | None:
293
+ return self._tier3_custom_provider or get_default_tier3_provider()
294
+
295
@staticmethod
def _validate_tier3_verdict(verdict: Any) -> Tier3Verdict | Tier3Skip:
    """Normalise whatever a Tier 3 provider returned into a verdict or a skip.

    A ``Tier3Verdict`` instance is passed through when its decision is
    ``"block"`` or ``"allow"``. A ``dict`` is converted into a
    ``Tier3Verdict`` (``latency_ms`` may also be spelled ``latencyMs``).
    Anything else, or an unrecognised decision, becomes a ``Tier3Skip``
    whose ``skip_reason`` describes the problem.
    """
    accepted = ("block", "allow")

    def _invalid_decision(value: Any) -> Tier3Skip:
        # Single place that formats the "invalid decision" skip reason.
        return Tier3Skip(
            skip_reason=f'Tier 3 provider returned invalid decision: {value!r} (expected "block" | "allow")'
        )

    if isinstance(verdict, Tier3Verdict):
        if verdict.decision in accepted:
            return verdict
        return _invalid_decision(verdict.decision)

    # ``None`` is not a dict either, so one check covers both cases.
    if not isinstance(verdict, dict):
        return Tier3Skip(
            skip_reason=f"Tier 3 provider returned non-object verdict: {type(verdict).__name__}"
        )

    decision = verdict.get("decision")
    if decision not in accepted:
        return _invalid_decision(decision)

    return Tier3Verdict(
        decision=decision,
        score=verdict.get("score"),
        raw=verdict.get("raw"),
        latency_ms=verdict.get("latency_ms", verdict.get("latencyMs")),
    )
321
+
322
+ @staticmethod
323
+ async def _invoke_tier3_classify(provider: Tier3Provider, text: str, tool_name: str) -> Any:
324
+ ctx = {"toolName": tool_name}
325
+ result = provider.classify(text, ctx=ctx)
326
+ if inspect.isawaitable(result):
327
+ return await result
328
+ return result
329
+
330
+ @staticmethod
331
+ def _tier1_metadata(sanitized) -> tuple[list[str], list[str], dict]:
332
+ prm = sanitized.metadata.patterns_removed_by_field
333
+ mbf = sanitized.metadata.methods_by_field
334
+ detections = list(dict.fromkeys(p for patterns in prm.values() for p in patterns))
335
+ active_methods = {"role_stripping", "pattern_removal", "encoding_detection"}
336
+ fields_sanitized = [
337
+ field_name for field_name, methods in mbf.items()
338
+ if any(m in active_methods for m in methods)
339
+ ]
340
+ return detections, fields_sanitized, prm
341
+
342
+ async def _maybe_tier3_cascade(
343
+ self,
344
+ tier2: _Tier2Outcome,
345
+ tool_name: str,
346
+ ) -> tuple[Tier3Result | None, bool | None]:
347
+ """Run Tier 3 cascade escalation when Tier 2 score is in the gray band."""
348
+ if not (self._tier3_enabled and self._defender_mode == "cascade"):
349
+ return None, None
350
+ eff = tier2.effective_score
351
+ if eff is None or not tier2.max_sentence:
352
+ return None, None
353
+ if eff < self._tier3_band.lower or eff >= self._tier3_band.upper:
354
+ return None, None
355
+
356
+ provider = self._resolve_tier3_provider()
357
+ if provider is None:
358
+ if not self._tier3_missing_provider_warned:
359
+ self._tier3_missing_provider_warned = True
360
+ _logger.warning(
361
+ "[defender] enable_tier3=true but no Tier 3 provider is registered. "
362
+ "Cascade will skip Tier 3 escalation. Call set_default_tier3_provider() at app startup."
363
+ )
364
+ return Tier3Skip(skip_reason="No Tier 3 provider registered"), None
365
+
366
+ max_sentence = tier2.max_sentence
367
+ bounded = (
368
+ max_sentence[: self._tier3_max_text_length]
369
+ if len(max_sentence) > self._tier3_max_text_length
370
+ else max_sentence
371
+ )
372
+ try:
373
+ raw = await self._invoke_tier3_classify(provider, bounded, tool_name)
374
+ validated = self._validate_tier3_verdict(raw)
375
+ if isinstance(validated, Tier3Skip):
376
+ return validated, None
377
+ return validated, validated.decision == "block"
378
+ except Exception as e:
379
+ return Tier3Skip(skip_reason=f"Tier 3 provider error: {e}"), None
380
+
381
+ @staticmethod
382
+ def _finalize_allowed_and_risk(
383
+ *,
384
+ detections: list[str],
385
+ fields_sanitized: list[str],
386
+ tier2_has_threat: bool,
387
+ tier2_idx: int,
388
+ tier1_idx: int,
389
+ risk_level: RiskLevel,
390
+ block_high_risk: bool,
391
+ tier3_override_block: bool | None,
392
+ ) -> tuple[RiskLevel, bool]:
393
+ tier3_overrode_to_allow = tier3_override_block is False
394
+ tier3_overrode_to_block = tier3_override_block is True
395
+
396
+ if tier3_overrode_to_block and _RISK_LEVELS.index(risk_level) < _RISK_LEVELS.index("high"):
397
+ risk_level = "high"
398
+ elif tier3_overrode_to_allow and tier2_idx > tier1_idx:
399
+ risk_level = _RISK_LEVELS[tier1_idx]
400
+
401
+ has_threats = (
402
+ bool(detections)
403
+ or bool(fields_sanitized)
404
+ or (tier2_has_threat and not tier3_overrode_to_allow)
405
+ or tier3_overrode_to_block
406
+ )
407
+ allowed = (
408
+ not block_high_risk
409
+ or not has_threats
410
+ or risk_level not in ("high", "critical")
411
+ )
412
+ return risk_level, allowed
413
+
414
async def _run_tier3_only(
    self,
    value: Any,
    provider: Tier3Provider,
    tool_name: str,
    depth_flag: dict[str, bool],
    start_time: float,
) -> DefenseResult:
    """Classify the tool result with Tier 3 alone and build the ``DefenseResult``.

    All non-empty strings extracted from ``value`` are joined (capped at the
    configured Tier 3 text length) and sent to ``provider``. Tier 1
    sanitization still runs and supplies ``sanitized``, ``detections`` and
    the field metadata, but the risk level is derived from the Tier 3
    verdict only: "high" on a "block" verdict, otherwise "low". When Tier 3
    yields no usable verdict the result carries a ``Tier3Skip`` with the reason.
    """
    extracted = [s for s in _extract_strings(value, None, depth_flag) if len(s) > 0]
    payload = _bounded_join_strings(extracted, self._tier3_max_text_length)

    verdict: Tier3Verdict | None = None
    skip_reason: str | None = None
    if payload:
        try:
            raw_verdict = await self._invoke_tier3_classify(provider, payload, tool_name)
            outcome = self._validate_tier3_verdict(raw_verdict)
            if isinstance(outcome, Tier3Skip):
                skip_reason = outcome.skip_reason
            else:
                verdict = outcome
        except Exception as exc:
            # A failing provider is recorded as a skip, not raised to the caller.
            skip_reason = f"Tier 3 provider error: {exc}"
    else:
        skip_reason = "No strings extracted from tool result"

    # Tier 1 sanitization runs after the Tier 3 call, on the original value.
    sanitized = self._tool_sanitizer.sanitize(value, tool_name=tool_name)
    detections, fields_sanitized, prm = self._tier1_metadata(sanitized)

    blocked = verdict is not None and verdict.decision == "block"
    risk_level: RiskLevel
    if blocked:
        risk_level = "high"
    else:
        risk_level = "low"
    allowed = not (self._config.block_high_risk and blocked)

    tier3_result: Tier3Result
    if verdict is None:
        tier3_result = Tier3Skip(skip_reason=skip_reason or "Tier 3 skipped")
    else:
        tier3_result = verdict

    return DefenseResult(
        allowed=allowed,
        risk_level=risk_level,
        sanitized=sanitized.sanitized,
        detections=detections,
        fields_sanitized=fields_sanitized,
        patterns_by_field=prm,
        tier3=tier3_result,
        fields_dropped=[],
        truncated_at_depth=depth_flag["hit"] or None,
        latency_ms=(time.perf_counter() - start_time) * 1000,
    )
462
+
201
463
  def defend_tool_result(self, value: Any, tool_name: str) -> DefenseResult:
202
- """Defend a tool result using Tier 1 and optionally Tier 2 classification.
464
+ """Defend a tool result using Tier 1 and optionally Tier 2 / Tier 3 classification.
203
465
 
204
466
  When SFE is enabled, ``fields_dropped`` lists paths excluded from **Tier 2**
205
467
  string extraction only; the returned ``sanitized`` payload is still Tier 1 output
206
468
  from the **original** tool value (SFE does not remove fields from the returned object).
469
+
470
+ When ``enable_tier3`` is on, this delegates to :meth:`defend_tool_result_async`
471
+ via ``asyncio.run``. Call that method directly from async code (e.g. FastAPI).
207
472
  """
473
+ if self._tier3_enabled:
474
+ try:
475
+ asyncio.get_running_loop()
476
+ except RuntimeError:
477
+ return asyncio.run(self.defend_tool_result_async(value, tool_name))
478
+ raise RuntimeError(
479
+ "defend_tool_result() cannot call Tier 3 from a running event loop; "
480
+ "use: await defense.defend_tool_result_async(value, tool_name)"
481
+ )
482
+ return self._defend_tool_result_sync(value, tool_name)
483
+
484
async def defend_tool_result_async(self, value: Any, tool_name: str) -> DefenseResult:
    """Async defense entry point; use it when Tier 3 is enabled inside a running event loop.

    In ``tier3_only`` mode with a provider available, the Tier 3-only path
    is used. If no provider is registered, a warning is logged once and
    the call falls through to the regular Tier 1 + Tier 2 pipeline, which
    is also the path for every other configuration.
    """
    start_time = time.perf_counter()
    depth_flag = {"hit": False}

    tier3_only = self._tier3_enabled and self._defender_mode == "tier3_only"
    if tier3_only:
        provider = self._resolve_tier3_provider()
        if provider is not None:
            return await self._run_tier3_only(value, provider, tool_name, depth_flag, start_time)
        if not self._tier3_missing_provider_warned:
            # Remember that we warned so the log is not repeated on every call.
            self._tier3_missing_provider_warned = True
            _logger.warning(
                "[defender] defender_mode=tier3_only but no Tier 3 provider is registered. "
                "Falling back to Tier 1 + Tier 2. Call set_default_tier3_provider() at app startup."
            )

    return await self._defend_tool_result_async_impl(
        value,
        tool_name,
        start_time=start_time,
        depth_flag=depth_flag,
    )
503
+
504
async def _defend_tool_result_async_impl(
    self,
    value: Any,
    tool_name: str,
    *,
    start_time: float,
    depth_flag: dict[str, bool],
) -> DefenseResult:
    """Run the Tier 1 + Tier 2 pipeline with optional Tier 3 cascade escalation.

    Steps: optional SFE preprocessing (its output feeds Tier 2 only), Tier 1
    sanitization of the original ``value``, Tier 2 evaluation, the Tier 3
    cascade hook, and finally the combined risk / allow decision.
    ``depth_flag["hit"]`` is set when SFE reports depth truncation and is
    also passed to the Tier 2 evaluation.
    """
    tier2_input: Any = value
    fields_dropped: list[str] = []
    if self._sfe_enabled:
        try:
            predictor = self._sfe_custom_predictor or get_default_predictor()
            if predictor is not None:
                sfe_options = {"predictor": predictor, "threshold": self._sfe_threshold}
                pre = sfe_preprocess(value, sfe_options)
                tier2_input = pre.filtered
                fields_dropped = pre.dropped
                if pre.truncated_at_depth:
                    depth_flag["hit"] = True
        except Exception as exc:
            # SFE is best-effort: on failure, continue with whatever was set so far.
            _logger.warning(
                "[defender] SFE preprocessing failed; continuing without filtering. Reason: %s",
                exc,
            )

    # Tier 1 always sanitizes the original value, not the SFE-filtered one.
    sanitized = self._tool_sanitizer.sanitize(value, tool_name=tool_name)
    detections, fields_sanitized, prm = self._tier1_metadata(sanitized)

    if self._tier2 is not None:
        tier2 = self._evaluate_tier2(self._tier2, tier2_input, depth_flag)
    else:
        tier2 = _Tier2Outcome()

    tier3_result, tier3_override_block = await self._maybe_tier3_cascade(tier2, tool_name)

    tier1_idx = _RISK_LEVELS.index(sanitized.metadata.overall_risk_level)
    tier2_idx = _RISK_LEVELS.index(tier2.risk)
    risk_level = _RISK_LEVELS[max(tier1_idx, tier2_idx)]

    # An explicit multi-head decision wins; otherwise compare the score to the threshold.
    multihead = tier2.multihead_blocked
    if multihead is True or multihead is False:
        tier2_has_threat = multihead
    else:
        tier2_has_threat = (
            tier2.effective_score is not None
            and tier2.effective_score >= self._config.tier2.high_risk_threshold
        )

    risk_level, allowed = self._finalize_allowed_and_risk(
        detections=detections,
        fields_sanitized=fields_sanitized,
        tier2_has_threat=tier2_has_threat,
        tier2_idx=tier2_idx,
        tier1_idx=tier1_idx,
        risk_level=risk_level,
        block_high_risk=self._config.block_high_risk,
        tier3_override_block=tier3_override_block,
    )

    return DefenseResult(
        allowed=allowed,
        risk_level=risk_level,
        sanitized=sanitized.sanitized,
        detections=detections,
        fields_sanitized=fields_sanitized,
        patterns_by_field=prm,
        tier2_score=tier2.effective_score,
        tier2_raw_score=tier2.raw_score,
        tier2_aux_score=tier2.aux_score,
        tier2_multihead_blocked=tier2.multihead_blocked,
        tier2_skip_reason=tier2.skip_reason,
        max_sentence=tier2.max_sentence,
        tier3=tier3_result,
        fields_dropped=fields_dropped,
        truncated_at_depth=depth_flag["hit"] or None,
        latency_ms=(time.perf_counter() - start_time) * 1000,
    )
583
+
584
+ def _defend_tool_result_sync(
585
+ self,
586
+ value: Any,
587
+ tool_name: str,
588
+ *,
589
+ start_time: float | None = None,
590
+ depth_flag: dict[str, bool] | None = None,
591
+ ) -> DefenseResult:
592
+ if start_time is None:
593
+ start_time = time.perf_counter()
594
+ if depth_flag is None:
595
+ depth_flag = {"hit": False}
596
+
211
597
  sfe_filtered_value: Any = value
212
598
  fields_dropped: list[str] = []
213
599
  if self._sfe_enabled:
@@ -275,25 +661,17 @@ class PromptDefense:
275
661
 
276
662
  # Threat signals: Tier 1 detections, Tier 1 sanitization methods, or
277
663
  # Tier 2 above-threshold (subject to multi-head veto).
278
- has_threats = bool(detections) or bool(fields_sanitized) or tier2_has_threat
279
-
280
- # Three cases for ``allowed``:
281
- # 1. ``block_high_risk`` is off -> always allow.
282
- # 2. No threat signals found -> allow (base risk from tool rules
283
- # alone does not block).
284
- # 3. Risk did not reach high/critical -> allow.
285
- allowed = (
286
- not self._config.block_high_risk
287
- or not has_threats
288
- or risk_level not in ("high", "critical")
664
+ risk_level, allowed = self._finalize_allowed_and_risk(
665
+ detections=detections,
666
+ fields_sanitized=fields_sanitized,
667
+ tier2_has_threat=tier2_has_threat,
668
+ tier2_idx=tier2_idx,
669
+ tier1_idx=tier1_idx,
670
+ risk_level=risk_level,
671
+ block_high_risk=self._config.block_high_risk,
672
+ tier3_override_block=None,
289
673
  )
290
674
 
291
- # ``tier2_score`` reports the effective score -- the value that
292
- # drove the block decision. The multi-head aux veto path sets it
293
- # to ``0.0`` (not ``None``), keeping the triple coherent:
294
- # tier2_score=0 / risk_level low / allowed=true.
295
- # ``tier2_raw_score`` is the pre-density / pre-rule max-chunk main
296
- # score for forensics -- never use it to make decisions.
297
675
  return DefenseResult(
298
676
  allowed=allowed,
299
677
  risk_level=risk_level,
@@ -578,9 +956,37 @@ class PromptDefense:
578
956
  out.risk = tier2.get_risk_level(out.effective_score)
579
957
 
580
958
  def defend_tool_results(self, items: list[dict[str, Any]]) -> list[DefenseResult]:
581
- """Defend multiple tool results."""
959
+ """Defend multiple tool results (sequential when Tier 3 is off).
960
+
961
+ When ``enable_tier3`` is on, delegates to :meth:`defend_tool_results_async`
962
+ via ``asyncio.run`` (parallel per item, matching npm ``defendToolResults``).
963
+ Use the async method directly inside a running event loop.
964
+ """
965
+ if self._tier3_enabled:
966
+ try:
967
+ asyncio.get_running_loop()
968
+ except RuntimeError:
969
+ return asyncio.run(self.defend_tool_results_async(items))
970
+ raise RuntimeError(
971
+ "defend_tool_results() cannot call Tier 3 from a running event loop; "
972
+ "use: await defense.defend_tool_results_async(items)"
973
+ )
582
974
  return [self.defend_tool_result(item["value"], item["tool_name"]) for item in items]
583
975
 
976
async def defend_tool_results_async(self, items: list[dict[str, Any]]) -> list[DefenseResult]:
    """Defend a batch of tool results concurrently (npm ``defendToolResults`` parity).

    One :meth:`defend_tool_result_async` call is scheduled per item and all
    of them are awaited together with ``asyncio.gather``. The returned list
    follows the order of ``items``; an empty input returns an empty list.
    """
    if not items:
        return []
    pending = [
        self.defend_tool_result_async(entry["value"], entry["tool_name"])
        for entry in items
    ]
    gathered = await asyncio.gather(*pending)
    return list(gathered)
989
+
584
990
  def analyze(self, text: str) -> Tier1Result:
585
991
  """Analyze text for injection patterns (Tier 1 only)."""
586
992
  return self._pattern_detector.analyze(text)