synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (107)
  1. examples/analyze_semantic_words.sh +17 -0
  2. examples/common_old/backend.py +21 -0
  3. examples/crafter_debug_render.py +180 -0
  4. examples/evals_old/README.md +98 -0
  5. examples/evals_old/__init__.py +6 -0
  6. examples/evals_old/compare_models.py +1037 -0
  7. examples/evals_old/example_log.md +145 -0
  8. examples/evals_old/run_demo.sh +126 -0
  9. examples/evals_old/trace_analysis.py +270 -0
  10. examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
  11. examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
  12. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
  13. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
  14. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
  15. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
  16. examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
  17. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
  18. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
  19. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
  20. examples/finetuning_old/synth_qwen_v1/README.md +68 -0
  21. examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
  22. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
  23. examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
  24. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
  25. examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
  26. examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
  27. examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
  28. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
  29. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
  30. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
  31. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
  32. examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
  33. examples/finetuning_old/synth_qwen_v1/util.py +147 -0
  34. examples/rl/README.md +169 -0
  35. examples/rl/configs/eval_base_qwen.toml +15 -0
  36. examples/rl/configs/eval_rl_qwen.toml +11 -0
  37. examples/rl/configs/rl_from_base_qwen.toml +35 -0
  38. examples/rl/configs/rl_from_base_qwen17.toml +74 -0
  39. examples/rl/configs/rl_from_ft_qwen.toml +35 -0
  40. examples/rl/download_dataset.py +64 -0
  41. examples/rl/run_eval.py +435 -0
  42. examples/rl/run_rl_and_save.py +94 -0
  43. examples/rl/task_app/README.md +22 -0
  44. {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
  45. examples/rl/task_app/math_task_app.py +107 -0
  46. examples/rl_old/task_app.py +962 -0
  47. examples/run_crafter_demo.sh +10 -0
  48. examples/warming_up_to_rl/analyze_trace_db.py +420 -0
  49. examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
  50. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  51. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
  52. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
  53. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
  54. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
  55. examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
  56. examples/warming_up_to_rl/export_trace_sft.py +541 -0
  57. examples/warming_up_to_rl/groq_test.py +88 -0
  58. examples/warming_up_to_rl/manage_secrets.py +127 -0
  59. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  60. examples/warming_up_to_rl/old/notes.md +73 -0
  61. examples/warming_up_to_rl/readme.md +172 -0
  62. examples/warming_up_to_rl/run_eval.py +434 -0
  63. examples/warming_up_to_rl/run_fft_and_save.py +309 -0
  64. examples/warming_up_to_rl/run_local_rollout.py +188 -0
  65. examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
  66. examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
  67. examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
  68. examples/warming_up_to_rl/run_rl_and_save.py +101 -0
  69. examples/warming_up_to_rl/run_rollout_remote.py +129 -0
  70. examples/warming_up_to_rl/task_app/README.md +38 -0
  71. {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
  72. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
  73. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  74. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  75. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
  76. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
  77. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  78. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
  98. synth_ai/api/train/config_finder.py +18 -18
  99. synth_ai/api/train/env_resolver.py +28 -1
  100. synth_ai/cli/task_apps.py +264 -55
  101. synth_ai/task/apps/__init__.py +54 -13
  102. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
  103. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +107 -12
  104. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
  105. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
  106. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
  107. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,434 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Baseline evaluation script (public-friendly skeleton)
4
+ - Targets a task app (Crafter-like) via initialize/step/terminate
5
+ - Uses a TaskAppClient interface (to be implemented in synth-ai SDK)
6
+ - Keeps structure aligned with research/testing/crafter eval harness
7
+ """
8
+ from __future__ import annotations
9
+ import os
10
+ import json
11
+ import re
12
+ from typing import Any, Dict, List, Optional
13
+ from collections import Counter
14
+ import asyncio
15
+ import httpx
16
+ import argparse
17
+ import tomllib
18
+ from pathlib import Path
19
+
20
class TaskAppClient:
    """Minimal async client for the task app initialize/step/terminate routes.

    This is a public-friendly shim for examples, pending SDK surface consolidation.

    Preferred use is as an async context manager (``async with TaskAppClient(...)``),
    which closes the underlying ``httpx.AsyncClient`` on exit. Direct use also
    works: the HTTP client is then created lazily by the ``client`` property and
    the caller is responsible for closing it (e.g. by awaiting ``__aexit__``).

    When ``api_key`` is set, requests to the task app carry an ``X-API-Key`` header.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self._client: Optional[httpx.AsyncClient] = None

    def _auth_headers(self) -> Dict[str, str]:
        """Return ``{"X-API-Key": api_key}`` when a key is configured, else ``{}``."""
        return {"X-API-Key": self.api_key} if self.api_key else {}

    def _new_http_client(self) -> httpx.AsyncClient:
        """Build the HTTP client shared by all task-app calls (600s timeout, follows redirects)."""
        return httpx.AsyncClient(
            base_url=self.base_url,
            headers=self._auth_headers(),
            timeout=600.0,
            follow_redirects=True,
        )

    async def __aenter__(self) -> "TaskAppClient":
        # Reuse a client that the `client` property may already have created
        # lazily; overwriting it would leak an unclosed connection pool.
        if self._client is None:
            self._client = self._new_http_client()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if self._client is not None:
            await self._client.aclose()
            self._client = None

    @property
    def client(self) -> httpx.AsyncClient:
        """The underlying HTTP client; created on first access when not inside ``async with``."""
        if self._client is None:
            self._client = self._new_http_client()
        return self._client

    async def initialize(self, env_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """POST /env/{env_name}/initialize (compat route supported in task app).

        Only ``seed`` is always sent; ``world_config`` and ``config`` are forwarded
        when present in *config*. Raises ``httpx.HTTPStatusError`` on HTTP errors.
        """
        payload: Dict[str, Any] = {"seed": config.get("seed")}
        # Allow both world_config and config inputs; env routes will normalize difficulty
        for key in ("world_config", "config"):
            if key in config:
                payload[key] = config[key]
        resp = await self.client.post(f"/env/{env_name}/initialize", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def step(self, env_name: str, env_id: str, tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
        """POST /env/{env_name}/step with wrapped tool_calls in action."""
        payload = {"env_id": env_id, "action": {"tool_calls": tool_calls}}
        resp = await self.client.post(f"/env/{env_name}/step", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def terminate(self, env_name: str, env_id: str) -> Dict[str, Any]:
        """POST /env/{env_name}/terminate for *env_id*."""
        resp = await self.client.post(f"/env/{env_name}/terminate", json={"env_id": env_id})
        resp.raise_for_status()
        return resp.json()

    async def get_info(self) -> Dict[str, Any]:
        """GET /info from the task app."""
        resp = await self.client.get("/info")
        resp.raise_for_status()
        return resp.json()

    async def proxy_groq_chat(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST an OpenAI-style chat payload to the task app's Groq proxy route."""
        resp = await self.client.post("/proxy/groq/v1/chat/completions", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def vllm_chat(self, vllm_base_url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST a chat payload directly to a vLLM server's /v1/chat/completions.

        Uses a separate short-lived client (60s timeout, no task-app API key).
        HTTP errors are not raised: a status >= 400 yields ``{"error": body}``, and
        a non-JSON body is reported as ``{"error": "invalid_json", "raw": ...}``.
        """
        async with httpx.AsyncClient(base_url=vllm_base_url.rstrip("/"), timeout=60.0) as c:
            resp = await c.post("/v1/chat/completions", json=payload)
            # Do not raise for status to surface body in errors
            try:
                data = resp.json()
            except Exception:
                data = {"error": "invalid_json", "raw": resp.text[:800]}
            if resp.status_code >= 400:
                return {"error": data}
            return data

    async def rollout(
        self,
        *,
        run_id: str,
        env_name: str,
        seed: int,
        difficulty: str,
        policy_name: str,
        policy_config: Dict[str, Any],
        max_turns: int,
    ) -> Dict[str, Any]:
        """POST /rollout to run a whole episode server-side.

        ``ops`` alternates ``"agent"``/``"env"`` once per turn (``max_turns`` pairs)
        and the request asks the server to terminate the env when done.
        Raises ``httpx.HTTPStatusError`` on HTTP errors.
        """
        payload: Dict[str, Any] = {
            "run_id": run_id,
            "env": {
                "env_name": env_name,
                "config": {"difficulty": difficulty},
                "seed": seed,
            },
            "policy": {
                "policy_name": policy_name,
                "config": policy_config,
            },
            "ops": ["agent", "env"] * max_turns,
            "on_done": "terminate",
        }
        # The key is already a default header on the client; sending it per request
        # too keeps it present even if `api_key` was changed after client creation.
        resp = await self.client.post("/rollout", json=payload, headers=self._auth_headers())
        resp.raise_for_status()
        return resp.json()
+
131
# Run defaults read from the environment at import time; main() may override
# them from a --toml config file.
TASK_APP_URL = os.environ.get("TASK_APP_URL", "https://YOUR-TASK-APP.modal.run").rstrip("/")
MODEL = os.environ.get("EVAL_MODEL", "qwen/qwen3-32b")
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "3"))
MAX_TURNS = int(os.environ.get("MAX_TURNS", "10"))
CONCURRENCY = int(os.environ.get("CONCURRENCY", "1"))
+
137
+ def _interact_tool_schema() -> List[Dict[str, Any]]:
138
+ return [{
139
+ "type": "function",
140
+ "function": {
141
+ "name": "interact",
142
+ "description": "Perform actions in the Crafter environment.",
143
+ "parameters": {
144
+ "type": "object",
145
+ "properties": {
146
+ "actions": {"type": "array", "items": {"type": "string"}},
147
+ "reasoning": {"type": "string"},
148
+ },
149
+ "required": ["actions", "reasoning"],
150
+ },
151
+ },
152
+ }]
153
+
154
+ def _build_messages_from_observation(observation: Dict[str, Any], history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
155
+ inv = observation.get("inventory") or {}
156
+ pos = observation.get("player_position") or []
157
+ ach = observation.get("achievements_status") or {}
158
+ turns_taken = observation.get("num_steps_taken") or 0
159
+ user_lines: List[str] = []
160
+ user_lines.append("Environment: CrafterClassic")
161
+ user_lines.append(f"Player position: {pos}")
162
+ user_lines.append(f"Inventory: {json.dumps(inv, ensure_ascii=False)}")
163
+ unlocked = [k for k, v in ach.items() if v]
164
+ if unlocked:
165
+ user_lines.append(f"Unlocked achievements: {unlocked}")
166
+ user_lines.append("Provide 2-5 actions as a plan to explore and progress.")
167
+ # short history summary
168
+ if history:
169
+ last = history[-1]
170
+ user_lines.append(f"Last actions: {last.get('actions')}")
171
+ content = "\n".join(user_lines)
172
+ return [{"role": "user", "content": content}]
173
+
174
+ def _parse_tool_calls_from_openai_response(data: Dict[str, Any]) -> List[str]:
175
+ try:
176
+ choices = data.get("choices")
177
+ if isinstance(choices, list) and choices:
178
+ msg = choices[0].get("message", {}) if isinstance(choices[0], dict) else {}
179
+ tcs = msg.get("tool_calls")
180
+ if isinstance(tcs, list) and tcs:
181
+ fn = tcs[0].get("function", {}) if isinstance(tcs[0], dict) else {}
182
+ args = fn.get("arguments")
183
+ if isinstance(args, str):
184
+ try:
185
+ obj = json.loads(args)
186
+ except Exception:
187
+ obj = {}
188
+ elif isinstance(args, dict):
189
+ obj = args
190
+ else:
191
+ obj = {}
192
+ acts = obj.get("actions")
193
+ if isinstance(acts, list):
194
+ return [str(a) for a in acts][:5]
195
+ except Exception:
196
+ pass
197
+ # Fallback: parse JSON object from assistant text
198
+ try:
199
+ choices = data.get("choices")
200
+ msg = choices[0].get("message", {}) if isinstance(choices, list) and choices else {}
201
+ content = msg.get("content")
202
+ text = ""
203
+ if isinstance(content, str):
204
+ text = content
205
+ elif isinstance(content, list):
206
+ text = "\n".join(str(part.get("text")) for part in content if isinstance(part, dict) and part.get("text"))
207
+ for raw in re.findall(r"\{[\s\S]*\}", text or ""):
208
+ try:
209
+ obj = json.loads(raw)
210
+ if isinstance(obj, dict) and obj.get("tool") in ("interact", None):
211
+ acts = obj.get("args", {}).get("actions")
212
+ if isinstance(acts, list):
213
+ return [str(a) for a in acts][:5]
214
+ except Exception:
215
+ continue
216
+ except Exception:
217
+ pass
218
+ return []
219
+
220
async def _choose_actions_via_llm(client: TaskAppClient, provider: str, model: str, observation: Dict[str, Any], history: List[Dict[str, Any]]) -> List[str]:
    """Ask the configured LLM provider for the next batch of Crafter actions.

    - ``"groq"``: sends the request through the task app's Groq proxy
      (HTTP errors propagate from ``proxy_groq_chat``).
    - ``"vllm"``: looks up ``inference.base_url`` via ``/info`` on every call and
      forces the ``interact`` tool; returns ``[]`` if no base URL is advertised
      or vLLM reports an error.
    - any other provider: returns ``[]``.

    Returns the parsed action names (possibly empty).
    """
    request: Dict[str, Any] = {
        "model": model,
        "messages": _build_messages_from_observation(observation, history),
        "tools": _interact_tool_schema(),
        "max_tokens": 8192,
        "temperature": 0.2,
    }

    if provider == "groq":
        # Groq path: avoid forcing tool_choice to reduce 400 errors; proxy will synthesize if missing
        response = await client.proxy_groq_chat(request)
    elif provider == "vllm":
        info = await client.get_info()
        inference = (info or {}).get("inference") or {}
        base_url = inference.get("base_url")
        if not base_url:
            return []
        # For vLLM path, we can force the single tool
        forced_request = {**request, "tool_choice": {"type": "function", "function": {"name": "interact"}}}
        response = await client.vllm_chat(str(base_url), forced_request)
        if isinstance(response, dict) and response.get("error"):
            return []
    else:
        return []

    return _parse_tool_calls_from_openai_response(response) or []
247
+
248
+ def _expand_actions_to_tool_calls(actions: List[str]) -> List[Dict[str, Any]]:
249
+ out: List[Dict[str, Any]] = []
250
+ for a in actions[:5]:
251
+ out.append({"tool": "interact", "args": {"action": a}})
252
+ return out
253
+
254
+ def _detect_provider(model: str) -> str:
255
+ m = (model or "").lower()
256
+ if "qwen/qwen3-32b" in m or "qwen-2.5-" in m or m.startswith("groq:"):
257
+ return "groq"
258
+ return "vllm"
259
+
260
+ def _rollout_inference_url_from_cfg(cfg: Dict[str, Any], default_vllm: Optional[str]) -> Optional[str]:
261
+ # Prefer explicit inference_url in TOML; else fall back to discovered vLLM base
262
+ url = cfg.get("inference_url")
263
+ if isinstance(url, str) and url:
264
+ return url
265
+ return default_vllm
266
+
267
async def eval_episode(client: TaskAppClient, seed: int) -> Dict[str, Any]:
    """Play one CrafterClassic episode locally via initialize/step/terminate.

    Each turn asks the LLM (provider inferred from ``MODEL``) for actions. If it
    returns none, ``["move_up", "do"]`` is used. The loop runs until the env
    reports ``done`` or ``MAX_TURNS`` steps have been taken. The env is always
    terminated (best-effort) afterwards.

    Args:
        client: an open :class:`TaskAppClient`.
        seed: env seed; difficulty comes from the ``DIFFICULTY`` env var (default "easy").

    Returns:
        ``{"seed", "turns", "achievements"}``. ``achievements`` is the sorted set
        of names marked truthy in any observation's ``achievements_status``.

    Raises:
        RuntimeError: if initialize does not return a usable ``env_id``.
    """
    env_name = "CrafterClassic"
    history: List[Dict[str, Any]] = []
    achievements: set[str] = set()
    turns = 0

    def _record_achievements(obs: Dict[str, Any]) -> None:
        # Accumulate unlocked achievements; previously this set was never filled.
        status = obs.get("achievements_status")
        if isinstance(status, dict):
            achievements.update(str(k) for k, v in status.items() if v)

    # Initialize environment
    init_cfg: Dict[str, Any] = {"seed": seed, "world_config": {"difficulty": os.getenv("DIFFICULTY", "easy")}}
    created = await client.initialize(env_name, init_cfg)
    env_id = created.get("env_id") if isinstance(created, dict) else None
    if not isinstance(env_id, str) or not env_id:
        raise RuntimeError(f"Invalid env_id from initialize: {created}")
    done = False
    provider = _detect_provider(MODEL)
    observation = created.get("observation")
    if not isinstance(observation, dict):
        observation = {}
    _record_achievements(observation)

    try:
        while turns < MAX_TURNS and not done:
            # Ask LLM for actions; fallback to a simple exploratory pair
            chosen_actions = await _choose_actions_via_llm(client, provider, MODEL, observation, history)
            if not chosen_actions:
                chosen_actions = ["move_up", "do"]
            tool_calls = _expand_actions_to_tool_calls(chosen_actions)
            step = await client.step(env_name, env_id, tool_calls)
            step_data = step if isinstance(step, dict) else {}
            done = bool(step_data.get("done"))
            turns += 1
            history.append({"actions": chosen_actions, "reasoning": "explore then interact"})
            # Update observation for next turn if available
            nxt = step_data.get("observation")
            if isinstance(nxt, dict):
                observation = nxt
                _record_achievements(observation)
    finally:
        try:
            await client.terminate(env_name, env_id)
        except Exception:
            # Best-effort cleanup: a failed terminate must not mask the episode result.
            pass

    return {"seed": seed, "turns": turns, "achievements": sorted(achievements)}
308
+
309
def _load_dotenv(env_path: Path) -> None:
    """Best-effort load of ``KEY=VALUE`` lines from *env_path* into ``os.environ``.

    Existing environment variables are never overwritten (``setdefault``). Blank
    lines, ``#`` comments and lines without ``=`` are skipped, and surrounding
    quotes are stripped from values. A missing or unreadable file is ignored. A
    single malformed entry (e.g. an empty key) is skipped rather than aborting
    the rest of the file.
    """
    try:
        if not env_path.exists():
            return
        text = env_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        if not key:
            continue
        try:
            os.environ.setdefault(key, value.strip().strip('"').strip("'"))
        except ValueError:
            # The OS rejected the name/value (e.g. embedded NUL); skip this entry.
            continue


def _cfg_int(cfg: Dict[str, Any], key: str, default: int) -> int:
    """Return ``int(cfg[key])``, or *default* when the key is absent, null or empty.

    Unlike ``int(cfg.get(key) or default)``, an explicit ``0`` is honored.
    """
    value = cfg.get(key)
    if value is None or value == "":
        return default
    return int(value)


def _summarize_rollout(seed: int, response: Any) -> Dict[str, Any]:
    """Reduce a /rollout response to ``{seed, turns, achievements}`` using its first trajectory."""
    trajs = response.get("trajectories") if isinstance(response, dict) else None
    first = trajs[0] if isinstance(trajs, list) and trajs and isinstance(trajs[0], dict) else {}
    final = first.get("final")
    final_obs = final.get("observation") if isinstance(final, dict) else None
    ach_map = final_obs.get("achievements_status") if isinstance(final_obs, dict) else None
    achievements = sorted(k for k, v in ach_map.items() if v) if isinstance(ach_map, dict) else []
    try:
        length = int(first.get("length") or 0)
    except (TypeError, ValueError):
        length = 0
    return {"seed": seed, "turns": length, "achievements": achievements}


def _aggregate_results(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Compute completion counts, averages and achievement frequencies over episode results."""
    counts = [len(r.get("achievements") or []) for r in results]
    turns = [int(r.get("turns") or 0) for r in results]
    freq: Counter = Counter(a for r in results for a in (r.get("achievements") or []))
    return {
        "completed": sum(1 for r in results if not r.get("error")),
        "total": len(results),
        "avg_turns": (sum(turns) / len(turns)) if turns else 0.0,
        "avg_achievements": (sum(counts) / len(counts)) if counts else 0.0,
        "achievements_freq": dict(freq),
    }


async def main() -> None:
    """CLI entry point: resolve configuration, run the episodes, print a JSON summary.

    Settings start from the module defaults (environment variables) and are
    overridden by ``--toml`` values when given. With ``--use-rollout`` each
    episode runs through the task app's server-side ``/rollout`` endpoint.
    Otherwise episodes are driven locally by :func:`eval_episode`. In both modes
    a failing episode is recorded with an ``"error"`` field instead of aborting
    the whole run, and an ``"aggregate"`` summary is printed with the episodes.

    Raises:
        RuntimeError: if the task app URL is still the placeholder, or (rollout
            mode) no inference URL can be resolved.
    """
    # Best-effort load local .env if present (ensures ENVIRONMENT_API_KEY for rollout)
    _load_dotenv(Path(__file__).resolve().parent / ".env")

    parser = argparse.ArgumentParser(description="Baseline eval against task app with optional TOML config")
    parser.add_argument("--toml", help="Path to TOML config file", default=None)
    parser.add_argument("--use-rollout", action="store_true", help="Use server-side rollout endpoint for eval")
    args = parser.parse_args()

    global TASK_APP_URL, MODEL, NUM_EPISODES, MAX_TURNS, CONCURRENCY
    cfg: Dict[str, Any] = {}
    if args.toml:
        with open(args.toml, "rb") as f:
            cfg = tomllib.load(f)
        # Map known keys; tolerate missing
        TASK_APP_URL = (cfg.get("task_app_url") or TASK_APP_URL).rstrip("/")
        MODEL = cfg.get("model") or MODEL
        NUM_EPISODES = _cfg_int(cfg, "num_episodes", NUM_EPISODES)
        MAX_TURNS = _cfg_int(cfg, "max_turns", MAX_TURNS)
        CONCURRENCY = _cfg_int(cfg, "concurrency", CONCURRENCY)
        if "difficulty" in cfg:
            os.environ["DIFFICULTY"] = str(cfg.get("difficulty"))
    # Replace placeholder URLs with env if present
    if "your-task-app.modal.run" in TASK_APP_URL.lower():
        env_url = os.getenv("TASK_APP_URL")
        if env_url:
            TASK_APP_URL = env_url.rstrip("/")
        else:
            raise RuntimeError("TASK_APP_URL is a placeholder. Set task_app_url in TOML or export TASK_APP_URL.")

    print(f"Task App: {TASK_APP_URL}")
    print(f"Model: {MODEL} Episodes: {NUM_EPISODES} Max turns: {MAX_TURNS} Concurrency: {CONCURRENCY}")
    sem = asyncio.Semaphore(max(CONCURRENCY, 1))
    seeds = range(1, NUM_EPISODES + 1)

    async with TaskAppClient(TASK_APP_URL, api_key=os.getenv("ENVIRONMENT_API_KEY")) as client:
        if args.use_rollout:
            # Use server-side rollout; derive inference URL per provider
            info = await client.get_info()
            default_vllm = ((info or {}).get("inference") or {}).get("base_url")
            inf_url = _rollout_inference_url_from_cfg(cfg, default_vllm)
            if not inf_url:
                raise RuntimeError("Could not resolve inference URL for rollout")

            # Build policy config from TOML (explicit control; no server-side guessing).
            # Use the resolved MODEL so an empty/null TOML model falls back to the
            # same value printed above instead of being sent as-is.
            policy_cfg: Dict[str, Any] = {"model": MODEL, "inference_url": inf_url}
            for k in ("max_tokens", "temperature", "top_p", "thinking_mode", "thinking_budget", "use_tools"):
                if cfg.get(k) is not None:
                    policy_cfg[k] = cfg[k]
            difficulty = os.getenv("DIFFICULTY", "easy")
            policy_name = cfg.get("policy_name", "crafter")

            async def _run(seed: int) -> Dict[str, Any]:
                async with sem:
                    try:
                        response = await client.rollout(
                            run_id=f"eval-{seed}",
                            env_name="crafter",
                            seed=seed,
                            difficulty=difficulty,
                            policy_name=policy_name,
                            policy_config=dict(policy_cfg),
                            max_turns=MAX_TURNS,
                        )
                        return _summarize_rollout(seed, response)
                    except Exception as e:
                        return {"seed": seed, "turns": 0, "achievements": [], "error": str(e)}
        else:
            async def _run(seed: int) -> Dict[str, Any]:
                async with sem:
                    try:
                        return await eval_episode(client, seed)
                    except Exception as e:
                        # Isolate failures per episode, matching the rollout path.
                        return {"seed": seed, "turns": 0, "achievements": [], "error": str(e)}

        results = await asyncio.gather(*(_run(seed) for seed in seeds))
        summary = {"episodes": results, "aggregate": _aggregate_results(results)}
        print(json.dumps(summary, indent=2))
432
+
433
if __name__ == "__main__":
    # Script entry point: run the async main() to completion via asyncio.run.
    asyncio.run(main())