decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
cli/main.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""Top-level Typer application wired as the `dtap` console script."""
|
|
2
|
+
import sys
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import typer
|
|
6
|
+
import yaml
|
|
7
|
+
from rich.console import Console
|
|
8
|
+
from rich.markup import escape
|
|
9
|
+
from rich.table import Table
|
|
10
|
+
|
|
11
|
+
from .scaffold import SUPPORTED_FRAMEWORKS, scaffold
|
|
12
|
+
|
|
13
|
+
# Typer settings shared by the root application and every sub-command group.
_TYPER_DEFAULTS = {"add_completion": False, "no_args_is_help": True}

# Root CLI application, wired as the `dtap` console script.
app = typer.Typer(
    name="dtap",
    help="DecodingTrust Agent Platform: run red-teaming evaluations against AI agents.",
    **_TYPER_DEFAULTS,
)

# `dtap info ...`: listings of benchmark domains and Docker environments.
info_app = typer.Typer(
    name="info",
    help="Inspect supported domains and environments.",
    **_TYPER_DEFAULTS,
)

# `dtap agent ...`: backend listing and custom-agent scaffolding.
agent_app = typer.Typer(
    name="agent",
    help="List supported agent backends and scaffold new ones.",
    **_TYPER_DEFAULTS,
)

# Mount both groups under the root command.
app.add_typer(info_app, name="info")
app.add_typer(agent_app, name="agent")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Let `--help` flow through to the underlying argparse so users see real flags.
|
|
37
|
+
_PASSTHROUGH = {
|
|
38
|
+
"allow_extra_args": True,
|
|
39
|
+
"ignore_unknown_options": True,
|
|
40
|
+
"help_option_names": [],
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@app.command(
    "eval",
    context_settings=_PASSTHROUGH,
    help="Run a parallel evaluation. Use 'dtap eval --help' for all flags.",
)
def cmd_eval(ctx: typer.Context) -> None:
    """Delegate to ``eval.evaluation.main`` with all remaining CLI arguments.

    ``main()`` is called without arguments, so the forwarded flags are handed
    over by rewriting ``sys.argv`` to look like ``dtap eval <args...>``.
    """
    # Deferred import: the evaluation module is only loaded when `dtap eval` runs.
    from eval.evaluation import main as run_evaluation

    forwarded_args = list(ctx.args)
    sys.argv = ["dtap eval", *forwarded_args]
    run_evaluation()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# ── `dtap info …` subcommands ────────────────────────────────────────────────
|
|
57
|
+
|
|
58
|
+
@info_app.command("domain", help="List benchmark domains with task counts.")
def info_domain() -> None:
    """Print a table of benchmark domains with task counts per split.

    Every sub-directory of ``BENCHMARK_ROOT`` whose name does not start with
    ``_`` is treated as a domain. For each domain the non-blank lines of
    ``benign.jsonl``, ``direct.jsonl`` and ``indirect.jsonl`` are counted
    (one task per line); a missing file counts as 0.

    Raises:
        typer.Exit: with code 1 when ``BENCHMARK_ROOT`` is not a directory.
    """
    from utils.config import BENCHMARK_ROOT

    console = Console()
    if not BENCHMARK_ROOT.is_dir():
        # Fail with a readable CLI error instead of an iterdir() traceback.
        console.print(
            f"[red]error:[/red] benchmark directory not found: {escape(str(BENCHMARK_ROOT))}"
        )
        raise typer.Exit(code=1)

    table = Table(title="Benchmark domains", show_lines=False)
    table.add_column("Domain", style="bold cyan")
    table.add_column("Benign", justify="right")
    table.add_column("Direct", justify="right")
    table.add_column("Indirect", justify="right")
    table.add_column("Total", justify="right", style="bold")

    def _count(p: Path) -> int:
        """Return the number of non-blank lines in *p* (0 if *p* does not exist)."""
        if not p.exists():
            return 0
        # Decode explicitly as UTF-8: the default text encoding is
        # locale-dependent (e.g. cp1252 on Windows) and can raise
        # UnicodeDecodeError on UTF-8 task files. Only line counts matter
        # here, so undecodable bytes are replaced rather than fatal.
        with p.open(encoding="utf-8", errors="replace") as f:
            return sum(1 for line in f if line.strip())

    domains = sorted(
        p.name for p in BENCHMARK_ROOT.iterdir()
        if p.is_dir() and not p.name.startswith("_")
    )
    grand_total = 0
    for d in domains:
        b = _count(BENCHMARK_ROOT / d / "benign.jsonl")
        dr = _count(BENCHMARK_ROOT / d / "direct.jsonl")
        ind = _count(BENCHMARK_ROOT / d / "indirect.jsonl")
        total = b + dr + ind
        grand_total += total
        table.add_row(d, str(b), str(dr), str(ind), str(total))

    console.print(table)
    console.print(f"[dim]{len(domains)} domains, {grand_total} tasks total[/dim]")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@info_app.command("env", help="List Docker environments defined in dt_arena/config/env.yaml.")
def info_env() -> None:
    """Print a table of the Docker environments declared in the env config.

    Reads the top-level ``environments`` mapping of ``ENV_CONFIG_PATH`` and,
    for each entry (sorted by name), shows the number of ``ports`` entries,
    ``max_instances`` and ``docker_compose`` ("—" when the key is missing).

    Raises:
        typer.Exit: with code 1 when ``ENV_CONFIG_PATH`` is not a file.
    """
    from utils.config import ENV_CONFIG_PATH

    console = Console()
    if not ENV_CONFIG_PATH.is_file():
        # Fail with a readable CLI error instead of an open() traceback.
        console.print(
            f"[red]error:[/red] environment config not found: {escape(str(ENV_CONFIG_PATH))}"
        )
        raise typer.Exit(code=1)

    # Decode as UTF-8 (the YAML default) instead of the locale-dependent
    # default text encoding.
    with ENV_CONFIG_PATH.open(encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}
    envs = cfg.get("environments", {}) or {}

    table = Table(title="Docker environments", show_lines=False)
    table.add_column("Name", style="bold cyan")
    table.add_column("Ports", justify="right")
    table.add_column("Max instances", justify="right")
    table.add_column("Compose file")

    for name in sorted(envs.keys()):
        spec = envs[name] or {}
        ports = len(spec.get("ports", {}) or {})
        max_inst = spec.get("max_instances", "—")
        compose = spec.get("docker_compose", "—")
        table.add_row(name, str(ports), str(max_inst), str(compose))

    console.print(table)
    console.print(f"[dim]{len(envs)} environments[/dim]")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# ── `dtap agent …` subcommands ───────────────────────────────────────────────
|
|
120
|
+
|
|
121
|
+
_AGENT_BACKENDS = [
|
|
122
|
+
("openaisdk", "openai", "OpenAI Agents SDK"),
|
|
123
|
+
("claudesdk", "claude", "Anthropic Claude Agent SDK"),
|
|
124
|
+
("googleadk", "google", "Google Agent Development Kit"),
|
|
125
|
+
("langchain", "langchain", "LangChain ReAct-style agent"),
|
|
126
|
+
("pocketflow", "pocketflow", "Pocketflow ReAct agent"),
|
|
127
|
+
("openclaw", "claude", "Claude SDK adapter (openclaw)"),
|
|
128
|
+
]
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@agent_app.command("list", help="List supported agent backends and their pip extras.")
def agent_list() -> None:
    """Render the supported agent backends and the pip extra that installs each."""
    table = Table(title="Agent backends", show_lines=False, expand=False)
    column_specs = (
        ("Agent type", {"style": "bold cyan", "no_wrap": True}),
        ("Install extra", {"no_wrap": True}),
        ("Description", {}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    for backend, extra_name, description in _AGENT_BACKENDS:
        # escape(): "[extra]" would otherwise be parsed as Rich markup.
        install_spec = escape(f"decodingtrust-agent-sdk[{extra_name}]")
        table.add_row(backend, install_spec, description)

    console = Console()
    console.print(table)
    console.print("[dim]Pass --agent-type <name> to 'dtap eval'.[/dim]")
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@agent_app.command(
    "create",
    help="Scaffold a new custom agent under agent/<name>/.",
)
def agent_create(
    name: str = typer.Argument(..., help="Agent name (e.g. my_agent). Becomes the directory name."),
    framework: str = typer.Option(
        None,
        "--framework",
        "-f",
        help=f"Inherit from an existing backend ({', '.join(SUPPORTED_FRAMEWORKS)}). Omit to scaffold from scratch.",
    ),
    out: Path = typer.Option(
        Path("agent"),
        "--out",
        help="Parent directory for the new scaffold (default: ./agent).",
    ),
    force: bool = typer.Option(
        False,
        "--force",
        help="Overwrite an existing scaffold directory.",
    ),
) -> None:
    """Create a new agent scaffold via ``scaffold()`` and print what was made.

    ``ValueError`` and ``FileExistsError`` raised by ``scaffold()`` are
    reported as a one-line error and turned into exit code 1.
    """
    console = Console()
    try:
        new_dir = scaffold(name=name, framework=framework, out_root=out, force=force)
    except (ValueError, FileExistsError) as err:
        console.print(f"[red]error:[/red] {err}")
        raise typer.Exit(code=1)

    console.print(f"[green]Created[/green] {new_dir}")
    # List the top-level entries of the new directory, relative to its parent.
    parent_dir = new_dir.parent
    for entry in sorted(new_dir.iterdir()):
        console.print(f" [dim]·[/dim] {entry.relative_to(parent_dir)}")

    next_steps = (
        f"\nNext: edit [bold]{new_dir}/agent.py[/bold], then run "
        f"[cyan]python -m {new_dir.name}.example[/cyan] (from {out})."
    )
    console.print(next_steps)
|
cli/scaffold.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
"""Templates for `dtap agent create`."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import re
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# kind="wrapper": user writes build_native_agent() returning a framework-native
|
|
10
|
+
# agent; we wrap it via build_agent(native_agent=...).
|
|
11
|
+
# kind="subclass": user subclasses the framework's DTap Agent class directly.
|
|
12
|
+
|
|
13
|
+
# Per-framework scaffolding recipe, keyed by the value accepted by
# `dtap agent create --framework`.
#   "wrapper" entries provide:
#     native_import  - text for the {native_import} placeholder of the
#                      wrapper agent template.
#     native_factory - text for the {native_factory} placeholder, i.e. the
#                      body of build_native_agent(). Its first line is indented
#                      by the template, so continuation lines carry their own
#                      indentation.
#   "subclass" entries provide:
#     base_import / base_class - import line and name of the DTap agent class
#                      the generated class extends.
_FRAMEWORK_INFO = {
    "openaisdk": {
        "kind": "wrapper",
        "native_import": "from agents import Agent as OpenAIAgent",
        # NOTE(review): "{name}" below is only substituted if scaffold()
        # formats this string itself before inserting it into the template
        # (str.format does not recurse into substituted values) — verify.
        "native_factory": (
            'return OpenAIAgent(\n'
            '        name="{name}",\n'
            '        instructions="You are a helpful agent. TODO: customize.",\n'
            '        tools=[],\n'
            '    )'
        ),
    },
    "pocketflow": {
        "kind": "wrapper",
        "native_import": "from agent.pocketflow.src.native_agent import NativeMCPReactAgent",
        "native_factory": (
            "# TODO: configure your NativeMCPReactAgent\n"
            "    return NativeMCPReactAgent(\n"
            '        system_prompt="You are a helpful agent.",\n'
            "    )"
        ),
    },
    "claudesdk": {
        "kind": "subclass",
        "base_import": "from agent.claudesdk import ClaudeSDKAgent",
        "base_class": "ClaudeSDKAgent",
    },
    "googleadk": {
        "kind": "subclass",
        "base_import": "from agent.googleadk import GoogleADKAgent",
        "base_class": "GoogleADKAgent",
    },
    "langchain": {
        "kind": "subclass",
        "base_import": "from agent.langchain import LangChainAgent",
        "base_class": "LangChainAgent",
    },
    "openclaw": {
        "kind": "subclass",
        "base_import": "from agent.openclaw import OpenClawAgent",
        "base_class": "OpenClawAgent",
    },
}

# Alphabetically sorted framework names; also interpolated into the
# `--framework` help text of `dtap agent create`.
SUPPORTED_FRAMEWORKS = sorted(_FRAMEWORK_INFO.keys())
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _to_class_name(name: str) -> str:
|
|
61
|
+
"""Turn 'my_custom-agent' into 'MyCustomAgent' (avoid doubling the 'Agent' suffix)."""
|
|
62
|
+
parts = re.split(r"[_\- ]+", name.strip())
|
|
63
|
+
camel = "".join(p.capitalize() for p in parts if p)
|
|
64
|
+
return camel if camel.endswith("Agent") else camel + "Agent"
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _validate_name(name: str) -> None:
|
|
68
|
+
if not re.fullmatch(r"[a-zA-Z][a-zA-Z0-9_\-]*", name):
|
|
69
|
+
raise ValueError(
|
|
70
|
+
f"Invalid agent name '{name}'. Use letters, digits, underscores, hyphens, "
|
|
71
|
+
"starting with a letter."
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# ── Templates ────────────────────────────────────────────────────────────────
|
|
76
|
+
|
|
77
|
+
_TPL_INIT_WRAPPER = '''"""Custom agent: {class_name}
|
|
78
|
+
|
|
79
|
+
Generated by `dtap agent create {name} --framework {framework}`.
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
from .agent import build_native_agent
|
|
83
|
+
|
|
84
|
+
__all__ = ["build_native_agent"]
|
|
85
|
+
'''
|
|
86
|
+
|
|
87
|
+
_TPL_INIT_SUBCLASS = '''"""Custom agent: {class_name}
|
|
88
|
+
|
|
89
|
+
Generated by `dtap agent create {name}{framework_arg}`.
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
from .agent import {class_name}
|
|
93
|
+
|
|
94
|
+
__all__ = ["{class_name}"]
|
|
95
|
+
'''
|
|
96
|
+
|
|
97
|
+
_TPL_AGENT_WRAPPER = '''"""Native-agent factory for {class_name} (wrapped at runtime by DTap)."""
|
|
98
|
+
|
|
99
|
+
{native_import}
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def build_native_agent():
|
|
103
|
+
"""Return your native {framework} agent.
|
|
104
|
+
|
|
105
|
+
DTap will wrap this via `build_agent(native_agent=...)`, which adds
|
|
106
|
+
red-teaming MCP servers and trajectory tracking automatically.
|
|
107
|
+
"""
|
|
108
|
+
{native_factory}
|
|
109
|
+
'''
|
|
110
|
+
|
|
111
|
+
_TPL_AGENT_SUBCLASS = '''"""Custom DTap agent {class_name} (extends {base_class})."""
|
|
112
|
+
|
|
113
|
+
from typing import Any, Dict, Optional
|
|
114
|
+
|
|
115
|
+
{base_import}
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class {class_name}({base_class}):
|
|
119
|
+
"""TODO: describe what {name} does differently from {base_class}."""
|
|
120
|
+
|
|
121
|
+
async def initialize(self) -> None:
|
|
122
|
+
await super().initialize()
|
|
123
|
+
# TODO: extra setup after MCP servers are connected.
|
|
124
|
+
|
|
125
|
+
async def run(self, user_input: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
|
|
126
|
+
# TODO: customize prompt assembly / tool routing here if needed.
|
|
127
|
+
return await super().run(user_input, metadata=metadata)
|
|
128
|
+
|
|
129
|
+
async def cleanup(self) -> None:
|
|
130
|
+
# TODO: release any extra resources your agent owns.
|
|
131
|
+
await super().cleanup()
|
|
132
|
+
'''
|
|
133
|
+
|
|
134
|
+
_TPL_AGENT_SCRATCH = '''"""Custom DTap agent {class_name}.
|
|
135
|
+
|
|
136
|
+
Subclasses the abstract Agent base. Fill in the four abstract methods.
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
from typing import Any, Dict, Optional
|
|
140
|
+
|
|
141
|
+
from dt_arena.src.types.agent import Agent, AgentConfig, MCPServerConfig, RuntimeConfig
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class {class_name}(Agent):
|
|
145
|
+
"""TODO: describe what {name} does."""
|
|
146
|
+
|
|
147
|
+
def __init__(
|
|
148
|
+
self,
|
|
149
|
+
agent_config: Optional[AgentConfig] = None,
|
|
150
|
+
runtime_config: Optional[RuntimeConfig] = None,
|
|
151
|
+
) -> None:
|
|
152
|
+
super().__init__(agent_config, runtime_config)
|
|
153
|
+
# TODO: initialise per-instance state (model client, tracing, etc.)
|
|
154
|
+
|
|
155
|
+
async def initialize(self) -> None:
|
|
156
|
+
# TODO: connect MCP servers, set up tracing, etc.
|
|
157
|
+
raise NotImplementedError("Implement initialize().")
|
|
158
|
+
|
|
159
|
+
def _create_mcp_server(self, server_config: MCPServerConfig) -> Any:
|
|
160
|
+
# TODO: return an SDK-specific MCP server instance for `server_config`.
|
|
161
|
+
raise NotImplementedError("Implement _create_mcp_server().")
|
|
162
|
+
|
|
163
|
+
async def run(self, user_input: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
|
|
164
|
+
# TODO: invoke the underlying LLM/agent and return an AgentResult.
|
|
165
|
+
raise NotImplementedError("Implement run().")
|
|
166
|
+
|
|
167
|
+
async def cleanup(self) -> None:
|
|
168
|
+
# TODO: close MCP server connections and release resources.
|
|
169
|
+
raise NotImplementedError("Implement cleanup().")
|
|
170
|
+
'''
|
|
171
|
+
|
|
172
|
+
_TPL_EXAMPLE_WRAPPER = '''"""Smoke test for the {class_name} scaffold (wrapper pattern)."""
|
|
173
|
+
|
|
174
|
+
import asyncio
|
|
175
|
+
|
|
176
|
+
from dt_arena.src.types.agent import AgentConfig, RuntimeConfig
|
|
177
|
+
from utils.agent_helpers import build_agent
|
|
178
|
+
|
|
179
|
+
from . import build_native_agent
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
async def main() -> None:
|
|
183
|
+
native = build_native_agent()
|
|
184
|
+
agent = build_agent(
|
|
185
|
+
native_agent=native,
|
|
186
|
+
agent_cfg=AgentConfig(system_prompt=""),
|
|
187
|
+
runtime_cfg=RuntimeConfig(model="gpt-4o", max_turns=10),
|
|
188
|
+
)
|
|
189
|
+
async with agent:
|
|
190
|
+
result = await agent.run("Say hello.", metadata={{"task_id": "smoke"}})
|
|
191
|
+
print(result.final_output)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
if __name__ == "__main__":
|
|
195
|
+
asyncio.run(main())
|
|
196
|
+
'''
|
|
197
|
+
|
|
198
|
+
_TPL_EXAMPLE_SUBCLASS = '''"""Smoke test for the {class_name} scaffold."""
|
|
199
|
+
|
|
200
|
+
import asyncio
|
|
201
|
+
|
|
202
|
+
from dt_arena.src.types.agent import AgentConfig, RuntimeConfig
|
|
203
|
+
|
|
204
|
+
from .agent import {class_name}
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
async def main() -> None:
|
|
208
|
+
agent = {class_name}(
|
|
209
|
+
agent_config=AgentConfig(system_prompt=""),
|
|
210
|
+
runtime_config=RuntimeConfig(model="gpt-4o", max_turns=10),
|
|
211
|
+
)
|
|
212
|
+
async with agent:
|
|
213
|
+
result = await agent.run("Say hello.", metadata={{"task_id": "smoke"}})
|
|
214
|
+
print(result.final_output)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
if __name__ == "__main__":
|
|
218
|
+
asyncio.run(main())
|
|
219
|
+
'''
|
|
220
|
+
|
|
221
|
+
_TPL_README_WRAPPER = '''# {class_name}
|
|
222
|
+
|
|
223
|
+
Generated by `dtap agent create {name} --framework {framework}`.
|
|
224
|
+
|
|
225
|
+
## Layout
|
|
226
|
+
|
|
227
|
+
- `agent.py` — your `build_native_agent()` factory. Edit this.
|
|
228
|
+
- `example.py` — runs a single turn through `build_agent(native_agent=...)`.
|
|
229
|
+
|
|
230
|
+
## Run
|
|
231
|
+
|
|
232
|
+
```bash
|
|
233
|
+
python -m agent.{name}.example
|
|
234
|
+
```
|
|
235
|
+
|
|
236
|
+
## Plug into `dtap eval`
|
|
237
|
+
|
|
238
|
+
To make `dtap eval --agent-type {name}` work, register this module in the
|
|
239
|
+
agent registry (see `utils/agent_helpers.py:_get_agent_registry`) or invoke
|
|
240
|
+
`build_agent(native_agent=build_native_agent(), ...)` from your own driver.
|
|
241
|
+
'''
|
|
242
|
+
|
|
243
|
+
_TPL_README_SUBCLASS = '''# {class_name}
|
|
244
|
+
|
|
245
|
+
Generated by `dtap agent create {name}{framework_arg}`.
|
|
246
|
+
|
|
247
|
+
## Layout
|
|
248
|
+
|
|
249
|
+
- `agent.py` — your `{class_name}` class. Edit the TODOs.
|
|
250
|
+
- `example.py` — instantiates the agent and runs one turn.
|
|
251
|
+
|
|
252
|
+
## Run
|
|
253
|
+
|
|
254
|
+
```bash
|
|
255
|
+
python -m agent.{name}.example
|
|
256
|
+
```
|
|
257
|
+
|
|
258
|
+
## Plug into `dtap eval`
|
|
259
|
+
|
|
260
|
+
To make `dtap eval --agent-type {name}` work, register `{class_name}` in the
|
|
261
|
+
agent registry (see `utils/agent_helpers.py:_get_agent_registry`).
|
|
262
|
+
'''
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def scaffold(
    name: str,
    framework: Optional[str],
    out_root: Path,
    *,
    force: bool = False,
) -> Path:
    """Create ``<out_root>/<name>/`` with templates. Returns the directory path.

    The layout depends on ``framework``:

    * ``None`` or ``""``: a from-scratch subclass of the abstract ``Agent``
      base (``_TPL_AGENT_SCRATCH``).
    * A framework whose ``_FRAMEWORK_INFO`` entry has ``kind == "wrapper"``:
      a ``build_native_agent()`` factory plus a ``build_agent`` example.
    * Any other registered framework: a subclass of ``info["base_class"]``
      (imported via ``info["base_import"]``).

    Each layout writes ``__init__.py``, ``agent.py``, ``example.py`` and
    ``README.md`` as UTF-8. All four files are rendered before anything is
    written, so a template error does not leave a half-populated directory.

    Args:
        name: Directory/package name of the new agent. It is checked by
            ``_validate_name`` and turned into a class name by
            ``_to_class_name``.
        framework: Key into ``_FRAMEWORK_INFO``, or ``None``/``""`` for a
            from-scratch agent.
        out_root: Parent directory. It is created (with parents) if missing.
        force: Allow writing into an existing ``<name>/``. Scaffold files are
            overwritten; other files in the directory are left untouched.

    Returns:
        The path ``out_root / name``.

    Raises:
        ValueError: If ``framework`` is not a key of ``_FRAMEWORK_INFO``. Also
            raised for anything ``_validate_name`` rejects, presumably.
        FileExistsError: If ``out_root / name`` exists and ``force`` is false.
    """
    _validate_name(name)
    # Treat "" as "no framework". This matches the truthiness test used for
    # ``framework_arg`` below. Previously "" slipped past the membership check
    # and then crashed with a KeyError on the ``_FRAMEWORK_INFO`` lookup.
    framework = framework or None
    if framework is not None and framework not in _FRAMEWORK_INFO:
        raise ValueError(
            f"Unknown framework '{framework}'. Supported: {', '.join(SUPPORTED_FRAMEWORKS)}."
        )

    class_name = _to_class_name(name)
    target = out_root / name
    if target.exists() and not force:
        raise FileExistsError(f"{target} already exists. Pass --force to overwrite.")

    framework_arg = f" --framework {framework}" if framework else ""
    info = _FRAMEWORK_INFO[framework] if framework is not None else None

    # Render every file before touching the filesystem. Dict insertion order
    # is the write order (same order as before).
    if info is not None and info["kind"] == "wrapper":
        files = {
            "__init__.py": _TPL_INIT_WRAPPER.format(
                class_name=class_name, name=name, framework=framework,
            ),
            "agent.py": _TPL_AGENT_WRAPPER.format(
                class_name=class_name,
                framework=framework,
                native_import=info["native_import"],
                native_factory=info["native_factory"].format(name=name),
            ),
            "example.py": _TPL_EXAMPLE_WRAPPER.format(class_name=class_name),
            "README.md": _TPL_README_WRAPPER.format(
                class_name=class_name, name=name, framework=framework,
            ),
        }
    else:
        # The scratch and framework-subclass layouts share the __init__,
        # example and README templates; only agent.py differs.
        if info is None:
            agent_src = _TPL_AGENT_SCRATCH.format(class_name=class_name, name=name)
        else:
            agent_src = _TPL_AGENT_SUBCLASS.format(
                class_name=class_name,
                name=name,
                base_import=info["base_import"],
                base_class=info["base_class"],
            )
        files = {
            "__init__.py": _TPL_INIT_SUBCLASS.format(
                class_name=class_name, name=name, framework_arg=framework_arg,
            ),
            "agent.py": agent_src,
            "example.py": _TPL_EXAMPLE_SUBCLASS.format(class_name=class_name),
            "README.md": _TPL_README_SUBCLASS.format(
                class_name=class_name, name=name, framework_arg=framework_arg,
            ),
        }

    target.mkdir(parents=True, exist_ok=True)
    for filename, content in files.items():
        # Use explicit UTF-8. The README templates contain non-ASCII characters
        # (em dashes), which the platform's default locale encoding may reject.
        (target / filename).write_text(content, encoding="utf-8")
    return target
|