synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.9.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. synth_ai/api/train/__init__.py +5 -0
  2. synth_ai/api/train/builders.py +165 -0
  3. synth_ai/api/train/cli.py +450 -0
  4. synth_ai/api/train/config_finder.py +168 -0
  5. synth_ai/api/train/env_resolver.py +302 -0
  6. synth_ai/api/train/pollers.py +66 -0
  7. synth_ai/api/train/task_app.py +193 -0
  8. synth_ai/api/train/utils.py +232 -0
  9. synth_ai/cli/__init__.py +23 -0
  10. synth_ai/cli/rl_demo.py +18 -6
  11. synth_ai/cli/root.py +38 -6
  12. synth_ai/cli/task_apps.py +1107 -0
  13. synth_ai/demo_registry.py +258 -0
  14. synth_ai/demos/core/cli.py +147 -111
  15. synth_ai/demos/demo_task_apps/__init__.py +7 -1
  16. synth_ai/demos/demo_task_apps/math/config.toml +55 -110
  17. synth_ai/demos/demo_task_apps/math/modal_task_app.py +157 -21
  18. synth_ai/demos/demo_task_apps/math/task_app_entry.py +39 -0
  19. synth_ai/task/__init__.py +94 -1
  20. synth_ai/task/apps/__init__.py +88 -0
  21. synth_ai/task/apps/grpo_crafter.py +438 -0
  22. synth_ai/task/apps/math_single_step.py +852 -0
  23. synth_ai/task/auth.py +153 -0
  24. synth_ai/task/client.py +165 -0
  25. synth_ai/task/contracts.py +29 -14
  26. synth_ai/task/datasets.py +105 -0
  27. synth_ai/task/errors.py +49 -0
  28. synth_ai/task/json.py +77 -0
  29. synth_ai/task/proxy.py +258 -0
  30. synth_ai/task/rubrics.py +212 -0
  31. synth_ai/task/server.py +398 -0
  32. synth_ai/task/tracing_utils.py +79 -0
  33. synth_ai/task/vendors.py +61 -0
  34. synth_ai/tracing_v3/session_tracer.py +13 -5
  35. synth_ai/tracing_v3/storage/base.py +10 -12
  36. synth_ai/tracing_v3/turso/manager.py +20 -6
  37. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/METADATA +3 -2
  38. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/RECORD +42 -18
  39. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/WHEEL +0 -0
  40. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/entry_points.txt +0 -0
  41. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/licenses/LICENSE +0 -0
  42. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/top_level.txt +0 -0
synth_ai/api/train/task_app.py ADDED
@@ -0,0 +1,193 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from typing import Iterable
6
+
7
+ import click
8
+ import requests
9
+
10
+ from .utils import CLIResult, http_get, run_cli
11
+
12
+
13
+ @dataclass(slots=True)
14
+ class TaskAppHealth:
15
+ ok: bool
16
+ health_status: int | None
17
+ task_info_status: int | None
18
+ detail: str | None = None
19
+
20
+
21
+ def _health_response_ok(resp: requests.Response | None) -> tuple[bool, str]:
22
+ if resp is None:
23
+ return False, ""
24
+ status = resp.status_code
25
+ if status == 200:
26
+ return True, ""
27
+ if status in {401, 403}:
28
+ try:
29
+ payload = resp.json()
30
+ except ValueError:
31
+ payload = {}
32
+ prefix = payload.get("expected_api_key_prefix")
33
+ detail = str(payload.get("detail", ""))
34
+ if prefix or "expected prefix" in detail.lower():
35
+ note = "auth-optional"
36
+ if prefix:
37
+ note += f" (expected-prefix={prefix})"
38
+ return True, note
39
+ return False, ""
40
+
41
+
42
def check_task_app_health(base_url: str, api_key: str, *, timeout: float = 10.0) -> TaskAppHealth:
    """Probe a task app's ``/health`` and ``/task_info`` endpoints.

    Both requests send the primary key as ``X-API-Key`` and as a bearer token.
    They also send every known key (the primary key plus the comma-separated
    ``ENVIRONMENT_API_KEY_ALIASES``) joined in ``X-API-Keys``.

    Args:
        base_url: Root URL of the task app; trailing slashes are ignored.
        api_key: Primary environment API key.
        timeout: Per-request timeout in seconds.

    Returns:
        A ``TaskAppHealth`` whose ``ok`` is true only when ``/health`` is
        judged healthy by ``_health_response_ok`` and ``/task_info`` returned
        200. ``requests.RequestException`` failures are recorded in ``detail``
        rather than raised.
    """
    import os

    # Send ALL known environment keys so the server can authorize any valid one.
    # De-duplicate (order preserved) and drop blanks so X-API-Keys never carries
    # empty or repeated entries, e.g. when api_key is also listed as an alias.
    candidates = [api_key]
    aliases = (os.getenv("ENVIRONMENT_API_KEY_ALIASES") or "").strip()
    if aliases:
        candidates.extend(part.strip() for part in aliases.split(","))
    keys = [key for key in dict.fromkeys(candidates) if key]

    headers = {"X-API-Key": api_key, "Authorization": f"Bearer {api_key}"}
    if keys:
        headers["X-API-Keys"] = ",".join(keys)
    base = base_url.rstrip("/")
    detail_parts: list[str] = []

    health_resp: requests.Response | None = None
    health_ok = False
    try:
        health_resp = http_get(f"{base}/health", headers=headers, timeout=timeout)
        health_ok, note = _health_response_ok(health_resp)
        suffix = f" ({note})" if note else ""
        if not health_ok:
            # Surface a few helpful JSON fields without dumping the whole body.
            try:
                hjs = health_resp.json()
            except Exception:
                hjs = None
            if isinstance(hjs, dict):
                expected = hjs.get("expected_api_key_prefix")
                authorized = hjs.get("authorized")
                detail = hjs.get("detail")
                extras = []
                if authorized is not None:
                    extras.append(f"authorized={authorized}")
                if expected:
                    extras.append(f"expected_prefix={expected}")
                if detail:
                    extras.append(f"detail={str(detail)[:80]}")
                if extras:
                    suffix += " [" + ", ".join(extras) + "]"
        detail_parts.append(f"/health={health_resp.status_code}{suffix}")
    except requests.RequestException as exc:
        detail_parts.append(f"/health_error={exc}")

    task_resp: requests.Response | None = None
    task_ok = False
    try:
        task_resp = http_get(f"{base}/task_info", headers=headers, timeout=timeout)
        task_ok = task_resp.status_code == 200
        entry = f"/task_info={task_resp.status_code}"
        if not task_ok:
            # Append the server's `detail` (or `status`) when the body is a JSON object.
            try:
                tjs = task_resp.json()
            except Exception:
                tjs = None
            if isinstance(tjs, dict):
                msg = tjs.get("detail") or tjs.get("status")
                entry += f" ({str(msg)[:80]})"
        detail_parts.append(entry)
    except requests.RequestException as exc:
        detail_parts.append(f"/task_info_error={exc}")

    return TaskAppHealth(
        ok=bool(health_ok and task_ok),
        health_status=None if health_resp is None else health_resp.status_code,
        task_info_status=None if task_resp is None else task_resp.status_code,
        detail=", ".join(detail_parts),
    )
110
+
111
+
112
+ @dataclass(slots=True)
113
+ class ModalSecret:
114
+ name: str
115
+ value: str
116
+
117
+
118
+ @dataclass(slots=True)
119
+ class ModalApp:
120
+ app_id: str
121
+ label: str
122
+ url: str
123
+
124
+
125
def _run_modal(args: Iterable[str]) -> CLIResult:
    """Run ``modal <args...>`` through ``run_cli`` with a 30 second timeout."""
    command = ["modal"]
    command.extend(args)
    return run_cli(command, timeout=30.0)
127
+
128
+
129
def list_modal_secrets(pattern: str | None = None) -> list[str]:
    """Return secret names parsed from ``modal secret list``.

    The first whitespace-separated token of every non-blank row is taken as
    the name; a row beginning with ``NAME`` is treated as the header. When
    *pattern* is non-empty, only names containing it (case-insensitively)
    are returned.

    Raises:
        click.ClickException: If the ``modal`` command exits non-zero.
    """
    result = _run_modal(["secret", "list"])
    if result.code != 0:
        raise click.ClickException(f"modal secret list failed: {result.stderr or result.stdout}")
    needle = pattern.lower() if pattern else None
    found: list[str] = []
    for raw_row in result.stdout.splitlines():
        row = raw_row.strip()
        if not row or row.startswith("NAME"):
            continue
        candidate = row.split()[0]
        if needle is None or needle in candidate.lower():
            found.append(candidate)
    return found
144
+
145
+
146
def get_modal_secret_value(name: str) -> str:
    """Return the stripped stdout of ``modal secret get <name>``.

    Raises:
        click.ClickException: If the command exits non-zero or prints nothing.
    """
    result = _run_modal(["secret", "get", name])
    if result.code:
        raise click.ClickException(f"modal secret get {name} failed: {result.stderr or result.stdout}")
    secret = result.stdout.strip()
    if secret:
        return secret
    raise click.ClickException(f"Secret {name} is empty")
154
+
155
+
156
def list_modal_apps(pattern: str | None = None) -> list[ModalApp]:
    """Parse ``modal app list`` output into ``ModalApp`` records.

    Blank rows, rows beginning with ``APP`` (the header), and rows with fewer
    than three columns are skipped. The first, second and last columns become
    ``app_id``, ``label`` and ``url``. A non-empty *pattern* keeps only rows
    whose concatenated label/url/app_id contains it, case-insensitively.

    Raises:
        click.ClickException: If the ``modal`` command exits non-zero.
    """
    result = _run_modal(["app", "list"])
    if result.code != 0:
        raise click.ClickException(f"modal app list failed: {result.stderr or result.stdout}")
    needle = pattern.lower() if pattern else None
    matches: list[ModalApp] = []
    for raw_row in result.stdout.splitlines():
        row = raw_row.strip()
        if not row or row.startswith("APP"):
            continue
        columns = row.split()
        if len(columns) < 3:
            continue
        app_id, label, url = columns[0], columns[1], columns[-1]
        haystack = label.lower() + url.lower() + app_id.lower()
        if needle is not None and needle not in haystack:
            continue
        matches.append(ModalApp(app_id=app_id, label=label, url=url))
    return matches
173
+
174
+
175
def format_modal_apps(apps: list[ModalApp]) -> str:
    """Render *apps* as a 1-based numbered list, one ``N) label url`` per line."""
    lines = []
    for position, app in enumerate(apps, start=1):
        lines.append(f"{position}) {app.label} {app.url}")
    return "\n".join(lines)
178
+
179
+
180
def format_modal_secrets(names: list[str]) -> str:
    """Render *names* as a 1-based numbered list, one ``N) name`` per line."""
    numbered = [f"{position}) {name}" for position, name in enumerate(names, 1)]
    return "\n".join(numbered)
182
+
183
+
184
# Public API of this module (names exported by `from ... import *`).
__all__ = [
    "ModalApp", "ModalSecret",
    "check_task_app_health",
    "format_modal_apps", "format_modal_secrets",
    "get_modal_secret_value",
    "list_modal_apps", "list_modal_secrets",
]
synth_ai/api/train/utils.py ADDED
@@ -0,0 +1,232 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import re
6
+ import subprocess
7
+ import sys
8
+ import tempfile
9
+ import time
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from typing import Any, Iterable, Mapping
13
+
14
+ import requests
15
+ import tomllib
16
+
17
+ REPO_ROOT = Path(__file__).resolve().parents[3]
18
+
19
+
20
class TrainError(RuntimeError):
    """Error signalling that an interactive training CLI step failed."""
22
+
23
+
24
+ def load_toml(path: Path) -> dict[str, Any]:
25
+ try:
26
+ with path.open("rb") as fh:
27
+ return tomllib.load(fh)
28
+ except FileNotFoundError as exc: # pragma: no cover - guarded by CLI
29
+ raise TrainError(f"Config not found: {path}") from exc
30
+ except tomllib.TOMLDecodeError as exc: # pragma: no cover - malformed input
31
+ raise TrainError(f"Failed to parse TOML: {path}\n{exc}") from exc
32
+
33
+
34
def mask_value(value: str | None) -> str:
    """Obscure a secret for display.

    Falsy values render as ``<unset>``. Values of six characters or fewer
    become ``****``; longer ones keep the first four and last two characters.
    """
    if not value:
        return "<unset>"
    text = str(value)
    return "****" if len(text) <= 6 else f"{text[:4]}…{text[-2:]}"
41
+
42
+
43
+ _ENV_LINE = re.compile(r"^\s*(?:export\s+)?(?P<key>[A-Za-z0-9_]+)\s*=\s*(?P<value>.*)$")
44
+
45
+
46
def read_env_file(path: Path) -> dict[str, str]:
    """Parse a dotenv-style file into a ``{key: value}`` dict.

    A missing file yields ``{}``. Lines that do not match ``_ENV_LINE`` are
    ignored. Values are stripped, and one pair of matching surrounding single
    or double quotes is removed. When a key repeats, the last occurrence wins.
    """
    if not path.exists():
        return {}
    parsed: dict[str, str] = {}
    text = path.read_text(encoding="utf-8", errors="ignore")
    for entry in text.splitlines():
        match = _ENV_LINE.match(entry)
        if match is None:
            continue
        value = match.group("value").strip()
        quoted = len(value) >= 2 and value[0] in ('"', "'") and value[0] == value[-1]
        parsed[match.group("key")] = value[1:-1] if quoted else value
    return parsed
59
+
60
+
61
def write_env_value(path: Path, key: str, value: str) -> None:
    """Set ``key=value`` in the dotenv-style file at *path*, creating it if needed.

    Every existing assignment of *key* is rewritten in place, keeping an
    ``export`` prefix if the line had one; all other lines are preserved.
    If *key* was absent, a ``key=value`` line is appended. The value is
    written verbatim, with no quoting or escaping.
    """
    existing: list[str] = []
    if path.exists():
        existing = path.read_text(encoding="utf-8", errors="ignore").splitlines()
    updated = False
    new_lines: list[str] = []
    for line in existing:
        m = _ENV_LINE.match(line)
        if m and m.group("key") == key:
            # Keep `export KEY=...` lines exported so the file still behaves the
            # same when sourced by a shell.
            exported = line[: m.start("key")].strip() == "export"
            new_lines.append(f"{'export ' if exported else ''}{key}={value}")
            updated = True
        else:
            new_lines.append(line)
    if not updated:
        new_lines.append(f"{key}={value}")
    path.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
77
+
78
+
79
+ @dataclass(slots=True)
80
+ class CLIResult:
81
+ code: int
82
+ stdout: str
83
+ stderr: str
84
+
85
+
86
def run_cli(args: Iterable[str], *, cwd: Path | None = None, env: Mapping[str, str] | None = None, timeout: float | None = None) -> CLIResult:
    """Run a command and capture its output as text.

    *env* entries are layered over the current ``os.environ``. The exit code
    is returned rather than raised; ``stdout`` and ``stderr`` are stripped of
    surrounding whitespace. ``subprocess.TimeoutExpired`` (when *timeout*
    elapses) and ``OSError`` (e.g. executable not found) propagate.
    """
    command = list(args)
    merged_env = dict(os.environ, **(env or {}))
    completed = subprocess.run(
        command,
        cwd=cwd,
        env=merged_env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return CLIResult(
        code=completed.returncode,
        stdout=completed.stdout.strip(),
        stderr=completed.stderr.strip(),
    )
96
+
97
+
98
def http_post(url: str, *, headers: Mapping[str, str] | None = None, json_body: Any | None = None, timeout: float = 60.0) -> requests.Response:
    """POST *json_body* as JSON to *url*; the raw response is returned unchecked."""
    request_headers = dict(headers) if headers else {}
    return requests.post(url, headers=request_headers, json=json_body, timeout=timeout)
101
+
102
+
103
def http_get(url: str, *, headers: Mapping[str, str] | None = None, timeout: float = 30.0) -> requests.Response:
    """GET *url* with optional *headers*; the raw response is returned unchecked."""
    request_headers = dict(headers) if headers else {}
    return requests.get(url, headers=request_headers, timeout=timeout)
106
+
107
+
108
def post_multipart(url: str, *, api_key: str, file_field: str, file_path: Path, purpose: str = "fine-tune") -> requests.Response:
    """Upload *file_path* as multipart form data with bearer-token auth.

    The file is read fully into memory and sent under *file_field* with
    content type ``application/jsonl``, alongside a ``purpose`` form field.
    The request times out after 300 seconds.
    """
    auth = {"Authorization": f"Bearer {api_key}"}
    content = file_path.read_bytes()
    return requests.post(
        url,
        headers=auth,
        files={file_field: (file_path.name, content, "application/jsonl")},
        data={"purpose": purpose},
        timeout=300,
    )
113
+
114
+
115
def fmt_duration(seconds: float) -> str:
    """Format a duration compactly: ``12.3s``, ``4m05s`` or ``2h07m``.

    Under a minute shows tenths of a second; under an hour shows whole
    minutes and seconds; otherwise whole hours and minutes (truncated).
    """
    if seconds < 60:
        return f"{seconds:.1f}s"
    total_minutes, remainder = divmod(seconds, 60)
    if total_minutes >= 60:
        hours, minutes = divmod(total_minutes, 60)
        return f"{int(hours)}h{int(minutes):02d}m"
    return f"{int(total_minutes)}m{int(remainder):02d}s"
123
+
124
+
125
def validate_sft_jsonl(path: Path, *, max_errors: int = 20) -> None:
    """Check that every non-blank line of a JSONL file is a chat-style SFT record.

    Each record must be a JSON object with a non-empty ``messages`` list whose
    items are objects with string ``role`` and ``content`` fields. At most one
    problem is reported per line, and scanning stops once *max_errors*
    problems have been collected.

    Raises:
        TrainError: If the file is missing or any record is invalid; the
            message lists the problems found.
    """
    errors: list[str] = []
    try:
        fh = path.open("r", encoding="utf-8")
    except FileNotFoundError as exc:  # pragma: no cover - upstream ensures existence
        raise TrainError(f"Dataset not found: {path}") from exc

    with fh:
        for idx, line in enumerate(fh, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError as exc:
                errors.append(f"Line {idx}: invalid JSON ({exc.msg})")
                if len(errors) >= max_errors:
                    break
                continue

            # Valid JSON may still be a list/number/string; `.get` would crash on it.
            if not isinstance(record, dict):
                errors.append(f"Line {idx}: record is not a JSON object")
                if len(errors) >= max_errors:
                    break
                continue

            messages = record.get("messages")
            if not isinstance(messages, list) or not messages:
                errors.append(f"Line {idx}: missing or empty 'messages' list")
                if len(errors) >= max_errors:
                    break
                continue

            # Report only the first bad message on each line.
            for msg_idx, msg in enumerate(messages):
                if not isinstance(msg, dict):
                    errors.append(f"Line {idx}: message {msg_idx} is not an object")
                    break
                if "role" not in msg or "content" not in msg:
                    errors.append(f"Line {idx}: message {msg_idx} missing 'role' or 'content'")
                    break
                if not isinstance(msg["role"], str) or not isinstance(msg["content"], str):
                    errors.append(f"Line {idx}: message {msg_idx} has non-string role/content")
                    break
            if len(errors) >= max_errors:
                break

    if errors:
        suffix = "" if len(errors) < max_errors else f" (showing first {max_errors} issues)"
        details = "\n - ".join(errors)
        raise TrainError(f"Dataset validation failed{suffix}:\n - {details}")
169
+
170
+
171
def limit_jsonl_examples(src: Path, limit: int) -> Path:
    """Copy the first *limit* non-blank lines of *src* into a new temporary file.

    The copy is written inside a fresh ``sft_subset_*`` directory created by
    ``tempfile.mkdtemp``. On success that directory is not deleted here; the
    caller owns it.

    Raises:
        TrainError: If *limit* is not positive, *src* does not exist, or *src*
            has no non-blank lines. On any failure after the temporary
            directory is created, the directory is removed before the error
            propagates.
    """
    import shutil

    if limit <= 0:
        raise TrainError("Example limit must be positive")
    if not src.exists():
        raise TrainError(f"Dataset not found: {src}")

    tmp_dir = Path(tempfile.mkdtemp(prefix="sft_subset_"))
    dest = tmp_dir / f"{src.stem}.head{limit}{src.suffix}"

    try:
        written = 0
        with src.open("r", encoding="utf-8") as fin, dest.open("w", encoding="utf-8") as fout:
            for line in fin:
                if not line.strip():
                    continue
                fout.write(line)
                written += 1
                if written >= limit:
                    break
        if written == 0:
            # `limit` is already known to be positive, so the source itself is empty.
            raise TrainError(f"Subset dataset is empty; {src} has no non-blank lines")
    except BaseException:
        # Don't leave an orphaned temp directory behind, then re-raise unchanged.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    return dest
194
+
195
+
196
def ensure_api_base(base: str) -> str:
    """Strip trailing slashes from *base* and append ``/api`` unless already present."""
    trimmed = base.rstrip("/")
    return trimmed if trimmed.endswith("/api") else f"{trimmed}/api"
201
+
202
+
203
def preview_json(data: Any, limit: int = 600) -> str:
    """Return at most *limit* characters of *data* rendered as indented JSON.

    Falls back to ``str(data)`` when *data* is not JSON-serialisable.
    """
    try:
        rendered = json.dumps(data, indent=2)
    except Exception:
        rendered = str(data)
    return rendered[:limit]
208
+
209
+
210
def sleep(seconds: float) -> None:
    """Block the calling thread for *seconds* (thin wrapper over ``time.sleep``)."""
    pause = time.sleep
    pause(seconds)
212
+
213
+
214
# Public API of this module; order kept as originally declared.
__all__ = [
    "CLIResult", "REPO_ROOT", "TrainError",
    "ensure_api_base", "fmt_duration",
    "http_get", "http_post",
    "load_toml", "mask_value",
    "post_multipart", "preview_json",
    "read_env_file", "run_cli", "sleep",
    "limit_jsonl_examples", "validate_sft_jsonl", "write_env_value",
]
synth_ai/cli/__init__.py CHANGED
@@ -75,3 +75,26 @@ try:
75
75
  _rl_demo.register(cli)
76
76
  except Exception:
77
77
  pass
78
# Register the `train` commands when their module imports cleanly; this is
# best-effort, so any failure simply leaves them out of the CLI.
try:
    from synth_ai.api.train import register as _train_register

    _train_register(cli)
except Exception:
    pass


# Task-app command group. Unlike the optional registrations around it, this
# import is unguarded, so a broken task_apps module fails CLI start-up.
from .task_apps import task_app_group

cli.add_command(task_app_group, name="task-app")


try:
    from . import task_apps as _task_apps

    _task_apps.register(cli)
except Exception:
    pass

# Top-level shortcuts for task-app subcommands, registered last and in this
# order. NOTE(review): click stores commands by name, so these presumably
# replace any same-named command added earlier — confirm that is intended.
for _alias in ("serve", "deploy", "modal-serve"):
    cli.add_command(task_app_group.commands[_alias], name=_alias)
del _alias
synth_ai/cli/rl_demo.py CHANGED
@@ -67,9 +67,15 @@ def register(cli):
67
67
  _forward(["rl_demo.configure"])
68
68
 
69
69
  @_rlg.command("init")
70
- @click.option("--force", is_flag=True, help="Overwrite existing files in CWD")
71
- def rl_init(force: bool):
70
+ @click.option("--template", type=str, default=None, help="Template id to instantiate")
71
+ @click.option("--dest", type=click.Path(), default=None, help="Destination directory for files")
72
+ @click.option("--force", is_flag=True, help="Overwrite existing files in destination")
73
+ def rl_init(template: str | None, dest: str | None, force: bool):
72
74
  args = ["rl_demo.init"]
75
+ if template:
76
+ args.extend(["--template", template])
77
+ if dest:
78
+ args.extend(["--dest", dest])
73
79
  if force:
74
80
  args.append("--force")
75
81
  _forward(args)
@@ -130,9 +136,15 @@ def register(cli):
130
136
  _forward(["rl_demo.configure"])
131
137
 
132
138
  @cli.command("rl_demo.init")
133
- @click.option("--force", is_flag=True, help="Overwrite existing files in CWD")
134
- def rl_init_alias(force: bool):
139
+ @click.option("--template", type=str, default=None, help="Template id to instantiate")
140
+ @click.option("--dest", type=click.Path(), default=None, help="Destination directory for files")
141
+ @click.option("--force", is_flag=True, help="Overwrite existing files in destination")
142
+ def rl_init_alias(template: str | None, dest: str | None, force: bool):
135
143
  args = ["rl_demo.init"]
144
+ if template:
145
+ args.extend(["--template", template])
146
+ if dest:
147
+ args.extend(["--dest", dest])
136
148
  if force:
137
149
  args.append("--force")
138
150
  _forward(args)
@@ -161,12 +173,12 @@ def register(cli):
161
173
  _forward(args)
162
174
 
163
175
  # Top-level convenience alias: `synth-ai deploy`
164
- @cli.command("deploy")
176
+ @cli.command("demo-deploy")
165
177
  @click.option("--local", is_flag=True, help="Run local FastAPI instead of Modal deploy")
166
178
  @click.option("--app", type=click.Path(), default=None, help="Path to Modal app.py for uv run modal deploy")
167
179
  @click.option("--name", type=str, default="synth-math-demo", help="Modal app name")
168
180
  @click.option("--script", type=click.Path(), default=None, help="Path to deploy_task_app.sh (optional legacy)")
169
- def deploy_top(local: bool, app: str | None, name: str, script: str | None):
181
+ def deploy_demo(local: bool, app: str | None, name: str, script: str | None):
170
182
  args: list[str] = ["rl_demo.deploy"]
171
183
  if local:
172
184
  args.append("--local")
synth_ai/cli/root.py CHANGED
@@ -14,6 +14,20 @@ import sys
14
14
  import time
15
15
 
16
16
  import click
17
# Resolve the version shown in --version and the group help text. Order of
# preference: installed distribution metadata for "synth-ai", then
# synth_ai.__version__, then the literal "unknown".
__pkg_version__ = None
try:
    from importlib.metadata import PackageNotFoundError, version as _pkg_version

    __pkg_version__ = _pkg_version("synth-ai")
except Exception:  # includes PackageNotFoundError (an ImportError subclass)
    pass
if __pkg_version__ is None:
    try:
        from synth_ai import __version__ as __pkg_version__  # type: ignore
    except Exception:
        __pkg_version__ = "unknown"
17
31
 
18
32
 
19
33
  def find_sqld_binary() -> str | None:
@@ -66,9 +80,10 @@ rm -rf "$TMP_DIR"
66
80
  return os.path.expanduser("~/.local/bin/sqld")
67
81
 
68
82
 
69
@click.group(
    help=f"Synth AI v{__pkg_version__} - Software for aiding the best and multiplying the will.",
)
@click.version_option(version=__pkg_version__, prog_name="synth-ai")
def cli():
    """Root ``synth-ai`` command group.

    The user-facing --help text comes from the explicit ``help=`` argument
    above; click uses that instead of this docstring.
    """
72
87
 
73
88
 
74
89
  # === Legacy demo command group (aliases new rl_demo implementation) ===
@@ -84,7 +99,7 @@ def _forward_to_demo(args: list[str]) -> None:
84
99
  except Exception as e: # pragma: no cover
85
100
  click.echo(f"Failed to import demo CLI: {e}")
86
101
  sys.exit(1)
87
- rc = int(demo_cli.main(args) or 0)
102
+ rc = int(getattr(demo_cli, "main")(args) or 0) # type: ignore[attr-defined]
88
103
  if rc != 0:
89
104
  sys.exit(rc)
90
105
 
@@ -123,6 +138,22 @@ def setup():
123
138
  _forward_to_demo(["rl_demo.setup"])
124
139
 
125
140
 
141
@demo.command()
@click.option("--template", type=str, default=None, help="Template id to instantiate")
@click.option("--dest", type=click.Path(), default=None, help="Destination directory for files")
@click.option("--force", is_flag=True, help="Overwrite existing files in destination")
def init(template: str | None, dest: str | None, force: bool):
    """Copy a demo task app template into the current directory (or --dest)."""
    # NOTE(review): sibling demo commands forward `rl_demo.*` subcommands;
    # confirm the demo CLI also accepts `demo.init`.
    args: list[str] = ["demo.init"]
    if template:
        args.extend(["--template", template])
    if dest:
        args.extend(["--dest", dest])
    if force:
        args.append("--force")
    _forward_to_demo(args)
155
+
156
+
126
157
  @demo.command()
127
158
  @click.option("--batch-size", type=int, default=None)
128
159
  @click.option("--group-size", type=int, default=None)
@@ -142,8 +173,8 @@ def run(batch_size: int | None, group_size: int | None, model: str | None, timeo
142
173
  _forward_to_demo(args)
143
174
 
144
175
 
145
@cli.command(name="setup")
def setup_command():
    """Perform SDK handshake and write keys to .env."""
    # Top-level shortcut that delegates to the demo CLI's `rl_demo.setup`.
    forwarded = ["rl_demo.setup"]
    _forward_to_demo(forwarded)
149
180
 
@@ -164,7 +195,7 @@ def setup():
164
195
  default=True,
165
196
  help="Kill any process already bound to --env-port without prompting",
166
197
  )
167
- def serve(
198
+ def serve_deprecated(
168
199
  db_file: str,
169
200
  sqld_port: int,
170
201
  env_port: int,
@@ -174,6 +205,7 @@ def serve(
174
205
  force: bool,
175
206
  ):
176
207
  logging.basicConfig(level=logging.INFO, format="%(message)s")
208
+ click.echo("⚠️ 'synth-ai serve' now targets task apps; use 'synth-ai serve' for task apps or 'synth-ai serve-deprecated' for this legacy service.", err=True)
177
209
  processes = []
178
210
 
179
211
  def signal_handler(sig, frame):