synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (229)
  1. examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
  2. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +186 -0
  3. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
  4. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
  5. examples/multi_step/crafter_rl_lora.md +51 -10
  6. examples/multi_step/sse_metrics_streaming_notes.md +357 -0
  7. examples/multi_step/task_app_config_notes.md +7 -1
  8. examples/swe/task_app/grpo_swe_mini.py +55 -26
  9. examples/swe/task_app/hosted/rollout.py +40 -0
  10. examples/swe/task_app/hosted/test_service.py +5 -6
  11. examples/task_apps/TESTING.md +275 -0
  12. examples/task_apps/__init__.py +0 -0
  13. examples/task_apps/crafter/__init__.py +0 -0
  14. examples/task_apps/crafter/task_app/__init__.py +2 -0
  15. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +21 -46
  16. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
  17. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
  18. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
  19. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +67 -49
  20. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +242 -193
  21. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
  22. examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
  23. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
  24. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
  25. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
  26. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
  27. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
  28. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
  29. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
  30. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
  31. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
  32. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
  33. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
  34. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
  35. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
  36. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
  37. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
  38. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
  39. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
  40. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
  41. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
  42. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
  43. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
  44. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
  45. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
  46. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
  47. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
  48. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
  49. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
  50. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
  51. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
  52. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
  53. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
  54. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
  55. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
  56. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
  57. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
  58. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
  59. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
  60. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
  61. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
  62. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
  63. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
  64. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
  65. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
  66. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
  67. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
  68. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
  69. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
  70. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
  71. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
  72. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
  73. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
  74. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
  75. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
  76. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
  77. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
  78. examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
  79. examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
  80. examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
  81. examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
  82. examples/task_apps/enron/__init__.py +1 -0
  83. examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
  84. examples/task_apps/enron/task_app/README.md +14 -0
  85. examples/task_apps/enron/task_app/__init__.py +1 -0
  86. examples/task_apps/enron/task_app/grpo_enron.py +906 -0
  87. examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
  88. examples/task_apps/enron/tests/__init__.py +2 -0
  89. examples/task_apps/enron/tests/conftest.py +115 -0
  90. examples/task_apps/enron/tests/integration/__init__.py +2 -0
  91. examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
  92. examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
  93. examples/task_apps/enron/tests/unit/__init__.py +2 -0
  94. examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
  95. examples/task_apps/math/__init__.py +0 -0
  96. examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
  97. examples/task_apps/pokemon_battle/__init__.py +2 -0
  98. examples/task_apps/pokemon_battle/modal_app.py +104 -0
  99. examples/task_apps/pokemon_battle/task_app/README.md +68 -0
  100. examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
  101. examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
  102. examples/task_apps/pokemon_red/README.md +357 -0
  103. examples/task_apps/pokemon_red/__init__.py +3 -0
  104. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
  105. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
  106. examples/task_apps/pokemon_red/task_app.py +606 -0
  107. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
  108. examples/task_apps/sokoban/README.md +307 -0
  109. examples/task_apps/sokoban/__init__.py +3 -0
  110. examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
  111. examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
  112. examples/task_apps/sokoban/task_app.py +1058 -0
  113. examples/task_apps/sokoban/tests/__init__.py +2 -0
  114. examples/task_apps/sokoban/tests/conftest.py +113 -0
  115. examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
  116. examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
  117. examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
  118. examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
  119. examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
  120. examples/task_apps/verilog/__init__.py +1 -0
  121. examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
  122. examples/task_apps/verilog/task_app/README.md +12 -0
  123. examples/task_apps/verilog/task_app/__init__.py +1 -0
  124. examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
  125. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
  126. examples/task_apps/verilog/tests/__init__.py +2 -0
  127. examples/task_apps/verilog/tests/conftest.py +115 -0
  128. examples/task_apps/verilog/tests/integration/__init__.py +2 -0
  129. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
  130. examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
  131. examples/task_apps/verilog/tests/unit/__init__.py +2 -0
  132. examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
  133. examples/vlm/crafter_openai_vlm_agent.py +4 -4
  134. examples/vlm/run_crafter_vlm_benchmark.py +4 -4
  135. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
  136. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
  137. examples/warming_up_to_rl/run_eval.py +127 -18
  138. examples/workflows/__init__.py +0 -0
  139. examples/workflows/math_rl/__init__.py +0 -0
  140. examples/workflows/math_rl/download_dataset.py +80 -0
  141. synth_ai/__init__.py +41 -1
  142. synth_ai/api/train/builders.py +73 -29
  143. synth_ai/api/train/cli.py +12 -6
  144. synth_ai/api/train/configs/__init__.py +44 -0
  145. synth_ai/api/train/configs/rl.py +134 -0
  146. synth_ai/api/train/configs/sft.py +95 -0
  147. synth_ai/api/train/configs/shared.py +24 -0
  148. synth_ai/api/train/env_resolver.py +5 -2
  149. synth_ai/api/train/supported_algos.py +10 -5
  150. synth_ai/api/train/utils.py +7 -4
  151. synth_ai/cli/__init__.py +7 -51
  152. synth_ai/cli/_storage.py +4 -3
  153. synth_ai/cli/_validate_task_app.py +11 -0
  154. synth_ai/cli/balance.py +4 -3
  155. synth_ai/cli/calc.py +2 -2
  156. synth_ai/cli/demo.py +49 -43
  157. synth_ai/cli/legacy_root_backup.py +1 -1
  158. synth_ai/cli/rl_demo.py +86 -106
  159. synth_ai/cli/root.py +0 -97
  160. synth_ai/cli/task_apps.py +1710 -186
  161. synth_ai/demos/core/cli.py +121 -159
  162. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
  163. synth_ai/environments/examples/crafter_classic/environment.py +16 -0
  164. synth_ai/environments/examples/enron/engine.py +7 -2
  165. synth_ai/environments/examples/enron/environment.py +68 -0
  166. synth_ai/environments/examples/red/engine.py +27 -0
  167. synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
  168. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
  169. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
  170. synth_ai/environments/examples/red/environment.py +60 -0
  171. synth_ai/environments/examples/sokoban/taskset.py +116 -0
  172. synth_ai/environments/examples/verilog/engine.py +30 -4
  173. synth_ai/evals/__init__.py +15 -0
  174. synth_ai/evals/client.py +82 -0
  175. synth_ai/evals/types.py +42 -0
  176. synth_ai/jobs/client.py +16 -4
  177. synth_ai/judge_schemas.py +127 -0
  178. synth_ai/py.typed +0 -0
  179. synth_ai/task/__init__.py +14 -5
  180. synth_ai/task/contracts.py +124 -38
  181. synth_ai/task/proxy.py +48 -56
  182. synth_ai/task/rubrics/__init__.py +53 -0
  183. synth_ai/task/rubrics/loaders.py +133 -0
  184. synth_ai/task/rubrics/models.py +57 -0
  185. synth_ai/task/rubrics/scoring.py +113 -0
  186. synth_ai/task/rubrics/strict.py +149 -0
  187. synth_ai/task/server.py +8 -7
  188. synth_ai/task/validators.py +269 -6
  189. synth_ai/tracing_v3/decorators.py +7 -3
  190. synth_ai/tracing_v3/replica_sync.py +4 -4
  191. synth_ai/tracing_v3/serialization.py +130 -0
  192. synth_ai/tracing_v3/trace_utils.py +317 -0
  193. synth_ai/tracing_v3/turso/native_manager.py +3 -3
  194. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
  195. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +228 -89
  196. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -1
  197. synth_ai/task/rubrics.py +0 -219
  198. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
  199. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
  200. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
  201. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
  202. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
  203. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
  204. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
  205. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
  206. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
  207. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
  208. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
  209. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
  210. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
  211. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
  212. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
  213. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
  214. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
  215. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
  216. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
  217. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
  218. /examples/{rl/task_app → task_apps/math}/README.md +0 -0
  219. /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
  220. /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
  221. /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
  222. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
  223. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
  224. /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
  225. /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
  226. /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
  227. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
  228. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
  229. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,931 @@
1
+ """Task App configuration for the GRPO Verilog spec-to-RTL example."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import contextlib
7
+ import json
8
+ import logging
9
+ import os
10
+ import re
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+ from typing import Any, Iterable, Optional, Sequence
14
+
15
+ import httpx
16
+
17
+ from synth_ai.environments.environment.tools import EnvToolCall
18
+ from synth_ai.environments.examples.verilog.environment import VerilogEnvironment
19
+ from synth_ai.environments.examples.verilog.taskset import (
20
+ VerilogTaskInstance,
21
+ VerilogTaskInstanceMetadata,
22
+ create_verilog_taskset,
23
+ )
24
+ from synth_ai.environments.tasks.core import TaskInstanceSet
25
+ from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
26
+ from synth_ai.task.contracts import (
27
+ RolloutMetrics,
28
+ RolloutRequest,
29
+ RolloutResponse,
30
+ RolloutTrajectory,
31
+ RolloutStep,
32
+ TaskInfo,
33
+ )
34
+ from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
35
+ from synth_ai.task.rubrics import load_rubric
36
+ from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
37
+ from synth_ai.task.tracing_utils import (
38
+ build_tracer_factory,
39
+ resolve_sft_output_dir,
40
+ resolve_tracing_db_url,
41
+ tracing_env_enabled,
42
+ )
43
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
44
+
45
# Module-level logger for this task app.
logger = logging.getLogger(__name__)

# Absolute path of this module. Assuming the file lives at
# examples/task_apps/verilog/task_app/grpo_verilog.py (as in the package file
# listing), parents[4] is the repository root — re-check if the file moves.
_HERE = Path(__file__).resolve()
REPO_ROOT = _HERE.parents[4]

# Dataset descriptor for the VerilogEval v2 spec-to-RTL problems; "val" is the
# default split.
DATASET_SPEC = TaskDatasetSpec(
    id="verilog_eval_v2",
    name="VerilogEval Spec-to-RTL",
    version="1.0.0",
    splits=["train", "val", "test"],
    default_split="val",
    description="Spec-to-RTL problems sourced from the VerilogEval v2 benchmark.",
)
+ )
58
+
59
# Runtime knobs, each overridable through an environment variable. The int()/float()
# conversions run at import time, so a malformed value raises ValueError on import.

# Upper bound on the number of task instances loaded into the dataset.
MAX_INSTANCES = int(os.getenv("VERILOG_MAX_INSTANCES", "10"))
# Tool names the model may call; also advertised in the task info action space.
TOOLS = ["write_file", "compile", "simulate", "submit"]
# Endpoint used by the agent when it is given no inference URL (Groq by default).
DEFAULT_INFERENCE_URL = os.getenv(
    "VERILOG_INFERENCE_URL", "https://api.groq.com/openai/v1/chat/completions"
)
# Model used by the agent when it is given no model name.
DEFAULT_MODEL = os.getenv("VERILOG_DEFAULT_MODEL", "qwen/qwen3-32b")
# NOTE(review): presumably the default sampling temperature for rollouts; the
# consumer is outside this chunk.
DEFAULT_TEMPERATURE = float(os.getenv("VERILOG_DEFAULT_TEMPERATURE", "0.2"))
# Presumably fed to the agent's max_tokens; the agent omits max_tokens from the
# request when that value is <= 0.
DEFAULT_MAX_TOKENS = int(os.getenv("VERILOG_DEFAULT_MAX_TOKENS", "768"))
# NOTE(review): presumably the per-rollout step limit — consumer not visible here.
DEFAULT_MAX_STEPS = int(os.getenv("VERILOG_DEFAULT_MAX_STEPS", "10"))
# Per-file character limit for the workspace previews shown to the model.
FILE_PREVIEW_CHARS = int(os.getenv("VERILOG_FILE_PREVIEW_CHARS", "600"))
# NOTE(review): presumably the HTTP timeout (seconds) for inference calls —
# consumer not visible here.
HTTP_TIMEOUT_SECONDS = float(os.getenv("VERILOG_INFERENCE_TIMEOUT", "90"))
70
+
71
# System prompt, sent as the first chat message of every agent conversation.
# It requires exactly one JSON tool call per reply ({"tool": ..., "args": {...}}),
# optionally wrapped in a ```json fence, and describes the intended
# write_file -> compile -> simulate -> submit loop.
VERILOG_SYSTEM_PROMPT = (
    "You are an expert digital design engineer helping with Verilog spec-to-RTL tasks. "
    "Choose between these tools: write_file, compile, simulate, submit. "
    "Always respond with a JSON object describing exactly one tool call in the form "
    "{\"tool\": \"<tool_name>\", \"args\": { ... }}. "
    "You may wrap the JSON inside a ```json``` block but MUST NOT include any other prose outside it. "
    "When editing files, rewrite the entire file content. Compile after code changes, simulate to verify behavior, "
    "and submit only after the tests pass. If compilation reports errors (missing ports, mismatched interfaces, etc.), "
    "fix the design with write_file before compiling again—never repeat compile without modifying the source first."
)
+ )
81
+
82
+
83
def _load_taskset_blocking(max_instances: int) -> TaskInstanceSet:
    """Synchronously build the VerilogEval taskset.

    ``create_verilog_taskset`` is a coroutine function, so it must be driven
    by an event loop. When no loop is running in the current thread, the
    coroutine is executed with :func:`asyncio.run`.

    When a loop *is* already running in this thread, neither ``asyncio.run``
    nor ``loop.run_until_complete`` may be used from it. This can happen, for
    example, when this module is imported from async code. In that case the
    coroutine runs on its own loop in a short-lived worker thread, and this
    call blocks until it finishes.

    Args:
        max_instances: Upper bound on the number of task instances to load.

    Returns:
        The loaded task instance set.

    Raises:
        Any exception raised by ``create_verilog_taskset`` propagates
        unchanged; it is not swallowed or retried.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: the normal synchronous path.
        return asyncio.run(create_verilog_taskset(max_instances=max_instances))

    # A loop is already running here. Nested loops are not allowed, so run the
    # coroutine on a fresh loop owned by a worker thread. The coroutine object
    # is created inside the worker so it is never left un-awaited.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1, thread_name_prefix="verilog-taskset") as pool:
        future = pool.submit(
            lambda: asyncio.run(create_verilog_taskset(max_instances=max_instances))
        )
        return future.result()
92
+
93
+
94
@dataclass
class VerilogDataset:
    """Ordered, in-memory collection of VerilogEval task instances.

    Attributes set by ``__post_init__``:
        instances: the loaded task instances, in load order.
        instance_ids: ``str(instance.id)`` for every entry of ``instances``.
        default_seed / seed_min / seed_max: the seed range that maps onto
            ``instances``; ``seed_max`` is 0 when nothing was loaded.
    """

    spec: TaskDatasetSpec
    max_instances: int

    def __post_init__(self) -> None:
        # Loading blocks until the taskset coroutine has completed.
        taskset = _load_taskset_blocking(self.max_instances)
        self._taskset = taskset
        loaded: list[VerilogTaskInstance] = list(taskset.instances)
        self.instances = loaded
        self.instance_ids = list(map(str, (task.id for task in loaded)))
        self.default_seed = self.seed_min = 0
        self.seed_max = len(loaded) - 1 if loaded else 0

    def describe(self) -> dict[str, Any]:
        """Return the spec fields plus the instance count and the first 50 instance ids."""
        summary = dict(self.spec.model_dump())
        summary["instance_count"] = len(self.instances)
        summary["instance_ids"] = self.instance_ids[:50]
        return summary

    def instance_by_seed(self, seed: int | None) -> VerilogTaskInstance:
        """Map *seed* onto an instance by taking it modulo the dataset size.

        ``None`` selects the first instance.

        Raises:
            ValueError: if the dataset holds no instances.
        """
        total = len(self.instances)
        if total == 0:
            raise ValueError("Verilog dataset is empty.")
        position = 0 if seed is None else int(seed) % total
        return self.instances[position]
122
+
123
+
124
def build_dataset() -> tuple[TaskDatasetRegistry, VerilogDataset]:
    """Create the dataset registry together with its single ``VerilogDataset``.

    The dataset is built eagerly, which loads the taskset. The registered
    factory ignores the spec it is given and always returns that same object.

    Returns:
        ``(registry, dataset)``.
    """
    registry = TaskDatasetRegistry()
    shared_dataset = VerilogDataset(DATASET_SPEC, MAX_INSTANCES)

    def _dataset_factory(_spec: TaskDatasetSpec) -> VerilogDataset:
        return shared_dataset

    registry.register(DATASET_SPEC, _dataset_factory, cache=True)
    return registry, shared_dataset
129
+
130
+
131
def _base_task_info(dataset: VerilogDataset) -> TaskInfo:
    """Build the static ``TaskInfo`` descriptor for the Verilog task app.

    The result combines fixed task, action-space, observation, rubric,
    inference and limit metadata with ``dataset.describe()``. The dataset
    summary is extended with the dataset's ``default_seed``.
    """
    task_meta = {"id": "verilog_eval_v2", "name": "VerilogEval Spec-to-RTL", "version": "1.0.0"}
    action_space = {
        "type": "tool_calls",
        "tools": TOOLS,
        "description": "Filesystem editing, compilation, simulation, and submission tools.",
    }
    observation_meta = {
        "summary": "Dictionary observations describing files, compilation status, simulation results, and rewards.",
        "format": "dict",
        "keys": ["files", "compile_status", "simulate_status", "reward_last"],
    }
    dataset_meta = dict(dataset.describe())
    dataset_meta["default_seed"] = dataset.default_seed
    rubric_meta = {
        "version": "1",
        "criteria_count": 1,
        "source": "inline",
        "aggregation": "weighted_sum",
    }
    inference_meta = {
        "supports_proxy": True,
        "endpoints": {
            "openai": "/proxy/v1/chat/completions",
            "groq": "/proxy/groq/v1/chat/completions",
        },
        "tool": {"name": "verilog_tools", "parallel_tool_calls": False},
    }
    return TaskInfo(
        task=task_meta,
        environment="verilog",
        action_space=action_space,
        observation=observation_meta,
        dataset=dataset_meta,
        rubric=rubric_meta,
        inference=inference_meta,
        limits={"max_ops": 0, "max_time_s": 3600},
    )
162
+
163
+
164
+ def _normalize_inference_url(url: str | None) -> str:
165
+ candidate = (url or DEFAULT_INFERENCE_URL).strip()
166
+ if not candidate:
167
+ candidate = DEFAULT_INFERENCE_URL
168
+ if candidate.endswith("/v1/chat/completions"):
169
+ return candidate
170
+ if candidate.endswith("/chat/completions"):
171
+ return candidate
172
+ if candidate.endswith("/v1"):
173
+ return f"{candidate.rstrip('/')}/chat/completions"
174
+ if candidate.endswith("/v1/"):
175
+ return f"{candidate.rstrip('/')}/chat/completions"
176
+ if candidate.endswith("/chat"):
177
+ return f"{candidate.rstrip('/')}/completions"
178
+ if candidate.endswith("/chat/"):
179
+ return f"{candidate.rstrip('/')}/completions"
180
+ return f"{candidate.rstrip('/')}/v1/chat/completions"
181
+
182
+
183
+ def _format_file_previews(files: dict[str, str]) -> str:
184
+ if not files:
185
+ return "No files in the workspace yet."
186
+
187
+ sections: list[str] = []
188
+ for name in sorted(files.keys()):
189
+ content = files[name] or ""
190
+ snippet = content.strip()
191
+ if len(snippet) > FILE_PREVIEW_CHARS:
192
+ snippet = snippet[:FILE_PREVIEW_CHARS] + "\n..."
193
+ sections.append(f"{name}:\n{snippet}")
194
+ return "\n\n".join(sections)
195
+
196
+
197
def _format_observation_text(
    *,
    observation: dict[str, Any],
    step_index: int,
    instructions: str | None,
    action_feedback: str | None,
    guidance: str | None = None,
) -> str:
    """Render one environment observation as the text of a user turn.

    Sections, in order:

    1. The task instructions, on step 0 only.
    2. A status block: rewards, completion flag, and compile/simulate status
       and build directory when present.
    3. The previous action's feedback, if any.
    4. Workspace file previews.
    5. The response-format reminder.
    6. Optional extra *guidance*.
    """
    out: list[str] = []
    if instructions and step_index == 0:
        out.extend(["Task instructions:", instructions.strip(), ""])

    out.append(f"Step {step_index} status:")
    reward_last = observation.get("reward_last")
    total_reward = observation.get("total_reward")
    if not (reward_last is None and total_reward is None):
        out.append(f"- reward_last={reward_last!r}, total_reward={total_reward!r}")
    out.append(f"- task_completed={bool(observation.get('task_completed'))}")
    # Optional status fields are only listed when truthy, always in this order.
    for key, label in (
        ("compile_status", "compile_status"),
        ("simulate_status", "simulate_status"),
        ("build_dir", "build_directory"),
    ):
        value = observation.get(key)
        if value:
            out.append(f"- {label}: {value}")

    if action_feedback:
        out.extend(["", action_feedback])

    out.extend(["", "Workspace files:", _format_file_previews(observation.get("files") or {})])
    out.extend(
        [
            "",
            "Select the single most helpful tool for the next step (write_file, compile, simulate, submit).",
            "Respond with JSON only: {\"tool\": \"<tool_name>\", \"args\": {...}}.",
        ]
    )
    if guidance:
        out.extend(["", guidance.strip()])
    return "\n".join(out)
249
+
250
+
251
+ def _summarize_action_feedback(
252
+ tool_name: str, args: dict[str, Any], observation: dict[str, Any], reward: float
253
+ ) -> str:
254
+ argument_preview = json.dumps(args, ensure_ascii=False)
255
+ parts = [
256
+ f"Previous action: {tool_name}({argument_preview})",
257
+ f"Reward delta: {reward:.4f}",
258
+ ]
259
+ compile_status = observation.get("compile_status")
260
+ if compile_status:
261
+ parts.append(f"Compile status: {compile_status}")
262
+ simulate_status = observation.get("simulate_status")
263
+ if simulate_status:
264
+ parts.append(f"Simulation status: {simulate_status}")
265
+ if observation.get("task_completed"):
266
+ parts.append("Task completed ✅")
267
+ total_reward = observation.get("total_reward")
268
+ if total_reward is not None:
269
+ parts.append(f"Total reward: {total_reward}")
270
+ return "\n".join(parts)
271
+
272
+
273
# Matches a fenced code block (``` or ```json) whose body is one JSON object and
# captures that object. DOTALL lets the object span lines; the lazy ``.*?`` stops
# at the first ``}`` that is followed only by optional whitespace and the closing
# fence.
JSON_BLOCK_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
274
+
275
+
276
+ def _parse_tool_json(text: str) -> list[dict[str, Any]]:
277
+ candidates: list[dict[str, Any]] = []
278
+
279
+ try:
280
+ parsed = json.loads(text)
281
+ if isinstance(parsed, dict):
282
+ candidates.append(parsed)
283
+ except Exception:
284
+ pass
285
+
286
+ if not candidates:
287
+ for match in JSON_BLOCK_PATTERN.finditer(text):
288
+ snippet = match.group(1)
289
+ try:
290
+ parsed = json.loads(snippet)
291
+ except Exception:
292
+ continue
293
+ if isinstance(parsed, dict):
294
+ candidates.append(parsed)
295
+
296
+ if not candidates:
297
+ brace_match = re.search(r"\{.*\}", text, re.DOTALL)
298
+ if brace_match:
299
+ try:
300
+ parsed = json.loads(brace_match.group(0))
301
+ if isinstance(parsed, dict):
302
+ candidates.append(parsed)
303
+ except Exception:
304
+ pass
305
+
306
+ for candidate in candidates:
307
+ tool_name = candidate.get("tool") if isinstance(candidate, dict) else None
308
+ if not isinstance(tool_name, str):
309
+ continue
310
+ raw_args = candidate.get("args") if isinstance(candidate, dict) else None
311
+ args = raw_args if isinstance(raw_args, dict) else {}
312
+ tool_name = tool_name.strip()
313
+ normalized_args: dict[str, Any] = dict(args)
314
+ if tool_name == "write_file":
315
+ if "file_path" in normalized_args and "path" not in normalized_args:
316
+ normalized_args["path"] = normalized_args.pop("file_path")
317
+ if "file" in normalized_args and "path" not in normalized_args:
318
+ normalized_args["path"] = normalized_args.pop("file")
319
+ if "contents" in normalized_args and "content" not in normalized_args:
320
+ normalized_args["content"] = normalized_args.pop("contents")
321
+ return [{"tool": tool_name, "args": normalized_args}]
322
+
323
+ return []
324
+
325
+
326
class VerilogLLMAgent:
    """Minimal ReAct-style agent that communicates with a chat-completions API.

    The full conversation (system prompt first) lives in ``messages``, and
    the user/assistant turns are mirrored into ``history``. Each ``invoke``
    sends the conversation so far and appends the model's reply.
    """

    def __init__(
        self,
        *,
        instructions: str,
        inference_url: str | None,
        model: str | None,
        temperature: float,
        max_tokens: int,
    ) -> None:
        """Configure the endpoint, model and sampling settings.

        Raises:
            RuntimeError: if the URL contains ``groq`` (or, failing that,
                ``openai``) and the matching API-key env var is unset or blank.
        """
        self.instructions = instructions.strip()
        self.inference_url = _normalize_inference_url(inference_url)
        self.model = model or DEFAULT_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.messages: list[dict[str, Any]] = [{"role": "system", "content": VERILOG_SYSTEM_PROMPT}]
        self.headers: dict[str, str] = {"Content-Type": "application/json"}

        lowered = self.inference_url.lower()
        # Groq is checked first because its default URL also contains "openai"
        # (https://api.groq.com/openai/v1/...).
        for marker, env_var in (("groq", "GROQ_API_KEY"), ("openai", "OPENAI_API_KEY")):
            if marker not in lowered:
                continue
            # Strip before validating so a whitespace-only key is rejected
            # instead of producing an empty "Bearer " header.
            api_key = (os.getenv(env_var) or "").strip()
            if not api_key:
                raise RuntimeError(f"{env_var} is not configured for Verilog inference.")
            self.headers["Authorization"] = f"Bearer {api_key}"
            break

        self.history: list[dict[str, Any]] = []

    def append_observation(
        self,
        *,
        observation: dict[str, Any],
        step_index: int,
        action_feedback: str | None,
        guidance: str | None = None,
    ) -> str:
        """Format *observation* as a user turn, append it and return the text.

        The task instructions are included only when ``step_index == 0``.
        """
        text = _format_observation_text(
            observation=observation,
            step_index=step_index,
            instructions=self.instructions if step_index == 0 else None,
            action_feedback=action_feedback,
            guidance=guidance,
        )
        self.messages.append({"role": "user", "content": text})
        self.history.append({"role": "user", "content": text})
        return text

    @staticmethod
    def _extract_assistant_text(data: Any) -> str:
        """Return ``choices[0].message.content`` from a response body, or ``""``.

        Tolerates missing or malformed ``choices``/``message`` entries instead
        of raising ``AttributeError``/``IndexError``.
        """
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first = choices[0]
        message = first.get("message") if isinstance(first, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        return content if isinstance(content, str) else ""

    async def invoke(
        self, client: httpx.AsyncClient
    ) -> tuple[str, list[dict[str, Any]], dict[str, Any], dict[str, Any]]:
        """Send the conversation to the inference endpoint and record the reply.

        Returns:
            ``(assistant_text, parsed_calls, response_json, request_payload)``.
            ``request_payload["messages"]`` is a snapshot of what was sent;
            the reply and any later turns are not added to it.

        Raises:
            RuntimeError: if the endpoint cannot be reached, answers with an
                HTTP error status, or returns a body that is not JSON.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            # Copy, don't alias: self.messages keeps growing after this call
            # (the reply is appended below), and the payload is handed back to
            # the caller as the record of the request that was actually sent.
            "messages": list(self.messages),
            "temperature": self.temperature,
        }
        if self.max_tokens > 0:
            payload["max_tokens"] = self.max_tokens

        try:
            response = await client.post(self.inference_url, json=payload, headers=self.headers)
        except Exception as exc:  # pragma: no cover - network failure
            raise RuntimeError(f"Failed to reach inference endpoint: {exc}") from exc

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:  # pragma: no cover - inference error
            preview = exc.response.text[:2000]
            raise RuntimeError(
                f"Inference call failed with status {exc.response.status_code}: {preview}"
            ) from exc

        try:
            data = response.json()
        except ValueError as exc:
            raise RuntimeError(
                f"Inference endpoint returned a non-JSON body: {response.text[:2000]}"
            ) from exc

        assistant_text = self._extract_assistant_text(data)
        self.messages.append({"role": "assistant", "content": assistant_text})
        self.history.append({"role": "assistant", "content": assistant_text})

        parsed_calls = _parse_tool_json(assistant_text)

        return assistant_text, parsed_calls, data, payload
413
+
414
# Outcome rubric: one criterion, scored on whether the final submission
# passes compilation and simulation against the provided testbench.
OUTCOME_RUBRIC = load_rubric(
    dict(
        version="1",
        goal_text="Produce a Verilog implementation that passes the provided testbench.",
        aggregation="weighted_sum",
        criteria=[
            dict(
                id="tests_pass",
                description="Submission passes all compile and simulation checks.",
                weight=1.0,
            ),
        ],
    )
)

# Events rubric: one criterion about how the write/compile/simulate tools are
# used along the way, before submission.
EVENTS_RUBRIC = load_rubric(
    dict(
        version="1",
        goal_text="Encourage deliberate hardware design iterations.",
        aggregation="weighted_sum",
        criteria=[
            dict(
                id="efficient_iterations",
                description="Use write/compile/simulate tools strategically before submitting.",
                weight=1.0,
            ),
        ],
    )
)
443
+
444
+
445
def describe_taskset(dataset: VerilogDataset) -> dict[str, Any]:
    """Return the dataset's own taskset summary (delegates to ``dataset.describe()``)."""
    summary = dataset.describe()
    return summary
447
+
448
+
449
def provide_task_instances(
    dataset: VerilogDataset, base_info: TaskInfo, seeds: Sequence[int]
) -> Iterable[TaskInfo]:
    """Build one ``TaskInfo`` per seed, specialised to that seed's dataset instance.

    Each result copies the task/environment/action-space/rubric/inference/limits
    from ``base_info``. It extends the base observation with the instance's
    ``problem_name`` and ``difficulty``. It extends the base dataset info with
    the instance id and a metadata summary.

    Returns:
        A list with one ``TaskInfo`` per entry of ``seeds``, in the same order.
    """
    # Base observation may be a pydantic model, a plain dict, or absent.
    base_observation = getattr(base_info, "observation", None)
    if hasattr(base_observation, "model_dump"):
        template = base_observation.model_dump()
    elif isinstance(base_observation, dict):
        template = dict(base_observation)
    else:
        template = {}

    summary_fields = ("problem_name", "difficulty", "description", "files_provided")
    results: list[TaskInfo] = []
    for seed in seeds:
        task_instance = dataset.instance_by_seed(seed)
        instance_meta: VerilogTaskInstanceMetadata = task_instance.metadata  # type: ignore[assignment]
        summary = {name: getattr(instance_meta, name, None) for name in summary_fields}
        seeded_observation = {
            **template,
            "problem_name": summary["problem_name"],
            "difficulty": summary["difficulty"],
        }
        seeded_dataset = {
            **base_info.dataset.model_dump(),
            "instance_id": str(task_instance.id),
            "metadata": summary,
        }
        results.append(
            TaskInfo(
                task=base_info.task,
                environment=base_info.environment,
                action_space=base_info.action_space,
                observation=seeded_observation,
                dataset=seeded_dataset,
                rubric=base_info.rubric,
                inference=base_info.inference,
                limits=base_info.limits,
            )
        )
    return results
491
+
492
+
493
+ def _ensure_dataset_from_state(fastapi_request, fallback: VerilogDataset) -> VerilogDataset:
494
+ if fastapi_request is None:
495
+ return fallback
496
+ state = getattr(getattr(fastapi_request, "app", None), "state", None)
497
+ candidate = getattr(state, "dataset", None)
498
+ return candidate or fallback
499
+
500
+
501
+ def _normalise_observation(value: Any) -> dict[str, Any]:
502
+ if isinstance(value, dict):
503
+ return value
504
+ if hasattr(value, "observation"):
505
+ obs = getattr(value, "observation")
506
+ if isinstance(obs, dict):
507
+ return obs
508
+ return {"text": str(obs)}
509
+ return {"text": str(value)}
510
+
511
+
512
def _infer_provider(inference_url: str | None) -> str:
    """Return a provider label for trace metadata, derived from the inference URL.

    Uses the same substring checks, in the same order, as the agent's credential
    selection. ``"groq"`` is checked before ``"openai"``. Any other endpoint is
    labelled ``"custom"``.
    """
    lowered = (inference_url or "").lower()
    if "groq" in lowered:
        return "groq"
    if "openai" in lowered:
        return "openai"
    return "custom"


def _resolve_policy_settings(request: RolloutRequest) -> tuple[str, float, int, int, str]:
    """Extract (model, temperature, max_tokens, max_steps, inference_url) from the policy.

    Values come from ``request.policy.config`` when valid and fall back to the
    module defaults otherwise. ``max_steps`` (alias ``max_llm_calls``) is
    clamped to ``[1, 25]``. The inference URL falls back to the
    ``VERILOG_INFERENCE_URL`` environment variable, then
    ``DEFAULT_INFERENCE_URL``.
    """
    policy = request.policy
    raw_config = getattr(policy, "config", {}) if policy else {}
    config = dict(raw_config) if isinstance(raw_config, dict) else {}

    model = config.get("model")
    if not isinstance(model, str) or not model.strip():
        model = getattr(policy, "policy_name", None) or DEFAULT_MODEL
    model = model.strip()

    try:
        temperature = float(config.get("temperature", DEFAULT_TEMPERATURE))
    except (TypeError, ValueError):
        temperature = DEFAULT_TEMPERATURE

    try:
        max_tokens = int(config.get("max_tokens", DEFAULT_MAX_TOKENS))
    except (TypeError, ValueError):
        max_tokens = DEFAULT_MAX_TOKENS

    raw_steps = config.get("max_steps") or config.get("max_llm_calls") or DEFAULT_MAX_STEPS
    try:
        max_steps = int(raw_steps)
    except (TypeError, ValueError):
        max_steps = DEFAULT_MAX_STEPS
    max_steps = max(1, min(25, max_steps))  # hard cap on LLM calls per rollout

    configured_url = config.get("inference_url")
    if isinstance(configured_url, str) and configured_url.strip():
        inference_url = configured_url.strip()
    else:
        inference_url = os.getenv("VERILOG_INFERENCE_URL", DEFAULT_INFERENCE_URL)

    return model, temperature, max_tokens, max_steps, inference_url


async def rollout_executor(
    request: RolloutRequest, fastapi_request
) -> RolloutResponse:
    """Run one Verilog episode and package it as a ``RolloutResponse``.

    Each step asks the LLM agent for tool calls and executes the first one in a
    ``VerilogEnvironment`` built from the seeded dataset instance. The agent
    also tracks the edit/compile/simulate cycle. It uses that to add guidance
    hints, to turn a redundant compile into a simulate, and to block a compile
    that follows a failed compile with no edits in between.

    Args:
        request: Rollout request. ``request.env.seed`` selects the task
            instance. ``request.policy.config`` may set ``model``,
            ``temperature``, ``max_tokens``, ``max_steps``/``max_llm_calls``
            and ``inference_url``.
        fastapi_request: Incoming FastAPI request, or ``None``. Its
            ``app.state.dataset`` (if set) overrides the module-level dataset.

    Returns:
        A response with one trajectory, rollout metrics and a minimal trace
        payload.

    Raises:
        RuntimeError: Propagated from the agent's ``invoke`` when the inference
            endpoint is unreachable or returns an error. The environment is
            still terminated.
    """
    dataset = _ensure_dataset_from_state(fastapi_request, RUNTIME_DATASET)
    env_seed = getattr(request.env, "seed", None) if request and request.env else None
    instance = dataset.instance_by_seed(env_seed)
    env = VerilogEnvironment(task_instance=instance)

    policy_model, temperature, max_tokens, max_steps, resolved_inference = (
        _resolve_policy_settings(request)
    )

    instructions = getattr(getattr(instance, "impetus", None), "instructions", "")
    agent = VerilogLLMAgent(
        instructions=instructions,
        inference_url=resolved_inference,
        model=policy_model,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    policy_id = (
        getattr(request.policy, "policy_id", None)
        or getattr(request.policy, "policy_name", None)
        or policy_model
    )
    env_id = getattr(request.env, "env_id", None) or getattr(request.env, "env_name", None) or "verilog"

    steps: list[RolloutStep] = []
    total_reward = 0.0
    final_observation: dict[str, Any] | None = None
    truncated_due_to_limit = False
    # Agent-side view of the edit -> compile -> simulate cycle.
    code_dirty = False  # write_file ran since the last successful compile
    last_compile_success = False
    simulate_since_last_compile = False
    last_compile_failed = False
    needs_design_update = False  # last compile failed and no write_file since

    def _build_guidance(step_idx: int) -> str | None:
        """Return hint text for the next user turn based on the cycle state, or None."""
        hints: list[str] = []
        if step_idx == 0 and not last_compile_success:
            hints.append("Begin by using write_file to implement TopModule according to the problem instructions before compiling.")
        if last_compile_failed or needs_design_update:
            hints.append("Compilation failed; update the design with write_file to match the required ports and behavior before compiling again.")
        if code_dirty and not last_compile_success:
            hints.append("Source was modified; run compile before simulate or submit.")
        if (not code_dirty) and last_compile_success and not simulate_since_last_compile:
            hints.append("Compilation succeeded; run simulate to verify before other actions.")
        if (not code_dirty) and last_compile_success and simulate_since_last_compile:
            hints.append("Simulation already ran after the latest compile; submit if the checks passed or make new edits first.")
        return " ".join(hints) if hints else None

    try:
        initial_raw_observation = await env.initialize()
        current_observation = _normalise_observation(initial_raw_observation)
        final_observation = current_observation
        agent.append_observation(
            observation=current_observation,
            step_index=0,
            action_feedback=None,
            guidance=_build_guidance(0),
        )

        total_reward = float(current_observation.get("total_reward") or 0.0)
        already_done = bool(
            current_observation.get("terminated") or current_observation.get("task_completed")
        )

        timeout = httpx.Timeout(
            HTTP_TIMEOUT_SECONDS,
            connect=HTTP_TIMEOUT_SECONDS,
            read=HTTP_TIMEOUT_SECONDS,
        )

        async with httpx.AsyncClient(timeout=timeout) as client:
            if not already_done:
                for step_index in range(1, max_steps + 1):
                    assistant_text, tool_calls, raw_response, request_payload = await agent.invoke(client)
                    override_info: dict[str, Any] | None = None
                    if not tool_calls:
                        # No parseable tool call: submit if already complete, else compile.
                        fallback_tool = (
                            "submit" if current_observation.get("task_completed") else "compile"
                        )
                        tool_calls = [{"tool": fallback_tool, "args": {}}]

                    primary_call = dict(tool_calls[0])
                    # Parsed calls may omit keys; default them so later indexing cannot
                    # raise KeyError outside the per-step error handling below.
                    primary_call.setdefault("tool", "")
                    primary_call.setdefault("args", {})
                    normalized_tool = str(primary_call["tool"]).strip().lower()
                    if (
                        normalized_tool == "compile"
                        and not code_dirty
                        and last_compile_success
                        and not simulate_since_last_compile
                    ):
                        # Recompiling unchanged code that already compiled is redundant;
                        # run simulate instead and record the override.
                        override_info = {
                            "from": dict(primary_call),
                            "reason": "compile_after_success_without_changes",
                        }
                        primary_call = {"tool": "simulate", "args": {}}
                        tool_calls = [primary_call]
                        override_info["to"] = dict(primary_call)
                    env_call = EnvToolCall(tool=primary_call["tool"], args=primary_call["args"])

                    try:
                        # A compile right after a failed compile with no edits since is
                        # not sent to the environment; it costs a small penalty instead.
                        skip_env_step = (
                            normalized_tool == "compile"
                            and needs_design_update
                            and not code_dirty
                        )
                        if skip_env_step:
                            reward_last = -0.01
                            total_reward += reward_last
                            current_observation = dict(current_observation)
                            current_observation["reward_last"] = reward_last
                            current_observation["total_reward"] = total_reward
                            final_observation = current_observation
                            done_flag = False
                            truncated_flag = False
                        else:
                            step_observation = await env.step(env_call)
                            current_observation = _normalise_observation(step_observation)
                            final_observation = current_observation
                            reward_last = float(current_observation.get("reward_last") or 0.0)
                            total_reward = float(
                                current_observation.get("total_reward") or (total_reward + reward_last)
                            )
                            done_flag = bool(
                                current_observation.get("terminated")
                                or current_observation.get("task_completed")
                            )
                            truncated_flag = bool(current_observation.get("truncated"))

                        normalized_executed_tool = str(primary_call["tool"]).strip().lower()

                        # Update the cycle state from the tool that actually ran.
                        if normalized_executed_tool == "write_file":
                            code_dirty = True
                            last_compile_success = False
                            simulate_since_last_compile = False
                            last_compile_failed = False
                            needs_design_update = False
                        elif normalized_executed_tool == "compile":
                            compile_status_text = str(current_observation.get("compile_status") or "")
                            if "success" in compile_status_text.lower():
                                code_dirty = False
                                last_compile_success = True
                                simulate_since_last_compile = False
                                last_compile_failed = False
                                needs_design_update = False
                            else:
                                last_compile_success = False
                                last_compile_failed = True
                                needs_design_update = True
                        elif normalized_executed_tool == "simulate":
                            simulate_since_last_compile = True

                        tool_call_records = [
                            {"tool_name": call.get("tool", ""), "arguments": call.get("args", {})}
                            for call in tool_calls
                        ]
                        step_info = {
                            "assistant_message": assistant_text,
                            "model_response": raw_response,
                            "llm_request": request_payload,
                        }
                        if override_info:
                            step_info["auto_override"] = override_info
                        if normalized_tool == "compile" and skip_env_step:
                            step_info["compile_blocked"] = {
                                "reason": "design_requires_update_before_compile",
                                "hint": "Use write_file to match required ports/behavior before compiling again.",
                            }
                        steps.append(
                            RolloutStep(
                                obs=current_observation,
                                tool_calls=tool_call_records,
                                reward=reward_last,
                                done=done_flag,
                                truncated=truncated_flag,
                                info=step_info,
                            )
                        )

                        if normalized_tool == "compile" and skip_env_step:
                            action_feedback = (
                                "Compilation blocked: update the design with write_file (declare required ports and logic) before compiling again."
                            )
                        else:
                            action_feedback = _summarize_action_feedback(
                                primary_call["tool"], primary_call["args"], current_observation, reward_last
                            )
                        agent.append_observation(
                            observation=current_observation,
                            step_index=step_index,
                            action_feedback=action_feedback,
                            guidance=_build_guidance(step_index),
                        )

                        if done_flag:
                            break

                        if step_index == max_steps:
                            truncated_due_to_limit = True
                            break
                    except Exception as exc:  # pragma: no cover - defensive path
                        # Record the failed step and end the episode as truncated.
                        error_text = str(exc)
                        logger.exception("Verilog environment step failed: %s", exc)
                        failure_observation = dict(current_observation)
                        failure_observation["error"] = error_text
                        final_observation = failure_observation
                        tool_call_records = [
                            {"tool_name": primary_call["tool"], "arguments": primary_call["args"]}
                        ]
                        step_info = {
                            "assistant_message": assistant_text,
                            "model_response": raw_response,
                            "llm_request": request_payload,
                            "error": error_text,
                        }
                        steps.append(
                            RolloutStep(
                                obs=failure_observation,
                                tool_calls=tool_call_records,
                                reward=0.0,
                                done=True,
                                truncated=True,
                                info=step_info,
                            )
                        )
                        truncated_due_to_limit = True
                        break
    finally:
        with contextlib.suppress(Exception):
            await env.terminate()

    if final_observation is None:
        final_observation = {}

    final_total_reward = float(final_observation.get("total_reward") or total_reward)
    final_done = bool(
        final_observation.get("terminated") or final_observation.get("task_completed")
    )
    final_truncated = truncated_due_to_limit or bool(final_observation.get("truncated"))

    metrics = RolloutMetrics(
        episode_returns=[final_total_reward],
        mean_return=final_total_reward,
        num_steps=len(steps),
        num_episodes=1,
        outcome_score=final_total_reward,
        events_score=None,
        details={
            "task_completed": bool(final_observation.get("task_completed")),
            "total_reward": final_total_reward,
            "steps": len(steps),
            "truncated": final_truncated,
        },
    )

    trajectory = RolloutTrajectory(
        env_id=str(env_id),
        policy_id=str(policy_id),
        steps=steps,
        final={
            "observation": final_observation,
            "reward": final_total_reward,
            "done": final_done,
            "truncated": final_truncated,
            "info": {
                "total_reward": final_total_reward,
                "task_completed": bool(final_observation.get("task_completed")),
                "policy_model": policy_model,
                "inference_url": agent.inference_url,
            },
        },
        length=len(steps),
        inference_url=agent.inference_url,  # required for trace correlation
        decision_samples=None,
    )

    # Minimal trace payload; the provider is derived from the endpoint actually
    # used rather than assumed.
    trace_payload = {
        "session_trace": {
            "session_id": request.run_id,
            "created_at": None,
            "metadata": {
                "task": "verilog",
                "provider": _infer_provider(agent.inference_url),
                "model": policy_model,
                "total_reward": final_total_reward,
                "task_completed": bool(final_observation.get("task_completed")),
            },
            "session_time_steps": [],
            "event_history": [],
            "markov_blanket_message_history": [],
        }
    }

    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        ops_executed=len(steps),
        trace=trace_payload,
    )
848
+
849
+
850
# Module-level dataset shared by the rollout executor and the task-app config.
# build_dataset() yields (registry, dataset); BASE_INFO is derived once at import.
RUNTIME_DATASET: VerilogDataset
[registry, RUNTIME_DATASET] = build_dataset()
BASE_INFO = _base_task_info(RUNTIME_DATASET)
853
+
854
+
855
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` for the GRPO Verilog task app.

    Tracing and SFT-output settings are resolved through the tracing helpers
    each time this is called. The module-level ``RUNTIME_DATASET``,
    ``BASE_INFO`` and rubrics are reused.
    """
    tracing_enabled = tracing_env_enabled()
    tracing_db_url = resolve_tracing_db_url()
    tracer_factory = build_tracer_factory(
        SessionTracer, enabled=tracing_enabled, db_url=tracing_db_url
    )
    sft_output_dir = resolve_sft_output_dir()

    app_state: dict[str, Any] = dict(
        dataset=RUNTIME_DATASET,
        allowed_environments=["verilog"],
        tracing_enabled=tracing_enabled,
    )
    if tracer_factory is not None:
        app_state["session_tracer_factory"] = tracer_factory

    if not tracing_enabled:
        logger.info("[verilog:tracing] disabled")
    else:
        logger.info("[verilog:tracing] enabled (db=%s)", tracing_db_url or "default")

    # The SFT output dir is stored on the app state and announced in the log.
    if sft_output_dir:
        app_state["sft_output_dir"] = sft_output_dir
        logger.info("[verilog:sft] writing JSONL to %s", sft_output_dir)

    def _describe() -> dict[str, Any]:
        return describe_taskset(RUNTIME_DATASET)

    def _instances(seeds):
        return provide_task_instances(RUNTIME_DATASET, BASE_INFO, seeds)

    return TaskAppConfig(
        app_id="grpo-verilog",
        name="GRPO Verilog Task App",
        description="Spec-to-RTL Verilog environment with GRPO-compatible metadata endpoints.",
        base_task_info=BASE_INFO,
        describe_taskset=_describe,
        provide_task_instances=_instances,
        rollout=rollout_executor,
        dataset_registry=registry,
        rubrics=RubricBundle(outcome=OUTCOME_RUBRIC, events=EVENTS_RUBRIC),
        proxy=ProxyConfig(
            enable_openai=True,
            enable_groq=True,
            system_hint=VERILOG_SYSTEM_PROMPT,
        ),
        routers=(),
        app_state=app_state,
        cors_origins=["*"],
    )
900
+
901
+
902
# Register the task app at import time under "grpo-verilog" (aliases "verilog"
# and "verilog-task"), together with its Modal deployment settings.
register_task_app(
    entry=TaskAppEntry(
        app_id="grpo-verilog",
        aliases=("verilog", "verilog-task"),
        description="Verilog spec-to-RTL task app with rollout metadata endpoints.",
        config_factory=build_config,
        env_files=(str(REPO_ROOT / "backend" / ".env.dev"),),
        modal=ModalDeploymentConfig(
            app_name="grpo-verilog-task-app",
            python_version="3.11",
            cpu=2.0,
            memory=8192,
            max_containers=4,
            secret_names=("groq-api-key", "openai-api-key"),
            pip_packages=(
                "fastapi>=0.100.0",
                "uvicorn>=0.23.0",
                "pydantic>=2.0.0",
                "httpx>=0.24.0",
                "python-dotenv>=1.0.1",
                "datasets>=2.10.0",
            ),
            # (local path, container mount point) pairs.
            extra_local_dirs=(
                (str(REPO_ROOT), "/opt/synth_ai_repo"),
                (str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
                (str(_HERE.parent), "/opt/synth_ai_repo/examples/task_apps/verilog/task_app"),
            ),
        ),
    )
)