proxilion 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. proxilion/__init__.py +1 -1
  2. proxilion/audit/__init__.py +15 -15
  3. proxilion/audit/compliance/base.py +4 -4
  4. proxilion/audit/compliance/eu_ai_act.py +14 -4
  5. proxilion/audit/compliance/iso27001.py +2 -2
  6. proxilion/audit/compliance/soc2.py +16 -3
  7. proxilion/audit/events.py +9 -5
  8. proxilion/audit/explainability.py +30 -19
  9. proxilion/audit/hash_chain.py +14 -0
  10. proxilion/caching/tool_cache.py +14 -8
  11. proxilion/context/context_window.py +27 -2
  12. proxilion/contrib/anthropic.py +2 -2
  13. proxilion/contrib/mcp.py +2 -1
  14. proxilion/contrib/openai.py +2 -2
  15. proxilion/core.py +26 -21
  16. proxilion/exceptions.py +84 -0
  17. proxilion/guards/output_guard.py +1 -1
  18. proxilion/observability/__init__.py +3 -1
  19. proxilion/observability/metrics.py +14 -6
  20. proxilion/observability/session_cost_tracker.py +6 -7
  21. proxilion/policies/builtin.py +2 -1
  22. proxilion/policies/registry.py +12 -6
  23. proxilion/security/__init__.py +51 -37
  24. proxilion/security/agent_trust.py +23 -8
  25. proxilion/security/behavioral_drift.py +14 -6
  26. proxilion/security/idor_protection.py +12 -4
  27. proxilion/security/intent_capsule.py +3 -2
  28. proxilion/security/intent_validator.py +89 -2
  29. proxilion/security/memory_integrity.py +14 -13
  30. proxilion/security/rate_limiter.py +112 -22
  31. proxilion/timeouts/manager.py +2 -0
  32. {proxilion-0.0.1.dist-info → proxilion-0.0.3.dist-info}/METADATA +6 -6
  33. {proxilion-0.0.1.dist-info → proxilion-0.0.3.dist-info}/RECORD +35 -35
  34. {proxilion-0.0.1.dist-info → proxilion-0.0.3.dist-info}/WHEEL +0 -0
  35. {proxilion-0.0.1.dist-info → proxilion-0.0.3.dist-info}/licenses/LICENSE +0 -0
proxilion/__init__.py CHANGED
@@ -35,7 +35,7 @@ https://proxilion.com
35
35
  Source code: https://github.com/clay-good/proxilion-sdk
36
36
  """
37
37
 
38
- __version__ = "0.1.0"
38
+ __version__ = "0.0.2"
39
39
 
40
40
  # Core types - always available
41
41
  # Main Proxilion class
@@ -56,28 +56,14 @@ from proxilion.audit.events import (
56
56
  redact_sensitive_data,
57
57
  reset_sequence,
58
58
  )
59
- from proxilion.audit.hash_chain import (
60
- GENESIS_HASH,
61
- BatchedHashChain,
62
- ChainVerificationResult,
63
- HashChain,
64
- MerkleBatch,
65
- MerkleTree,
66
- )
67
- from proxilion.audit.logger import (
68
- AuditLogger,
69
- InMemoryAuditLogger,
70
- LoggerConfig,
71
- RotationPolicy,
72
- )
73
59
 
74
60
  # Explainability (CA SB 53 compliance)
75
61
  from proxilion.audit.explainability import (
76
62
  DecisionExplainer,
77
63
  DecisionFactor,
78
64
  DecisionType,
79
- ExplainableDecision,
80
65
  ExplainabilityLogger,
66
+ ExplainableDecision,
81
67
  Explanation,
82
68
  ExplanationFormat,
83
69
  Outcome,
@@ -86,6 +72,20 @@ from proxilion.audit.explainability import (
86
72
  create_guard_decision,
87
73
  create_rate_limit_decision,
88
74
  )
75
+ from proxilion.audit.hash_chain import (
76
+ GENESIS_HASH,
77
+ BatchedHashChain,
78
+ ChainVerificationResult,
79
+ HashChain,
80
+ MerkleBatch,
81
+ MerkleTree,
82
+ )
83
+ from proxilion.audit.logger import (
84
+ AuditLogger,
85
+ InMemoryAuditLogger,
86
+ LoggerConfig,
87
+ RotationPolicy,
88
+ )
89
89
 
90
90
  __all__ = [
91
91
  # Events
@@ -18,13 +18,13 @@ from proxilion.audit.events import AuditEventV2, EventType
18
18
 
19
19
 
20
20
  class ComplianceFramework(Enum):
21
- """Supported compliance frameworks."""
21
+ """Supported compliance frameworks.
22
+
23
+ Currently implemented: EU_AI_ACT, SOC2, ISO27001
24
+ """
22
25
  EU_AI_ACT = "eu_ai_act"
23
26
  SOC2 = "soc2"
24
27
  ISO27001 = "iso27001"
25
- NIST_AI_RMF = "nist_ai_rmf"
26
- HIPAA = "hipaa"
27
- GDPR = "gdpr"
28
28
 
29
29
 
30
30
  @dataclass
@@ -535,6 +535,12 @@ class EUAIActExporter(BaseComplianceExporter):
535
535
 
536
536
  # Article 15: Risk Assessment
537
537
  risk_log = self.export_risk_assessment_log(start, end)
538
+ # Article 15 compliance requires that security events have mitigation responses
539
+ # or no security events occurred during the period
540
+ art_15_compliant = (
541
+ len(risk_log["security_events"]) == 0 or
542
+ len(risk_log.get("mitigation_actions", [])) > 0
543
+ )
538
544
  article_15_evidence = ComplianceEvidence(
539
545
  control_id="Article 15",
540
546
  control_name="Accuracy, Robustness and Cybersecurity",
@@ -545,7 +551,7 @@ class EUAIActExporter(BaseComplianceExporter):
545
551
  ),
546
552
  events=risk_log["security_events"] + risk_log["anomaly_detections"],
547
553
  summary=risk_log["summary"],
548
- compliant=True, # Having the log demonstrates compliance
554
+ compliant=art_15_compliant,
549
555
  notes=(
550
556
  "All security events were handled with appropriate mitigation actions."
551
557
  if risk_log["security_events"] else
@@ -554,7 +560,7 @@ class EUAIActExporter(BaseComplianceExporter):
554
560
  )
555
561
  evidence.append(article_15_evidence)
556
562
 
557
- if risk_log["summary"]["total_security_events"] > 10:
563
+ if risk_log.get("summary", {}).get("total_security_events", 0) > 10:
558
564
  recommendations.append(
559
565
  "Review security event patterns and consider strengthening access controls."
560
566
  )
@@ -563,6 +569,10 @@ class EUAIActExporter(BaseComplianceExporter):
563
569
  all_events = self.filter_by_date_range(start, end)
564
570
  stats = self.compute_summary_stats(all_events)
565
571
 
572
+ # Article 17 compliance requires evidence of systematic monitoring
573
+ # (having events logged demonstrates the quality management system is active)
574
+ # Allow empty for zero-duration periods
575
+ art_17_compliant = stats["total_events"] > 0 or start == end
566
576
  article_17_evidence = ComplianceEvidence(
567
577
  control_id="Article 17",
568
578
  control_name="Quality Management System",
@@ -578,7 +588,7 @@ class EUAIActExporter(BaseComplianceExporter):
578
588
  "tools_available": stats["unique_tools"],
579
589
  "period_coverage": f"{start.date()} to {end.date()}",
580
590
  },
581
- compliant=True,
591
+ compliant=art_17_compliant,
582
592
  )
583
593
  evidence.append(article_17_evidence)
584
594
 
@@ -589,7 +599,7 @@ class EUAIActExporter(BaseComplianceExporter):
589
599
  "risk_classification": self._risk_classification,
590
600
  "total_operations": stats["total_events"],
591
601
  "human_oversight_rate": f"{oversight.human_involvement_rate:.1%}",
592
- "security_events": risk_log["summary"]["total_security_events"],
602
+ "security_events": risk_log.get("summary", {}).get("total_security_events", 0),
593
603
  "compliance_status": (
594
604
  "Compliant" if all(e.compliant for e in evidence) else "Review Required"
595
605
  ),
@@ -453,7 +453,7 @@ class ISO27001Exporter(BaseComplianceExporter):
453
453
  )
454
454
  evidence_list.append(a9_evidence)
455
455
 
456
- if access_data["summary"]["denial_rate"] > 0.3:
456
+ if access_data.get("summary", {}).get("denial_rate", 0.0) > 0.3:
457
457
  recommendations.append(
458
458
  "High access denial rate detected. Review access policies and user permissions."
459
459
  )
@@ -501,7 +501,7 @@ class ISO27001Exporter(BaseComplianceExporter):
501
501
  ),
502
502
  events=incidents.security_incidents[:50],
503
503
  summary=incidents_data["summary"],
504
- compliant=incidents_data["summary"]["response_rate"] >= 0.95,
504
+ compliant=incidents_data.get("summary", {}).get("response_rate", 0.0) >= 0.95,
505
505
  notes=(
506
506
  f"{len(incidents.security_incidents)} security events detected. "
507
507
  f"Response rate: {incidents_data['summary']['response_rate']:.1%}. "
@@ -397,7 +397,7 @@ class SOC2Exporter(BaseComplianceExporter):
397
397
  )
398
398
  evidence_list.append(cc6_1_evidence)
399
399
 
400
- denial_rate = access_data["summary"]["denial_rate"]
400
+ denial_rate = access_data.get("summary", {}).get("denial_rate", 0.0)
401
401
  if denial_rate > 0.2:
402
402
  recommendations.append(
403
403
  f"High denial rate ({denial_rate:.1%}) detected. "
@@ -415,6 +415,12 @@ class SOC2Exporter(BaseComplianceExporter):
415
415
  monitoring = self.export_monitoring_evidence(start, end)
416
416
  monitoring_data = monitoring.to_dict()
417
417
 
418
+ # CC7.2 is compliant if monitoring coverage exists and incident responses
419
+ # are documented when anomalies are detected
420
+ cc7_2_compliant = (
421
+ monitoring.monitoring_coverage > 0.0 and
422
+ (len(monitoring.anomaly_detections) == 0 or len(monitoring.incident_responses) > 0)
423
+ )
418
424
  cc7_2_evidence = ComplianceEvidence(
419
425
  control_id="CC7.2",
420
426
  control_name="System Monitoring",
@@ -427,7 +433,7 @@ class SOC2Exporter(BaseComplianceExporter):
427
433
  ),
428
434
  events=monitoring.security_alerts + monitoring.anomaly_detections,
429
435
  summary=monitoring_data["summary"],
430
- compliant=True, # Having monitoring in place is compliant
436
+ compliant=cc7_2_compliant,
431
437
  notes=(
432
438
  "Continuous monitoring is in place. "
433
439
  f"{len(monitoring.security_alerts)} alerts and "
@@ -446,6 +452,13 @@ class SOC2Exporter(BaseComplianceExporter):
446
452
  changes = self.export_change_management_evidence(start, end)
447
453
  changes_data = changes.to_dict()
448
454
 
455
+ # CC8.1 is compliant if changes are tracked and approval workflows
456
+ # exist for configuration changes
457
+ total_changes = len(changes.policy_updates) + len(changes.configuration_changes)
458
+ cc8_1_compliant = (
459
+ total_changes == 0 or # No changes is compliant
460
+ len(changes.approval_workflows) > 0 # Changes should have approvals
461
+ )
449
462
  cc8_1_evidence = ComplianceEvidence(
450
463
  control_id="CC8.1",
451
464
  control_name="Change Management",
@@ -458,7 +471,7 @@ class SOC2Exporter(BaseComplianceExporter):
458
471
  ),
459
472
  events=changes.policy_updates + changes.configuration_changes,
460
473
  summary=changes_data["summary"],
461
- compliant=True, # Having change tracking is compliant
474
+ compliant=cc8_1_compliant,
462
475
  notes=(
463
476
  f"Tracked {len(changes.policy_updates)} policy updates and "
464
477
  f"{len(changes.configuration_changes)} configuration changes."
proxilion/audit/events.py CHANGED
@@ -12,6 +12,7 @@ import hashlib
12
12
  import json
13
13
  import os
14
14
  import re
15
+ import threading
15
16
  import time
16
17
  from dataclasses import dataclass, field
17
18
  from datetime import datetime, timezone
@@ -76,21 +77,24 @@ def _utc_now() -> datetime:
76
77
  return datetime.now(timezone.utc)
77
78
 
78
79
 
79
- # Global sequence counter (thread-safe via GIL for simple increments)
80
+ # Global sequence counter with thread-safe access
80
81
  _sequence_counter = 0
81
- _sequence_lock = None
82
+ _sequence_lock = threading.Lock()
83
+
82
84
 
83
85
  def _next_sequence() -> int:
84
86
  """Get next sequence number (monotonically increasing)."""
85
87
  global _sequence_counter
86
- _sequence_counter += 1
87
- return _sequence_counter
88
+ with _sequence_lock:
89
+ _sequence_counter += 1
90
+ return _sequence_counter
88
91
 
89
92
 
90
93
  def reset_sequence(value: int = 0) -> None:
91
94
  """Reset the sequence counter (for testing)."""
92
95
  global _sequence_counter
93
- _sequence_counter = value
96
+ with _sequence_lock:
97
+ _sequence_counter = value
94
98
 
95
99
 
96
100
  @dataclass
@@ -47,15 +47,16 @@ Example:
47
47
 
48
48
  from __future__ import annotations
49
49
 
50
+ import contextlib
50
51
  import hashlib
51
52
  import json
52
53
  import logging
53
- import re
54
54
  import threading
55
- from dataclasses import asdict, dataclass, field
55
+ from collections.abc import Callable
56
+ from dataclasses import dataclass, field
56
57
  from datetime import datetime, timezone
57
58
  from enum import Enum
58
- from typing import Any, Callable
59
+ from typing import Any
59
60
 
60
61
  logger = logging.getLogger(__name__)
61
62
 
@@ -165,16 +166,12 @@ class ExplainableDecision:
165
166
 
166
167
  # Convert string enums
167
168
  if isinstance(self.decision_type, str):
168
- try:
169
+ with contextlib.suppress(ValueError):
169
170
  self.decision_type = DecisionType(self.decision_type)
170
- except ValueError:
171
- pass # Keep as string if not a known type
172
171
 
173
172
  if isinstance(self.outcome, str):
174
- try:
173
+ with contextlib.suppress(ValueError):
175
174
  self.outcome = Outcome(self.outcome)
176
- except ValueError:
177
- pass
178
175
 
179
176
  @property
180
177
  def passed(self) -> bool:
@@ -198,10 +195,16 @@ class ExplainableDecision:
198
195
 
199
196
  def to_dict(self) -> dict[str, Any]:
200
197
  """Convert to dictionary."""
198
+ dt_val = self.decision_type
199
+ if isinstance(dt_val, DecisionType):
200
+ dt_val = dt_val.value
201
+ outcome_val = self.outcome
202
+ if isinstance(outcome_val, Outcome):
203
+ outcome_val = outcome_val.value
201
204
  return {
202
205
  "decision_id": self.decision_id,
203
- "decision_type": str(self.decision_type.value if isinstance(self.decision_type, DecisionType) else self.decision_type),
204
- "outcome": str(self.outcome.value if isinstance(self.outcome, Outcome) else self.outcome),
206
+ "decision_type": str(dt_val),
207
+ "outcome": str(outcome_val),
205
208
  "factors": [f.to_dict() for f in self.factors],
206
209
  "context": self.context,
207
210
  "timestamp": self.timestamp.isoformat(),
@@ -414,7 +417,10 @@ class DecisionExplainer:
414
417
  factors_explained = self._explain_factors(decision, templates)
415
418
  counterfactual = self._generate_counterfactual(decision, templates)
416
419
  confidence_breakdown = self._explain_confidence(decision, templates)
417
- recommendations = self._generate_recommendations(decision) if self._include_recommendations else []
420
+ if self._include_recommendations:
421
+ recommendations = self._generate_recommendations(decision)
422
+ else:
423
+ recommendations = []
418
424
 
419
425
  # Format the output
420
426
  if format == ExplanationFormat.MARKDOWN:
@@ -465,7 +471,8 @@ class DecisionExplainer:
465
471
  template = templates.get("rate_denied", "Rate limit exceeded")
466
472
  return template.format(**context)
467
473
 
468
- elif dt in (DecisionType.INPUT_GUARD, DecisionType.OUTPUT_GUARD) or dt in ("input_guard", "output_guard"):
474
+ elif dt in (DecisionType.INPUT_GUARD, DecisionType.OUTPUT_GUARD,
475
+ "input_guard", "output_guard"):
469
476
  if outcome in (Outcome.ALLOWED, "ALLOWED"):
470
477
  return templates.get("guard_pass", "Content allowed")
471
478
  elif outcome in (Outcome.MODIFIED, "MODIFIED"):
@@ -483,7 +490,8 @@ class DecisionExplainer:
483
490
  return templates.get("circuit_closed", "Service available")
484
491
  elif state == "open":
485
492
  failures = context.get("failures", 0)
486
- return templates.get("circuit_open", "Service unavailable").format(failures=failures)
493
+ tmpl = templates.get("circuit_open", "Service unavailable")
494
+ return tmpl.format(failures=failures)
487
495
  else:
488
496
  return templates.get("circuit_half_open", "Service testing")
489
497
 
@@ -622,13 +630,13 @@ class DecisionExplainer:
622
630
  for f in failing_factors:
623
631
  # Generate specific counterfactual based on factor name
624
632
  if "role" in f.name.lower():
625
- changes.append(f"User had the required role")
633
+ changes.append("User had the required role")
626
634
  elif "rate" in f.name.lower():
627
- changes.append(f"Request was within rate limits")
635
+ changes.append("Request was within rate limits")
628
636
  elif "budget" in f.name.lower():
629
- changes.append(f"Budget was not exceeded")
637
+ changes.append("Budget was not exceeded")
630
638
  elif "trust" in f.name.lower():
631
- changes.append(f"Trust level was sufficient")
639
+ changes.append("Trust level was sufficient")
632
640
  else:
633
641
  changes.append(f"{f.name} check passed")
634
642
 
@@ -1080,7 +1088,10 @@ def create_guard_decision(
1080
1088
  context = {"guard_type": guard_type}
1081
1089
  if content_sample:
1082
1090
  # Truncate and sanitize
1083
- context["content_preview"] = content_sample[:100] + "..." if len(content_sample) > 100 else content_sample
1091
+ if len(content_sample) > 100:
1092
+ context["content_preview"] = content_sample[:100] + "..."
1093
+ else:
1094
+ context["content_preview"] = content_sample
1084
1095
 
1085
1096
  return ExplainableDecision(
1086
1097
  decision_type=decision_type,
@@ -162,6 +162,7 @@ class HashChain:
162
162
  )
163
163
 
164
164
  expected_previous = GENESIS_HASH
165
+ last_sequence = -1
165
166
 
166
167
  for i, event in enumerate(self._events):
167
168
  # Check previous_hash linkage
@@ -185,6 +186,19 @@ class HashChain:
185
186
  verified_count=i,
186
187
  )
187
188
 
189
+ # Verify monotonically increasing sequence numbers
190
+ if event.sequence_number <= last_sequence:
191
+ return ChainVerificationResult(
192
+ valid=False,
193
+ error_message=(
194
+ f"Sequence number not monotonically increasing at index {i}: "
195
+ f"got {event.sequence_number}, previous was {last_sequence}"
196
+ ),
197
+ error_index=i,
198
+ verified_count=i,
199
+ )
200
+ last_sequence = event.sequence_number
201
+
188
202
  expected_previous = event.event_hash
189
203
 
190
204
  return ChainVerificationResult(
@@ -23,6 +23,9 @@ from typing import Any, ParamSpec, TypeVar
23
23
 
24
24
  logger = logging.getLogger(__name__)
25
25
 
26
+ # Sentinel object to distinguish cache misses from cached None values
27
+ _CACHE_MISS = object()
28
+
26
29
  P = ParamSpec("P")
27
30
  T = TypeVar("T")
28
31
 
@@ -320,7 +323,8 @@ class ToolCache:
320
323
  tool_name: str,
321
324
  args: dict[str, Any],
322
325
  user_id: str | None = None,
323
- ) -> Any | None:
326
+ default: Any = None,
327
+ ) -> Any:
324
328
  """
325
329
  Get a cached result.
326
330
 
@@ -328,9 +332,11 @@ class ToolCache:
328
332
  tool_name: Name of the tool.
329
333
  args: Tool arguments.
330
334
  user_id: Optional user ID.
335
+ default: Value to return if not found/expired (default: None).
336
+ Use _CACHE_MISS sentinel to distinguish misses from cached None.
331
337
 
332
338
  Returns:
333
- Cached value or None if not found/expired.
339
+ Cached value or default if not found/expired.
334
340
  """
335
341
  key = self._generate_key(tool_name, args, user_id)
336
342
 
@@ -339,14 +345,14 @@ class ToolCache:
339
345
 
340
346
  if entry is None:
341
347
  self._stats.misses += 1
342
- return None
348
+ return default
343
349
 
344
350
  if entry.is_expired():
345
351
  # Remove expired entry
346
352
  del self._cache[key]
347
353
  self._stats.misses += 1
348
354
  self._stats.expirations += 1
349
- return None
355
+ return default
350
356
 
351
357
  # Record hit and move to end (for LRU)
352
358
  entry.access()
@@ -557,7 +563,7 @@ class ToolCache:
557
563
  def __contains__(self, key: tuple[str, dict[str, Any]]) -> bool:
558
564
  """Check if a tool/args combination is cached."""
559
565
  tool_name, args = key
560
- return self.get(tool_name, args) is not None
566
+ return self.get(tool_name, args, default=_CACHE_MISS) is not _CACHE_MISS
561
567
 
562
568
  def __len__(self) -> int:
563
569
  """Get number of cached entries."""
@@ -613,9 +619,9 @@ def cached_tool(
613
619
  else:
614
620
  cache_args = all_args
615
621
 
616
- # Check cache
617
- cached_result = cache.get(tool_name, cache_args)
618
- if cached_result is not None:
622
+ # Check cache — use sentinel to handle cached falsy values (None, False, 0)
623
+ cached_result = cache.get(tool_name, cache_args, default=_CACHE_MISS)
624
+ if cached_result is not _CACHE_MISS:
619
625
  logger.debug(f"Cache hit for {tool_name}")
620
626
  return cached_result
621
627
 
@@ -195,8 +195,33 @@ class KeepFirstLastStrategy:
195
195
  total = sum(m.token_count or 0 for m in messages)
196
196
  if total <= max_tokens:
197
197
  return messages
198
- # Still need to truncate - use sliding window
199
- return SlidingWindowStrategy().fit(messages, max_tokens)
198
+ # Still need to truncate — keep first and last, trim from middle
199
+ # to maintain the KeepFirstLast contract
200
+ first_msgs = messages[: self.keep_first]
201
+ last_msgs = messages[self.keep_first :]
202
+ first_tokens = sum(m.token_count or 0 for m in first_msgs)
203
+ remaining = max_tokens - first_tokens
204
+ if remaining <= 0:
205
+ # First messages alone exceed budget, trim first messages
206
+ kept: list[Message] = []
207
+ budget = 0
208
+ for msg in first_msgs:
209
+ msg_tokens = msg.token_count or 0
210
+ if budget + msg_tokens > max_tokens:
211
+ break
212
+ kept.append(msg)
213
+ budget += msg_tokens
214
+ return kept
215
+ # Fill remaining budget from the end
216
+ kept_last: list[Message] = []
217
+ budget = 0
218
+ for msg in reversed(last_msgs):
219
+ msg_tokens = msg.token_count or 0
220
+ if budget + msg_tokens > remaining:
221
+ break
222
+ kept_last.insert(0, msg)
223
+ budget += msg_tokens
224
+ return first_msgs + kept_last
200
225
 
201
226
  # Get first and last messages
202
227
  first_msgs = messages[: self.keep_first]
@@ -325,9 +325,9 @@ class ProxilionToolHandler:
325
325
  # Check authorization
326
326
  if user is not None:
327
327
  context = {
328
+ **input_data,
328
329
  "tool_name": tool_name,
329
330
  "input": input_data,
330
- **input_data,
331
331
  }
332
332
 
333
333
  auth_result = self.proxilion.check(user, tool.action, tool.resource, context)
@@ -436,9 +436,9 @@ class ProxilionToolHandler:
436
436
  # Check authorization
437
437
  if user is not None:
438
438
  context = {
439
+ **input_data,
439
440
  "tool_name": tool_name,
440
441
  "input": input_data,
441
- **input_data,
442
442
  }
443
443
 
444
444
  auth_result = self.proxilion.check(user, tool.action, tool.resource, context)
proxilion/contrib/mcp.py CHANGED
@@ -520,10 +520,11 @@ class MCPToolWrapper:
520
520
  )
521
521
 
522
522
  # Build context for authorization
523
+ # Spread arguments first so trusted keys can't be overridden
523
524
  context = {
525
+ **arguments, # Flatten arguments for policy access
524
526
  "arguments": arguments,
525
527
  "tool_name": self.name,
526
- **arguments, # Flatten arguments for policy access
527
528
  }
528
529
 
529
530
  # Check authorization
@@ -314,9 +314,9 @@ class ProxilionFunctionHandler:
314
314
  # Check authorization
315
315
  if user is not None:
316
316
  context = {
317
+ **arguments,
317
318
  "function_name": function_name,
318
319
  "arguments": arguments,
319
- **arguments,
320
320
  }
321
321
 
322
322
  auth_result = self.proxilion.check(user, func.action, func.resource, context)
@@ -430,9 +430,9 @@ class ProxilionFunctionHandler:
430
430
  # Check authorization
431
431
  if user is not None:
432
432
  context = {
433
+ **arguments,
433
434
  "function_name": function_name,
434
435
  "arguments": arguments,
435
- **arguments,
436
436
  }
437
437
 
438
438
  auth_result = self.proxilion.check(user, func.action, func.resource, context)
proxilion/core.py CHANGED
@@ -33,6 +33,7 @@ from proxilion.context.session import (
33
33
  )
34
34
  from proxilion.engines import EngineFactory
35
35
  from proxilion.exceptions import (
36
+ ApprovalRequiredError,
36
37
  AuthorizationError,
37
38
  CircuitOpenError,
38
39
  IDORViolationError,
@@ -2488,13 +2489,12 @@ class Proxilion:
2488
2489
  ... tools=tools,
2489
2490
  ... )
2490
2491
  """
2491
- # Get filtered tools
2492
+ # Get filtered tools (apply both filters when both are provided)
2493
+ tools = self._tool_registry.list_enabled()
2492
2494
  if category is not None:
2493
- tools = self._tool_registry.list_by_category(category)
2494
- elif max_risk_level is not None:
2495
- tools = self._tool_registry.list_by_risk_level(max_risk_level)
2496
- else:
2497
- tools = self._tool_registry.list_enabled()
2495
+ tools = [t for t in tools if t.category == category]
2496
+ if max_risk_level is not None:
2497
+ tools = [t for t in tools if t.risk_level.value <= max_risk_level.value]
2498
2498
 
2499
2499
  # Export each tool manually
2500
2500
  if format == "openai":
@@ -2550,11 +2550,12 @@ class Proxilion:
2550
2550
  reason=auth_result.reason,
2551
2551
  )
2552
2552
 
2553
- # Check risk level requires approval
2553
+ # Check if tool requires approval before execution
2554
2554
  if tool_def.requires_approval:
2555
- # In a real implementation, this would trigger an approval workflow
2556
- logger.warning(
2557
- f"Tool {name} requires approval but automatic approval is not implemented"
2555
+ raise ApprovalRequiredError(
2556
+ tool_name=name,
2557
+ user=user.user_id,
2558
+ reason="Tool is marked as requiring approval before execution",
2558
2559
  )
2559
2560
 
2560
2561
  return self._tool_registry.execute(name, **kwargs)
@@ -2603,9 +2604,12 @@ class Proxilion:
2603
2604
  reason=auth_result.reason,
2604
2605
  )
2605
2606
 
2607
+ # Check if tool requires approval before execution
2606
2608
  if tool_def.requires_approval:
2607
- logger.warning(
2608
- f"Tool {name} requires approval but automatic approval is not implemented"
2609
+ raise ApprovalRequiredError(
2610
+ tool_name=name,
2611
+ user=user.user_id,
2612
+ reason="Tool is marked as requiring approval before execution",
2609
2613
  )
2610
2614
 
2611
2615
  return await self._tool_registry.execute_async(name, **kwargs)
@@ -2816,8 +2820,9 @@ class Proxilion:
2816
2820
  results = []
2817
2821
  for call in tool_calls:
2818
2822
  merged_context = dict(context or {})
2819
- merged_context["tool_call_id"] = call.id
2820
2823
  merged_context.update(call.arguments)
2824
+ # Set tool_call_id after arguments to prevent override from untrusted data
2825
+ merged_context["tool_call_id"] = call.id
2821
2826
 
2822
2827
  auth_result = self.check(user, "execute", call.name, merged_context)
2823
2828
  results.append((call, auth_result))
@@ -2859,8 +2864,9 @@ class Proxilion:
2859
2864
  results = []
2860
2865
  for call in tool_calls:
2861
2866
  merged_context = dict(context or {})
2862
- merged_context["tool_call_id"] = call.id
2863
2867
  merged_context.update(call.arguments)
2868
+ # Set tool_call_id after arguments to prevent override from untrusted data
2869
+ merged_context["tool_call_id"] = call.id
2864
2870
 
2865
2871
  # Check authorization
2866
2872
  auth_result = self.check(user, "execute", call.name, merged_context)
@@ -2934,13 +2940,12 @@ class Proxilion:
2934
2940
  ... tools=openai_tools,
2935
2941
  ... )
2936
2942
  """
2937
- # Get filtered tools
2943
+ # Get filtered tools (apply both filters when both are provided)
2944
+ tools = self._tool_registry.list_enabled()
2938
2945
  if category is not None:
2939
- tools = self._tool_registry.list_by_category(category)
2940
- elif max_risk_level is not None:
2941
- tools = self._tool_registry.list_by_risk_level(max_risk_level)
2942
- else:
2943
- tools = self._tool_registry.list_enabled()
2946
+ tools = [t for t in tools if t.category == category]
2947
+ if max_risk_level is not None:
2948
+ tools = [t for t in tools if t.risk_level.value <= max_risk_level.value]
2944
2949
 
2945
2950
  adapter = get_adapter(provider=provider)
2946
2951
  return adapter.format_tools(tools)
@@ -2997,7 +3002,7 @@ class Proxilion:
2997
3002
  # Check authorization
2998
3003
  auth_result = self.check(
2999
3004
  user, "execute", call.name,
3000
- {"tool_call_id": call.id, **call.arguments}
3005
+ {**call.arguments, "tool_call_id": call.id}
3001
3006
  )
3002
3007
 
3003
3008
  if not auth_result.allowed: