synth-ai 0.2.9.dev4__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/common_old/backend.py +0 -1
- examples/crafter_debug_render.py +15 -6
- examples/evals_old/compare_models.py +1 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
- examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
- examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
- examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
- examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
- examples/finetuning_old/synth_qwen_v1/util.py +7 -2
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +17 -15
- examples/rl/run_rl_and_save.py +24 -7
- examples/rl/task_app/math_single_step.py +128 -11
- examples/rl/task_app/math_task_app.py +11 -3
- examples/rl_old/task_app.py +222 -53
- examples/warming_up_to_rl/analyze_trace_db.py +7 -5
- examples/warming_up_to_rl/export_trace_sft.py +141 -16
- examples/warming_up_to_rl/groq_test.py +11 -4
- examples/warming_up_to_rl/manage_secrets.py +15 -6
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +108 -30
- examples/warming_up_to_rl/run_fft_and_save.py +128 -52
- examples/warming_up_to_rl/run_local_rollout.py +87 -36
- examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
- examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
- examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
- examples/warming_up_to_rl/run_rl_and_save.py +31 -7
- examples/warming_up_to_rl/run_rollout_remote.py +37 -10
- examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
- synth_ai/__init__.py +1 -0
- synth_ai/api/train/builders.py +34 -10
- synth_ai/api/train/cli.py +172 -32
- synth_ai/api/train/config_finder.py +59 -4
- synth_ai/api/train/env_resolver.py +32 -14
- synth_ai/api/train/pollers.py +11 -3
- synth_ai/api/train/task_app.py +4 -1
- synth_ai/api/train/utils.py +20 -4
- synth_ai/cli/__init__.py +11 -4
- synth_ai/cli/balance.py +1 -1
- synth_ai/cli/demo.py +19 -5
- synth_ai/cli/rl_demo.py +75 -16
- synth_ai/cli/root.py +116 -37
- synth_ai/cli/task_apps.py +1286 -170
- synth_ai/cli/traces.py +1 -0
- synth_ai/cli/turso.py +73 -0
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +67 -30
- synth_ai/demos/core/cli.py +493 -164
- synth_ai/demos/demo_task_apps/core.py +50 -6
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/reproducibility/tree.py +3 -1
- synth_ai/environments/service/core_routes.py +6 -2
- synth_ai/evals/base.py +0 -2
- synth_ai/experimental/synth_oss.py +11 -12
- synth_ai/handshake.py +3 -1
- synth_ai/http_client.py +31 -7
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +8 -4
- synth_ai/jobs/client.py +40 -10
- synth_ai/learning/client.py +33 -8
- synth_ai/learning/config.py +0 -2
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +6 -3
- synth_ai/learning/health.py +9 -2
- synth_ai/learning/jobs.py +17 -5
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
- synth_ai/learning/prompts/random_search.py +4 -1
- synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
- synth_ai/learning/rl_client.py +42 -14
- synth_ai/learning/sse.py +0 -2
- synth_ai/learning/validators.py +6 -2
- synth_ai/lm/caching/ephemeral.py +1 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +13 -1
- synth_ai/lm/core/synth_models.py +0 -1
- synth_ai/lm/core/vendor_clients.py +4 -2
- synth_ai/lm/overrides.py +2 -2
- synth_ai/lm/vendors/core/anthropic_api.py +7 -7
- synth_ai/lm/vendors/core/openai_api.py +2 -0
- synth_ai/lm/vendors/openai_standard.py +3 -1
- synth_ai/lm/vendors/openai_standard_responses.py +6 -3
- synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
- synth_ai/lm/vendors/synth_client.py +37 -10
- synth_ai/rl/__init__.py +0 -1
- synth_ai/rl/contracts.py +0 -2
- synth_ai/rl/env_keys.py +6 -1
- synth_ai/task/__init__.py +1 -0
- synth_ai/task/apps/__init__.py +11 -11
- synth_ai/task/auth.py +29 -17
- synth_ai/task/client.py +3 -1
- synth_ai/task/contracts.py +1 -0
- synth_ai/task/datasets.py +3 -1
- synth_ai/task/errors.py +3 -2
- synth_ai/task/health.py +0 -2
- synth_ai/task/json.py +0 -1
- synth_ai/task/proxy.py +2 -5
- synth_ai/task/rubrics.py +9 -3
- synth_ai/task/server.py +31 -5
- synth_ai/task/tracing_utils.py +8 -3
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +0 -1
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +1 -0
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +2 -0
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +24 -3
- synth_ai/tracing_v3/storage/base.py +4 -1
- synth_ai/tracing_v3/storage/factory.py +0 -1
- synth_ai/tracing_v3/turso/manager.py +102 -38
- synth_ai/tracing_v3/turso/models.py +4 -1
- synth_ai/tracing_v3/utils.py +1 -0
- synth_ai/v0/tracing/upload.py +32 -135
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -156
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +0 -58
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- synth_ai/install_sqld.sh +0 -40
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
|
@@ -171,9 +171,7 @@ async def step_policy(
|
|
|
171
171
|
"""Execute a policy step to generate actions."""
|
|
172
172
|
handle = registry.get_policy(request.policy_id)
|
|
173
173
|
if not handle:
|
|
174
|
-
raise HTTPException(
|
|
175
|
-
status_code=404, detail=f"Policy {request.policy_id} not found"
|
|
176
|
-
)
|
|
174
|
+
raise HTTPException(status_code=404, detail=f"Policy {request.policy_id} not found")
|
|
177
175
|
|
|
178
176
|
try:
|
|
179
177
|
task_app = req.app.state.task_app
|
|
@@ -196,9 +194,7 @@ async def step_policy(
|
|
|
196
194
|
from .envs.wordle.shared import format_observation_wordle
|
|
197
195
|
|
|
198
196
|
# ASSERTION: Validate observation structure
|
|
199
|
-
assert request.observation is not None,
|
|
200
|
-
"request.observation cannot be None"
|
|
201
|
-
)
|
|
197
|
+
assert request.observation is not None, "request.observation cannot be None"
|
|
202
198
|
assert isinstance(request.observation, dict), (
|
|
203
199
|
f"request.observation must be dict, got {type(request.observation)}"
|
|
204
200
|
)
|
|
@@ -215,22 +211,14 @@ async def step_policy(
|
|
|
215
211
|
"terminated",
|
|
216
212
|
}
|
|
217
213
|
missing_keys = required_keys - set(request.observation.keys())
|
|
218
|
-
assert not missing_keys,
|
|
219
|
-
f"Wordle observation missing required keys: {missing_keys}"
|
|
220
|
-
)
|
|
214
|
+
assert not missing_keys, f"Wordle observation missing required keys: {missing_keys}"
|
|
221
215
|
|
|
222
216
|
print("DEBUG POLICY_ROUTES: About to format Wordle observation")
|
|
223
|
-
print(
|
|
224
|
-
|
|
225
|
-
)
|
|
226
|
-
print(
|
|
227
|
-
f"DEBUG POLICY_ROUTES: Observation keys: {list(request.observation.keys())}"
|
|
228
|
-
)
|
|
217
|
+
print(f"DEBUG POLICY_ROUTES: Observation type: {type(request.observation)}")
|
|
218
|
+
print(f"DEBUG POLICY_ROUTES: Observation keys: {list(request.observation.keys())}")
|
|
229
219
|
feedback_val = request.observation["feedback"]
|
|
230
220
|
print(f"DEBUG POLICY_ROUTES: Observation feedback: {feedback_val}")
|
|
231
|
-
print(
|
|
232
|
-
f"DEBUG POLICY_ROUTES: Observation guesses: {request.observation['guesses']}"
|
|
233
|
-
)
|
|
221
|
+
print(f"DEBUG POLICY_ROUTES: Observation guesses: {request.observation['guesses']}")
|
|
234
222
|
print(
|
|
235
223
|
f"DEBUG POLICY_ROUTES: Observation text length: {len(request.observation['text'])}"
|
|
236
224
|
)
|
|
@@ -238,50 +226,34 @@ async def step_policy(
|
|
|
238
226
|
# ASSERTION: Validate feedback data
|
|
239
227
|
guesses = request.observation["guesses"]
|
|
240
228
|
feedback = request.observation["feedback"]
|
|
241
|
-
assert isinstance(guesses, list), (
|
|
242
|
-
|
|
243
|
-
)
|
|
244
|
-
assert isinstance(feedback, list), (
|
|
245
|
-
f"feedback must be list, got {type(feedback)}"
|
|
246
|
-
)
|
|
229
|
+
assert isinstance(guesses, list), f"guesses must be list, got {type(guesses)}"
|
|
230
|
+
assert isinstance(feedback, list), f"feedback must be list, got {type(feedback)}"
|
|
247
231
|
# Note: We don't assert equal lengths here since the environment is broken
|
|
248
232
|
|
|
249
233
|
obs_text = format_observation_wordle(request.observation)
|
|
250
234
|
|
|
251
235
|
# ASSERTION: Validate formatted output
|
|
252
|
-
assert isinstance(obs_text, str), (
|
|
253
|
-
f"obs_text must be string, got {type(obs_text)}"
|
|
254
|
-
)
|
|
236
|
+
assert isinstance(obs_text, str), f"obs_text must be string, got {type(obs_text)}"
|
|
255
237
|
assert len(obs_text) > 0, "obs_text cannot be empty"
|
|
256
238
|
assert "WORDLE" in obs_text, "obs_text must contain 'WORDLE' header"
|
|
257
239
|
assert "Respond with a single tool call" in obs_text, (
|
|
258
240
|
"obs_text must contain instruction text"
|
|
259
241
|
)
|
|
260
242
|
|
|
261
|
-
print(
|
|
262
|
-
|
|
263
|
-
)
|
|
264
|
-
print(
|
|
265
|
-
|
|
266
|
-
)
|
|
267
|
-
print(
|
|
268
|
-
f"DEBUG POLICY_ROUTES: Formatted obs_text contains 🟨: {'🟨' in obs_text}"
|
|
269
|
-
)
|
|
270
|
-
print(
|
|
271
|
-
f"DEBUG POLICY_ROUTES: Formatted obs_text contains ⬛: {'⬛' in obs_text}"
|
|
272
|
-
)
|
|
273
|
-
print(
|
|
274
|
-
f"DEBUG POLICY_ROUTES: Formatted obs_text first 200 chars: {obs_text[:200]}"
|
|
275
|
-
)
|
|
243
|
+
print(f"DEBUG POLICY_ROUTES: Formatted obs_text length: {len(obs_text)}")
|
|
244
|
+
print(f"DEBUG POLICY_ROUTES: Formatted obs_text contains 🟩: {'🟩' in obs_text}")
|
|
245
|
+
print(f"DEBUG POLICY_ROUTES: Formatted obs_text contains 🟨: {'🟨' in obs_text}")
|
|
246
|
+
print(f"DEBUG POLICY_ROUTES: Formatted obs_text contains ⬛: {'⬛' in obs_text}")
|
|
247
|
+
print(f"DEBUG POLICY_ROUTES: Formatted obs_text first 200 chars: {obs_text[:200]}")
|
|
276
248
|
elif True:
|
|
277
249
|
try:
|
|
278
250
|
from .envs.sokoban.policy import SokobanPolicy as _SokobanPolicy
|
|
279
251
|
except Exception:
|
|
280
252
|
_SokobanPolicy = None # type: ignore
|
|
281
|
-
|
|
253
|
+
|
|
282
254
|
if _SokobanPolicy is not None and isinstance(policy, _SokobanPolicy):
|
|
283
255
|
from .envs.sokoban.shared import format_observation_sokoban
|
|
284
|
-
|
|
256
|
+
|
|
285
257
|
obs_text = format_observation_sokoban(request.observation)
|
|
286
258
|
elif True:
|
|
287
259
|
try:
|
|
@@ -291,7 +263,9 @@ async def step_policy(
|
|
|
291
263
|
if _MathPolicy is not None and isinstance(policy, _MathPolicy):
|
|
292
264
|
# Simple extraction of problem text
|
|
293
265
|
try:
|
|
294
|
-
obs_text = str(
|
|
266
|
+
obs_text = str(
|
|
267
|
+
request.observation.get("problem_text") or request.observation
|
|
268
|
+
)
|
|
295
269
|
except Exception:
|
|
296
270
|
obs_text = str(request.observation)
|
|
297
271
|
else:
|
|
@@ -316,9 +290,7 @@ async def step_policy(
|
|
|
316
290
|
user_messages: List[str] = []
|
|
317
291
|
if msgs and len(msgs) > 0 and msgs[0]["role"] == "system":
|
|
318
292
|
sys_text = msgs[0]["content"]
|
|
319
|
-
policy_name = (
|
|
320
|
-
getattr(policy, "name", "") or type(policy).__name__.lower()
|
|
321
|
-
)
|
|
293
|
+
policy_name = getattr(policy, "name", "") or type(policy).__name__.lower()
|
|
322
294
|
|
|
323
295
|
# Assert environment-specific prompts match the policy
|
|
324
296
|
if policy_name in ("wordle-react", "wordle"):
|
|
@@ -363,6 +335,7 @@ async def step_policy(
|
|
|
363
335
|
|
|
364
336
|
# Emit full system/user prompts for observability (no secrets included)
|
|
365
337
|
try:
|
|
338
|
+
|
|
366
339
|
def _as_text(content: object) -> str:
|
|
367
340
|
if isinstance(content, str):
|
|
368
341
|
return content
|
|
@@ -404,7 +377,7 @@ async def step_policy(
|
|
|
404
377
|
# Print concise preview for visibility in standard logs
|
|
405
378
|
try:
|
|
406
379
|
last_user = user_messages[-1] if user_messages else ""
|
|
407
|
-
#preview = last_user[:400] if isinstance(last_user, str) else str(last_user)[:400]
|
|
380
|
+
# preview = last_user[:400] if isinstance(last_user, str) else str(last_user)[:400]
|
|
408
381
|
print(f"[task:crafter] user prompt: {last_user}", flush=True)
|
|
409
382
|
except Exception:
|
|
410
383
|
pass
|
|
@@ -435,16 +408,27 @@ async def step_policy(
|
|
|
435
408
|
api_key_override = None
|
|
436
409
|
try:
|
|
437
410
|
import os as _os
|
|
411
|
+
|
|
438
412
|
if isinstance(target_url, str):
|
|
439
413
|
low_url = target_url.lower()
|
|
440
414
|
if "openai.com" in low_url:
|
|
441
|
-
api_key_override = _os.getenv("OPENAI_API_KEY") or getattr(
|
|
415
|
+
api_key_override = _os.getenv("OPENAI_API_KEY") or getattr(
|
|
416
|
+
task_app, "openai_api_key", None
|
|
417
|
+
)
|
|
442
418
|
elif "groq.com" in low_url:
|
|
443
419
|
api_key_override = _os.getenv("GROQ_API_KEY")
|
|
444
420
|
else:
|
|
445
|
-
api_key_override =
|
|
421
|
+
api_key_override = (
|
|
422
|
+
_os.getenv("SYNTH_API_KEY")
|
|
423
|
+
or _os.getenv("OPENAI_API_KEY")
|
|
424
|
+
or getattr(task_app, "openai_api_key", None)
|
|
425
|
+
)
|
|
446
426
|
else:
|
|
447
|
-
api_key_override =
|
|
427
|
+
api_key_override = (
|
|
428
|
+
_os.getenv("SYNTH_API_KEY")
|
|
429
|
+
or _os.getenv("OPENAI_API_KEY")
|
|
430
|
+
or getattr(task_app, "openai_api_key", None)
|
|
431
|
+
)
|
|
448
432
|
except Exception:
|
|
449
433
|
api_key_override = None
|
|
450
434
|
|
|
@@ -455,7 +439,9 @@ async def step_policy(
|
|
|
455
439
|
masked = "<masked>"
|
|
456
440
|
logger.debug(f"INFERENCE_AUTH: Using bearer key {masked}")
|
|
457
441
|
else:
|
|
458
|
-
logger.warning(
|
|
442
|
+
logger.warning(
|
|
443
|
+
"INFERENCE_AUTH: No API key resolved for inference request; downstream may 401"
|
|
444
|
+
)
|
|
459
445
|
|
|
460
446
|
client = create_inference_client(task_app, api_key=api_key_override)
|
|
461
447
|
|
|
@@ -650,6 +636,7 @@ async def step_policy(
|
|
|
650
636
|
if model_for_diag and messages_for_diag:
|
|
651
637
|
try:
|
|
652
638
|
from transformers import AutoTokenizer
|
|
639
|
+
|
|
653
640
|
tok = AutoTokenizer.from_pretrained(model_for_diag)
|
|
654
641
|
prompt_preview = tok.apply_chat_template(
|
|
655
642
|
messages_for_diag,
|
|
@@ -660,7 +647,9 @@ async def step_policy(
|
|
|
660
647
|
max_len = getattr(tok, "model_max_length", None)
|
|
661
648
|
over_limit = False
|
|
662
649
|
try:
|
|
663
|
-
over_limit =
|
|
650
|
+
over_limit = (
|
|
651
|
+
isinstance(max_len, int) and max_len > 0 and len(ids) > int(max_len)
|
|
652
|
+
)
|
|
664
653
|
except Exception:
|
|
665
654
|
over_limit = False
|
|
666
655
|
if over_limit or len(ids) > 10000:
|
|
@@ -672,7 +661,9 @@ async def step_policy(
|
|
|
672
661
|
"prompt_token_overflow_local": True,
|
|
673
662
|
"model": str(model_for_diag),
|
|
674
663
|
"token_count": int(len(ids)),
|
|
675
|
-
"model_max_length": int(max_len)
|
|
664
|
+
"model_max_length": int(max_len)
|
|
665
|
+
if isinstance(max_len, int)
|
|
666
|
+
else None,
|
|
676
667
|
"preview_tokens_logged": int(len(preview_ids)),
|
|
677
668
|
"prompt_preview_first_10k_tokens": preview_text,
|
|
678
669
|
}
|
|
@@ -682,7 +673,9 @@ async def step_policy(
|
|
|
682
673
|
try:
|
|
683
674
|
meta["prompt_debug"] = {
|
|
684
675
|
"token_count": int(len(ids)),
|
|
685
|
-
"model_max_length": int(max_len)
|
|
676
|
+
"model_max_length": int(max_len)
|
|
677
|
+
if isinstance(max_len, int)
|
|
678
|
+
else None,
|
|
686
679
|
"preview_first_10k_tokens": preview_text,
|
|
687
680
|
}
|
|
688
681
|
except Exception:
|
|
@@ -700,14 +693,19 @@ async def step_policy(
|
|
|
700
693
|
if isinstance(msgs, list):
|
|
701
694
|
# Print compact messages structure and tool schema with bounded length
|
|
702
695
|
import json as _json
|
|
696
|
+
|
|
703
697
|
msgs_compact = _json.dumps(msgs)[:20000]
|
|
704
|
-
tools_compact =
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
698
|
+
tools_compact = (
|
|
699
|
+
_json.dumps(tools_dump)[:8000] if tools_dump is not None else None
|
|
700
|
+
)
|
|
701
|
+
print(
|
|
702
|
+
{
|
|
703
|
+
"llm.call": True,
|
|
704
|
+
"policy": str(policy_name),
|
|
705
|
+
"messages_preview": msgs_compact,
|
|
706
|
+
"tools_preview": tools_compact,
|
|
707
|
+
}
|
|
708
|
+
)
|
|
711
709
|
except Exception:
|
|
712
710
|
pass
|
|
713
711
|
|
|
@@ -724,13 +722,20 @@ async def step_policy(
|
|
|
724
722
|
try:
|
|
725
723
|
tools_arr = req_body.get("tools") or []
|
|
726
724
|
if isinstance(tools_arr, list) and tools_arr:
|
|
727
|
-
f =
|
|
725
|
+
f = (
|
|
726
|
+
tools_arr[0].get("function")
|
|
727
|
+
if isinstance(tools_arr[0], dict)
|
|
728
|
+
else None
|
|
729
|
+
)
|
|
728
730
|
cand = (f or {}).get("name") if isinstance(f, dict) else None
|
|
729
731
|
if isinstance(cand, str) and cand:
|
|
730
732
|
func_name = cand
|
|
731
733
|
except Exception:
|
|
732
734
|
pass
|
|
733
|
-
req_body["tool_choice"] = {
|
|
735
|
+
req_body["tool_choice"] = {
|
|
736
|
+
"type": "function",
|
|
737
|
+
"function": {"name": func_name},
|
|
738
|
+
}
|
|
734
739
|
req_body["parallel_tool_calls"] = False
|
|
735
740
|
req_body.setdefault("function_call", {"name": func_name})
|
|
736
741
|
# Inject extra_body for thinking controls expected by Modal service
|
|
@@ -799,10 +804,13 @@ async def step_policy(
|
|
|
799
804
|
else:
|
|
800
805
|
try:
|
|
801
806
|
import json as _json
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
807
|
+
|
|
808
|
+
print(
|
|
809
|
+
{
|
|
810
|
+
"tool_calls_parsed": int(len(tool_calls)),
|
|
811
|
+
"tool_calls_preview": _json.dumps(tool_calls)[:20000],
|
|
812
|
+
}
|
|
813
|
+
)
|
|
806
814
|
except Exception:
|
|
807
815
|
logger.info(f"Parsed {len(tool_calls)} tool calls: {tool_calls}")
|
|
808
816
|
|
|
@@ -814,9 +822,7 @@ async def step_policy(
|
|
|
814
822
|
inference_response, getattr(policy, "use_tools", True)
|
|
815
823
|
)
|
|
816
824
|
else:
|
|
817
|
-
parsed = policy.parse_model_response(
|
|
818
|
-
inference_response, request.observation
|
|
819
|
-
)
|
|
825
|
+
parsed = policy.parse_model_response(inference_response, request.observation)
|
|
820
826
|
# Replace tool_calls with parsed result
|
|
821
827
|
if isinstance(parsed, list):
|
|
822
828
|
tool_calls = parsed
|
|
@@ -866,9 +872,7 @@ async def snapshot_policy(request: PolicySnapshotRequest) -> PolicySnapshotRespo
|
|
|
866
872
|
"""Create a snapshot of the policy state."""
|
|
867
873
|
handle = registry.get_policy(request.policy_id)
|
|
868
874
|
if not handle:
|
|
869
|
-
raise HTTPException(
|
|
870
|
-
status_code=404, detail=f"Policy {request.policy_id} not found"
|
|
871
|
-
)
|
|
875
|
+
raise HTTPException(status_code=404, detail=f"Policy {request.policy_id} not found")
|
|
872
876
|
|
|
873
877
|
try:
|
|
874
878
|
# Serialize policy state
|
|
@@ -906,9 +910,7 @@ async def restore_policy(request: PolicyRestoreRequest) -> PolicyRestoreResponse
|
|
|
906
910
|
"""Restore a policy from a snapshot."""
|
|
907
911
|
snapshot = registry.get_snapshot(request.snapshot_id)
|
|
908
912
|
if not snapshot:
|
|
909
|
-
raise HTTPException(
|
|
910
|
-
status_code=404, detail=f"Snapshot {request.snapshot_id} not found"
|
|
911
|
-
)
|
|
913
|
+
raise HTTPException(status_code=404, detail=f"Snapshot {request.snapshot_id} not found")
|
|
912
914
|
|
|
913
915
|
if snapshot.kind != "policy":
|
|
914
916
|
raise HTTPException(
|
|
@@ -956,9 +958,7 @@ async def restore_policy(request: PolicyRestoreRequest) -> PolicyRestoreResponse
|
|
|
956
958
|
return PolicyRestoreResponse(policy_id=policy_id)
|
|
957
959
|
|
|
958
960
|
except Exception as e:
|
|
959
|
-
logger.error(
|
|
960
|
-
f"Failed to restore policy from snapshot {request.snapshot_id}: {e}"
|
|
961
|
-
)
|
|
961
|
+
logger.error(f"Failed to restore policy from snapshot {request.snapshot_id}: {e}")
|
|
962
962
|
raise HTTPException(status_code=500, detail=str(e))
|
|
963
963
|
|
|
964
964
|
|
|
@@ -967,9 +967,7 @@ async def terminate_policy(request: PolicyTerminateRequest) -> PolicyTerminateRe
|
|
|
967
967
|
"""Terminate a policy and clean up resources."""
|
|
968
968
|
handle = registry.get_policy(request.policy_id)
|
|
969
969
|
if not handle:
|
|
970
|
-
raise HTTPException(
|
|
971
|
-
status_code=404, detail=f"Policy {request.policy_id} not found"
|
|
972
|
-
)
|
|
970
|
+
raise HTTPException(status_code=404, detail=f"Policy {request.policy_id} not found")
|
|
973
971
|
|
|
974
972
|
try:
|
|
975
973
|
# Call terminate on the policy
|