synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (48)
  1. examples/agora_ex/README_MoE.md +224 -0
  2. examples/agora_ex/__init__.py +7 -0
  3. examples/agora_ex/agora_ex.py +65 -0
  4. examples/agora_ex/agora_ex_task_app.py +590 -0
  5. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
  6. examples/agora_ex/reward_fn_grpo-human.py +129 -0
  7. examples/agora_ex/system_prompt_CURRENT.md +63 -0
  8. examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
  9. examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
  10. examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
  11. examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
  12. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
  13. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
  14. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
  15. examples/multi_step/crafter_rl_lora.md +51 -10
  16. examples/multi_step/sse_metrics_streaming_notes.md +357 -0
  17. examples/multi_step/task_app_config_notes.md +7 -1
  18. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
  19. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
  20. examples/warming_up_to_rl/run_eval.py +127 -18
  21. examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
  22. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
  23. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
  24. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +232 -193
  25. synth_ai/__init__.py +41 -1
  26. synth_ai/api/train/builders.py +49 -19
  27. synth_ai/api/train/configs/__init__.py +44 -0
  28. synth_ai/api/train/configs/rl.py +133 -0
  29. synth_ai/api/train/configs/sft.py +94 -0
  30. synth_ai/api/train/configs/shared.py +24 -0
  31. synth_ai/cli/demo.py +38 -39
  32. synth_ai/cli/rl_demo.py +81 -102
  33. synth_ai/cli/task_apps.py +3 -0
  34. synth_ai/demos/core/cli.py +121 -159
  35. synth_ai/environments/examples/crafter_classic/environment.py +16 -0
  36. synth_ai/evals/__init__.py +15 -0
  37. synth_ai/evals/client.py +85 -0
  38. synth_ai/evals/types.py +42 -0
  39. synth_ai/judge_schemas.py +127 -0
  40. synth_ai/rubrics/__init__.py +22 -0
  41. synth_ai/rubrics/validators.py +126 -0
  42. synth_ai/tracing_v3/serialization.py +130 -0
  43. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +1 -1
  44. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +48 -22
  45. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
  46. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
  47. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
  48. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
synth_ai/api/train/builders.py CHANGED
@@ -6,6 +6,7 @@ from pathlib import Path
6
6
  from typing import Any, cast
7
7
 
8
8
  import click
9
+ from pydantic import ValidationError
9
10
 
10
11
  try:
11
12
  _models_module = importlib.import_module("synth_ai.api.models.supported")
@@ -25,7 +26,8 @@ from .supported_algos import (
25
26
  ensure_model_supported_for_algorithm,
26
27
  validate_algorithm_config,
27
28
  )
28
- from .utils import TrainError, ensure_api_base, load_toml
29
+ from .utils import TrainError, ensure_api_base
30
+ from .configs import RLConfig, SFTConfig
29
31
 
30
32
 
31
33
  @dataclass(slots=True)
@@ -42,6 +44,16 @@ class SFTBuildResult:
42
44
  validation_file: Path | None
43
45
 
44
46
 
47
+ def _format_validation_error(path: Path, exc: ValidationError) -> str:
48
+ lines: list[str] = []
49
+ for error in exc.errors():
50
+ loc = ".".join(str(part) for part in error.get("loc", ()))
51
+ msg = error.get("msg", "invalid value")
52
+ lines.append(f"{loc or '<root>'}: {msg}")
53
+ details = "\n".join(f" - {line}" for line in lines) or " - Invalid configuration"
54
+ return f"Config validation failed ({path}):\n{details}"
55
+
56
+
45
57
  def build_rl_payload(
46
58
  *,
47
59
  config_path: Path,
@@ -50,13 +62,30 @@ def build_rl_payload(
50
62
  idempotency: str | None,
51
63
  allow_experimental: bool | None = None,
52
64
  ) -> RLBuildResult:
53
- data = load_toml(config_path)
54
65
  try:
55
- spec = validate_algorithm_config(data.get("algorithm"), expected_family="rl")
66
+ rl_cfg = RLConfig.from_path(config_path)
67
+ except ValidationError as exc:
68
+ raise click.ClickException(_format_validation_error(config_path, exc)) from exc
69
+
70
+ data = rl_cfg.to_dict()
71
+ # Ensure required [reference] section for backend validators
72
+ try:
73
+ ref_cfg = data.get("reference") if isinstance(data, dict) else None
74
+ if not isinstance(ref_cfg, dict):
75
+ data["reference"] = {"placement": "none"}
76
+ else:
77
+ ref_cfg.setdefault("placement", "none")
78
+ except Exception:
79
+ # Defensive: never fail builder due to optional defaults
80
+ data["reference"] = {"placement": "none"}
81
+ try:
82
+ spec = validate_algorithm_config(
83
+ rl_cfg.algorithm.model_dump(), expected_family="rl"
84
+ )
56
85
  except AlgorithmValidationError as exc:
57
86
  raise click.ClickException(str(exc)) from exc
58
87
  services = data.get("services") if isinstance(data.get("services"), dict) else {}
59
- model_cfg = data.get("model") if isinstance(data.get("model"), dict) else {}
88
+ model_cfg = rl_cfg.model
60
89
 
61
90
  final_task_url = (
62
91
  overrides.get("task_url")
@@ -69,10 +98,8 @@ def build_rl_payload(
69
98
  "Task app URL required (provide --task-url or set services.task_url in TOML)"
70
99
  )
71
100
 
72
- raw_source = model_cfg.get("source") if isinstance(model_cfg, dict) else ""
73
- model_source = str(raw_source or "").strip()
74
- raw_base = model_cfg.get("base") if isinstance(model_cfg, dict) else ""
75
- model_base = str(raw_base or "").strip()
101
+ model_source = (model_cfg.source or "").strip()
102
+ model_base = (model_cfg.base or "").strip()
76
103
  override_model = (overrides.get("model") or "").strip()
77
104
  if override_model:
78
105
  model_source = override_model
@@ -160,22 +187,23 @@ def build_sft_payload(
160
187
  dataset_override: Path | None,
161
188
  allow_experimental: bool | None,
162
189
  ) -> SFTBuildResult:
163
- data = load_toml(config_path)
164
190
  try:
165
- spec = validate_algorithm_config(data.get("algorithm"), expected_family="sft")
191
+ sft_cfg = SFTConfig.from_path(config_path)
192
+ except ValidationError as exc:
193
+ raise TrainError(_format_validation_error(config_path, exc)) from exc
194
+
195
+ data = sft_cfg.to_dict()
196
+ try:
197
+ algo_mapping = sft_cfg.algorithm.model_dump() if sft_cfg.algorithm else None
198
+ spec = validate_algorithm_config(algo_mapping, expected_family="sft")
166
199
  except AlgorithmValidationError as exc:
167
200
  raise TrainError(str(exc)) from exc
168
- job_cfg = data.get("job") if isinstance(data.get("job"), dict) else {}
169
201
  data_cfg = data.get("data") if isinstance(data.get("data"), dict) else {}
170
202
  hp_cfg = data.get("hyperparameters") if isinstance(data.get("hyperparameters"), dict) else {}
171
203
  train_cfg = data.get("training") if isinstance(data.get("training"), dict) else {}
172
204
  compute_cfg = data.get("compute") if isinstance(data.get("compute"), dict) else {}
173
205
 
174
- raw_dataset = (
175
- dataset_override
176
- or (job_cfg.get("data") if isinstance(job_cfg, dict) else None)
177
- or (job_cfg.get("data_path") if isinstance(job_cfg, dict) else None)
178
- )
206
+ raw_dataset = dataset_override or sft_cfg.job.data or sft_cfg.job.data_path
179
207
  if not raw_dataset:
180
208
  raise TrainError("Dataset not specified; pass --dataset or set [job].data")
181
209
  dataset_path = Path(raw_dataset)
@@ -260,9 +288,11 @@ def build_sft_payload(
260
288
  "enabled": bool(validation_cfg.get("enabled", True))
261
289
  }
262
290
 
263
- raw_model = str(
264
- job_cfg.get("model") if isinstance(job_cfg, dict) else None or data.get("model") or ""
265
- ).strip()
291
+ raw_model = (sft_cfg.job.model or "").strip()
292
+ if not raw_model:
293
+ model_block = data.get("model")
294
+ if isinstance(model_block, str):
295
+ raw_model = model_block.strip()
266
296
  if not raw_model:
267
297
  raise TrainError("Model not specified; set [job].model or [model].base in the config")
268
298
 
synth_ai/api/train/configs/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ """Typed training config loaders for RL and SFT jobs."""
2
+
3
+ from .shared import AlgorithmConfig, ComputeConfig
4
+ from .sft import (
5
+ HyperparametersConfig,
6
+ HyperparametersParallelism,
7
+ JobConfig,
8
+ SFTConfig,
9
+ SFTDataConfig,
10
+ TrainingConfig,
11
+ TrainingValidationConfig,
12
+ )
13
+ from .rl import (
14
+ EvaluationConfig,
15
+ JudgeConfig,
16
+ JudgeOptionsConfig,
17
+ ModelConfig,
18
+ RLConfig,
19
+ RLServicesConfig,
20
+ RLTrainingConfig,
21
+ RolloutConfig,
22
+ WeightSyncConfig,
23
+ )
24
+
25
+ __all__ = [
26
+ "AlgorithmConfig",
27
+ "ComputeConfig",
28
+ "EvaluationConfig",
29
+ "HyperparametersConfig",
30
+ "HyperparametersParallelism",
31
+ "JobConfig",
32
+ "JudgeConfig",
33
+ "JudgeOptionsConfig",
34
+ "ModelConfig",
35
+ "RLConfig",
36
+ "RLServicesConfig",
37
+ "RLTrainingConfig",
38
+ "RolloutConfig",
39
+ "SFTConfig",
40
+ "SFTDataConfig",
41
+ "TrainingConfig",
42
+ "TrainingValidationConfig",
43
+ "WeightSyncConfig",
44
+ ]
synth_ai/api/train/configs/rl.py ADDED
@@ -0,0 +1,133 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Mapping
5
+
6
+ from pydantic import model_validator
7
+
8
+ from ..utils import load_toml
9
+ from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
10
+
11
+
12
+ class RLServicesConfig(ExtraModel):
13
+ task_url: str
14
+ judge_url: str | None = None
15
+
16
+
17
+ class ModelConfig(ExtraModel):
18
+ source: str | None = None
19
+ base: str | None = None
20
+ trainer_mode: str
21
+ label: str
22
+
23
+ @model_validator(mode="after")
24
+ def _ensure_exactly_one_source_or_base(self) -> "ModelConfig":
25
+ if bool(self.source) == bool(self.base):
26
+ raise ValueError("Config must set exactly one of [model].source or [model].base")
27
+ return self
28
+
29
+
30
+ class RolloutConfig(ExtraModel):
31
+ env_name: str
32
+ policy_name: str
33
+ env_config: dict[str, Any] | None = None
34
+ policy_config: dict[str, Any] | None = None
35
+ max_turns: int
36
+ episodes_per_batch: int
37
+ max_concurrent_rollouts: int
38
+ batches_per_step: int | None = None
39
+ ops: list[str] | None = None
40
+
41
+
42
+ class WeightSyncConfig(ExtraModel):
43
+ enable: bool | None = None
44
+ targets: list[str] | None = None
45
+ mode: str | None = None
46
+ direct: bool | None = None
47
+ verify_every_k: int | None = None
48
+
49
+
50
+ class RLTrainingConfig(ExtraModel):
51
+ num_epochs: int
52
+ iterations_per_epoch: int
53
+ gradient_accumulation_steps: int | None = None
54
+ max_accumulated_minibatch: int | None = None
55
+ max_turns: int
56
+ batch_size: int
57
+ group_size: int
58
+ learning_rate: float
59
+ log_interval: int | None = None
60
+ weight_sync_interval: int | None = None
61
+ step_rewards_enabled: bool | None = None
62
+ step_rewards_mode: str | None = None
63
+ step_rewards_indicator_lambda: float | None = None
64
+ step_rewards_beta: float | None = None
65
+ step_rewards_strategy: str | None = None
66
+ event_rewards_kind: str | None = None
67
+ weight_sync: WeightSyncConfig | None = None
68
+
69
+
70
+ class EvaluationConfig(ExtraModel):
71
+ instances: int
72
+ every_n_iters: int
73
+ seeds: list[int]
74
+
75
+
76
+ class JudgeOptionsConfig(ExtraModel):
77
+ event: bool | None = None
78
+ outcome: bool | None = None
79
+ provider: str | None = None
80
+ model: str | None = None
81
+ rubric_id: str | None = None
82
+ rubric_overrides: dict[str, Any] | None = None
83
+ tracks: list[str] | None = None
84
+ weights: dict[str, float] | None = None
85
+ max_concurrency: int | None = None
86
+
87
+
88
+ class JudgeConfig(ExtraModel):
89
+ type: str | None = None
90
+ timeout_s: int | None = None
91
+ options: JudgeOptionsConfig | None = None
92
+
93
+
94
+ class RLConfig(ExtraModel):
95
+ algorithm: AlgorithmConfig
96
+ services: RLServicesConfig
97
+ compute: ComputeConfig | None = None
98
+ topology: dict[str, Any] | None = None
99
+ vllm: dict[str, Any] | None = None
100
+ reference: dict[str, Any] | None = None
101
+ model: ModelConfig
102
+ lora: dict[str, Any] | None = None
103
+ rollout: RolloutConfig | None = None
104
+ evaluation: EvaluationConfig | None = None
105
+ training: RLTrainingConfig | None = None
106
+ rubric: dict[str, Any] | None = None
107
+ judge: JudgeConfig | None = None
108
+ tags: dict[str, Any] | None = None
109
+
110
+ def to_dict(self) -> dict[str, Any]:
111
+ return self.model_dump(mode="python", exclude_none=True)
112
+
113
+ @classmethod
114
+ def from_mapping(cls, data: Mapping[str, Any]) -> "RLConfig":
115
+ return cls.model_validate(dict(data))
116
+
117
+ @classmethod
118
+ def from_path(cls, path: Path) -> "RLConfig":
119
+ content = load_toml(path)
120
+ return cls.from_mapping(content)
121
+
122
+
123
+ __all__ = [
124
+ "EvaluationConfig",
125
+ "JudgeConfig",
126
+ "JudgeOptionsConfig",
127
+ "ModelConfig",
128
+ "RLConfig",
129
+ "RLServicesConfig",
130
+ "RLTrainingConfig",
131
+ "RolloutConfig",
132
+ "WeightSyncConfig",
133
+ ]
synth_ai/api/train/configs/sft.py ADDED
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Mapping
5
+
6
+ from pydantic import Field
7
+
8
+ from ..utils import load_toml
9
+ from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
10
+
11
+
12
+ class JobConfig(ExtraModel):
13
+ model: str
14
+ data: str | None = None
15
+ data_path: str | None = None
16
+ poll_seconds: int | None = None
17
+
18
+
19
+ class SFTDataConfig(ExtraModel):
20
+ topology: dict[str, Any] | None = None
21
+ validation_path: str | None = None
22
+
23
+
24
+ class TrainingValidationConfig(ExtraModel):
25
+ enabled: bool | None = None
26
+ evaluation_strategy: str | None = None
27
+ eval_steps: int | None = None
28
+ save_best_model_at_end: bool | None = None
29
+ metric_for_best_model: str | None = None
30
+ greater_is_better: bool | None = None
31
+
32
+
33
+ class TrainingConfig(ExtraModel):
34
+ mode: str | None = None
35
+ use_qlora: bool | None = None
36
+ validation: TrainingValidationConfig | None = None
37
+
38
+
39
+ class HyperparametersParallelism(ExtraModel):
40
+ use_deepspeed: bool | None = None
41
+ deepspeed_stage: int | None = None
42
+ fsdp: bool | None = None
43
+ bf16: bool | None = None
44
+ fp16: bool | None = None
45
+ activation_checkpointing: bool | None = None
46
+ tensor_parallel_size: int | None = None
47
+ pipeline_parallel_size: int | None = None
48
+
49
+
50
+ class HyperparametersConfig(ExtraModel):
51
+ n_epochs: int = 1
52
+ batch_size: int | None = None
53
+ global_batch: int | None = None
54
+ per_device_batch: int | None = None
55
+ gradient_accumulation_steps: int | None = None
56
+ sequence_length: int | None = None
57
+ learning_rate: float | None = None
58
+ warmup_ratio: float | None = None
59
+ train_kind: str | None = None
60
+ weight_decay: float | None = None
61
+ parallelism: HyperparametersParallelism | None = None
62
+
63
+
64
+ class SFTConfig(ExtraModel):
65
+ algorithm: AlgorithmConfig | None = None
66
+ job: JobConfig
67
+ compute: ComputeConfig | None = None
68
+ data: SFTDataConfig | None = None
69
+ training: TrainingConfig | None = None
70
+ hyperparameters: HyperparametersConfig = Field(default_factory=HyperparametersConfig)
71
+ tags: dict[str, Any] | None = None
72
+
73
+ def to_dict(self) -> dict[str, Any]:
74
+ return self.model_dump(mode="python", exclude_none=True)
75
+
76
+ @classmethod
77
+ def from_mapping(cls, data: Mapping[str, Any]) -> "SFTConfig":
78
+ return cls.model_validate(dict(data))
79
+
80
+ @classmethod
81
+ def from_path(cls, path: Path) -> "SFTConfig":
82
+ content = load_toml(path)
83
+ return cls.from_mapping(content)
84
+
85
+
86
+ __all__ = [
87
+ "HyperparametersConfig",
88
+ "HyperparametersParallelism",
89
+ "JobConfig",
90
+ "SFTConfig",
91
+ "SFTDataConfig",
92
+ "TrainingConfig",
93
+ "TrainingValidationConfig",
94
+ ]
synth_ai/api/train/configs/shared.py ADDED
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+
6
+ class ExtraModel(BaseModel):
7
+ """Base model that tolerates unknown keys so configs keep forward compatibility."""
8
+
9
+ model_config = ConfigDict(extra="allow")
10
+
11
+
12
+ class AlgorithmConfig(ExtraModel):
13
+ type: str
14
+ method: str
15
+ variety: str
16
+
17
+
18
+ class ComputeConfig(ExtraModel):
19
+ gpu_type: str
20
+ gpu_count: int
21
+ nodes: int | None = None
22
+
23
+
24
+ __all__ = ["ExtraModel", "AlgorithmConfig", "ComputeConfig"]
synth_ai/cli/demo.py CHANGED
@@ -1,9 +1,9 @@
1
1
  #!/usr/bin/env python3
2
2
  """
3
- CLI: interactive launcher for example demos and forwarders for new RL demo.
3
+ CLI: interactive launcher for example demos and RL demo helpers.
4
4
 
5
- - `synth-ai demo` (no subcommand) -> legacy examples/ runner (run_demo.sh picker)
6
- - `synth-ai demo deploy|configure|run` -> forwards to synth_ai.demos.core.cli
5
+ - `synth-ai demo` (no subcommand) -> initialize RL demo files into ./synth_demo/
6
+ - `synth-ai demo deploy|configure|run` -> invoke RL demo helpers directly.
7
7
  """
8
8
 
9
9
  from __future__ import annotations
@@ -14,6 +14,8 @@ from pathlib import Path
14
14
 
15
15
  import click
16
16
 
17
+ from synth_ai.demos.core import cli as demo_commands
18
+
17
19
 
18
20
  def _find_demo_scripts(root: Path) -> list[Path]:
19
21
  if not root.exists():
@@ -21,17 +23,23 @@ def _find_demo_scripts(root: Path) -> list[Path]:
21
23
  return sorted([p for p in root.rglob("run_demo.sh") if p.is_file()])
22
24
 
23
25
 
24
- def _forward_to_new(args: list[str]) -> None:
25
- import sys
26
+ def _run_demo_command(func, *args, **kwargs) -> None:
27
+ """Invoke a demo command and exit via Click on non-zero status codes."""
28
+
29
+ try:
30
+ result = func(*args, **kwargs)
31
+ except SystemExit as exc: # pragma: no cover - defensive
32
+ raise click.exceptions.Exit(exc.code or 1) from exc
33
+
34
+ if result is None:
35
+ return
26
36
 
27
37
  try:
28
- from synth_ai.demos.core import cli as demo_cli # type: ignore
29
- except Exception as e: # pragma: no cover
30
- click.echo(f"Failed to import demo CLI: {e}")
31
- sys.exit(1)
32
- rc = int(demo_cli.main(args) or 0)
33
- if rc != 0:
34
- sys.exit(rc)
38
+ code = int(result)
39
+ except (TypeError, ValueError):
40
+ return
41
+ if code != 0:
42
+ raise click.exceptions.Exit(code)
35
43
 
36
44
 
37
45
  def register(cli):
@@ -92,11 +100,8 @@ def register(cli):
92
100
  click.echo("\n🛑 Demo interrupted by user")
93
101
  return
94
102
 
95
- # Default: forward to RL demo init behavior, optionally with --force
96
- args: list[str] = ["rl_demo.init"]
97
- if force:
98
- args.append("--force")
99
- _forward_to_new(args)
103
+ # Default: initialize RL demo files via new command
104
+ _run_demo_command(demo_commands.init, force=force)
100
105
 
101
106
  # (prepare command removed; configure now prepares baseline TOML)
102
107
 
@@ -122,24 +127,21 @@ def register(cli):
122
127
  help="Path to deploy_task_app.sh (optional legacy)",
123
128
  )
124
129
  def demo_deploy(local: bool, app: str | None, name: str, script: str | None):
125
- args: list[str] = ["rl_demo.deploy"]
126
- if local:
127
- args.append("--local")
128
- if app:
129
- args.extend(["--app", app])
130
- if name:
131
- args.extend(["--name", name])
132
- if script:
133
- args.extend(["--script", script])
134
- _forward_to_new(args)
130
+ _run_demo_command(
131
+ demo_commands.deploy,
132
+ local=local,
133
+ app=app,
134
+ name=name,
135
+ script=script,
136
+ )
135
137
 
136
138
  @_dg.command("configure")
137
139
  def demo_configure():
138
- _forward_to_new(["rl_demo.configure"])
140
+ _run_demo_command(demo_commands.run)
139
141
 
140
142
  @_dg.command("setup")
141
143
  def demo_setup():
142
- _forward_to_new(["rl_demo.setup"])
144
+ _run_demo_command(demo_commands.setup)
143
145
 
144
146
  @_dg.command("run")
145
147
  @click.option("--batch-size", type=int, default=None)
@@ -147,13 +149,10 @@ def register(cli):
147
149
  @click.option("--model", type=str, default=None)
148
150
  @click.option("--timeout", type=int, default=600)
149
151
  def demo_run(batch_size: int | None, group_size: int | None, model: str | None, timeout: int):
150
- args = ["rl_demo.run"]
151
- if batch_size is not None:
152
- args.extend(["--batch-size", str(batch_size)])
153
- if group_size is not None:
154
- args.extend(["--group-size", str(group_size)])
155
- if model:
156
- args.extend(["--model", model])
157
- if timeout:
158
- args.extend(["--timeout", str(timeout)])
159
- _forward_to_new(args)
152
+ _run_demo_command(
153
+ demo_commands.run,
154
+ batch_size=batch_size,
155
+ group_size=group_size,
156
+ model=model,
157
+ timeout=timeout,
158
+ )