synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their public registries and is provided for informational purposes only.
Potentially problematic release.
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
- examples/rl/configs/rl_from_base_qwen17.toml +1 -0
- examples/swe/task_app/hosted/inference/openai_client.py +0 -34
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/task_app.py +254 -36
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
- synth_ai/api/train/builders.py +90 -1
- synth_ai/api/train/cli.py +396 -21
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +15 -1
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +29 -0
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +85 -17
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +1 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/eval/core.py +13 -10
- synth_ai/cli/commands/filter/core.py +53 -17
- synth_ai/cli/commands/help/core.py +0 -1
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/train/judge_schemas.py +1 -0
- synth_ai/cli/commands/train/judge_validation.py +1 -0
- synth_ai/cli/commands/train/validation.py +0 -57
- synth_ai/cli/demo.py +35 -3
- synth_ai/cli/deploy/__init__.py +40 -25
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/task_app_deploy.py +1 -1
- synth_ai/cli/task_apps.py +53 -53
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/judge_schemas.py +1 -0
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/handlers.py +53 -4
- synth_ai/streaming/streamer.py +19 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +44 -8
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +17 -17
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +283 -1
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
- synth_ai/cli/commands/deploy/__init__.py +0 -23
- synth_ai/cli/commands/deploy/core.py +0 -614
- synth_ai/cli/commands/deploy/errors.py +0 -72
- synth_ai/cli/commands/deploy/validation.py +0 -11
- synth_ai/cli/deploy/core.py +0 -5
- synth_ai/cli/deploy/errors.py +0 -23
- synth_ai/cli/deploy/validation.py +0 -5
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
examples/baseline/banking77_baseline.py
@@ -0,0 +1,204 @@
+"""Banking77 baseline file for intent classification evaluation."""
+
+from __future__ import annotations
+
+from typing import Any, Dict
+
+from datasets import load_dataset
+
+from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
+from synth_ai.inference import InferenceClient
+import os
+import httpx
+
+
+# Load dataset once at module level
+_dataset = None
+_label_names = None
+
+
+def _load_dataset():
+    """Load Banking77 dataset."""
+    global _dataset, _label_names
+    if _dataset is None:
+        try:
+            _dataset = load_dataset("PolyAI/banking77")
+        except Exception:
+            # Fallback: try without org prefix
+            _dataset = load_dataset("banking77")
+        _label_names = _dataset["train"].features["label"].names
+    return _dataset, _label_names
+
+
+class Banking77TaskRunner(BaselineTaskRunner):
+    """Task runner for Banking77 intent classification."""
+
+    def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
+        super().__init__(policy_config, env_config)
+
+        # Load dataset
+        self.dataset, self.label_names = _load_dataset()
+
+        # Store config for inference
+        self.model = policy_config["model"]
+        self.temperature = policy_config.get("temperature", 0.0)
+        self.max_tokens = policy_config.get("max_tokens", 128)
+        self.inference_url = policy_config.get("inference_url")
+
+        # Tool definition
+        self.tool = {
+            "type": "function",
+            "function": {
+                "name": "banking77_classify",
+                "description": "Classify a banking query into an intent",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "label": {
+                            "type": "string",
+                            "enum": self.label_names,
+                            "description": "The intent label",
+                        }
+                    },
+                    "required": ["label"],
+                },
+            },
+        }
+
+    async def run_task(self, seed: int) -> TaskResult:
+        """Run a single Banking77 classification task."""
+
+        # Get split
+        split = self.env_config.get("split", "train")
+
+        # Get example from dataset
+        example = self.dataset[split][seed]
+
+        # Build prompt
+        system_prompt = f"""You are an expert banking assistant that classifies customer queries.
+Given a customer message, respond with exactly one intent label using the tool call.
+
+Valid intents: {', '.join(self.label_names)}"""
+
+        user_prompt = f"Customer Query: {example['text']}\n\nClassify this query."
+
+        # Run inference
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+
+        # Use InferenceClient if URL provided, otherwise use OpenAI-compatible API
+        if self.inference_url and self.inference_url.startswith("http"):
+            api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
+            base_url = self.inference_url.rstrip("/")
+            if not base_url.endswith("/api"):
+                base_url = f"{base_url}/api" if "/api" not in base_url else base_url
+            client = InferenceClient(base_url=base_url, api_key=api_key)
+            response = await client.create_chat_completion(
+                model=self.model,
+                messages=messages,
+                tools=[self.tool],
+                tool_choice={"type": "function", "function": {"name": "banking77_classify"}},
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+            )
+        else:
+            # Use OpenAI/Groq directly
+            # Check if model starts with groq: prefix
+            model_name = self.model
+            use_groq = model_name.startswith("groq:")
+            if use_groq:
+                model_name = model_name[5:]  # Remove "groq:" prefix
+
+            api_key = os.getenv("GROQ_API_KEY") if use_groq else os.getenv("OPENAI_API_KEY") or ""
+            base_url = "https://api.groq.com/openai/v1" if use_groq else "https://api.openai.com/v1"
+            async with httpx.AsyncClient() as http_client:
+                resp = await http_client.post(
+                    f"{base_url}/chat/completions",
+                    json={
+                        "model": model_name,
+                        "messages": messages,
+                        "tools": [self.tool],
+                        "tool_choice": {"type": "function", "function": {"name": "banking77_classify"}},
+                        "temperature": self.temperature,
+                        "max_tokens": self.max_tokens,
+                    },
+                    headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
+                )
+                response = resp.json()
+
+        # Extract prediction
+        predicted_label = ""
+        tool_calls = []
+        if "choices" in response and len(response["choices"]) > 0:
+            message = response["choices"][0].get("message", {})
+            tool_calls = message.get("tool_calls", [])
+        elif "tool_calls" in response:
+            tool_calls = response["tool_calls"]
+
+        if tool_calls:
+            # Handle both string and dict arguments
+            args = tool_calls[0]["function"].get("arguments", "")
+            if isinstance(args, str):
+                import json
+                args = json.loads(args)
+            predicted_label = args.get("label", "") if isinstance(args, dict) else ""
+
+        # Evaluate
+        expected_label = self.label_names[example["label"]]
+        correct = predicted_label == expected_label
+
+        return TaskResult(
+            seed=seed,
+            success=True,
+            outcome_reward=1.0 if correct else 0.0,
+            total_steps=1,
+            metadata={
+                "query": example["text"],
+                "expected": expected_label,
+                "predicted": predicted_label,
+                "correct": correct,
+                "split": split,
+            },
+        )
+
+
+# Define baseline config
+# Note: We need to load the dataset first to get the label names
+_load_dataset()
+banking77_baseline = BaselineConfig(
+    baseline_id="banking77",
+    name="Banking77 Intent Classification",
+    description="Banking intent classification from customer queries",
+    task_runner=Banking77TaskRunner,
+    splits={
+        "train": DataSplit(
+            name="train",
+            seeds=list(range(min(10000, len(_dataset["train"]))) if _dataset else range(10000)),
+        ),
+        "val": DataSplit(
+            name="val",
+            seeds=list(range(min(1000, len(_dataset["test"]))) if _dataset else range(1000)),
+        ),
+        "test": DataSplit(
+            name="test",
+            seeds=list(range(min(3000, len(_dataset["test"]))) if _dataset else range(3000)),
+        ),
+    },
+    default_policy_config={
+        "model": "groq:llama-3.1-70b-versatile",
+        "temperature": 0.0,
+        "max_tokens": 128,
+    },
+    default_env_config={
+        "split": "train",
+    },
+    metadata={
+        "dataset": "PolyAI/banking77",
+        "num_classes": 77,
+        "task_type": "classification",
+    },
+    tags=["classification", "nlp", "intent"],
+)
+
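
The runner above plugs into the new synth_ai.baseline package introduced in this release (see synth_ai/baseline/config.py, discovery.py, and execution.py in the file list). As a rough illustration only, not code shipped in the wheel, the following sketch shows how such a baseline definition could be exercised directly; it assumes the examples/ directory is importable, that BaselineConfig exposes task_runner, default_policy_config, and default_env_config as attributes, and that a GROQ_API_KEY is set for the default groq: model.

import asyncio

from examples.baseline.banking77_baseline import banking77_baseline


async def main() -> None:
    # Instantiate the runner class declared in the BaselineConfig with its default configs.
    runner = banking77_baseline.task_runner(
        dict(banking77_baseline.default_policy_config),
        dict(banking77_baseline.default_env_config),
    )
    # Score a few seeds from the train split; outcome_reward is 1.0 when the predicted
    # intent label matches the dataset label, 0.0 otherwise.
    results = [await runner.run_task(seed) for seed in range(5)]
    accuracy = sum(r.outcome_reward for r in results) / len(results)
    print(f"accuracy over {len(results)} seeds: {accuracy:.2f}")


if __name__ == "__main__":
    asyncio.run(main())
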
examples/baseline/crafter_baseline.py
@@ -0,0 +1,407 @@
+"""Crafter baseline file for self-contained evaluation.
+
+This baseline file defines how to evaluate agents on Crafter without
+requiring a deployed task app. It includes train/val/test splits and
+computes both event rewards (achievement deltas) and outcome rewards
+(total unique achievements).
+"""
+
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, Set
+from uuid import uuid4
+
+from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
+from synth_ai.environments.examples.crafter_classic.environment import (
+    CrafterClassicEnvironment,
+)
+from synth_ai.environments.examples.crafter_classic.taskset import (
+    CrafterTaskInstance,
+    CrafterTaskInstanceMetadata,
+)
+from synth_ai.environments.tasks.core import Impetus, Intent
+from synth_ai.inference import InferenceClient
+from synth_ai.tracing_v3.session_tracer import SessionTracer
+import os
+
+
+# Action mapping: string names to action indices
+CRAFTER_ACTION_MAP: Dict[str, int] = {
+    "noop": 0,
+    "move_left": 1,
+    "move_right": 2,
+    "move_up": 3,
+    "move_down": 4,
+    "do": 5,
+    "sleep": 6,
+    "place_stone": 7,
+    "place_table": 8,
+    "place_furnace": 9,
+    "place_plant": 10,
+    "make_wood_pickaxe": 11,
+    "make_stone_pickaxe": 12,
+    "make_iron_pickaxe": 13,
+    "make_wood_sword": 14,
+    "make_stone_sword": 15,
+    "make_iron_sword": 16,
+}
+
+
+def format_crafter_observation(obs: Dict[str, Any]) -> str:
+    """Format Crafter observation as text for LLM."""
+    health = obs.get("health") or obs.get("inventory", {}).get("health", 0)
+    inventory = obs.get("inventory", {})
+    pos = obs.get("player_position", [0, 0])
+    achievements_status = obs.get("achievements_status", {})
+
+    # Format inventory (skip health)
+    inv_items = [f"{k}:{v}" for k, v in inventory.items() if v > 0 and k != "health"]
+    inventory_str = ", ".join(inv_items) if inv_items else "empty"
+
+    # Format achievements
+    achieved_list = [k for k, v in achievements_status.items() if v]
+    achievements_str = ", ".join(achieved_list) if achieved_list else "none"
+
+    return f"""Crafter Game State:
+- Health: {health}/10
+- Hunger: {inventory.get('hunger', 0)}/10
+- Position: {pos}
+- Inventory: {inventory_str}
+- Achievements unlocked: {len(achieved_list)}/22
+- Achievements: {achievements_str}
+
+What actions should we take?"""
+
+
+class CrafterTaskRunner(BaselineTaskRunner):
+    """Task runner for Crafter survival game."""
+
+    def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
+        super().__init__(policy_config, env_config)
+
+        # Initialize inference client
+        inference_url = policy_config.get("inference_url")
+        if inference_url and inference_url.startswith("http"):
+            # External URL - use InferenceClient
+            api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
+            base_url = inference_url.rstrip("/")
+            if not base_url.endswith("/api"):
+                base_url = f"{base_url}/api" if "/api" not in base_url else base_url
+            self.client = InferenceClient(base_url=base_url, api_key=api_key)
+            self.use_inference_client = True
+        else:
+            # For OpenAI/Groq direct APIs, we'll use httpx
+            import httpx
+            self.http_client = httpx.AsyncClient()
+            self.use_inference_client = False
+
+        self.model = policy_config["model"]
+        self.temperature = policy_config.get("temperature", 0.0)
+        self.max_tokens = policy_config.get("max_tokens", 512)
+
+        # System prompt
+        self.system_prompt = """You are playing Crafter, a survival game. Your goal is to unlock achievements.
+
+Core rules:
+- The world contains trees (wood), stone, coal, iron, plants, cows, zombies, and water.
+- Movement constraints: you cannot walk onto blocking tiles (tree, stone, water, lava, coal, iron).
+- You start with empty hands and low health/hunger.
+- Interact ('do') only when adjacent to a resource.
+- Movement is essential: move multiple steps in one turn to explore.
+
+Available actions: noop, move_up, move_down, move_left, move_right, do, sleep,
+place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe,
+make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword
+
+Always return a tool call: interact_many({actions: [...]})
+Use 2-5 actions per call. Prefer long movement sequences."""
+
+        # Tool definition
+        self.tools = [{
+            "type": "function",
+            "function": {
+                "name": "interact_many",
+                "description": "Execute multiple Crafter actions in sequence",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "actions": {
+                            "type": "array",
+                            "items": {"type": "string", "enum": list(CRAFTER_ACTION_MAP.keys())},
+                            "description": "List of actions to execute",
+                        }
+                    },
+                    "required": ["actions"],
+                },
+            },
+        }]
+
+    async def run_task(self, seed: int) -> TaskResult:
+        """Run a single Crafter episode and return results."""
+
+        # Create task instance
+        difficulty = self.env_config.get("difficulty", "normal")
+        max_steps = self.env_config.get("max_steps", 100)
+
+        impetus = Impetus(instructions="Survive and unlock achievements.")
+        intent = Intent(
+            rubric={"goal": "Unlock achievements"},
+            gold_trajectories=None,
+            gold_state_diff={},
+        )
+        metadata = CrafterTaskInstanceMetadata(
+            difficulty=difficulty,
+            seed=seed,
+            num_trees_radius=0,
+            num_cows_radius=0,
+            num_hostiles_radius=0,
+        )
+        task_instance = CrafterTaskInstance(
+            id=uuid4(),
+            impetus=impetus,
+            intent=intent,
+            metadata=metadata,
+            is_reproducible=True,
+            initial_engine_snapshot=None,
+        )
+
+        # Attach config
+        task_instance.config = {"seed": seed, "length": 256, "area": [64, 64]}
+
+        # Create environment
+        env = CrafterClassicEnvironment(task_instance=task_instance)
+
+        # Setup tracing
+        tracer: Optional[SessionTracer] = None
+        session_id: Optional[str] = None
+        if self.env_config.get("enable_tracing", True):
+            tracer = SessionTracer(db_url=None, auto_save=False)
+            await tracer.initialize()
+            session_id = tracer.create_session(metadata={
+                "seed": seed,
+                "difficulty": difficulty,
+                "model": self.policy_config["model"],
+            })
+
+        # Initialize environment
+        raw_obs = await env.initialize()
+        observation = getattr(raw_obs, "observation", raw_obs) if hasattr(raw_obs, "observation") else raw_obs
+        obs_dict = observation if isinstance(observation, dict) else {}
+
+        # Track achievements
+        prev_achievements: Set[str] = set()
+        if isinstance(obs_dict.get("achievements_status"), dict):
+            prev_achievements = {
+                k for k, v in obs_dict.get("achievements_status", {}).items() if v
+            }
+
+        event_rewards: List[Dict[str, Any]] = []
+        total_steps = 0
+        tool_calls_history: List[Dict[str, Any]] = []
+
+        # Episode loop
+        for step in range(max_steps):
+            # Format observation
+            obs_text = format_crafter_observation(obs_dict)
+
+            # Build messages
+            messages = [
+                {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": f"{obs_text}\n\nPrevious tool calls: {tool_calls_history[-3:]}"},
+            ]
+
+            # Record LLM event
+            llm_event_id = None
+            if tracer and session_id:
+                llm_event_id = tracer.record_event(
+                    session_id=session_id,
+                    event_type="cais",
+                    data={"messages": messages, "step": step},
+                )
+
+            # Get action from LLM
+            if self.use_inference_client:
+                response = await self.client.create_chat_completion(
+                    model=self.model,
+                    messages=messages,
+                    tools=self.tools,
+                    tool_choice={"type": "function", "function": {"name": "interact_many"}},
+                    temperature=self.temperature,
+                    max_tokens=self.max_tokens,
+                )
+            else:
+                # Fallback: use OpenAI-compatible API
+                import httpx
+                import json as json_lib
+                api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
+                base_url = "https://api.openai.com/v1" if "openai" in self.model.lower() else "https://api.groq.com/openai/v1"
+                async with httpx.AsyncClient() as client:
+                    resp = await client.post(
+                        f"{base_url}/chat/completions",
+                        json={
+                            "model": self.model,
+                            "messages": messages,
+                            "tools": self.tools,
+                            "tool_choice": {"type": "function", "function": {"name": "interact_many"}},
+                            "temperature": self.temperature,
+                            "max_tokens": self.max_tokens,
+                        },
+                        headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
+                    )
+                    response = resp.json()
+
+            # Parse tool call
+            tool_calls = []
+            if "choices" in response and len(response["choices"]) > 0:
+                message = response["choices"][0].get("message", {})
+                tool_calls = message.get("tool_calls", [])
+            elif "tool_calls" in response:
+                tool_calls = response["tool_calls"]
+
+            if not tool_calls:
+                break
+
+            tool_call = tool_calls[0]
+            actions = tool_call["function"]["arguments"].get("actions", [])
+            tool_calls_history.append({"step": step, "actions": actions})
+
+            # Execute actions
+            for action_name in actions:
+                if total_steps >= max_steps:
+                    break
+
+                # Map action string to index
+                action_idx = CRAFTER_ACTION_MAP.get(action_name, 0)
+
+                # Step environment
+                step_result = await env.step(action_idx)
+                total_steps += 1
+
+                # Get observation from step result
+                step_obs = getattr(step_result, "observation", step_result) if hasattr(step_result, "observation") else step_result
+                obs_dict = step_obs if isinstance(step_obs, dict) else {}
+
+                # Record environment event
+                env_event_id = None
+                if tracer and session_id:
+                    env_event_id = tracer.record_event(
+                        session_id=session_id,
+                        event_type="environment",
+                        data={
+                            "action": action_name,
+                            "reward": getattr(step_result, "reward", 0.0),
+                            "terminated": getattr(step_result, "terminated", False),
+                            "step": total_steps,
+                        },
+                    )
+
+                # Check for new achievements
+                current_achievements: Set[str] = set()
+                if isinstance(obs_dict.get("achievements_status"), dict):
+                    current_achievements = {
+                        k for k, v in obs_dict.get("achievements_status", {}).items() if v
+                    }
+
+                new_achievements = current_achievements - prev_achievements
+
+                if new_achievements:
+                    event_reward_value = len(new_achievements)
+                    if tracer and session_id and env_event_id:
+                        tracer.record_event_reward(
+                            session_id=session_id,
+                            event_id=env_event_id,
+                            reward_value=float(event_reward_value),
+                            reward_type="achievement_delta",
+                            key="achievements",
+                            annotation={"new_achievements": list(new_achievements)},
+                            source="environment",
+                        )
+                    event_rewards.append({
+                        "step": total_steps,
+                        "reward": event_reward_value,
+                        "achievements": list(new_achievements),
+                    })
+
+                prev_achievements = current_achievements
+
+                # Check termination
+                if getattr(step_result, "terminated", False) or getattr(step_result, "truncated", False):
+                    break
+
+            if getattr(step_result, "terminated", False) or getattr(step_result, "truncated", False):
+                break
+
+        # Compute outcome reward
+        unique_achievements = len(prev_achievements)
+        if tracer and session_id:
+            tracer.record_outcome_reward(
+                session_id=session_id,
+                total_reward=unique_achievements,
+                achievements_count=unique_achievements,
+                total_steps=total_steps,
+                reward_metadata={
+                    "achievements": list(prev_achievements),
+                },
+            )
+
+            # Export trace
+            trace_dict = await tracer.export_session(session_id)
+        else:
+            trace_dict = None
+
+        return TaskResult(
+            seed=seed,
+            success=True,
+            outcome_reward=float(unique_achievements),
+            event_rewards=event_rewards,
+            total_steps=total_steps,
+            metadata={
+                "achievements": list(prev_achievements),
+                "achievement_count": unique_achievements,
+                "difficulty": difficulty,
+            },
+            trace=trace_dict,
+        )
+
+
+# Define baseline config
+crafter_baseline = BaselineConfig(
+    baseline_id="crafter",
+    name="Crafter Survival",
+    description="Crafter survival game with achievement tracking",
+    task_runner=CrafterTaskRunner,
+    splits={
+        "train": DataSplit(
+            name="train",
+            seeds=list(range(100)),
+            metadata={"difficulty": "normal"},
+        ),
+        "val": DataSplit(
+            name="val",
+            seeds=list(range(100, 150)),
+            metadata={"difficulty": "normal"},
+        ),
+        "test": DataSplit(
+            name="test",
+            seeds=list(range(150, 200)),
+            metadata={"difficulty": "hard"},
+        ),
+    },
+    default_policy_config={
+        "model": "groq:llama-3.1-70b-versatile",
+        "temperature": 0.0,
+        "max_tokens": 1024,
+    },
+    default_env_config={
+        "difficulty": "normal",
+        "max_steps": 100,
+        "enable_tracing": True,
+    },
+    metadata={
+        "environment": "crafter",
+        "reward_type": "achievements",
+        "max_achievements": 22,
+    },
+    tags=["rl", "gym", "survival", "achievements"],
+)
+
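
The Crafter runner produces two reward streams: per-step event_rewards (achievement deltas recorded during the episode) and a single outcome_reward (the count of unique achievements at episode end). A similar hypothetical driver, under the same assumptions as the Banking77 sketch above and not part of the published diff, shows how those values could be read back from the returned TaskResult:

import asyncio

from examples.baseline.crafter_baseline import crafter_baseline


async def main() -> None:
    # Run one short, untraced episode so the sketch stays cheap to execute.
    runner = crafter_baseline.task_runner(
        dict(crafter_baseline.default_policy_config),
        {**crafter_baseline.default_env_config, "max_steps": 20, "enable_tracing": False},
    )
    result = await runner.run_task(seed=0)
    print("unique achievements (outcome reward):", result.outcome_reward)
    print("per-step achievement deltas:", [(e["step"], e["reward"]) for e in result.event_rewards])
    print("steps taken:", result.total_steps)


if __name__ == "__main__":
    asyncio.run(main())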