synth-ai 0.2.4.dev7__py3-none-any.whl → 0.2.4.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai may be problematic. See the registry's advisory page for more details.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/balance.py +3 -15
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/config/base_url.py +47 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +621 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/http.py +102 -0
- synth_ai/inference/__init__.py +7 -0
- synth_ai/inference/client.py +20 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai/jobs/client.py +246 -0
- synth_ai/learning/__init__.py +24 -0
- synth_ai/learning/client.py +149 -0
- synth_ai/learning/config.py +43 -0
- synth_ai/learning/constants.py +29 -0
- synth_ai/learning/ft_client.py +59 -0
- synth_ai/learning/health.py +43 -0
- synth_ai/learning/jobs.py +205 -0
- synth_ai/learning/rl_client.py +256 -0
- synth_ai/learning/sse.py +58 -0
- synth_ai/learning/validators.py +48 -0
- synth_ai/lm/core/main_v3.py +13 -0
- synth_ai/lm/core/synth_models.py +48 -0
- synth_ai/lm/core/vendor_clients.py +9 -6
- synth_ai/lm/vendors/core/openai_api.py +31 -3
- synth_ai/lm/vendors/openai_standard.py +45 -14
- synth_ai/lm/vendors/supported/custom_endpoint.py +12 -2
- synth_ai/lm/vendors/synth_client.py +372 -28
- synth_ai/rl/__init__.py +30 -0
- synth_ai/rl/contracts.py +32 -0
- synth_ai/rl/env_keys.py +137 -0
- synth_ai/rl/secrets.py +19 -0
- synth_ai/scripts/verify_rewards.py +100 -0
- synth_ai/task/__init__.py +10 -0
- synth_ai/task/contracts.py +120 -0
- synth_ai/task/health.py +28 -0
- synth_ai/task/validators.py +12 -0
- synth_ai/tracing_v3/hooks.py +3 -1
- synth_ai/tracing_v3/session_tracer.py +123 -2
- synth_ai/tracing_v3/turso/manager.py +218 -0
- synth_ai/tracing_v3/turso/models.py +53 -0
- synth_ai-0.2.4.dev9.dist-info/METADATA +91 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/RECORD +147 -30
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/entry_points.txt +1 -0
- synth_ai/tui/__init__.py +0 -1
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -340
- synth_ai-0.2.4.dev7.dist-info/METADATA +0 -193
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from ..http import AsyncHttpClient, HTTPError, sleep
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LearningClient:
    """Thin async client for the backend's ``/api/learning`` endpoints.

    Each request opens a fresh :class:`AsyncHttpClient` context; the instance
    itself only stores the base URL, API key and timeout — no connection state.
    """

    def __init__(self, base_url: str, api_key: str, *, timeout: float = 30.0) -> None:
        # Strip trailing slashes so path joins stay canonical.
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._timeout = timeout

    async def upload_training_file(self, path: str | Path, *, purpose: str = "fine-tune") -> str:
        """Upload a local file for training and return the server-assigned file id.

        Raises:
            HTTPError: if the response is not a dict containing an ``id`` key.
        """
        p = Path(path)
        content = p.read_bytes()
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            data = {"purpose": purpose}
            # Content type is inferred from the filename extension.
            files = {"file": (p.name, content, _infer_content_type(p.name))}
            js = await http.post_multipart("/api/learning/files", data=data, files=files)
            if not isinstance(js, dict) or "id" not in js:
                raise HTTPError(status=500, url="/api/learning/files", message="invalid_upload_response", body_snippet=str(js)[:200])
            return str(js["id"])

    async def create_job(
        self,
        *,
        training_type: str,
        model: str,
        training_file_id: str,
        hyperparameters: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Create (but do not start) a learning job; returns the raw job record."""
        body = {
            "training_type": training_type,
            "model": model,
            "training_file_id": training_file_id,
            "hyperparameters": hyperparameters or {},
            "metadata": metadata or {},
        }
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            return await http.post_json("/api/learning/jobs", json=body)

    async def start_job(self, job_id: str) -> Dict[str, Any]:
        """Start a previously created job."""
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            return await http.post_json(f"/api/learning/jobs/{job_id}/start", json={})

    async def get_job(self, job_id: str) -> Dict[str, Any]:
        """Fetch the current job record (including its ``status`` field)."""
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            return await http.get(f"/api/learning/jobs/{job_id}")

    async def get_events(self, job_id: str, *, since_seq: int = 0, limit: int = 200) -> List[Dict[str, Any]]:
        """Return job events with ``seq`` greater than *since_seq* (empty list on malformed responses)."""
        params = {"since_seq": since_seq, "limit": limit}
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            js = await http.get(f"/api/learning/jobs/{job_id}/events", params=params)
            if isinstance(js, dict) and isinstance(js.get("events"), list):
                return js["events"]
            return []

    async def get_metrics(self, job_id: str, *, name: str | None = None, after_step: int | None = None, limit: int = 500, run_id: str | None = None) -> List[Dict[str, Any]]:
        """Return metric points for a job, optionally filtered by metric name, step, or run id."""
        params: Dict[str, Any] = {"limit": limit}
        # Only include filters the caller actually supplied.
        if name is not None:
            params["name"] = name
        if after_step is not None:
            params["after_step"] = after_step
        if run_id is not None:
            params["run_id"] = run_id
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            js = await http.get(f"/api/learning/jobs/{job_id}/metrics", params=params)
            if isinstance(js, dict) and isinstance(js.get("points"), list):
                return js["points"]
            return []

    async def get_timeline(self, job_id: str, *, limit: int = 200) -> List[Dict[str, Any]]:
        """Return the job's timeline events (empty list on malformed responses)."""
        params = {"limit": limit}
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            js = await http.get(f"/api/learning/jobs/{job_id}/timeline", params=params)
            if isinstance(js, dict) and isinstance(js.get("events"), list):
                return js["events"]
            return []

    async def poll_until_terminal(
        self,
        job_id: str,
        *,
        interval_seconds: float = 2.0,
        max_seconds: float | None = 3600,
        on_event: Callable[[Dict[str, Any]], None] | None = None,
    ) -> Dict[str, Any]:
        """Poll events and status until the job reaches a terminal state.

        *on_event* is called once per fetched event; exceptions it raises are
        swallowed so a faulty callback cannot abort polling.

        Raises:
            TimeoutError: once cumulative polling time reaches *max_seconds*.
        """
        last_seq = 0
        elapsed = 0.0
        while True:
            # Events
            events = await self.get_events(job_id, since_seq=last_seq, limit=200)
            for e in events:
                if isinstance(e, dict) and isinstance(e.get("seq"), int):
                    last_seq = max(last_seq, int(e["seq"]))
                if on_event:
                    try:
                        on_event(e)
                    except Exception:
                        pass

            # Status
            job = await self.get_job(job_id)
            status = str(job.get("status") or "").lower()
            # Both US/UK spellings of "cancelled" appear across backends.
            if status in {"succeeded", "failed", "canceled", "cancelled"}:
                return job

            # Sleep and time budget
            await sleep(interval_seconds)
            elapsed += interval_seconds
            if max_seconds is not None and elapsed >= max_seconds:
                raise TimeoutError(f"Polling timed out after {elapsed} seconds for job {job_id}")

    # --- Optional diagnostics ---
    async def pricing_preflight(self, *, job_type: str, gpu_type: str, estimated_seconds: float, container_count: int) -> Dict[str, Any]:
        """Ask the backend for a pricing estimate before submitting a job."""
        body = {
            "job_type": job_type,
            "gpu_type": gpu_type,
            # Coerce falsy inputs to safe defaults (0.0 seconds, 1 container).
            "estimated_seconds": float(estimated_seconds or 0.0),
            "container_count": int(container_count or 1),
        }
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            js = await http.post_json("/api/v1/pricing/preflight", json=body)
            if not isinstance(js, dict):
                raise HTTPError(status=500, url="/api/v1/pricing/preflight", message="invalid_preflight_response", body_snippet=str(js)[:200])
            return js

    async def balance_autumn_normalized(self) -> Dict[str, Any]:
        """Fetch the caller's normalized account balance."""
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            js = await http.get("/api/v1/balance/autumn-normalized")
            if not isinstance(js, dict):
                raise HTTPError(status=500, url="/api/v1/balance/autumn-normalized", message="invalid_balance_response", body_snippet=str(js)[:200])
            return js
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _infer_content_type(filename: str) -> str:
|
|
140
|
+
name = filename.lower()
|
|
141
|
+
if name.endswith(".jsonl"):
|
|
142
|
+
return "application/jsonl"
|
|
143
|
+
if name.endswith(".json"):
|
|
144
|
+
return "application/json"
|
|
145
|
+
if name.endswith(".txt"):
|
|
146
|
+
return "text/plain"
|
|
147
|
+
return "application/octet-stream"
|
|
148
|
+
|
|
149
|
+
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class FTJobConfig:
    """Configuration for an offline fine-tuning (SFT) job."""

    model: str
    training_file_id: str
    n_epochs: int = 1
    batch_size: int = 1
    upload_to_wasabi: bool = True

    def hyperparameters(self) -> Dict[str, Any]:
        """Validate and return the hyperparameter payload for job creation."""
        for field_name, value in (("n_epochs", self.n_epochs), ("batch_size", self.batch_size)):
            if value < 1:
                raise ValueError(f"{field_name} must be >= 1")
        return {"n_epochs": int(self.n_epochs), "batch_size": int(self.batch_size)}

    def metadata(self) -> Dict[str, Any]:
        """Return the job metadata payload."""
        return {"upload_to_wasabi": bool(self.upload_to_wasabi)}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class RLJobConfig:
    """Configuration for a reinforcement-learning job."""

    model: str
    task_app_url: str
    trainer_id: str
    batch_size: int = 1
    group_size: int = 2
    job_config_id: Optional[str] = None
    inline_config: Optional[Dict[str, Any]] = None

    def trainer_dict(self) -> Dict[str, Any]:
        """Validate and return the trainer payload (batch and group sizes)."""
        # RL requires at least two rollouts per group to compare against.
        if not self.batch_size >= 1:
            raise ValueError("batch_size must be >= 1")
        if not self.group_size >= 2:
            raise ValueError("group_size must be >= 2")
        return {"batch_size": int(self.batch_size), "group_size": int(self.group_size)}
|
|
42
|
+
|
|
43
|
+
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# Terminal statuses normalized across FT and RL.
# Both US ("canceled") and UK ("cancelled") spellings are included because
# different backends emit different spellings.
TERMINAL_STATUSES = {
    "succeeded",
    "failed",
    "cancelled",
    "canceled",
    "error",
    "completed",
}

# Terminal event types (success/failure) across FT and RL.
# Event-type strings observed from SFT workflows, RL jobs, and the generic
# workflow stream; pollers treat any of these as "job finished successfully".
TERMINAL_EVENT_SUCCESS = {
    "sft.completed",
    "sft.workflow.completed",
    "rl.job.completed",
    "rl.train.completed",
    "workflow.completed",
}

# Counterpart failure event types; pollers treat any of these as "job failed".
TERMINAL_EVENT_FAILURE = {
    "sft.failed",
    "sft.workflow.failed",
    "rl.job.failed",
    "workflow.failed",
}
|
|
28
|
+
|
|
29
|
+
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Dict, Optional
|
|
5
|
+
|
|
6
|
+
from ..http import AsyncHttpClient, HTTPError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FtClient:
    """Async client for submitting supervised fine-tuning (SFT) jobs.

    Each call opens a fresh :class:`AsyncHttpClient` context; only the base
    URL, API key and timeout are stored on the instance.
    """

    def __init__(self, base_url: str, api_key: str, *, timeout: float = 30.0) -> None:
        # Strip trailing slashes so path joins stay canonical.
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._timeout = timeout

    async def upload_training_file(self, path: str | Path, *, purpose: str = "fine-tune") -> str:
        """Upload a local training file and return the server-assigned file id.

        Raises:
            HTTPError: if the response is not a dict containing an ``id`` key.
        """
        p = Path(path)
        content = p.read_bytes()
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            data = {"purpose": purpose}
            # Content type is inferred from the filename extension.
            files = {"file": (p.name, content, _infer_content_type(p.name))}
            js = await http.post_multipart("/api/learning/files", data=data, files=files)
            if not isinstance(js, dict) or "id" not in js:
                raise HTTPError(status=500, url="/api/learning/files", message="invalid_upload_response", body_snippet=str(js)[:200])
            return str(js["id"])

    async def create_sft_job(
        self,
        *,
        model: str,
        training_file_id: str,
        hyperparameters: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Create (but do not start) an offline SFT job; returns the raw job record."""
        body = {
            # Training type is fixed: this client only submits offline SFT.
            "training_type": "sft_offline",
            "model": model,
            "training_file_id": training_file_id,
            "hyperparameters": dict(hyperparameters or {}),
            "metadata": dict(metadata or {}),
        }
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            return await http.post_json("/api/learning/jobs", json=body)

    async def start_job(self, job_id: str) -> Dict[str, Any]:
        """Start a previously created job."""
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            return await http.post_json(f"/api/learning/jobs/{job_id}/start", json={})
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _infer_content_type(filename: str) -> str:
|
|
50
|
+
name = filename.lower()
|
|
51
|
+
if name.endswith(".jsonl"):
|
|
52
|
+
return "application/jsonl"
|
|
53
|
+
if name.endswith(".json"):
|
|
54
|
+
return "application/json"
|
|
55
|
+
if name.endswith(".txt"):
|
|
56
|
+
return "text/plain"
|
|
57
|
+
return "application/octet-stream"
|
|
58
|
+
|
|
59
|
+
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Optional
|
|
4
|
+
import aiohttp
|
|
5
|
+
|
|
6
|
+
from ..http import AsyncHttpClient
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _api_base(b: str) -> str:
|
|
10
|
+
b = (b or "").rstrip("/")
|
|
11
|
+
return b if b.endswith("/api") else f"{b}/api"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
async def backend_health(base_url: str, api_key: str) -> Dict[str, Any]:
    """Ping the backend health endpoint; returns ``{"ok": True, "raw": <response>}``.

    NOTE(review): the fully-qualified URL from ``_api_base(base_url)`` is passed
    as the request path while AsyncHttpClient was constructed with the same
    base_url — assumes the client accepts absolute URLs without double-prefixing;
    confirm against AsyncHttpClient.get.
    """
    async with AsyncHttpClient(base_url, api_key, timeout=15.0) as http:
        js = await http.get(f"{_api_base(base_url)}/health")
        return {"ok": True, "raw": js}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
async def task_app_health(task_app_url: str) -> Dict[str, Any]:
    """Check a task app's health by delegating to ``synth_ai.task.health``.

    The import is deferred to call time, presumably to avoid an import cycle
    at module load — TODO confirm.
    """
    # Delegate to central task module for consistency
    from synth_ai.task.health import task_app_health as _th

    return await _th(task_app_url)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
async def pricing_preflight(base_url: str, api_key: str, *, job_type: str, gpu_type: str, estimated_seconds: float, container_count: int) -> Dict[str, Any]:
    """Request a pricing estimate for a prospective job.

    Non-dict responses are wrapped as ``{"raw": <response>}`` rather than raised.
    """
    body = {
        "job_type": job_type,
        "gpu_type": gpu_type,
        # Coerce falsy inputs to safe defaults (0.0 seconds, 1 container).
        "estimated_seconds": float(estimated_seconds or 0.0),
        "container_count": int(container_count or 1),
    }
    async with AsyncHttpClient(base_url, api_key, timeout=30.0) as http:
        js = await http.post_json(f"{_api_base(base_url)}/v1/pricing/preflight", json=body)
        return js if isinstance(js, dict) else {"raw": js}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
async def balance_autumn_normalized(base_url: str, api_key: str) -> Dict[str, Any]:
    """Fetch the caller's normalized account balance.

    Non-dict responses are wrapped as ``{"raw": <response>}`` rather than raised.
    """
    async with AsyncHttpClient(base_url, api_key, timeout=30.0) as http:
        js = await http.get(f"{_api_base(base_url)}/v1/balance/autumn-normalized")
        return js if isinstance(js, dict) else {"raw": js}
|
|
43
|
+
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
from .constants import TERMINAL_EVENT_FAILURE, TERMINAL_EVENT_SUCCESS, TERMINAL_STATUSES
|
|
7
|
+
from ..http import AsyncHttpClient, sleep
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _api_base(b: str) -> str:
|
|
11
|
+
b = (b or "").rstrip("/")
|
|
12
|
+
return b if b.endswith("/api") else f"{b}/api"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class JobsApiResolver:
    """Builds status/event/metric endpoint URLs for a job.

    In strict mode only the canonical ``/learning`` routes are returned; in
    relaxed mode legacy ``/rl`` and ``/orchestration`` fallbacks are included.
    """

    def __init__(self, base_url: str, *, strict: bool) -> None:
        # Inline normalization (same logic as module-level _api_base): ensure
        # the base ends with exactly one '/api' segment.
        base = (base_url or "").rstrip("/")
        self._base = base if base.endswith("/api") else f"{base}/api"
        self._strict = strict

    def status_urls(self, job_id: str) -> List[str]:
        """Candidate URLs for fetching job status, in probe order."""
        candidates = [f"{self._base}/learning/jobs/{job_id}"]
        if not self._strict:
            candidates.append(f"{self._base}/rl/jobs/{job_id}")
            candidates.append(f"{self._base}/orchestration/jobs/{job_id}")
        return candidates

    def events_urls(self, job_id: str, since: int) -> List[str]:
        """Candidate URLs for fetching JSON job events after sequence *since*."""
        query = f"since_seq={since}&limit=200"
        candidates = [f"{self._base}/learning/jobs/{job_id}/events?{query}"]
        if not self._strict:
            # RL /jobs/{id}/events is SSE in backend; avoid in JSON poller
            candidates.append(f"{self._base}/orchestration/jobs/{job_id}/events?{query}")
        return candidates

    def metrics_url(self, job_id: str, after_step: int) -> str:
        """URL for fetching metric points recorded after *after_step*."""
        return f"{self._base}/learning/jobs/{job_id}/metrics?after_step={after_step}&limit=200"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class JobHandle:
    """Handle to a remote job, able to poll it to completion.

    Stores connection details only; all network work happens inside
    :meth:`poll_until_terminal`.
    """

    def __init__(self, base_url: str, api_key: str, job_id: str, *, strict: bool = True, timeout: float = 600.0) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.job_id = job_id
        # strict=True restricts polling to the canonical /learning routes.
        self.strict = strict
        self.timeout = timeout

    async def poll_until_terminal(
        self,
        *,
        interval_seconds: float = 2.0,
        max_seconds: float | None = None,
        empty_polls_threshold: int = 5,
        startup_deadline_s: int = 45,
        on_event: Optional[Callable[[Dict[str, Any]], None]] = None,
        on_metric: Optional[Callable[[Dict[str, Any]], None]] = None,
    ) -> Dict[str, Any]:
        """Poll status, events and metrics until the job terminates.

        Returns a summary dict with ``status``, ``job_id`` and, when detected,
        ``fine_tuned_model``. Callback exceptions are swallowed so a faulty
        callback cannot abort polling.

        NOTE(review): ``empty_polls_threshold`` is currently unused beyond
        bookkeeping — the empty-poll guard was deliberately relaxed (see the
        comment near the bottom of the loop).

        Raises:
            AssertionError: if no events arrive within *startup_deadline_s*.
            TimeoutError: once total polling time reaches *max_seconds*.
        """
        # Per-stream event cursor: the primary job id and (if discovered) a
        # linked job id each track their own last-seen sequence number.
        last_seq_by_stream: Dict[str, int] = {}
        events_job_id: Optional[str] = None
        last_status: Optional[str] = None
        # Per-metric-name high-water mark of the last step delivered.
        last_step_by_name: Dict[str, int] = {}
        empty_polls = 0
        saw_any_event = False
        start_t = time.time()
        resolver = JobsApiResolver(self.base_url, strict=self.strict)
        detected_fine_tuned_model: Optional[str] = None

        async with AsyncHttpClient(self.base_url, self.api_key, timeout=self.timeout) as http:
            while True:
                # Status: try each candidate URL until one returns a dict.
                status_data: Optional[Dict[str, Any]] = None
                for su in resolver.status_urls(self.job_id):
                    try:
                        status_data = await http.get(su)
                        if isinstance(status_data, dict):
                            break
                    except Exception:
                        continue
                status = str((status_data or {}).get("status") or "").lower()
                if status_data:
                    # A linked job (e.g. a spawned worker job) gets its own
                    # event stream added to the poll set below.
                    linked = status_data.get("linked_job_id")
                    if isinstance(linked, str) and linked and linked != events_job_id:
                        events_job_id = linked
                    # Capture fine_tuned_model if already present on status
                    if not detected_fine_tuned_model:
                        ftm = status_data.get("fine_tuned_model")
                        if isinstance(ftm, str) and ftm:
                            detected_fine_tuned_model = ftm
                if status and status != last_status:
                    # Surface status transitions to the caller as synthetic events.
                    last_status = status
                    if on_event:
                        try:
                            on_event({"type": "job.status", "message": status})
                        except Exception:
                            pass

                # Events: poll every known stream (primary job + linked job).
                stream_ids = [self.job_id]
                if events_job_id and events_job_id not in stream_ids:
                    stream_ids.append(events_job_id)
                total_events_this_cycle = 0
                terminal_event_seen = False
                terminal_event_status: Optional[str] = None
                for ev_id in stream_ids:
                    since = last_seq_by_stream.get(ev_id, 0)
                    for eu in resolver.events_urls(ev_id, since):
                        try:
                            ev_js = await http.get(eu)
                        except Exception:
                            continue
                        try:
                            # Accept either {"events": [...]} or {"data": [...]}.
                            events = (ev_js or {}).get("events") or (ev_js or {}).get("data") or []
                            if not isinstance(events, list):
                                events = []
                        except Exception:
                            events = []
                        total_events_this_cycle += len(events)
                        if events:
                            saw_any_event = True
                        for e in events:
                            # Skip events at or below the stream's cursor
                            # (duplicates across fallback URLs).
                            seq_val = int(e.get("seq") or 0)
                            if seq_val <= last_seq_by_stream.get(ev_id, 0):
                                continue
                            last_seq_by_stream[ev_id] = seq_val
                            if on_event:
                                try:
                                    on_event(e)
                                except Exception:
                                    pass
                            et = str(e.get("type") or e.get("event_type") or "").lower()
                            # Capture fine_tuned_model from event data when available
                            if not detected_fine_tuned_model:
                                try:
                                    data_obj = e.get("data") or {}
                                    ftm = data_obj.get("fine_tuned_model") if isinstance(data_obj, dict) else None
                                    if isinstance(ftm, str) and ftm:
                                        detected_fine_tuned_model = ftm
                                except Exception:
                                    pass
                            if et in TERMINAL_EVENT_SUCCESS:
                                terminal_event_seen = True
                                terminal_event_status = "succeeded"
                            elif et in TERMINAL_EVENT_FAILURE:
                                terminal_event_seen = True
                                terminal_event_status = "failed"

                # Metrics: fetch points past the highest step seen so far;
                # best-effort, failures are ignored.
                try:
                    after = max(last_step_by_name.values()) if last_step_by_name else -1
                    mu = resolver.metrics_url(self.job_id, after)
                    md = await http.get(mu)
                    for p in (md or {}).get("points", []):
                        name = str(p.get("name") or "")
                        step = int(p.get("step") or -1)
                        if step <= last_step_by_name.get(name, -1):
                            continue
                        last_step_by_name[name] = step
                        if on_metric:
                            try:
                                on_metric(p)
                            except Exception:
                                pass
                except Exception:
                    pass

                # Terminal decisions
                if terminal_event_seen or (status and status in TERMINAL_STATUSES):
                    # Best-effort enrichment of final result with fine_tuned_model
                    result_status = terminal_event_status or status or "completed"
                    final_res: Dict[str, Any] = {"status": result_status, "job_id": self.job_id}
                    if not detected_fine_tuned_model:
                        # Briefly try to re-fetch status to see if fine_tuned_model is persisted
                        try:
                            for su in resolver.status_urls(self.job_id):
                                try:
                                    final_status = await http.get(su)
                                    if isinstance(final_status, dict):
                                        ftm2 = final_status.get("fine_tuned_model")
                                        if isinstance(ftm2, str) and ftm2:
                                            detected_fine_tuned_model = ftm2
                                        break
                                except Exception:
                                    continue
                        except Exception:
                            pass
                    if detected_fine_tuned_model:
                        final_res["fine_tuned_model"] = detected_fine_tuned_model
                    return final_res

                # Guards (relaxed): do not abort on consecutive empty polls
                if total_events_this_cycle == 0:
                    empty_polls += 1
                else:
                    empty_polls = 0
                if not saw_any_event and (time.time() - start_t) > int(startup_deadline_s):
                    raise AssertionError(
                        f"No events observed within startup window ({startup_deadline_s}s). Investigate event streaming."
                    )
                await sleep(interval_seconds)
                if max_seconds is not None and (time.time() - start_t) >= max_seconds:
                    raise TimeoutError(f"Polling timed out after {max_seconds}s for job {self.job_id}")
|
|
204
|
+
|
|
205
|
+
|