ea-agentgate 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ea_agentgate/README.md +28 -0
- ea_agentgate/__init__.py +60 -0
- ea_agentgate/_version.py +7 -0
- ea_agentgate/agent.py +604 -0
- ea_agentgate/api_client.py +352 -0
- ea_agentgate/backends/__init__.py +101 -0
- ea_agentgate/backends/compliant.py +705 -0
- ea_agentgate/backends/guardrail_backend.py +29 -0
- ea_agentgate/backends/guardrail_memory.py +326 -0
- ea_agentgate/backends/guardrail_redis.py +443 -0
- ea_agentgate/backends/guardrail_types.py +123 -0
- ea_agentgate/backends/memory.py +391 -0
- ea_agentgate/backends/protocols.py +512 -0
- ea_agentgate/backends/redis.py +560 -0
- ea_agentgate/backends/redis_async.py +582 -0
- ea_agentgate/backends/redis_common.py +84 -0
- ea_agentgate/backends/types.py +30 -0
- ea_agentgate/cli/__init__.py +171 -0
- ea_agentgate/cli/__main__.py +5 -0
- ea_agentgate/cli/cmd_approvals.py +155 -0
- ea_agentgate/cli/cmd_audit.py +111 -0
- ea_agentgate/cli/cmd_auth.py +73 -0
- ea_agentgate/cli/cmd_costs.py +147 -0
- ea_agentgate/cli/cmd_datasets.py +197 -0
- ea_agentgate/cli/cmd_formal.py +382 -0
- ea_agentgate/cli/cmd_overview.py +38 -0
- ea_agentgate/cli/cmd_pii.py +264 -0
- ea_agentgate/cli/cmd_settings.py +93 -0
- ea_agentgate/cli/cmd_threats.py +164 -0
- ea_agentgate/cli/cmd_traces.py +120 -0
- ea_agentgate/cli/cmd_users.py +138 -0
- ea_agentgate/cli/formatters.py +63 -0
- ea_agentgate/cli/table_helpers.py +31 -0
- ea_agentgate/client.py +612 -0
- ea_agentgate/examples/demo.py +1 -0
- ea_agentgate/examples/jinja2_templating_demo.py +1 -0
- ea_agentgate/examples/openai/README.md +3 -0
- ea_agentgate/examples/openai/__init__.py +1 -0
- ea_agentgate/examples/openai/sdk_example.py +12 -0
- ea_agentgate/examples/policy_engine_demo.py +1 -0
- ea_agentgate/examples/prompt_guard_demo.py +1 -0
- ea_agentgate/examples/resilience_demo.py +13 -0
- ea_agentgate/exceptions.py +355 -0
- ea_agentgate/feedback/__init__.py +49 -0
- ea_agentgate/feedback/dpo_formatter.py +346 -0
- ea_agentgate/feedback/models.py +108 -0
- ea_agentgate/feedback/storage.py +316 -0
- ea_agentgate/formal/__init__.py +44 -0
- ea_agentgate/formal/helpers.py +47 -0
- ea_agentgate/formal/models.py +292 -0
- ea_agentgate/inference/__init__.py +14 -0
- ea_agentgate/inference/sidecar.py +277 -0
- ea_agentgate/integrations/__init__.py +8 -0
- ea_agentgate/integrations/anthropic.py +140 -0
- ea_agentgate/integrations/base.py +160 -0
- ea_agentgate/integrations/openai.py +198 -0
- ea_agentgate/integrations/request_utils.py +30 -0
- ea_agentgate/integrations/types.py +30 -0
- ea_agentgate/middleware/__init__.py +76 -0
- ea_agentgate/middleware/approval.py +247 -0
- ea_agentgate/middleware/audit_log.py +197 -0
- ea_agentgate/middleware/base.py +251 -0
- ea_agentgate/middleware/cost_tracker.py +107 -0
- ea_agentgate/middleware/dashboard.py +144 -0
- ea_agentgate/middleware/dataset_recorder.py +400 -0
- ea_agentgate/middleware/feedback_collector.py +284 -0
- ea_agentgate/middleware/guardrail.py +288 -0
- ea_agentgate/middleware/input_extraction.py +20 -0
- ea_agentgate/middleware/otel_exporter.py +243 -0
- ea_agentgate/middleware/pii_vault.py +810 -0
- ea_agentgate/middleware/pii_vault_detector.py +438 -0
- ea_agentgate/middleware/pii_vault_manager.py +78 -0
- ea_agentgate/middleware/pii_vault_models.py +31 -0
- ea_agentgate/middleware/policy_middleware.py +213 -0
- ea_agentgate/middleware/prompt_guard.py +690 -0
- ea_agentgate/middleware/prompt_template.py +282 -0
- ea_agentgate/middleware/proof_middleware.py +622 -0
- ea_agentgate/middleware/rate_limiter.py +205 -0
- ea_agentgate/middleware/semantic_cache.py +368 -0
- ea_agentgate/middleware/semantic_validator.py +518 -0
- ea_agentgate/middleware/validator.py +280 -0
- ea_agentgate/prompts/JINJA2_QUICKSTART.md +122 -0
- ea_agentgate/prompts/__init__.py +46 -0
- ea_agentgate/prompts/filters.py +264 -0
- ea_agentgate/prompts/manager.py +601 -0
- ea_agentgate/prompts/registry.json +26 -0
- ea_agentgate/prompts/registry.py +61 -0
- ea_agentgate/prompts/templates/chain_of_thought.json +63 -0
- ea_agentgate/prompts/templates/creative_steering.json +120 -0
- ea_agentgate/prompts/templates/few_shot.json +80 -0
- ea_agentgate/prompts/templates/role_based.json +102 -0
- ea_agentgate/providers/__init__.py +80 -0
- ea_agentgate/providers/anthropic_async.py +138 -0
- ea_agentgate/providers/anthropic_base.py +92 -0
- ea_agentgate/providers/anthropic_provider.py +138 -0
- ea_agentgate/providers/base.py +90 -0
- ea_agentgate/providers/google_provider.py +291 -0
- ea_agentgate/providers/health.py +469 -0
- ea_agentgate/providers/openai_async.py +111 -0
- ea_agentgate/providers/openai_common.py +174 -0
- ea_agentgate/providers/openai_provider.py +111 -0
- ea_agentgate/providers/registry.py +409 -0
- ea_agentgate/providers/routing.py +410 -0
- ea_agentgate/py.typed +0 -0
- ea_agentgate/resilience/__init__.py +15 -0
- ea_agentgate/resilience/circuit_breaker.py +293 -0
- ea_agentgate/security/__init__.py +76 -0
- ea_agentgate/security/access_control.py +494 -0
- ea_agentgate/security/audit.py +567 -0
- ea_agentgate/security/audit_models.py +367 -0
- ea_agentgate/security/encryption.py +402 -0
- ea_agentgate/security/integrity.py +375 -0
- ea_agentgate/security/policies/default_guardrails.json +109 -0
- ea_agentgate/security/policies/pii_protection.json +88 -0
- ea_agentgate/security/policy.py +422 -0
- ea_agentgate/security/policy_engine.py +486 -0
- ea_agentgate/security/policy_io.py +29 -0
- ea_agentgate/security/policy_parser.py +237 -0
- ea_agentgate/security/policy_types.py +130 -0
- ea_agentgate/security/secure_delete.py +402 -0
- ea_agentgate/tool_registry.py +102 -0
- ea_agentgate/trace.py +133 -0
- ea_agentgate/transaction_manager.py +261 -0
- ea_agentgate/verification.py +404 -0
- ea_agentgate/verification_manager.py +138 -0
- ea_agentgate-1.0.0.dist-info/METADATA +1419 -0
- ea_agentgate-1.0.0.dist-info/RECORD +289 -0
- ea_agentgate-1.0.0.dist-info/WHEEL +4 -0
- ea_agentgate-1.0.0.dist-info/entry_points.txt +2 -0
- ea_agentgate-1.0.0.dist-info/licenses/LICENSE +21 -0
- server/.env.example +69 -0
- server/__init__.py +20 -0
- server/adapters/__init__.py +1 -0
- server/adapters/budget_deepseek.py +139 -0
- server/adapters/mcp_policy/__init__.py +1 -0
- server/adapters/mcp_policy/tools_policy_governance.py +397 -0
- server/audit/__init__.py +50 -0
- server/audit/bus.py +163 -0
- server/audit/config.py +19 -0
- server/audit/consumer.py +263 -0
- server/config.py +80 -0
- server/config_runtime.py +56 -0
- server/config_secrets.py +179 -0
- server/cors_config.py +57 -0
- server/db/__init__.py +11 -0
- server/db/readiness.py +248 -0
- server/db/schema_guard.py +163 -0
- server/lifespan.py +353 -0
- server/logging_config.py +72 -0
- server/main.py +872 -0
- server/mcp/__init__.py +44 -0
- server/mcp/__main__.py +82 -0
- server/mcp/api_client.py +295 -0
- server/mcp/auth_session.py +432 -0
- server/mcp/azure_mfa_guard.py +60 -0
- server/mcp/confirm.py +114 -0
- server/mcp/execution_policy.py +294 -0
- server/mcp/guardrails.py +552 -0
- server/mcp/guardrails_sync.py +349 -0
- server/mcp/job_store.py +247 -0
- server/mcp/models.py +154 -0
- server/mcp/monitoring.py +93 -0
- server/mcp/policy_engine.py +274 -0
- server/mcp/resources.py +198 -0
- server/mcp/server.py +89 -0
- server/mcp/tools_api.py +286 -0
- server/mcp/tools_async.py +175 -0
- server/mcp/tools_governance.py +633 -0
- server/mcp/tools_safety.py +39 -0
- server/mcp/types_ground_truth.py +47 -0
- server/metrics.py +309 -0
- server/middleware/__init__.py +19 -0
- server/middleware/security_headers.py +100 -0
- server/middleware/threat_detection.py +578 -0
- server/migrations/add_failed_login_tracking.sql +5 -0
- server/migrations/add_mfa_fields.sql +7 -0
- server/migrations/add_webauthn_fields.sql +9 -0
- server/models/__init__.py +164 -0
- server/models/approval_schemas.py +85 -0
- server/models/audit_schemas.py +177 -0
- server/models/common_enums.py +27 -0
- server/models/database.py +503 -0
- server/models/dataset_schemas.py +380 -0
- server/models/formal_security_schemas.py +336 -0
- server/models/governance_schemas.py +68 -0
- server/models/identity_schemas.py +275 -0
- server/models/pii_schemas.py +345 -0
- server/models/policy_input_schemas.py +82 -0
- server/models/prompt_schemas.py +160 -0
- server/models/schemas.py +236 -0
- server/models/security_policy_schemas.py +75 -0
- server/models/trace_schemas.py +79 -0
- server/models/user_schemas.py +318 -0
- server/policy_governance/__init__.py +1 -0
- server/policy_governance/kernel/__init__.py +1 -0
- server/policy_governance/kernel/alert_dispatch.py +341 -0
- server/policy_governance/kernel/alert_models.py +249 -0
- server/policy_governance/kernel/alerting_factory.py +289 -0
- server/policy_governance/kernel/alerts.py +367 -0
- server/policy_governance/kernel/consensus_verifier.py +685 -0
- server/policy_governance/kernel/counterfactual_verifier.py +109 -0
- server/policy_governance/kernel/credential_check.py +198 -0
- server/policy_governance/kernel/deception_injector.py +590 -0
- server/policy_governance/kernel/delegation_lineage.py +496 -0
- server/policy_governance/kernel/detection_behavioral.py +143 -0
- server/policy_governance/kernel/detection_input.py +226 -0
- server/policy_governance/kernel/detection_ip_blocking.py +192 -0
- server/policy_governance/kernel/distributed_health_monitor.py +561 -0
- server/policy_governance/kernel/enforcement.py +452 -0
- server/policy_governance/kernel/evidence_log.py +226 -0
- server/policy_governance/kernel/formal_models.py +101 -0
- server/policy_governance/kernel/gamma_builder.py +249 -0
- server/policy_governance/kernel/master_key.py +463 -0
- server/policy_governance/kernel/master_key_router.py +445 -0
- server/policy_governance/kernel/patterns_injection.py +207 -0
- server/policy_governance/kernel/patterns_traversal.py +90 -0
- server/policy_governance/kernel/pii_token_service.py +653 -0
- server/policy_governance/kernel/runtime_settings.py +86 -0
- server/policy_governance/kernel/solver_engine.py +759 -0
- server/policy_governance/kernel/spec_synthesizer.py +617 -0
- server/policy_governance/kernel/threat_definitions.py +56 -0
- server/policy_governance/kernel/threat_detector.py +641 -0
- server/policy_governance/kernel/threat_detector_analysis.py +567 -0
- server/policy_governance/kernel/threat_detector_config.py +134 -0
- server/policy_governance/kernel/threat_detector_events.py +201 -0
- server/policy_governance/kernel/threat_detector_utils.py +230 -0
- server/policy_governance/kernel/threat_pattern_base.py +160 -0
- server/policy_governance/kernel/threat_patterns.py +300 -0
- server/policy_governance/kernel/verification_grants.py +162 -0
- server/policy_governance/kernel/z3_runtime_engine.py +165 -0
- server/rate_limiting.py +134 -0
- server/routers/__init__.py +55 -0
- server/routers/access_mode.py +71 -0
- server/routers/api_keys.py +429 -0
- server/routers/approvals.py +165 -0
- server/routers/audit.py +219 -0
- server/routers/auth.py +690 -0
- server/routers/auth_helpers.py +336 -0
- server/routers/auth_mfa.py +299 -0
- server/routers/auth_registration.py +638 -0
- server/routers/auth_utils.py +158 -0
- server/routers/dataset_helpers.py +79 -0
- server/routers/datasets.py +786 -0
- server/routers/datasets_operations.py +219 -0
- server/routers/device_auth.py +330 -0
- server/routers/health.py +57 -0
- server/routers/mcp_mfa_callback.py +240 -0
- server/routers/passkey.py +517 -0
- server/routers/pii.py +771 -0
- server/routers/pii_compliance.py +514 -0
- server/routers/pii_nlp.py +503 -0
- server/routers/pii_utils.py +81 -0
- server/routers/policies.py +776 -0
- server/routers/policy_governance.py +678 -0
- server/routers/policy_governance_verification.py +843 -0
- server/routers/result_utils.py +40 -0
- server/routers/settings.py +128 -0
- server/routers/setup.py +444 -0
- server/routers/test.py +658 -0
- server/routers/traces.py +295 -0
- server/routers/users.py +167 -0
- server/routers/verification.py +165 -0
- server/runtime/__init__.py +21 -0
- server/runtime/profile.py +78 -0
- server/security/__init__.py +1 -0
- server/security/azure/__init__.py +17 -0
- server/security/azure/credential_factory.py +86 -0
- server/security/azure/postgres_token_provider.py +62 -0
- server/security/identity/__init__.py +64 -0
- server/security/identity/adapter.py +109 -0
- server/security/identity/custom_oidc_provider.py +14 -0
- server/security/identity/descope_provider.py +14 -0
- server/security/identity/local_provider.py +73 -0
- server/security/identity/mcp_access.py +100 -0
- server/security/identity/oidc.py +106 -0
- server/security/identity/policy.py +175 -0
- server/security/identity/roles.py +59 -0
- server/security/identity/service.py +125 -0
- server/security/identity/store.py +281 -0
- server/sentry_config.py +203 -0
- server/static/vendor/scalar-api-reference-1.44.13.min.js +8 -0
- server/utils/__init__.py +21 -0
- server/utils/captcha.py +159 -0
- server/utils/db.py +66 -0
- server/utils/mfa.py +150 -0
- server/utils/secret_loader.py +117 -0
- server/utils/test_runner.py +401 -0
- server/utils/time_buckets.py +65 -0
- server/utils/webauthn_helper.py +278 -0
ea_agentgate/README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# AgentGate Package Notes
|
|
2
|
+
|
|
3
|
+
## Principal-Led Run Cleanup
|
|
4
|
+
|
|
5
|
+
Principal-led cleanup is approval-gated and run-scoped. After the run is approved and reaches
|
|
6
|
+
`gatekeeper PASS` with `state DONE`, execute:
|
|
7
|
+
|
|
8
|
+
```bash
|
|
9
|
+
python3 scripts/cleanup_completed_run_worktrees.py \
|
|
10
|
+
--run-dir tests/artifacts/workflow/<run_id> \
|
|
11
|
+
--apply
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
Only the worktrees and branches declared for that `run_id` under `lanes` in its `plan_proposal.json` are
|
|
15
|
+
eligible for deletion.
|
|
16
|
+
|
|
17
|
+
```mermaid
|
|
18
|
+
flowchart TD
|
|
19
|
+
A["User approves run"] --> B["Gatekeeper PASS + state DONE"]
|
|
20
|
+
B --> C["cleanup_completed_run_worktrees.py for run_id"]
|
|
21
|
+
C --> D["Delete run-scoped lane branches/worktrees only"]
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
To run the cleanup from the Codex app:
|
|
25
|
+
|
|
26
|
+
1. Open the terminal with `Toggle Terminal` (top-right) or `Cmd+J` on macOS.
|
|
27
|
+
2. Optional: right-click the run in the left sidebar and select `Fork into local`.
|
|
28
|
+
3. Run the cleanup command above, substituting the target `<run_id>`.
|
ea_agentgate/__init__.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Public package surface for AgentGate.
|
|
2
|
+
|
|
3
|
+
Keep the top-level import lightweight so ``import ea_agentgate`` works
|
|
4
|
+
from a base wheel install without requiring optional provider, crypto,
|
|
5
|
+
or server dependencies. Public symbols are loaded lazily on demand.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from importlib import import_module
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from ._version import __version__
|
|
14
|
+
|
|
15
|
+
# Public name -> (module to import, attribute to fetch from that module).
# Entries are resolved on first attribute access by the module-level
# ``__getattr__`` so that ``import ea_agentgate`` stays lightweight.
# Key order here also fixes the order of ``__all__``.
_LAZY_EXPORTS: dict[str, tuple[str, str]] = {
    "Agent": ("ea_agentgate.agent", "Agent"),
    "ToolDef": ("ea_agentgate.tool_registry", "ToolDef"),
    # Placeholder: ``__getattr__`` special-cases "providers" and imports the
    # ``ea_agentgate.providers`` subpackage directly; this tuple is never
    # used for the lookup itself.
    "providers": ("ea_agentgate", "providers"),
    "Trace": ("ea_agentgate.trace", "Trace"),
    "TraceStatus": ("ea_agentgate.trace", "TraceStatus"),
    "UniversalClient": ("ea_agentgate.client", "UniversalClient"),
    # ``AgentGate`` is an alias that resolves to the same ``UniversalClient``.
    "AgentGate": ("ea_agentgate.client", "UniversalClient"),
    "CompletionResult": ("ea_agentgate.client", "CompletionResult"),
    "AllProvidersFailedError": ("ea_agentgate.client", "AllProvidersFailedError"),
    "Metadata": ("ea_agentgate.client", "Metadata"),
    "TokenUsage": ("ea_agentgate.client", "TokenUsage"),
    "Performance": ("ea_agentgate.client", "Performance"),
    "AgentGateError": ("ea_agentgate.exceptions", "AgentGateError"),
    "AgentSafetyError": ("ea_agentgate.exceptions", "AgentSafetyError"),
    "ValidationError": ("ea_agentgate.exceptions", "ValidationError"),
    "RateLimitError": ("ea_agentgate.exceptions", "RateLimitError"),
    "BudgetExceededError": ("ea_agentgate.exceptions", "BudgetExceededError"),
    "ApprovalRequired": ("ea_agentgate.exceptions", "ApprovalRequired"),
    "ApprovalDenied": ("ea_agentgate.exceptions", "ApprovalDenied"),
    "ApprovalTimeout": ("ea_agentgate.exceptions", "ApprovalTimeout"),
    "TransactionFailed": ("ea_agentgate.exceptions", "TransactionFailed"),
    "GuardrailViolationError": ("ea_agentgate.exceptions", "GuardrailViolationError"),
    "check_admissibility": ("ea_agentgate.verification", "check_admissibility"),
    "verify_certificate": ("ea_agentgate.verification", "verify_certificate"),
    "verify_plan": ("ea_agentgate.verification", "verify_plan"),
    "AdmissibilityResult": ("ea_agentgate.verification", "AdmissibilityResult"),
    "CertificateVerificationResult": (
        "ea_agentgate.verification",
        "CertificateVerificationResult",
    ),
    "PlanVerificationResult": ("ea_agentgate.verification", "PlanVerificationResult"),
}

# Every lazily exported name plus the eagerly imported ``__version__``.
__all__ = [*_LAZY_EXPORTS, "__version__"]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def __getattr__(name: str) -> Any:
    """Resolve public exports lazily (PEP 562 module ``__getattr__``).

    Python only calls this hook when normal module attribute lookup fails,
    so each lazy export is imported on first use.

    Args:
        name: Attribute requested on the ``ea_agentgate`` module.

    Returns:
        The ``providers`` subpackage, or the attribute named in
        ``_LAZY_EXPORTS`` for *name*.

    Raises:
        AttributeError: If *name* is not a known public export.
    """
    if name == "providers":
        # Import the subpackage directly. Routing it through the table entry
        # ("ea_agentgate", "providers") would call getattr() on this package
        # while the submodule is still unloaded, re-entering this hook forever.
        return import_module("ea_agentgate.providers")
    target = _LAZY_EXPORTS.get(name)
    if target is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    module_name, attr_name = target
    return getattr(import_module(module_name), attr_name)


def __dir__() -> list[str]:
    """List module attributes, including lazy exports not yet loaded (PEP 562).

    Without this hook, ``dir(ea_agentgate)`` and interactive tab-completion
    would omit every public name that has not been accessed yet.
    """
    return sorted(set(globals()) | set(__all__))
|
ea_agentgate/_version.py
ADDED
ea_agentgate/agent.py
ADDED
|
@@ -0,0 +1,604 @@
|
|
|
1
|
+
"""Main Agent class for safe tool execution."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import functools
|
|
7
|
+
import inspect
|
|
8
|
+
import uuid
|
|
9
|
+
from contextlib import contextmanager, asynccontextmanager
|
|
10
|
+
from dataclasses import dataclass, field, replace
|
|
11
|
+
from typing import (
|
|
12
|
+
Any,
|
|
13
|
+
TypeVar,
|
|
14
|
+
ParamSpec,
|
|
15
|
+
cast,
|
|
16
|
+
)
|
|
17
|
+
from collections.abc import Callable, Generator, AsyncGenerator
|
|
18
|
+
|
|
19
|
+
from .trace import Trace
|
|
20
|
+
from .middleware.base import Middleware, MiddlewareChain, MiddlewareContext
|
|
21
|
+
from .tool_registry import ToolDef, ToolRegistry
|
|
22
|
+
from .transaction_manager import TransactionManager
|
|
23
|
+
from .verification_manager import VerificationConfig, VerificationInputs, VerificationManager
|
|
24
|
+
|
|
25
|
+
P = ParamSpec("P")
|
|
26
|
+
R = TypeVar("R")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class AgentConfig:
    """Identifier bundle attached to an agent.

    Only ``agent_id`` is always populated; the session and user scopes stay
    ``None`` unless explicitly supplied.
    """

    # Short random identifier: the first 8 hex digits of a fresh UUID4
    # (identical to the first dash-delimited group of ``str(uuid4())``).
    agent_id: str = field(default_factory=lambda: uuid.uuid4().hex[:8])
    # Optional session scope.
    session_id: str | None = None
    # Optional end-user scope.
    user_id: str | None = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
_VERIFICATION_KWARG_ALIASES = {
|
|
39
|
+
"formal_verification": "enabled",
|
|
40
|
+
"principal": "principal",
|
|
41
|
+
"tenant_id": "tenant_id",
|
|
42
|
+
"verification_mode": "mode",
|
|
43
|
+
"verification_provider": "provider",
|
|
44
|
+
"formal_api_client": "api_client",
|
|
45
|
+
"certificate_callback": "certificate_callback",
|
|
46
|
+
}
|
|
47
|
+
_VERIFICATION_INPUT_KWARG_ALIASES = {
|
|
48
|
+
"policies": "policies",
|
|
49
|
+
"grants": "grants",
|
|
50
|
+
"revocations": "revocations",
|
|
51
|
+
"obligations": "obligations",
|
|
52
|
+
"environment": "environment",
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _build_verification_config(
    verification: VerificationConfig | None,
    overrides: dict[str, Any],
) -> VerificationConfig:
    """Fold legacy verification keyword arguments into a ``VerificationConfig``.

    Args:
        verification: Grouped config supplied by the caller, if any. A fresh
            ``VerificationConfig()`` is used when it is falsy.
        overrides: Legacy keyword arguments, keyed by the names in
            ``_VERIFICATION_KWARG_ALIASES`` / ``_VERIFICATION_INPUT_KWARG_ALIASES``.

    Returns:
        The base config unchanged when *overrides* is empty. Otherwise a copy
        (via ``dataclasses.replace``) carrying the overridden fields.

    Raises:
        TypeError: If *overrides* contains a name that neither alias table knows.
    """
    base = verification if verification else VerificationConfig()
    if not overrides:
        return base

    # Work on a copy so the caller's dict is left untouched.
    pending = dict(overrides)

    def take(aliases: dict[str, str]) -> dict[str, Any]:
        """Pop every legacy key present in *pending*, renamed to its field."""
        picked: dict[str, Any] = {}
        for legacy, target in aliases.items():
            if legacy in pending:
                picked[target] = pending.pop(legacy)
        return picked

    config_changes = take(_VERIFICATION_KWARG_ALIASES)
    input_changes = take(_VERIFICATION_INPUT_KWARG_ALIASES)

    if pending:
        raise TypeError(
            f"Unexpected verification arguments: {', '.join(sorted(pending))}"
        )

    if input_changes:
        if verification:
            current_inputs = verification.inputs
        else:
            current_inputs = VerificationInputs()
        config_changes["inputs"] = replace(current_inputs, **input_changes)

    return replace(base, **config_changes)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _convert_args_to_kwargs(
|
|
85
|
+
tool_def: ToolDef,
|
|
86
|
+
args: tuple[Any, ...],
|
|
87
|
+
kwargs: dict[str, Any],
|
|
88
|
+
) -> dict[str, Any]:
|
|
89
|
+
"""Convert positional args to keyword args using tool signature.
|
|
90
|
+
|
|
91
|
+
Inspects the function signature of the tool and maps positional
|
|
92
|
+
arguments to their corresponding parameter names.
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
tool_def: The tool definition whose function signature is used.
|
|
96
|
+
args: Positional arguments to convert.
|
|
97
|
+
kwargs: Existing keyword arguments (mutated in place).
|
|
98
|
+
|
|
99
|
+
Returns:
|
|
100
|
+
The updated kwargs dict with positional args merged in.
|
|
101
|
+
"""
|
|
102
|
+
sig = inspect.signature(tool_def.fn)
|
|
103
|
+
params = list(sig.parameters.keys())
|
|
104
|
+
for i, arg in enumerate(args):
|
|
105
|
+
if i < len(params):
|
|
106
|
+
kwargs[params[i]] = arg
|
|
107
|
+
return kwargs
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class Agent:
|
|
111
|
+
"""Agent with safe, traced tool execution.
|
|
112
|
+
|
|
113
|
+
Provides:
|
|
114
|
+
- Automatic tracing of all tool calls
|
|
115
|
+
- Middleware stack for validation, rate limiting, etc.
|
|
116
|
+
- Transaction support with automatic rollback
|
|
117
|
+
- Human-in-the-loop approvals
|
|
118
|
+
- Optional formal verification with proof-carrying authorization
|
|
119
|
+
|
|
120
|
+
Example:
|
|
121
|
+
agent = Agent(
|
|
122
|
+
middleware=[
|
|
123
|
+
Validator(block_paths=["/"]),
|
|
124
|
+
RateLimiter(max_calls=100, window="1m"),
|
|
125
|
+
AuditLog(destination="audit.jsonl"),
|
|
126
|
+
]
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
@agent.tool
|
|
130
|
+
def delete_file(path: str) -> str:
|
|
131
|
+
os.remove(path)
|
|
132
|
+
return f"Deleted {path}"
|
|
133
|
+
|
|
134
|
+
# Execute with full tracing
|
|
135
|
+
result = agent.call("delete_file", path="/tmp/cache.txt")
|
|
136
|
+
|
|
137
|
+
# View traces
|
|
138
|
+
for trace in agent.traces:
|
|
139
|
+
print(trace)
|
|
140
|
+
|
|
141
|
+
Formal Verification:
|
|
142
|
+
agent = Agent(
|
|
143
|
+
formal_verification=True,
|
|
144
|
+
principal="agent:ops",
|
|
145
|
+
policies=[{"effect": "deny", "action": "delete", "resource": "/prod/*"}],
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
result = agent.call("read_data", resource="/api/users")
|
|
149
|
+
cert = agent.last_certificate # DecisionCertificate dict
|
|
150
|
+
"""
|
|
151
|
+
|
|
152
|
+
def __init__(
    self,
    middleware: list[Middleware] | None = None,
    agent_id: str | None = None,
    session_id: str | None = None,
    user_id: str | None = None,
    *,
    verification: VerificationConfig | None = None,
    **verification_kwargs: Any,
) -> None:
    """Create an agent with an optional middleware stack and identifiers.

    Args:
        middleware: Middleware applied to every tool call. NOTE: a non-empty
            list is stored by reference (not copied) unless proof middleware
            has to be prepended, in which case a new list is built.
        agent_id: Agent identifier. A random 8-character id is generated
            when this is falsy.
        session_id: Optional session identifier.
        user_id: Optional user identifier.
        verification: Optional grouped formal-verification config.
        **verification_kwargs: Legacy formal-verification keyword arguments,
            e.g. ``formal_verification``, ``principal``, ``tenant_id``,
            ``verification_mode`` and ``formal_api_client``.

    Raises:
        TypeError: If ``verification_kwargs`` contains a name that is not a
            recognised legacy verification argument.
    """
    self.middleware = middleware or []
    resolved_agent_id = agent_id or str(uuid.uuid4())[:8]
    self.config = AgentConfig(
        agent_id=resolved_agent_id,
        session_id=session_id,
        user_id=user_id,
    )

    merged_config = _build_verification_config(verification, verification_kwargs)
    self._verification = VerificationManager(merged_config)

    # The manager returns a ProofCarryingMiddleware only when verification is
    # enabled. It is placed at the front of the stack.
    proof_middleware = self._verification.build_middleware()
    if proof_middleware is not None:
        self.middleware = [proof_middleware] + self.middleware

    self._tool_registry = ToolRegistry()
    self._traces: list[Trace] = []
    self._chain = MiddlewareChain(self.middleware)
    self.txn = TransactionManager()
|
|
196
|
+
|
|
197
|
+
# ------------------------------------------------------------------
|
|
198
|
+
# Identity properties
|
|
199
|
+
# ------------------------------------------------------------------
|
|
200
|
+
|
|
201
|
+
@property
|
|
202
|
+
def agent_id(self) -> str:
|
|
203
|
+
"""Current agent identifier."""
|
|
204
|
+
return self.config.agent_id
|
|
205
|
+
|
|
206
|
+
@agent_id.setter
|
|
207
|
+
def agent_id(self, value: str) -> None:
|
|
208
|
+
self.config.agent_id = value
|
|
209
|
+
|
|
210
|
+
@property
|
|
211
|
+
def session_id(self) -> str | None:
|
|
212
|
+
"""Current session identifier."""
|
|
213
|
+
return self.config.session_id
|
|
214
|
+
|
|
215
|
+
@session_id.setter
|
|
216
|
+
def session_id(self, value: str | None) -> None:
|
|
217
|
+
self.config.session_id = value
|
|
218
|
+
|
|
219
|
+
@property
|
|
220
|
+
def user_id(self) -> str | None:
|
|
221
|
+
"""Current user identifier."""
|
|
222
|
+
return self.config.user_id
|
|
223
|
+
|
|
224
|
+
@user_id.setter
|
|
225
|
+
def user_id(self, value: str | None) -> None:
|
|
226
|
+
self.config.user_id = value
|
|
227
|
+
|
|
228
|
+
# ------------------------------------------------------------------
|
|
229
|
+
# Tool and trace accessors
|
|
230
|
+
# ------------------------------------------------------------------
|
|
231
|
+
|
|
232
|
+
@property
|
|
233
|
+
def traces(self) -> list[Trace]:
|
|
234
|
+
"""Return a copy of all traces from this agent."""
|
|
235
|
+
return self._traces.copy()
|
|
236
|
+
|
|
237
|
+
@property
|
|
238
|
+
def tools(self) -> dict[str, ToolDef]:
|
|
239
|
+
"""Return registered tools as a dict copy."""
|
|
240
|
+
return self._tool_registry.tools
|
|
241
|
+
|
|
242
|
+
@property
|
|
243
|
+
def formal_verification(self) -> bool:
|
|
244
|
+
"""Whether formal verification is enabled."""
|
|
245
|
+
return self._verification.enabled
|
|
246
|
+
|
|
247
|
+
@property
|
|
248
|
+
def last_certificate(self) -> dict[str, Any] | None:
|
|
249
|
+
"""Most recent DecisionCertificate from formal verification.
|
|
250
|
+
|
|
251
|
+
Returns ``None`` if formal verification is disabled or no tool
|
|
252
|
+
call has been made yet.
|
|
253
|
+
|
|
254
|
+
The dict contains:
|
|
255
|
+
- ``decision_id``: Unique certificate ID
|
|
256
|
+
- ``result``: ``"ADMISSIBLE"`` or ``"INADMISSIBLE"``
|
|
257
|
+
- ``proof_type``: ``"CONSTRUCTIVE_TRACE"`` / ``"UNSAT_CORE"``
|
|
258
|
+
/ ``"COUNTEREXAMPLE"``
|
|
259
|
+
- ``theorem_hash``: SHA-256 of the theorem expression
|
|
260
|
+
- ``alpha_hash``: SHA-256 of the action context
|
|
261
|
+
- ``gamma_hash``: SHA-256 of the knowledge base
|
|
262
|
+
- ``signature``: Ed25519 signature (base64)
|
|
263
|
+
"""
|
|
264
|
+
return self._verification.last_certificate
|
|
265
|
+
|
|
266
|
+
# ------------------------------------------------------------------
|
|
267
|
+
# Verification
|
|
268
|
+
# ------------------------------------------------------------------
|
|
269
|
+
|
|
270
|
+
def verify_last_certificate(self) -> bool:
|
|
271
|
+
"""Verify the most recent certificate's signature and theorem hash.
|
|
272
|
+
|
|
273
|
+
Returns:
|
|
274
|
+
``True`` if the certificate is valid, ``False`` if invalid
|
|
275
|
+
or no certificate exists.
|
|
276
|
+
"""
|
|
277
|
+
return self._verification.verify_last_certificate()
|
|
278
|
+
|
|
279
|
+
# ------------------------------------------------------------------
|
|
280
|
+
# Tool registration
|
|
281
|
+
# ------------------------------------------------------------------
|
|
282
|
+
|
|
283
|
+
def tool(
|
|
284
|
+
self,
|
|
285
|
+
fn: Callable[P, R] | None = None,
|
|
286
|
+
*,
|
|
287
|
+
name: str | None = None,
|
|
288
|
+
requires_approval: bool = False,
|
|
289
|
+
cost: float | None = None,
|
|
290
|
+
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
|
291
|
+
"""Decorator to register a tool with the agent.
|
|
292
|
+
|
|
293
|
+
Example:
|
|
294
|
+
@agent.tool
|
|
295
|
+
def my_tool(x: int) -> int:
|
|
296
|
+
return x * 2
|
|
297
|
+
|
|
298
|
+
@agent.tool(requires_approval=True, cost=0.10)
|
|
299
|
+
def expensive_tool(data: str) -> str:
|
|
300
|
+
...
|
|
301
|
+
"""
|
|
302
|
+
|
|
303
|
+
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
|
304
|
+
"""Register *func* in the tool registry."""
|
|
305
|
+
tool_name = name or func.__name__
|
|
306
|
+
self._tool_registry.register(
|
|
307
|
+
tool_name,
|
|
308
|
+
func,
|
|
309
|
+
requires_approval=requires_approval,
|
|
310
|
+
cost=cost,
|
|
311
|
+
)
|
|
312
|
+
|
|
313
|
+
@functools.wraps(func)
|
|
314
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
315
|
+
"""Delegate to ``Agent.call`` for traced execution."""
|
|
316
|
+
return cast(R, self.call(tool_name, *args, **kwargs))
|
|
317
|
+
|
|
318
|
+
return wrapper
|
|
319
|
+
|
|
320
|
+
if fn is not None:
|
|
321
|
+
return decorator(fn)
|
|
322
|
+
return decorator
|
|
323
|
+
|
|
324
|
+
def register_tool(
|
|
325
|
+
self,
|
|
326
|
+
name: str,
|
|
327
|
+
fn: Callable[..., Any],
|
|
328
|
+
requires_approval: bool = False,
|
|
329
|
+
cost: float | None = None,
|
|
330
|
+
) -> None:
|
|
331
|
+
"""Register a tool with the agent.
|
|
332
|
+
|
|
333
|
+
Args:
|
|
334
|
+
name: Name to register the tool under.
|
|
335
|
+
fn: The tool function.
|
|
336
|
+
requires_approval: Whether tool requires human approval.
|
|
337
|
+
cost: Optional cost per invocation.
|
|
338
|
+
"""
|
|
339
|
+
self._tool_registry.register(
|
|
340
|
+
name,
|
|
341
|
+
fn,
|
|
342
|
+
requires_approval=requires_approval,
|
|
343
|
+
cost=cost,
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
def compensate(
|
|
347
|
+
self,
|
|
348
|
+
tool_name: str,
|
|
349
|
+
compensation: Callable[..., Any],
|
|
350
|
+
) -> None:
|
|
351
|
+
"""Register a compensation function for a tool.
|
|
352
|
+
|
|
353
|
+
Compensation is called during rollback if the tool succeeded
|
|
354
|
+
but a later step failed.
|
|
355
|
+
|
|
356
|
+
Args:
|
|
357
|
+
tool_name: Name of the tool.
|
|
358
|
+
compensation: Callable invoked with the tool output during
|
|
359
|
+
rollback.
|
|
360
|
+
|
|
361
|
+
Example:
|
|
362
|
+
agent.compensate(
|
|
363
|
+
"create_user",
|
|
364
|
+
lambda ctx: db.delete_user(ctx["user_id"]),
|
|
365
|
+
)
|
|
366
|
+
"""
|
|
367
|
+
self.txn.set_compensation(tool_name, compensation)
|
|
368
|
+
self._tool_registry.set_compensation(tool_name, compensation)
|
|
369
|
+
|
|
370
|
+
# ------------------------------------------------------------------
|
|
371
|
+
# Synchronous execution
|
|
372
|
+
# ------------------------------------------------------------------
|
|
373
|
+
|
|
374
|
+
def call(self, tool_name: str, *args: Any, **kwargs: Any) -> Any:
    """Call a registered tool by name through the middleware chain, with tracing.

    Args:
        tool_name: Name of the registered tool.
        *args: Positional arguments; converted to keyword arguments using
            the tool's signature before execution.
        **kwargs: Keyword arguments passed to the tool.

    Returns:
        The value produced by executing the tool through the middleware
        chain.

    Raises:
        RuntimeError: If the tool is a coroutine function and this method
            is called while an event loop is running (use ``acall``).
        Exception: Any exception from middleware (e.g. ValidationError) or
            from the tool itself is re-raised after the failed trace has
            been recorded.
    """
    tool_def = self._tool_registry.get(tool_name)

    # Async tool guard: with no running loop, drive ``acall`` to completion
    # here; inside a running loop the caller must use ``acall`` instead.
    if inspect.iscoroutinefunction(tool_def.fn):
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_running = False
        else:
            loop_running = True
        if not loop_running:
            # Deliberately outside the ``except`` block: otherwise every
            # exception raised by the tool would be chained (``__context__``)
            # to the unrelated "no running event loop" RuntimeError.
            return asyncio.run(
                self.acall(tool_name, *args, **kwargs),
            )
        raise RuntimeError(
            f"Tool '{tool_name}' is async. Use agent.acall() "
            f"instead of agent.call() when in an async context."
        )

    if args:
        kwargs = _convert_args_to_kwargs(tool_def, args, kwargs)

    trace, ctx = self._prepare_call(tool_name, tool_def, kwargs)

    trace.start()
    try:
        result = self._chain.execute(ctx, tool_def.fn)
        trace.succeed(result)
    except Exception as exc:
        self._verification.extract_certificate(ctx)
        trace.fail(str(exc))
        # Record the failed trace so the transaction sees it before re-raise.
        self._record_trace(trace)
        raise

    self._verification.extract_certificate(ctx)
    self._record_trace(trace)
    return result
|
|
424
|
+
|
|
425
|
+
# ------------------------------------------------------------------
|
|
426
|
+
# Transaction delegation
|
|
427
|
+
# ------------------------------------------------------------------
|
|
428
|
+
|
|
429
|
+
@contextmanager
def transaction(self) -> Generator[None, None, None]:
    """Run the enclosed tool calls inside ``self.txn``'s transaction scope.

    Per the transaction manager's contract, a failing tool causes the
    compensations of previously successful tools to run in reverse order.

    Example:
        with agent.transaction():
            agent.call("create_user", email="test@example.com")
            agent.call("charge_card", amount=99.00)
    """
    scope = self.txn.transaction()
    with scope:
        yield
|
|
443
|
+
|
|
444
|
+
def rollback(self) -> None:
    """Synchronously undo the current transaction via ``self.txn``."""
    manager = self.txn
    manager.rollback()
|
|
447
|
+
|
|
448
|
+
# ------------------------------------------------------------------
|
|
449
|
+
# Trace management
|
|
450
|
+
# ------------------------------------------------------------------
|
|
451
|
+
|
|
452
|
+
def record_trace(self, trace: Trace) -> None:
    """Store a trace produced outside ``call``/``acall``.

    Meant for integrations or external validation paths; delegates to
    ``_record_trace`` so the agent's trace list and the transaction
    manager are updated together.
    """
    recorder = self._record_trace
    recorder(trace)
|
|
455
|
+
|
|
456
|
+
def clear_traces(self) -> None:
    """Empty the agent's trace list in place."""
    collected = self._traces
    collected.clear()
|
|
459
|
+
|
|
460
|
+
def _record_trace(self, trace: Trace) -> None:
    """Append ``trace`` locally, then forward it to the transaction manager.

    The local append happens first, so the trace is kept in
    ``self._traces`` even if ``self.txn.record_trace`` raises.
    """
    history = self._traces
    history.append(trace)
    self.txn.record_trace(trace)
|
|
464
|
+
|
|
465
|
+
# ------------------------------------------------------------------
|
|
466
|
+
# Middleware management
|
|
467
|
+
# ------------------------------------------------------------------
|
|
468
|
+
|
|
469
|
+
def add_middleware(self, middleware: Middleware) -> None:
    """Push ``middleware`` onto the stack and rebuild the execution chain.

    The chain is reconstructed from the whole ``self.middleware`` list, so
    the new entry takes part in subsequent tool calls.
    """
    stack = self.middleware
    stack.append(middleware)
    self._chain = MiddlewareChain(stack)
|
|
473
|
+
|
|
474
|
+
# ------------------------------------------------------------------
|
|
475
|
+
# Lifecycle
|
|
476
|
+
# ------------------------------------------------------------------
|
|
477
|
+
|
|
478
|
+
def _close(self) -> None:
|
|
479
|
+
"""Close the agent and release middleware resources."""
|
|
480
|
+
for mw in self.middleware:
|
|
481
|
+
closer = getattr(mw, "close", None)
|
|
482
|
+
if callable(closer):
|
|
483
|
+
closer()
|
|
484
|
+
|
|
485
|
+
def __enter__(self) -> "Agent":
|
|
486
|
+
"""Enter context manager."""
|
|
487
|
+
return self
|
|
488
|
+
|
|
489
|
+
def __exit__(self, *_: Any) -> None:
|
|
490
|
+
"""Exit context manager and close resources."""
|
|
491
|
+
self._close()
|
|
492
|
+
|
|
493
|
+
# ------------------------------------------------------------------
|
|
494
|
+
# Async methods
|
|
495
|
+
# ------------------------------------------------------------------
|
|
496
|
+
|
|
497
|
+
async def acall(
    self,
    tool_name: str,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Async call a tool by name with tracing and middleware.

    Supports both sync and async tool functions. Use this method when
    running in an async context (e.g., FastAPI).

    Args:
        tool_name: Name of the registered tool.
        *args: Positional arguments; converted to keyword arguments using
            the tool's signature before execution.
        **kwargs: Keyword arguments.

    Returns:
        The value produced by awaiting the middleware chain's
        ``aexecute`` for this tool.

    Raises:
        Exception: Any exception from middleware or the tool is re-raised
            after the failed trace has been recorded.

    Example:
        result = await agent.acall(
            "fetch_data", url="https://api.example.com",
        )
    """
    tool_def = self._tool_registry.get(tool_name)

    if args:
        kwargs = _convert_args_to_kwargs(tool_def, args, kwargs)

    trace, ctx = self._prepare_call(tool_name, tool_def, kwargs)

    trace.start()
    try:
        result = await self._chain.aexecute(ctx, tool_def.fn)
        trace.succeed(result)
    except Exception as exc:
        self._verification.extract_certificate(ctx)
        trace.fail(str(exc))
        # Same bookkeeping path as ``call`` and ``record_trace``.
        self._record_trace(trace)
        raise

    self._verification.extract_certificate(ctx)
    self._record_trace(trace)
    return result
|
|
543
|
+
|
|
544
|
+
@asynccontextmanager
async def atransaction(self) -> AsyncGenerator[None, None]:
    """Async counterpart of ``transaction``, backed by ``self.txn``.

    Per the transaction manager's contract, a failing tool causes the
    compensations of previously successful tools to run in reverse order.
    Both sync and async compensation functions are supported.

    Example:
        async with agent.atransaction():
            await agent.acall("create_user", email="test@example.com")
            await agent.acall("charge_card", amount=99.00)
    """
    scope = self.txn.atransaction()
    async with scope:
        yield
|
|
559
|
+
|
|
560
|
+
async def arollback(self) -> None:
    """Asynchronously undo the current transaction via ``self.txn``."""
    pending = self.txn.arollback()
    await pending
|
|
563
|
+
|
|
564
|
+
# ------------------------------------------------------------------
|
|
565
|
+
# Private helpers
|
|
566
|
+
# ------------------------------------------------------------------
|
|
567
|
+
|
|
568
|
+
def _prepare_call(
    self,
    tool_name: str,
    tool_def: ToolDef,
    kwargs: dict[str, Any],
) -> tuple[Trace, MiddlewareContext]:
    """Build the Trace and MiddlewareContext for one tool call.

    Args:
        tool_name: Name of the tool being called.
        tool_def: The tool definition.
        kwargs: Resolved keyword arguments.

    Returns:
        A ``(trace, ctx)`` tuple ready for execution. Both objects are
        constructed with the same ``kwargs`` dict as their inputs, and
        ``ctx`` is constructed with ``trace=trace``.
    """
    trace = Trace(tool=tool_name, inputs=kwargs)

    compensation = tool_def.compensation
    if compensation:
        # Not every callable has ``__name__`` (e.g. functools.partial objects
        # or instances defining ``__call__``); fall back to the type name
        # instead of raising AttributeError on every call of this tool.
        trace.context.compensation = getattr(
            compensation, "__name__", type(compensation).__name__
        )
    else:
        trace.context.compensation = None

    ctx = MiddlewareContext(
        tool=tool_name,
        inputs=kwargs,
        trace=trace,
        agent_id=self.config.agent_id,
        session_id=self.config.session_id,
        user_id=self.config.user_id,
        metadata={
            "requires_approval": tool_def.requires_approval,
        },
    )

    # Only override the context's cost when the tool declares one.
    if tool_def.cost is not None:
        ctx.cost = tool_def.cost

    return trace, ctx
|