synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (155)
  1. examples/common_old/backend.py +0 -1
  2. examples/crafter_debug_render.py +15 -6
  3. examples/evals_old/compare_models.py +1 -0
  4. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
  5. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
  6. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
  7. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
  8. examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
  9. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
  10. examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
  11. examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
  12. examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
  13. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
  14. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
  15. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
  16. examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
  17. examples/finetuning_old/synth_qwen_v1/util.py +7 -2
  18. examples/rl/configs/eval_base_qwen.toml +1 -1
  19. examples/rl/configs/rl_from_base_qwen17.toml +1 -1
  20. examples/rl/download_dataset.py +26 -10
  21. examples/rl/run_eval.py +17 -15
  22. examples/rl/run_rl_and_save.py +24 -7
  23. examples/rl/task_app/math_single_step.py +128 -11
  24. examples/rl/task_app/math_task_app.py +11 -3
  25. examples/rl_old/task_app.py +222 -53
  26. examples/warming_up_to_rl/analyze_trace_db.py +7 -5
  27. examples/warming_up_to_rl/export_trace_sft.py +141 -16
  28. examples/warming_up_to_rl/groq_test.py +11 -4
  29. examples/warming_up_to_rl/manage_secrets.py +15 -6
  30. examples/warming_up_to_rl/readme.md +9 -2
  31. examples/warming_up_to_rl/run_eval.py +108 -30
  32. examples/warming_up_to_rl/run_fft_and_save.py +128 -52
  33. examples/warming_up_to_rl/run_local_rollout.py +87 -36
  34. examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
  35. examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
  36. examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
  37. examples/warming_up_to_rl/run_rl_and_save.py +31 -7
  38. examples/warming_up_to_rl/run_rollout_remote.py +37 -10
  39. examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
  40. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
  41. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
  42. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
  43. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
  44. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
  45. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
  46. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
  47. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
  48. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
  49. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
  50. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
  51. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
  52. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
  53. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
  54. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
  55. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
  56. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
  57. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
  58. synth_ai/__init__.py +1 -0
  59. synth_ai/api/train/builders.py +34 -10
  60. synth_ai/api/train/cli.py +172 -32
  61. synth_ai/api/train/config_finder.py +59 -4
  62. synth_ai/api/train/env_resolver.py +32 -14
  63. synth_ai/api/train/pollers.py +11 -3
  64. synth_ai/api/train/task_app.py +4 -1
  65. synth_ai/api/train/utils.py +20 -4
  66. synth_ai/cli/__init__.py +11 -4
  67. synth_ai/cli/balance.py +1 -1
  68. synth_ai/cli/demo.py +19 -5
  69. synth_ai/cli/rl_demo.py +75 -16
  70. synth_ai/cli/root.py +116 -37
  71. synth_ai/cli/task_apps.py +1276 -186
  72. synth_ai/cli/traces.py +1 -0
  73. synth_ai/cli/turso.py +73 -0
  74. synth_ai/core/experiment.py +0 -2
  75. synth_ai/demo_registry.py +67 -30
  76. synth_ai/demos/core/cli.py +493 -164
  77. synth_ai/demos/demo_task_apps/core.py +50 -6
  78. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
  79. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
  80. synth_ai/demos/demo_task_apps/math/_common.py +1 -2
  81. synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
  82. synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
  83. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
  84. synth_ai/environments/examples/bandit/engine.py +12 -4
  85. synth_ai/environments/examples/bandit/taskset.py +4 -4
  86. synth_ai/environments/reproducibility/tree.py +3 -1
  87. synth_ai/environments/service/core_routes.py +6 -2
  88. synth_ai/evals/base.py +0 -2
  89. synth_ai/experimental/synth_oss.py +11 -12
  90. synth_ai/handshake.py +3 -1
  91. synth_ai/http_client.py +31 -7
  92. synth_ai/inference/__init__.py +0 -2
  93. synth_ai/inference/client.py +8 -4
  94. synth_ai/jobs/client.py +40 -10
  95. synth_ai/learning/client.py +33 -8
  96. synth_ai/learning/config.py +0 -2
  97. synth_ai/learning/constants.py +0 -2
  98. synth_ai/learning/ft_client.py +6 -3
  99. synth_ai/learning/health.py +9 -2
  100. synth_ai/learning/jobs.py +17 -5
  101. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
  102. synth_ai/learning/prompts/random_search.py +4 -1
  103. synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
  104. synth_ai/learning/rl_client.py +42 -14
  105. synth_ai/learning/sse.py +0 -2
  106. synth_ai/learning/validators.py +6 -2
  107. synth_ai/lm/caching/ephemeral.py +1 -3
  108. synth_ai/lm/core/exceptions.py +0 -2
  109. synth_ai/lm/core/main.py +13 -1
  110. synth_ai/lm/core/synth_models.py +0 -1
  111. synth_ai/lm/core/vendor_clients.py +4 -2
  112. synth_ai/lm/overrides.py +2 -2
  113. synth_ai/lm/vendors/core/anthropic_api.py +7 -7
  114. synth_ai/lm/vendors/core/openai_api.py +2 -0
  115. synth_ai/lm/vendors/openai_standard.py +3 -1
  116. synth_ai/lm/vendors/openai_standard_responses.py +6 -3
  117. synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
  118. synth_ai/lm/vendors/synth_client.py +37 -10
  119. synth_ai/rl/__init__.py +0 -1
  120. synth_ai/rl/contracts.py +0 -2
  121. synth_ai/rl/env_keys.py +6 -1
  122. synth_ai/task/__init__.py +1 -0
  123. synth_ai/task/apps/__init__.py +11 -11
  124. synth_ai/task/auth.py +29 -17
  125. synth_ai/task/client.py +3 -1
  126. synth_ai/task/contracts.py +1 -0
  127. synth_ai/task/datasets.py +3 -1
  128. synth_ai/task/errors.py +3 -2
  129. synth_ai/task/health.py +0 -2
  130. synth_ai/task/json.py +0 -1
  131. synth_ai/task/proxy.py +2 -5
  132. synth_ai/task/rubrics.py +9 -3
  133. synth_ai/task/server.py +31 -5
  134. synth_ai/task/tracing_utils.py +8 -3
  135. synth_ai/task/validators.py +0 -1
  136. synth_ai/task/vendors.py +0 -1
  137. synth_ai/tracing_v3/db_config.py +26 -1
  138. synth_ai/tracing_v3/decorators.py +1 -0
  139. synth_ai/tracing_v3/examples/basic_usage.py +3 -2
  140. synth_ai/tracing_v3/hooks.py +2 -0
  141. synth_ai/tracing_v3/replica_sync.py +1 -0
  142. synth_ai/tracing_v3/session_tracer.py +24 -3
  143. synth_ai/tracing_v3/storage/base.py +4 -1
  144. synth_ai/tracing_v3/storage/factory.py +0 -1
  145. synth_ai/tracing_v3/turso/manager.py +102 -38
  146. synth_ai/tracing_v3/turso/models.py +4 -1
  147. synth_ai/tracing_v3/utils.py +1 -0
  148. synth_ai/v0/tracing/upload.py +32 -135
  149. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
  150. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -154
  151. synth_ai/install_sqld.sh +0 -40
  152. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
  153. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
  154. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
  155. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
@@ -171,9 +171,7 @@ async def step_policy(
171
171
  """Execute a policy step to generate actions."""
172
172
  handle = registry.get_policy(request.policy_id)
173
173
  if not handle:
174
- raise HTTPException(
175
- status_code=404, detail=f"Policy {request.policy_id} not found"
176
- )
174
+ raise HTTPException(status_code=404, detail=f"Policy {request.policy_id} not found")
177
175
 
178
176
  try:
179
177
  task_app = req.app.state.task_app
@@ -196,9 +194,7 @@ async def step_policy(
196
194
  from .envs.wordle.shared import format_observation_wordle
197
195
 
198
196
  # ASSERTION: Validate observation structure
199
- assert request.observation is not None, (
200
- "request.observation cannot be None"
201
- )
197
+ assert request.observation is not None, "request.observation cannot be None"
202
198
  assert isinstance(request.observation, dict), (
203
199
  f"request.observation must be dict, got {type(request.observation)}"
204
200
  )
@@ -215,22 +211,14 @@ async def step_policy(
215
211
  "terminated",
216
212
  }
217
213
  missing_keys = required_keys - set(request.observation.keys())
218
- assert not missing_keys, (
219
- f"Wordle observation missing required keys: {missing_keys}"
220
- )
214
+ assert not missing_keys, f"Wordle observation missing required keys: {missing_keys}"
221
215
 
222
216
  print("DEBUG POLICY_ROUTES: About to format Wordle observation")
223
- print(
224
- f"DEBUG POLICY_ROUTES: Observation type: {type(request.observation)}"
225
- )
226
- print(
227
- f"DEBUG POLICY_ROUTES: Observation keys: {list(request.observation.keys())}"
228
- )
217
+ print(f"DEBUG POLICY_ROUTES: Observation type: {type(request.observation)}")
218
+ print(f"DEBUG POLICY_ROUTES: Observation keys: {list(request.observation.keys())}")
229
219
  feedback_val = request.observation["feedback"]
230
220
  print(f"DEBUG POLICY_ROUTES: Observation feedback: {feedback_val}")
231
- print(
232
- f"DEBUG POLICY_ROUTES: Observation guesses: {request.observation['guesses']}"
233
- )
221
+ print(f"DEBUG POLICY_ROUTES: Observation guesses: {request.observation['guesses']}")
234
222
  print(
235
223
  f"DEBUG POLICY_ROUTES: Observation text length: {len(request.observation['text'])}"
236
224
  )
@@ -238,50 +226,34 @@ async def step_policy(
238
226
  # ASSERTION: Validate feedback data
239
227
  guesses = request.observation["guesses"]
240
228
  feedback = request.observation["feedback"]
241
- assert isinstance(guesses, list), (
242
- f"guesses must be list, got {type(guesses)}"
243
- )
244
- assert isinstance(feedback, list), (
245
- f"feedback must be list, got {type(feedback)}"
246
- )
229
+ assert isinstance(guesses, list), f"guesses must be list, got {type(guesses)}"
230
+ assert isinstance(feedback, list), f"feedback must be list, got {type(feedback)}"
247
231
  # Note: We don't assert equal lengths here since the environment is broken
248
232
 
249
233
  obs_text = format_observation_wordle(request.observation)
250
234
 
251
235
  # ASSERTION: Validate formatted output
252
- assert isinstance(obs_text, str), (
253
- f"obs_text must be string, got {type(obs_text)}"
254
- )
236
+ assert isinstance(obs_text, str), f"obs_text must be string, got {type(obs_text)}"
255
237
  assert len(obs_text) > 0, "obs_text cannot be empty"
256
238
  assert "WORDLE" in obs_text, "obs_text must contain 'WORDLE' header"
257
239
  assert "Respond with a single tool call" in obs_text, (
258
240
  "obs_text must contain instruction text"
259
241
  )
260
242
 
261
- print(
262
- f"DEBUG POLICY_ROUTES: Formatted obs_text length: {len(obs_text)}"
263
- )
264
- print(
265
- f"DEBUG POLICY_ROUTES: Formatted obs_text contains 🟩: {'🟩' in obs_text}"
266
- )
267
- print(
268
- f"DEBUG POLICY_ROUTES: Formatted obs_text contains 🟨: {'🟨' in obs_text}"
269
- )
270
- print(
271
- f"DEBUG POLICY_ROUTES: Formatted obs_text contains ⬛: {'⬛' in obs_text}"
272
- )
273
- print(
274
- f"DEBUG POLICY_ROUTES: Formatted obs_text first 200 chars: {obs_text[:200]}"
275
- )
243
+ print(f"DEBUG POLICY_ROUTES: Formatted obs_text length: {len(obs_text)}")
244
+ print(f"DEBUG POLICY_ROUTES: Formatted obs_text contains 🟩: {'🟩' in obs_text}")
245
+ print(f"DEBUG POLICY_ROUTES: Formatted obs_text contains 🟨: {'🟨' in obs_text}")
246
+ print(f"DEBUG POLICY_ROUTES: Formatted obs_text contains ⬛: {'⬛' in obs_text}")
247
+ print(f"DEBUG POLICY_ROUTES: Formatted obs_text first 200 chars: {obs_text[:200]}")
276
248
  elif True:
277
249
  try:
278
250
  from .envs.sokoban.policy import SokobanPolicy as _SokobanPolicy
279
251
  except Exception:
280
252
  _SokobanPolicy = None # type: ignore
281
-
253
+
282
254
  if _SokobanPolicy is not None and isinstance(policy, _SokobanPolicy):
283
255
  from .envs.sokoban.shared import format_observation_sokoban
284
-
256
+
285
257
  obs_text = format_observation_sokoban(request.observation)
286
258
  elif True:
287
259
  try:
@@ -291,7 +263,9 @@ async def step_policy(
291
263
  if _MathPolicy is not None and isinstance(policy, _MathPolicy):
292
264
  # Simple extraction of problem text
293
265
  try:
294
- obs_text = str(request.observation.get("problem_text") or request.observation)
266
+ obs_text = str(
267
+ request.observation.get("problem_text") or request.observation
268
+ )
295
269
  except Exception:
296
270
  obs_text = str(request.observation)
297
271
  else:
@@ -316,9 +290,7 @@ async def step_policy(
316
290
  user_messages: List[str] = []
317
291
  if msgs and len(msgs) > 0 and msgs[0]["role"] == "system":
318
292
  sys_text = msgs[0]["content"]
319
- policy_name = (
320
- getattr(policy, "name", "") or type(policy).__name__.lower()
321
- )
293
+ policy_name = getattr(policy, "name", "") or type(policy).__name__.lower()
322
294
 
323
295
  # Assert environment-specific prompts match the policy
324
296
  if policy_name in ("wordle-react", "wordle"):
@@ -363,6 +335,7 @@ async def step_policy(
363
335
 
364
336
  # Emit full system/user prompts for observability (no secrets included)
365
337
  try:
338
+
366
339
  def _as_text(content: object) -> str:
367
340
  if isinstance(content, str):
368
341
  return content
@@ -404,7 +377,7 @@ async def step_policy(
404
377
  # Print concise preview for visibility in standard logs
405
378
  try:
406
379
  last_user = user_messages[-1] if user_messages else ""
407
- #preview = last_user[:400] if isinstance(last_user, str) else str(last_user)[:400]
380
+ # preview = last_user[:400] if isinstance(last_user, str) else str(last_user)[:400]
408
381
  print(f"[task:crafter] user prompt: {last_user}", flush=True)
409
382
  except Exception:
410
383
  pass
@@ -435,16 +408,27 @@ async def step_policy(
435
408
  api_key_override = None
436
409
  try:
437
410
  import os as _os
411
+
438
412
  if isinstance(target_url, str):
439
413
  low_url = target_url.lower()
440
414
  if "openai.com" in low_url:
441
- api_key_override = _os.getenv("OPENAI_API_KEY") or getattr(task_app, "openai_api_key", None)
415
+ api_key_override = _os.getenv("OPENAI_API_KEY") or getattr(
416
+ task_app, "openai_api_key", None
417
+ )
442
418
  elif "groq.com" in low_url:
443
419
  api_key_override = _os.getenv("GROQ_API_KEY")
444
420
  else:
445
- api_key_override = _os.getenv("SYNTH_API_KEY") or _os.getenv("OPENAI_API_KEY") or getattr(task_app, "openai_api_key", None)
421
+ api_key_override = (
422
+ _os.getenv("SYNTH_API_KEY")
423
+ or _os.getenv("OPENAI_API_KEY")
424
+ or getattr(task_app, "openai_api_key", None)
425
+ )
446
426
  else:
447
- api_key_override = _os.getenv("SYNTH_API_KEY") or _os.getenv("OPENAI_API_KEY") or getattr(task_app, "openai_api_key", None)
427
+ api_key_override = (
428
+ _os.getenv("SYNTH_API_KEY")
429
+ or _os.getenv("OPENAI_API_KEY")
430
+ or getattr(task_app, "openai_api_key", None)
431
+ )
448
432
  except Exception:
449
433
  api_key_override = None
450
434
 
@@ -455,7 +439,9 @@ async def step_policy(
455
439
  masked = "<masked>"
456
440
  logger.debug(f"INFERENCE_AUTH: Using bearer key {masked}")
457
441
  else:
458
- logger.warning("INFERENCE_AUTH: No API key resolved for inference request; downstream may 401")
442
+ logger.warning(
443
+ "INFERENCE_AUTH: No API key resolved for inference request; downstream may 401"
444
+ )
459
445
 
460
446
  client = create_inference_client(task_app, api_key=api_key_override)
461
447
 
@@ -650,6 +636,7 @@ async def step_policy(
650
636
  if model_for_diag and messages_for_diag:
651
637
  try:
652
638
  from transformers import AutoTokenizer
639
+
653
640
  tok = AutoTokenizer.from_pretrained(model_for_diag)
654
641
  prompt_preview = tok.apply_chat_template(
655
642
  messages_for_diag,
@@ -660,7 +647,9 @@ async def step_policy(
660
647
  max_len = getattr(tok, "model_max_length", None)
661
648
  over_limit = False
662
649
  try:
663
- over_limit = isinstance(max_len, int) and max_len > 0 and len(ids) > int(max_len)
650
+ over_limit = (
651
+ isinstance(max_len, int) and max_len > 0 and len(ids) > int(max_len)
652
+ )
664
653
  except Exception:
665
654
  over_limit = False
666
655
  if over_limit or len(ids) > 10000:
@@ -672,7 +661,9 @@ async def step_policy(
672
661
  "prompt_token_overflow_local": True,
673
662
  "model": str(model_for_diag),
674
663
  "token_count": int(len(ids)),
675
- "model_max_length": int(max_len) if isinstance(max_len, int) else None,
664
+ "model_max_length": int(max_len)
665
+ if isinstance(max_len, int)
666
+ else None,
676
667
  "preview_tokens_logged": int(len(preview_ids)),
677
668
  "prompt_preview_first_10k_tokens": preview_text,
678
669
  }
@@ -682,7 +673,9 @@ async def step_policy(
682
673
  try:
683
674
  meta["prompt_debug"] = {
684
675
  "token_count": int(len(ids)),
685
- "model_max_length": int(max_len) if isinstance(max_len, int) else None,
676
+ "model_max_length": int(max_len)
677
+ if isinstance(max_len, int)
678
+ else None,
686
679
  "preview_first_10k_tokens": preview_text,
687
680
  }
688
681
  except Exception:
@@ -700,14 +693,19 @@ async def step_policy(
700
693
  if isinstance(msgs, list):
701
694
  # Print compact messages structure and tool schema with bounded length
702
695
  import json as _json
696
+
703
697
  msgs_compact = _json.dumps(msgs)[:20000]
704
- tools_compact = _json.dumps(tools_dump)[:8000] if tools_dump is not None else None
705
- print({
706
- "llm.call": True,
707
- "policy": str(policy_name),
708
- "messages_preview": msgs_compact,
709
- "tools_preview": tools_compact,
710
- })
698
+ tools_compact = (
699
+ _json.dumps(tools_dump)[:8000] if tools_dump is not None else None
700
+ )
701
+ print(
702
+ {
703
+ "llm.call": True,
704
+ "policy": str(policy_name),
705
+ "messages_preview": msgs_compact,
706
+ "tools_preview": tools_compact,
707
+ }
708
+ )
711
709
  except Exception:
712
710
  pass
713
711
 
@@ -724,13 +722,20 @@ async def step_policy(
724
722
  try:
725
723
  tools_arr = req_body.get("tools") or []
726
724
  if isinstance(tools_arr, list) and tools_arr:
727
- f = tools_arr[0].get("function") if isinstance(tools_arr[0], dict) else None
725
+ f = (
726
+ tools_arr[0].get("function")
727
+ if isinstance(tools_arr[0], dict)
728
+ else None
729
+ )
728
730
  cand = (f or {}).get("name") if isinstance(f, dict) else None
729
731
  if isinstance(cand, str) and cand:
730
732
  func_name = cand
731
733
  except Exception:
732
734
  pass
733
- req_body["tool_choice"] = {"type": "function", "function": {"name": func_name}}
735
+ req_body["tool_choice"] = {
736
+ "type": "function",
737
+ "function": {"name": func_name},
738
+ }
734
739
  req_body["parallel_tool_calls"] = False
735
740
  req_body.setdefault("function_call", {"name": func_name})
736
741
  # Inject extra_body for thinking controls expected by Modal service
@@ -799,10 +804,13 @@ async def step_policy(
799
804
  else:
800
805
  try:
801
806
  import json as _json
802
- print({
803
- "tool_calls_parsed": int(len(tool_calls)),
804
- "tool_calls_preview": _json.dumps(tool_calls)[:20000],
805
- })
807
+
808
+ print(
809
+ {
810
+ "tool_calls_parsed": int(len(tool_calls)),
811
+ "tool_calls_preview": _json.dumps(tool_calls)[:20000],
812
+ }
813
+ )
806
814
  except Exception:
807
815
  logger.info(f"Parsed {len(tool_calls)} tool calls: {tool_calls}")
808
816
 
@@ -814,9 +822,7 @@ async def step_policy(
814
822
  inference_response, getattr(policy, "use_tools", True)
815
823
  )
816
824
  else:
817
- parsed = policy.parse_model_response(
818
- inference_response, request.observation
819
- )
825
+ parsed = policy.parse_model_response(inference_response, request.observation)
820
826
  # Replace tool_calls with parsed result
821
827
  if isinstance(parsed, list):
822
828
  tool_calls = parsed
@@ -866,9 +872,7 @@ async def snapshot_policy(request: PolicySnapshotRequest) -> PolicySnapshotRespo
866
872
  """Create a snapshot of the policy state."""
867
873
  handle = registry.get_policy(request.policy_id)
868
874
  if not handle:
869
- raise HTTPException(
870
- status_code=404, detail=f"Policy {request.policy_id} not found"
871
- )
875
+ raise HTTPException(status_code=404, detail=f"Policy {request.policy_id} not found")
872
876
 
873
877
  try:
874
878
  # Serialize policy state
@@ -906,9 +910,7 @@ async def restore_policy(request: PolicyRestoreRequest) -> PolicyRestoreResponse
906
910
  """Restore a policy from a snapshot."""
907
911
  snapshot = registry.get_snapshot(request.snapshot_id)
908
912
  if not snapshot:
909
- raise HTTPException(
910
- status_code=404, detail=f"Snapshot {request.snapshot_id} not found"
911
- )
913
+ raise HTTPException(status_code=404, detail=f"Snapshot {request.snapshot_id} not found")
912
914
 
913
915
  if snapshot.kind != "policy":
914
916
  raise HTTPException(
@@ -956,9 +958,7 @@ async def restore_policy(request: PolicyRestoreRequest) -> PolicyRestoreResponse
956
958
  return PolicyRestoreResponse(policy_id=policy_id)
957
959
 
958
960
  except Exception as e:
959
- logger.error(
960
- f"Failed to restore policy from snapshot {request.snapshot_id}: {e}"
961
- )
961
+ logger.error(f"Failed to restore policy from snapshot {request.snapshot_id}: {e}")
962
962
  raise HTTPException(status_code=500, detail=str(e))
963
963
 
964
964
 
@@ -967,9 +967,7 @@ async def terminate_policy(request: PolicyTerminateRequest) -> PolicyTerminateRe
967
967
  """Terminate a policy and clean up resources."""
968
968
  handle = registry.get_policy(request.policy_id)
969
969
  if not handle:
970
- raise HTTPException(
971
- status_code=404, detail=f"Policy {request.policy_id} not found"
972
- )
970
+ raise HTTPException(status_code=404, detail=f"Policy {request.policy_id} not found")
973
971
 
974
972
  try:
975
973
  # Call terminate on the policy