synth-ai 0.2.9.dev4__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/common_old/backend.py +0 -1
- examples/crafter_debug_render.py +15 -6
- examples/evals_old/compare_models.py +1 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
- examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
- examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
- examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
- examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
- examples/finetuning_old/synth_qwen_v1/util.py +7 -2
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +17 -15
- examples/rl/run_rl_and_save.py +24 -7
- examples/rl/task_app/math_single_step.py +128 -11
- examples/rl/task_app/math_task_app.py +11 -3
- examples/rl_old/task_app.py +222 -53
- examples/warming_up_to_rl/analyze_trace_db.py +7 -5
- examples/warming_up_to_rl/export_trace_sft.py +141 -16
- examples/warming_up_to_rl/groq_test.py +11 -4
- examples/warming_up_to_rl/manage_secrets.py +15 -6
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +108 -30
- examples/warming_up_to_rl/run_fft_and_save.py +128 -52
- examples/warming_up_to_rl/run_local_rollout.py +87 -36
- examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
- examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
- examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
- examples/warming_up_to_rl/run_rl_and_save.py +31 -7
- examples/warming_up_to_rl/run_rollout_remote.py +37 -10
- examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
- synth_ai/__init__.py +1 -0
- synth_ai/api/train/builders.py +34 -10
- synth_ai/api/train/cli.py +172 -32
- synth_ai/api/train/config_finder.py +59 -4
- synth_ai/api/train/env_resolver.py +32 -14
- synth_ai/api/train/pollers.py +11 -3
- synth_ai/api/train/task_app.py +4 -1
- synth_ai/api/train/utils.py +20 -4
- synth_ai/cli/__init__.py +11 -4
- synth_ai/cli/balance.py +1 -1
- synth_ai/cli/demo.py +19 -5
- synth_ai/cli/rl_demo.py +75 -16
- synth_ai/cli/root.py +116 -37
- synth_ai/cli/task_apps.py +1286 -170
- synth_ai/cli/traces.py +1 -0
- synth_ai/cli/turso.py +73 -0
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +67 -30
- synth_ai/demos/core/cli.py +493 -164
- synth_ai/demos/demo_task_apps/core.py +50 -6
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/reproducibility/tree.py +3 -1
- synth_ai/environments/service/core_routes.py +6 -2
- synth_ai/evals/base.py +0 -2
- synth_ai/experimental/synth_oss.py +11 -12
- synth_ai/handshake.py +3 -1
- synth_ai/http_client.py +31 -7
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +8 -4
- synth_ai/jobs/client.py +40 -10
- synth_ai/learning/client.py +33 -8
- synth_ai/learning/config.py +0 -2
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +6 -3
- synth_ai/learning/health.py +9 -2
- synth_ai/learning/jobs.py +17 -5
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
- synth_ai/learning/prompts/random_search.py +4 -1
- synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
- synth_ai/learning/rl_client.py +42 -14
- synth_ai/learning/sse.py +0 -2
- synth_ai/learning/validators.py +6 -2
- synth_ai/lm/caching/ephemeral.py +1 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +13 -1
- synth_ai/lm/core/synth_models.py +0 -1
- synth_ai/lm/core/vendor_clients.py +4 -2
- synth_ai/lm/overrides.py +2 -2
- synth_ai/lm/vendors/core/anthropic_api.py +7 -7
- synth_ai/lm/vendors/core/openai_api.py +2 -0
- synth_ai/lm/vendors/openai_standard.py +3 -1
- synth_ai/lm/vendors/openai_standard_responses.py +6 -3
- synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
- synth_ai/lm/vendors/synth_client.py +37 -10
- synth_ai/rl/__init__.py +0 -1
- synth_ai/rl/contracts.py +0 -2
- synth_ai/rl/env_keys.py +6 -1
- synth_ai/task/__init__.py +1 -0
- synth_ai/task/apps/__init__.py +11 -11
- synth_ai/task/auth.py +29 -17
- synth_ai/task/client.py +3 -1
- synth_ai/task/contracts.py +1 -0
- synth_ai/task/datasets.py +3 -1
- synth_ai/task/errors.py +3 -2
- synth_ai/task/health.py +0 -2
- synth_ai/task/json.py +0 -1
- synth_ai/task/proxy.py +2 -5
- synth_ai/task/rubrics.py +9 -3
- synth_ai/task/server.py +31 -5
- synth_ai/task/tracing_utils.py +8 -3
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +0 -1
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +1 -0
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +2 -0
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +24 -3
- synth_ai/tracing_v3/storage/base.py +4 -1
- synth_ai/tracing_v3/storage/factory.py +0 -1
- synth_ai/tracing_v3/turso/manager.py +102 -38
- synth_ai/tracing_v3/turso/models.py +4 -1
- synth_ai/tracing_v3/utils.py +1 -0
- synth_ai/v0/tracing/upload.py +32 -135
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -156
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +0 -58
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- synth_ai/install_sqld.sh +0 -40
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
|
@@ -90,6 +90,7 @@ class CrafterEnvironmentWrapper:
|
|
|
90
90
|
logger.warning("Unknown Crafter action: %s - ignoring", action_str)
|
|
91
91
|
return None # Signal to skip this action
|
|
92
92
|
return CRAFTER_ACTIONS[action_str]
|
|
93
|
+
|
|
93
94
|
for tc in tool_calls:
|
|
94
95
|
if isinstance(tc, EnvToolCall):
|
|
95
96
|
# Expand interact_many; otherwise coerce non-interact tools into interact(action=tool)
|
|
@@ -103,12 +104,12 @@ class CrafterEnvironmentWrapper:
|
|
|
103
104
|
)
|
|
104
105
|
elif tc.tool != "interact":
|
|
105
106
|
candidate_action = tc.args.get("action") if isinstance(tc.args, dict) else None
|
|
106
|
-
resolved_action =
|
|
107
|
+
resolved_action = (
|
|
108
|
+
candidate_action if candidate_action in allowed_actions else tc.tool
|
|
109
|
+
)
|
|
107
110
|
action_int = _action_to_int(resolved_action)
|
|
108
111
|
if action_int is not None: # Skip invalid actions
|
|
109
|
-
normalized.append(
|
|
110
|
-
EnvToolCall(tool="interact", args={"action": action_int})
|
|
111
|
-
)
|
|
112
|
+
normalized.append(EnvToolCall(tool="interact", args={"action": action_int}))
|
|
112
113
|
else:
|
|
113
114
|
normalized.append(tc)
|
|
114
115
|
else:
|
|
@@ -120,13 +121,14 @@ class CrafterEnvironmentWrapper:
|
|
|
120
121
|
args = tc.get("arguments") or tc.get("args") or {}
|
|
121
122
|
if isinstance(args, str):
|
|
122
123
|
import json as _json
|
|
124
|
+
|
|
123
125
|
try:
|
|
124
126
|
args = _json.loads(args)
|
|
125
127
|
except Exception:
|
|
126
128
|
args = {}
|
|
127
129
|
# Expand interact_many into multiple interacts
|
|
128
130
|
if tool_name == "interact_many":
|
|
129
|
-
for action in
|
|
131
|
+
for action in args.get("actions") or []:
|
|
130
132
|
action_int = _action_to_int(action)
|
|
131
133
|
if action_int is not None: # Skip invalid actions
|
|
132
134
|
normalized.append(
|
|
@@ -135,11 +137,17 @@ class CrafterEnvironmentWrapper:
|
|
|
135
137
|
else:
|
|
136
138
|
# For any non-interact tool, resolve to an interact action.
|
|
137
139
|
# Support a packed list of actions under 'actions' for convenience.
|
|
138
|
-
if
|
|
140
|
+
if (
|
|
141
|
+
isinstance(args, dict)
|
|
142
|
+
and isinstance(args.get("actions"), list)
|
|
143
|
+
and args.get("actions")
|
|
144
|
+
):
|
|
139
145
|
for action in args.get("actions"):
|
|
140
146
|
action_int = _action_to_int(action)
|
|
141
147
|
if action_int is not None:
|
|
142
|
-
normalized.append(
|
|
148
|
+
normalized.append(
|
|
149
|
+
EnvToolCall(tool="interact", args={"action": action_int})
|
|
150
|
+
)
|
|
143
151
|
else:
|
|
144
152
|
candidate_action = None
|
|
145
153
|
if isinstance(args, dict) and "action" in args:
|
|
@@ -148,13 +156,18 @@ class CrafterEnvironmentWrapper:
|
|
|
148
156
|
action_int: Optional[int]
|
|
149
157
|
if isinstance(candidate_action, int):
|
|
150
158
|
action_int = _action_to_int(candidate_action)
|
|
151
|
-
elif
|
|
159
|
+
elif (
|
|
160
|
+
isinstance(candidate_action, str)
|
|
161
|
+
and candidate_action in allowed_actions
|
|
162
|
+
):
|
|
152
163
|
action_int = _action_to_int(candidate_action)
|
|
153
164
|
else:
|
|
154
165
|
# Fallback: interpret the tool name itself as the action label
|
|
155
166
|
action_int = _action_to_int(tool_name)
|
|
156
167
|
if action_int is not None:
|
|
157
|
-
normalized.append(
|
|
168
|
+
normalized.append(
|
|
169
|
+
EnvToolCall(tool="interact", args={"action": action_int})
|
|
170
|
+
)
|
|
158
171
|
|
|
159
172
|
# Ensure we have at least one valid action; default to noop if none provided
|
|
160
173
|
if not normalized:
|
|
@@ -173,7 +186,9 @@ class CrafterEnvironmentWrapper:
|
|
|
173
186
|
"semantic_map": pub_before.semantic_map,
|
|
174
187
|
}
|
|
175
188
|
actions_printable = [
|
|
176
|
-
(tc.args.get("action") if isinstance(tc.args, dict) else None)
|
|
189
|
+
(tc.args.get("action") if isinstance(tc.args, dict) else None)
|
|
190
|
+
if isinstance(tc, EnvToolCall)
|
|
191
|
+
else None
|
|
177
192
|
for tc in normalized
|
|
178
193
|
]
|
|
179
194
|
logger.info(
|
|
@@ -185,7 +200,11 @@ class CrafterEnvironmentWrapper:
|
|
|
185
200
|
[k for k, v in before_state["achievements_status"].items() if v],
|
|
186
201
|
actions_printable,
|
|
187
202
|
)
|
|
188
|
-
logger.info(
|
|
203
|
+
logger.info(
|
|
204
|
+
"Surroundings BEFORE (seed=%s):\n%s",
|
|
205
|
+
str(self.seed),
|
|
206
|
+
_format_semantic_map_view(before_state),
|
|
207
|
+
)
|
|
189
208
|
except Exception as _:
|
|
190
209
|
# Logging should not interfere with stepping; fail-fast elsewhere
|
|
191
210
|
pass
|
|
@@ -253,8 +272,14 @@ class CrafterEnvironmentWrapper:
|
|
|
253
272
|
inv_changes = ", ".join(changed_items) if changed_items else "none"
|
|
254
273
|
|
|
255
274
|
# Achievements gained/lost
|
|
256
|
-
ach_b = {
|
|
257
|
-
|
|
275
|
+
ach_b = {
|
|
276
|
+
k
|
|
277
|
+
for k, v in (before_state.get("achievements_status", {}) or {}).items()
|
|
278
|
+
if v
|
|
279
|
+
}
|
|
280
|
+
ach_a = {
|
|
281
|
+
k for k, v in (after_dict.get("achievements_status", {}) or {}).items() if v
|
|
282
|
+
}
|
|
258
283
|
ach_added = sorted(list(ach_a - ach_b))
|
|
259
284
|
ach_added_latest = ach_added
|
|
260
285
|
ach_removed = sorted(list(ach_b - ach_a))
|
|
@@ -272,12 +297,19 @@ class CrafterEnvironmentWrapper:
|
|
|
272
297
|
if reward is None and ach_added_latest:
|
|
273
298
|
try:
|
|
274
299
|
reward = float(len(ach_added_latest))
|
|
275
|
-
logger.info(
|
|
300
|
+
logger.info(
|
|
301
|
+
"Reward shaping applied: +%s (achievements added)",
|
|
302
|
+
len(ach_added_latest),
|
|
303
|
+
)
|
|
276
304
|
except Exception:
|
|
277
305
|
pass
|
|
278
306
|
except Exception:
|
|
279
307
|
pass
|
|
280
|
-
logger.info(
|
|
308
|
+
logger.info(
|
|
309
|
+
"Surroundings AFTER (seed=%s):\n%s",
|
|
310
|
+
str(self.seed),
|
|
311
|
+
_format_semantic_map_view(after_dict),
|
|
312
|
+
)
|
|
281
313
|
except Exception as _:
|
|
282
314
|
pass
|
|
283
315
|
result: Dict[str, Any] = {
|
|
@@ -340,6 +372,7 @@ class CrafterEnvironmentWrapper:
|
|
|
340
372
|
# Build reverse action map for readability
|
|
341
373
|
int_to_action = {v: k for k, v in CRAFTER_ACTIONS.items()}
|
|
342
374
|
from collections import Counter
|
|
375
|
+
|
|
343
376
|
action_ids = []
|
|
344
377
|
for tc in normalized:
|
|
345
378
|
if isinstance(tc, EnvToolCall) and isinstance(tc.args, dict):
|
|
@@ -380,7 +413,7 @@ class CrafterEnvironmentWrapper:
|
|
|
380
413
|
return {
|
|
381
414
|
"observation": convert_numpy_to_python(observation),
|
|
382
415
|
"info": convert_numpy_to_python(info) if info else None,
|
|
383
|
-
"step_idx": self.step_idx
|
|
416
|
+
"step_idx": self.step_idx,
|
|
384
417
|
}
|
|
385
418
|
|
|
386
419
|
async def terminate(self) -> Dict[str, Any]:
|
|
@@ -390,7 +423,7 @@ class CrafterEnvironmentWrapper:
|
|
|
390
423
|
return {
|
|
391
424
|
"observation": convert_numpy_to_python(observation),
|
|
392
425
|
"info": convert_numpy_to_python(info) if info else None,
|
|
393
|
-
"step_idx": self.step_idx
|
|
426
|
+
"step_idx": self.step_idx,
|
|
394
427
|
}
|
|
395
428
|
|
|
396
429
|
def state_dict(self) -> Dict[str, Any]:
|
|
@@ -5,17 +5,18 @@ from abc import ABC, abstractmethod
|
|
|
5
5
|
from .react_agent import CrafterReActAgent
|
|
6
6
|
from .tools import TOOLS_SCHEMA
|
|
7
7
|
|
|
8
|
+
|
|
8
9
|
# Define Policy base class here to avoid circular import
|
|
9
10
|
class Policy(ABC):
|
|
10
11
|
"""Base class for environment-specific policies."""
|
|
11
|
-
|
|
12
|
+
|
|
12
13
|
@abstractmethod
|
|
13
14
|
def prepare_inference_request(
|
|
14
15
|
self, observation: Dict[str, Any], history: List[Dict[str, Any]] = None
|
|
15
16
|
) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
|
|
16
17
|
"""Prepare an inference request."""
|
|
17
18
|
pass
|
|
18
|
-
|
|
19
|
+
|
|
19
20
|
@abstractmethod
|
|
20
21
|
def parse_model_response(
|
|
21
22
|
self, response: str, observation: Dict[str, Any]
|
|
@@ -23,6 +24,7 @@ class Policy(ABC):
|
|
|
23
24
|
"""Parse model response into tool calls."""
|
|
24
25
|
pass
|
|
25
26
|
|
|
27
|
+
|
|
26
28
|
# (imports moved to top of file to satisfy linter)
|
|
27
29
|
|
|
28
30
|
|
|
@@ -161,7 +163,7 @@ class CrafterPolicy(Policy):
|
|
|
161
163
|
# First check if we got actual tool calls
|
|
162
164
|
choices = response.get("choices", [])
|
|
163
165
|
tool_calls: List[Dict[str, Any]] = []
|
|
164
|
-
|
|
166
|
+
|
|
165
167
|
for choice in choices:
|
|
166
168
|
msg = choice.get("message", {})
|
|
167
169
|
if "tool_calls" in msg and msg["tool_calls"] is not None:
|
|
@@ -185,7 +187,7 @@ class CrafterPolicy(Policy):
|
|
|
185
187
|
"arguments": tc["arguments"],
|
|
186
188
|
}
|
|
187
189
|
)
|
|
188
|
-
|
|
190
|
+
|
|
189
191
|
# If we got tool calls, return them
|
|
190
192
|
if tool_calls:
|
|
191
193
|
# Normalize common degenerate pattern ["move_right", "do"] when nothing is nearby.
|
|
@@ -197,6 +199,7 @@ class CrafterPolicy(Policy):
|
|
|
197
199
|
if isinstance(args, str):
|
|
198
200
|
try:
|
|
199
201
|
import json
|
|
202
|
+
|
|
200
203
|
args = json.loads(args)
|
|
201
204
|
except (json.JSONDecodeError, ValueError):
|
|
202
205
|
args = {}
|
|
@@ -208,11 +211,13 @@ class CrafterPolicy(Policy):
|
|
|
208
211
|
# Simple heuristic: avoid repeating same pair; avoid 'do' with no context
|
|
209
212
|
if len(actions) == 2 and actions[0] == "move_right" and actions[1] == "do":
|
|
210
213
|
actions = ["move_right"]
|
|
211
|
-
normalized.append(
|
|
214
|
+
normalized.append(
|
|
215
|
+
{"tool_name": "interact_many", "arguments": {"actions": actions or []}}
|
|
216
|
+
)
|
|
212
217
|
else:
|
|
213
218
|
normalized.append(tc)
|
|
214
219
|
return normalized
|
|
215
|
-
|
|
220
|
+
|
|
216
221
|
# Otherwise, parse plain text content for actions
|
|
217
222
|
text = ""
|
|
218
223
|
for choice in choices:
|
|
@@ -221,15 +226,16 @@ class CrafterPolicy(Policy):
|
|
|
221
226
|
if content:
|
|
222
227
|
text = content
|
|
223
228
|
break
|
|
224
|
-
|
|
229
|
+
|
|
225
230
|
if text:
|
|
226
231
|
# Try to parse actions from the text
|
|
227
232
|
from .shared import parse_actions
|
|
233
|
+
|
|
228
234
|
actions = parse_actions(text)
|
|
229
235
|
if actions:
|
|
230
236
|
# Wrap actions in interact_many tool call
|
|
231
237
|
return [{"tool_name": "interact_many", "arguments": {"actions": actions}}]
|
|
232
|
-
|
|
238
|
+
|
|
233
239
|
# No actions found
|
|
234
240
|
return []
|
|
235
241
|
|
|
@@ -264,7 +270,11 @@ class CrafterPolicy(Policy):
|
|
|
264
270
|
prev_tool_calls = metadata["prev_tool_calls"]
|
|
265
271
|
if "prev_env_result" in metadata:
|
|
266
272
|
prev_env_result = metadata["prev_env_result"]
|
|
267
|
-
if
|
|
273
|
+
if (
|
|
274
|
+
prev_assistant_text is not None
|
|
275
|
+
or prev_tool_calls is not None
|
|
276
|
+
or prev_env_result is not None
|
|
277
|
+
):
|
|
268
278
|
self._append_assistant_turn(prev_assistant_text, prev_tool_calls, prev_env_result)
|
|
269
279
|
|
|
270
280
|
# Append current observation as the next user message (internal history only)
|
|
@@ -274,8 +284,12 @@ class CrafterPolicy(Policy):
|
|
|
274
284
|
# (formatted surroundings/inventory) with the previous 3 tool calls as context.
|
|
275
285
|
# Most recent first.
|
|
276
286
|
lines: List[str] = []
|
|
277
|
-
|
|
287
|
+
|
|
288
|
+
def _format_tool_call_line_for_context(
|
|
289
|
+
tool_name: str, arguments: Any, max_chars: int = 500
|
|
290
|
+
) -> str:
|
|
278
291
|
import json as _json
|
|
292
|
+
|
|
279
293
|
# Render arguments compactly, then clip to max_chars
|
|
280
294
|
if isinstance(arguments, (dict, list)):
|
|
281
295
|
try:
|
|
@@ -289,6 +303,7 @@ class CrafterPolicy(Policy):
|
|
|
289
303
|
if isinstance(rendered, str) and len(rendered) > max_chars:
|
|
290
304
|
rendered = rendered[:max_chars]
|
|
291
305
|
return f"- {tool_name}: {rendered}"
|
|
306
|
+
|
|
292
307
|
# Prefer pulling from trajectory_history (accumulates over turns)
|
|
293
308
|
for record in reversed(self.trajectory_history):
|
|
294
309
|
if len(lines) >= 3:
|
|
@@ -316,7 +331,9 @@ class CrafterPolicy(Policy):
|
|
|
316
331
|
args = call.get("arguments")
|
|
317
332
|
lines.append(_format_tool_call_line_for_context(name, args))
|
|
318
333
|
|
|
319
|
-
context_text = "Previous tool calls (most recent first):\n" + (
|
|
334
|
+
context_text = "Previous tool calls (most recent first):\n" + (
|
|
335
|
+
"\n".join(lines) if lines else "- none"
|
|
336
|
+
)
|
|
320
337
|
|
|
321
338
|
# Combine observation with context so the model always sees surroundings/inventory
|
|
322
339
|
combined_text = f"{observation_text}\n\n{context_text}"
|
|
@@ -326,7 +343,7 @@ class CrafterPolicy(Policy):
|
|
|
326
343
|
history=[], # no prior user/assistant history
|
|
327
344
|
turn=self.turn_index,
|
|
328
345
|
)
|
|
329
|
-
#print("Debugging only:; ", payload)
|
|
346
|
+
# print("Debugging only:; ", payload)
|
|
330
347
|
meta_out = {
|
|
331
348
|
"inference_url": self.inference_url,
|
|
332
349
|
"inference_request": payload,
|
|
@@ -372,7 +389,7 @@ class CrafterPolicy(Policy):
|
|
|
372
389
|
|
|
373
390
|
async def terminate(self) -> None:
|
|
374
391
|
return None
|
|
375
|
-
|
|
392
|
+
|
|
376
393
|
def prepare_inference_request(
|
|
377
394
|
self, observation: Dict[str, Any], history: List[Dict[str, Any]] = None
|
|
378
395
|
) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
|
|
@@ -382,9 +399,7 @@ class CrafterPolicy(Policy):
|
|
|
382
399
|
|
|
383
400
|
# Build messages (observation_text already formatted; no raw matrices)
|
|
384
401
|
messages = CrafterReActAgent.build_messages(
|
|
385
|
-
observation=observation_text,
|
|
386
|
-
history=history,
|
|
387
|
-
turn=self.turn_index
|
|
402
|
+
observation=observation_text, history=history, turn=self.turn_index
|
|
388
403
|
)
|
|
389
404
|
|
|
390
405
|
# Return messages and tools schema
|
|
@@ -402,7 +417,6 @@ class CrafterPolicy(Policy):
|
|
|
402
417
|
if not isinstance(obs_data, dict):
|
|
403
418
|
return f"Observation: {str(observation)}"
|
|
404
419
|
|
|
405
|
-
|
|
406
420
|
# Use the shared format_observation function with step information
|
|
407
421
|
step_idx = observation.get("step_idx", 0)
|
|
408
422
|
max_steps = 100 # Default max steps, could be made configurable
|
|
@@ -416,25 +430,25 @@ class CrafterPolicy(Policy):
|
|
|
416
430
|
obs_data["health"] = info["health"]
|
|
417
431
|
|
|
418
432
|
return format_observation(obs_data, step_count=step_idx, max_steps=max_steps)
|
|
419
|
-
|
|
433
|
+
|
|
420
434
|
def parse_model_response(
|
|
421
435
|
self, response: str, observation: Dict[str, Any]
|
|
422
436
|
) -> List[Dict[str, Any]]:
|
|
423
437
|
"""Parse model response into tool calls (implementing abstract method).
|
|
424
|
-
|
|
438
|
+
|
|
425
439
|
Note: Despite the type hint, vLLM actually returns a dict response,
|
|
426
440
|
not a string. We handle both cases.
|
|
427
441
|
"""
|
|
428
442
|
# Handle dict response from vLLM (the actual case)
|
|
429
443
|
if isinstance(response, dict):
|
|
430
444
|
return self.parse_response_to_tool_calls(response, self.use_tools)
|
|
431
|
-
|
|
445
|
+
|
|
432
446
|
# Handle string response (fallback case for raw text)
|
|
433
447
|
if isinstance(response, str):
|
|
434
448
|
actions = CrafterReActAgent.parse_actions_from_response(response)
|
|
435
449
|
if actions:
|
|
436
450
|
return [{"tool_name": "interact_many", "arguments": {"actions": actions}}]
|
|
437
|
-
|
|
451
|
+
|
|
438
452
|
# Default empty response
|
|
439
453
|
return []
|
|
440
454
|
|
|
@@ -51,7 +51,7 @@ class CrafterReActAgent:
|
|
|
51
51
|
"place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, "
|
|
52
52
|
"make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword\n"
|
|
53
53
|
)
|
|
54
|
-
|
|
54
|
+
|
|
55
55
|
@staticmethod
|
|
56
56
|
def get_system_prompt_with_tools() -> str:
|
|
57
57
|
"""System prompt for tool-based interaction (e.g., Qwen3 models)."""
|
|
@@ -80,9 +80,13 @@ class CrafterReActAgent:
|
|
|
80
80
|
)
|
|
81
81
|
|
|
82
82
|
@staticmethod
|
|
83
|
-
def build_messages(
|
|
83
|
+
def build_messages(
|
|
84
|
+
observation: str, history: Optional[List[Dict[str, str]]] = None, turn: Optional[int] = None
|
|
85
|
+
) -> List[Dict[str, str]]:
|
|
84
86
|
"""Construct OpenAI-style messages list for vLLM generation."""
|
|
85
|
-
msgs: List[Dict[str, str]] = [
|
|
87
|
+
msgs: List[Dict[str, str]] = [
|
|
88
|
+
{"role": "system", "content": CrafterReActAgent.get_system_prompt()}
|
|
89
|
+
]
|
|
86
90
|
if history:
|
|
87
91
|
msgs.extend(history)
|
|
88
92
|
msgs.append({"role": "user", "content": observation})
|
|
@@ -93,4 +97,4 @@ class CrafterReActAgent:
|
|
|
93
97
|
return parse_actions(response_text)
|
|
94
98
|
|
|
95
99
|
|
|
96
|
-
__all__ = ["CrafterReActAgent"]
|
|
100
|
+
__all__ = ["CrafterReActAgent"]
|
|
@@ -71,7 +71,7 @@ def validate_action(action: str) -> bool:
|
|
|
71
71
|
|
|
72
72
|
def parse_actions(action_text: str) -> List[str]:
|
|
73
73
|
"""Extract actions from response text.
|
|
74
|
-
|
|
74
|
+
|
|
75
75
|
Tries multiple parsing strategies:
|
|
76
76
|
1. <action>...</action> tags (original format)
|
|
77
77
|
2. [action]...[/action] or [action]... format
|
|
@@ -80,43 +80,43 @@ def parse_actions(action_text: str) -> List[str]:
|
|
|
80
80
|
5. Newline-separated actions
|
|
81
81
|
"""
|
|
82
82
|
import json
|
|
83
|
-
|
|
83
|
+
|
|
84
84
|
# First try the original <action> tag format
|
|
85
85
|
matches = re.findall(r"<action>(.*?)</action>", action_text, re.IGNORECASE)
|
|
86
86
|
if matches:
|
|
87
87
|
return [m.strip() for m in matches if validate_action(m.strip())]
|
|
88
|
-
|
|
88
|
+
|
|
89
89
|
# Try [action] format
|
|
90
90
|
matches = re.findall(r"\[action\](.*?)(?:\[/action\]|\n|$)", action_text, re.IGNORECASE)
|
|
91
91
|
if matches:
|
|
92
92
|
return [m.strip() for m in matches if validate_action(m.strip())]
|
|
93
|
-
|
|
93
|
+
|
|
94
94
|
# If no tags found, try to parse plain text
|
|
95
95
|
text = action_text.strip()
|
|
96
|
-
|
|
96
|
+
|
|
97
97
|
# Check if the entire text is a valid action
|
|
98
98
|
if validate_action(text):
|
|
99
99
|
return [text]
|
|
100
|
-
|
|
100
|
+
|
|
101
101
|
# Try splitting by newlines and checking each line
|
|
102
|
-
lines = text.split(
|
|
102
|
+
lines = text.split("\n")
|
|
103
103
|
actions = []
|
|
104
104
|
for line in lines:
|
|
105
105
|
line = line.strip()
|
|
106
|
-
|
|
106
|
+
|
|
107
107
|
# Remove various prefixes
|
|
108
|
-
for prefix in [
|
|
108
|
+
for prefix in ["ACTION:", "Action:", "action:", "ACTION", "-", "*", "•", "**ACTION:**"]:
|
|
109
109
|
if line.startswith(prefix):
|
|
110
|
-
line = line[len(prefix):].strip()
|
|
110
|
+
line = line[len(prefix) :].strip()
|
|
111
111
|
break
|
|
112
|
-
|
|
112
|
+
|
|
113
113
|
# Also handle numbered lists
|
|
114
|
-
if re.match(r
|
|
115
|
-
line = re.sub(r
|
|
116
|
-
|
|
114
|
+
if re.match(r"^\d+\.\s*", line):
|
|
115
|
+
line = re.sub(r"^\d+\.\s*", "", line)
|
|
116
|
+
|
|
117
117
|
# Split by common separators to handle multiple actions on one line
|
|
118
|
-
parts = re.split(r
|
|
119
|
-
|
|
118
|
+
parts = re.split(r"[,;]|\s+and\s+|\s+then\s+", line)
|
|
119
|
+
|
|
120
120
|
for part in parts:
|
|
121
121
|
part = part.strip()
|
|
122
122
|
# Remove quotes if present
|
|
@@ -124,23 +124,23 @@ def parse_actions(action_text: str) -> List[str]:
|
|
|
124
124
|
part = part[1:-1]
|
|
125
125
|
if part.startswith("'") and part.endswith("'"):
|
|
126
126
|
part = part[1:-1]
|
|
127
|
-
|
|
127
|
+
|
|
128
128
|
# Check if it's a valid action
|
|
129
129
|
if part and validate_action(part):
|
|
130
130
|
actions.append(part)
|
|
131
|
-
|
|
131
|
+
|
|
132
132
|
return actions
|
|
133
133
|
|
|
134
134
|
|
|
135
135
|
def format_observation(obs_data: Dict[str, Any], step_count: int = 0, max_steps: int = 100) -> str:
|
|
136
136
|
"""Format a Crafter observation dictionary into a human-readable string.
|
|
137
|
-
|
|
137
|
+
|
|
138
138
|
This is critical for preventing massive token counts when observations
|
|
139
139
|
contain large numpy arrays or deeply nested structures.
|
|
140
140
|
"""
|
|
141
141
|
if not obs_data:
|
|
142
142
|
return ""
|
|
143
|
-
|
|
143
|
+
|
|
144
144
|
# Extract key information
|
|
145
145
|
health = obs_data.get("health") or obs_data.get("inventory", {}).get("health", 0)
|
|
146
146
|
inventory_dict = obs_data.get("inventory", {})
|
|
@@ -160,18 +160,18 @@ def format_observation(obs_data: Dict[str, Any], step_count: int = 0, max_steps:
|
|
|
160
160
|
max_steps_from_obs = obs_data.get("max_steps_episode") or obs_data.get("max_steps")
|
|
161
161
|
if isinstance(max_steps_from_obs, (int, float)) and max_steps_from_obs > 0:
|
|
162
162
|
max_steps = int(max_steps_from_obs)
|
|
163
|
-
|
|
163
|
+
|
|
164
164
|
# Format inventory (skip health as it's shown separately)
|
|
165
165
|
inv_items = [f"{k}:{v}" for k, v in inventory_dict.items() if v > 0 and k != "health"]
|
|
166
166
|
inventory_str = ", ".join(inv_items) if inv_items else "empty"
|
|
167
|
-
|
|
167
|
+
|
|
168
168
|
# Format achievements
|
|
169
169
|
achieved_list = [k for k, v in achievements.items() if v]
|
|
170
170
|
achievements_str = ", ".join(achieved_list) if achieved_list else "none"
|
|
171
|
-
|
|
171
|
+
|
|
172
172
|
# Format semantic map view (simplified version)
|
|
173
173
|
map_view = _format_semantic_map_view(obs_data, VIEW_SIZE)
|
|
174
|
-
|
|
174
|
+
|
|
175
175
|
return (
|
|
176
176
|
f"=== CRAFTER GAME STATE ===\n"
|
|
177
177
|
f"Step: {step_count}/{max_steps}\n"
|
|
@@ -184,6 +184,7 @@ def format_observation(obs_data: Dict[str, Any], step_count: int = 0, max_steps:
|
|
|
184
184
|
f"Choose your next actions.\n"
|
|
185
185
|
)
|
|
186
186
|
|
|
187
|
+
|
|
187
188
|
def _try_build_dynamic_mapping():
|
|
188
189
|
"""Attempt to build id->name mapping from a real Crafter env.
|
|
189
190
|
|
|
@@ -232,7 +233,7 @@ def _try_build_dynamic_mapping():
|
|
|
232
233
|
# Build dynamic mapping if possible; otherwise fall back to a basic map
|
|
233
234
|
_ID_TO_NAME = _try_build_dynamic_mapping()
|
|
234
235
|
_FALLBACK_ID_TO_NAME = {
|
|
235
|
-
0: "none",
|
|
236
|
+
0: "none", # None from materials
|
|
236
237
|
1: "water",
|
|
237
238
|
2: "grass",
|
|
238
239
|
3: "stone",
|
|
@@ -299,4 +300,6 @@ def _format_semantic_map_view(obs_data: Dict[str, Any], view_size: int = VIEW_SI
|
|
|
299
300
|
|
|
300
301
|
transposed = list(zip(*matrix))
|
|
301
302
|
grid_rows: List[str] = [" ".join(row) for row in transposed]
|
|
302
|
-
return
|
|
303
|
+
return (
|
|
304
|
+
"\nLocal Map View (" + str(view_size) + "x" + str(view_size) + "):\n" + "\n".join(grid_rows)
|
|
305
|
+
)
|
|
@@ -22,9 +22,7 @@ class TaskApp:
|
|
|
22
22
|
self.service_base_url = service_base_url or os.getenv(
|
|
23
23
|
"SERVICE_BASE_URL", "http://localhost:8000"
|
|
24
24
|
)
|
|
25
|
-
self.vllm_base_url = vllm_base_url or os.getenv(
|
|
26
|
-
"VLLM_BASE_URL", "http://localhost:8001"
|
|
27
|
-
)
|
|
25
|
+
self.vllm_base_url = vllm_base_url or os.getenv("VLLM_BASE_URL", "http://localhost:8001")
|
|
28
26
|
self.default_model = default_model or os.getenv("DEFAULT_MODEL")
|
|
29
27
|
|
|
30
28
|
|
|
@@ -69,9 +67,7 @@ def create_app(allowed_environments: list[str] = None) -> FastAPI:
|
|
|
69
67
|
@app.middleware("http")
|
|
70
68
|
async def validate_environment(request, call_next):
|
|
71
69
|
# Check if this is an environment-related request
|
|
72
|
-
if request.url.path.startswith("/env/") or request.url.path.startswith(
|
|
73
|
-
"/rollout"
|
|
74
|
-
):
|
|
70
|
+
if request.url.path.startswith("/env/") or request.url.path.startswith("/rollout"):
|
|
75
71
|
# Extract environment name from request body for POST requests
|
|
76
72
|
if request.method == "POST":
|
|
77
73
|
# We need to read the body to check env_name
|
|
@@ -83,9 +79,7 @@ def create_app(allowed_environments: list[str] = None) -> FastAPI:
|
|
|
83
79
|
env_name = data.get("env_name", "").lower()
|
|
84
80
|
|
|
85
81
|
# Check if environment is allowed
|
|
86
|
-
if env_name and env_name not in [
|
|
87
|
-
e.lower() for e in allowed_environments
|
|
88
|
-
]:
|
|
82
|
+
if env_name and env_name not in [e.lower() for e in allowed_environments]:
|
|
89
83
|
from fastapi import HTTPException
|
|
90
84
|
|
|
91
85
|
raise HTTPException(
|
|
@@ -111,6 +105,7 @@ def create_app(allowed_environments: list[str] = None) -> FastAPI:
|
|
|
111
105
|
# Policy routes are optional; skip if optional envs are missing in this build
|
|
112
106
|
try:
|
|
113
107
|
from .policy_routes import router as policy_router
|
|
108
|
+
|
|
114
109
|
app.include_router(policy_router, prefix="/policy", tags=["policy"])
|
|
115
110
|
except Exception as _e:
|
|
116
111
|
# Log lightweight message; policy endpoints will be unavailable
|
|
@@ -157,6 +152,7 @@ def create_app(allowed_environments: list[str] = None) -> FastAPI:
|
|
|
157
152
|
|
|
158
153
|
# Check if any environment API keys are configured
|
|
159
154
|
from synth_ai.task.auth import allowed_environment_api_keys
|
|
155
|
+
|
|
160
156
|
allowed_keys = allowed_environment_api_keys()
|
|
161
157
|
if not allowed_keys:
|
|
162
158
|
# Server-side misconfiguration; rollout would fail with 503
|
|
@@ -167,22 +163,28 @@ def create_app(allowed_environments: list[str] = None) -> FastAPI:
|
|
|
167
163
|
"detail": "Auth not configured: missing ENVIRONMENT_API_KEY in task service environment",
|
|
168
164
|
},
|
|
169
165
|
)
|
|
170
|
-
|
|
166
|
+
|
|
171
167
|
# Authorize using all header variants without typed Header params (avoid 422s)
|
|
172
168
|
from synth_ai.task.auth import is_api_key_header_authorized
|
|
169
|
+
|
|
173
170
|
authorized = is_api_key_header_authorized(request)
|
|
174
171
|
if not authorized:
|
|
175
172
|
# Soft-pass 200 with authorized=False to avoid failing CLI preflight
|
|
176
173
|
primary_key = list(allowed_keys)[0] if allowed_keys else None
|
|
177
|
-
prefix =
|
|
174
|
+
prefix = primary_key[: max(1, len(primary_key) // 2)] if primary_key else None
|
|
178
175
|
content = {"status": "healthy", "authorized": False}
|
|
179
176
|
if prefix:
|
|
180
177
|
content["expected_api_key_prefix"] = prefix
|
|
181
178
|
return JSONResponse(status_code=200, content=content)
|
|
182
|
-
return {
|
|
179
|
+
return {
|
|
180
|
+
"status": "healthy",
|
|
181
|
+
"authorized": True,
|
|
182
|
+
"service": {"base_url": task_app.service_base_url},
|
|
183
|
+
}
|
|
183
184
|
|
|
184
185
|
# Log and surface 422 validation errors with header presence
|
|
185
186
|
from fastapi.exceptions import RequestValidationError
|
|
187
|
+
|
|
186
188
|
@app.exception_handler(RequestValidationError)
|
|
187
189
|
async def _on_validation_error(request: Request, exc: RequestValidationError):
|
|
188
190
|
try:
|
|
@@ -197,6 +199,8 @@ def create_app(allowed_environments: list[str] = None) -> FastAPI:
|
|
|
197
199
|
print("[422] validation", snapshot, flush=True)
|
|
198
200
|
except Exception:
|
|
199
201
|
pass
|
|
200
|
-
return JSONResponse(
|
|
202
|
+
return JSONResponse(
|
|
203
|
+
status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]}
|
|
204
|
+
)
|
|
201
205
|
|
|
202
206
|
return app
|