synth-ai 0.2.4.dev6__py3-none-any.whl → 0.2.4.dev8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +18 -9
- synth_ai/cli/__init__.py +10 -5
- synth_ai/cli/balance.py +25 -32
- synth_ai/cli/calc.py +2 -3
- synth_ai/cli/demo.py +3 -5
- synth_ai/cli/legacy_root_backup.py +58 -32
- synth_ai/cli/man.py +22 -19
- synth_ai/cli/recent.py +9 -8
- synth_ai/cli/root.py +58 -13
- synth_ai/cli/status.py +13 -6
- synth_ai/cli/traces.py +45 -21
- synth_ai/cli/watch.py +40 -37
- synth_ai/config/base_url.py +47 -2
- synth_ai/core/experiment.py +1 -2
- synth_ai/environments/__init__.py +2 -6
- synth_ai/environments/environment/artifacts/base.py +3 -1
- synth_ai/environments/environment/db/sqlite.py +1 -1
- synth_ai/environments/environment/registry.py +19 -20
- synth_ai/environments/environment/resources/sqlite.py +2 -3
- synth_ai/environments/environment/rewards/core.py +3 -2
- synth_ai/environments/environment/tools/__init__.py +6 -4
- synth_ai/environments/examples/crafter_classic/__init__.py +1 -1
- synth_ai/environments/examples/crafter_classic/engine.py +13 -13
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +1 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +2 -1
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +2 -1
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +3 -2
- synth_ai/environments/examples/crafter_classic/environment.py +16 -15
- synth_ai/environments/examples/crafter_classic/taskset.py +2 -2
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +2 -3
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +2 -1
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +2 -2
- synth_ai/environments/examples/crafter_custom/crafter/config.py +2 -2
- synth_ai/environments/examples/crafter_custom/crafter/env.py +1 -5
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +1 -2
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +1 -2
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +5 -5
- synth_ai/environments/examples/crafter_custom/environment.py +13 -13
- synth_ai/environments/examples/crafter_custom/run_dataset.py +5 -5
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +2 -2
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +5 -4
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +2 -1
- synth_ai/environments/examples/enron/engine.py +18 -14
- synth_ai/environments/examples/enron/environment.py +12 -11
- synth_ai/environments/examples/enron/taskset.py +7 -7
- synth_ai/environments/examples/minigrid/__init__.py +6 -6
- synth_ai/environments/examples/minigrid/engine.py +6 -6
- synth_ai/environments/examples/minigrid/environment.py +6 -6
- synth_ai/environments/examples/minigrid/puzzle_loader.py +3 -2
- synth_ai/environments/examples/minigrid/taskset.py +13 -13
- synth_ai/environments/examples/nethack/achievements.py +1 -1
- synth_ai/environments/examples/nethack/engine.py +8 -7
- synth_ai/environments/examples/nethack/environment.py +10 -9
- synth_ai/environments/examples/nethack/helpers/__init__.py +8 -9
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +1 -1
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +2 -1
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +1 -1
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +3 -4
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +6 -5
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +5 -5
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +7 -6
- synth_ai/environments/examples/nethack/taskset.py +5 -5
- synth_ai/environments/examples/red/engine.py +9 -8
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +7 -7
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +3 -2
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +2 -1
- synth_ai/environments/examples/red/environment.py +18 -15
- synth_ai/environments/examples/red/taskset.py +5 -3
- synth_ai/environments/examples/sokoban/engine.py +16 -13
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +3 -2
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +7 -5
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +5 -4
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +3 -2
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +5 -4
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +1 -1
- synth_ai/environments/examples/sokoban/environment.py +15 -14
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +5 -3
- synth_ai/environments/examples/sokoban/puzzle_loader.py +3 -2
- synth_ai/environments/examples/sokoban/taskset.py +13 -10
- synth_ai/environments/examples/tictactoe/engine.py +6 -6
- synth_ai/environments/examples/tictactoe/environment.py +8 -7
- synth_ai/environments/examples/tictactoe/taskset.py +6 -5
- synth_ai/environments/examples/verilog/engine.py +4 -3
- synth_ai/environments/examples/verilog/environment.py +11 -10
- synth_ai/environments/examples/verilog/taskset.py +14 -12
- synth_ai/environments/examples/wordle/__init__.py +5 -5
- synth_ai/environments/examples/wordle/engine.py +32 -25
- synth_ai/environments/examples/wordle/environment.py +21 -16
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +6 -6
- synth_ai/environments/examples/wordle/taskset.py +20 -12
- synth_ai/environments/reproducibility/core.py +1 -1
- synth_ai/environments/reproducibility/tree.py +21 -21
- synth_ai/environments/service/app.py +3 -2
- synth_ai/environments/service/core_routes.py +104 -110
- synth_ai/environments/service/external_registry.py +1 -2
- synth_ai/environments/service/registry.py +1 -1
- synth_ai/environments/stateful/core.py +1 -2
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/api.py +4 -4
- synth_ai/environments/tasks/core.py +14 -12
- synth_ai/environments/tasks/filters.py +6 -4
- synth_ai/environments/tasks/utils.py +13 -11
- synth_ai/evals/base.py +2 -3
- synth_ai/experimental/synth_oss.py +4 -4
- synth_ai/http.py +102 -0
- synth_ai/inference/__init__.py +7 -0
- synth_ai/inference/client.py +20 -0
- synth_ai/jobs/client.py +246 -0
- synth_ai/learning/__init__.py +24 -0
- synth_ai/learning/client.py +149 -0
- synth_ai/learning/config.py +43 -0
- synth_ai/learning/constants.py +29 -0
- synth_ai/learning/ft_client.py +59 -0
- synth_ai/learning/gateway.py +1 -3
- synth_ai/learning/health.py +43 -0
- synth_ai/learning/jobs.py +205 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +15 -10
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +26 -14
- synth_ai/learning/prompts/mipro.py +61 -52
- synth_ai/learning/prompts/random_search.py +42 -43
- synth_ai/learning/prompts/run_mipro_banking77.py +32 -20
- synth_ai/learning/prompts/run_random_search_banking77.py +71 -52
- synth_ai/learning/rl_client.py +256 -0
- synth_ai/learning/sse.py +58 -0
- synth_ai/learning/validators.py +48 -0
- synth_ai/lm/__init__.py +5 -5
- synth_ai/lm/caching/ephemeral.py +9 -9
- synth_ai/lm/caching/handler.py +20 -20
- synth_ai/lm/caching/persistent.py +10 -10
- synth_ai/lm/config.py +3 -3
- synth_ai/lm/constants.py +7 -7
- synth_ai/lm/core/all.py +17 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +26 -41
- synth_ai/lm/core/main_v3.py +33 -10
- synth_ai/lm/core/synth_models.py +48 -0
- synth_ai/lm/core/vendor_clients.py +26 -22
- synth_ai/lm/injection.py +7 -8
- synth_ai/lm/overrides.py +21 -19
- synth_ai/lm/provider_support/__init__.py +1 -1
- synth_ai/lm/provider_support/anthropic.py +15 -15
- synth_ai/lm/provider_support/openai.py +23 -21
- synth_ai/lm/structured_outputs/handler.py +34 -32
- synth_ai/lm/structured_outputs/inject.py +24 -27
- synth_ai/lm/structured_outputs/rehabilitate.py +19 -15
- synth_ai/lm/tools/base.py +17 -16
- synth_ai/lm/unified_interface.py +17 -18
- synth_ai/lm/vendors/base.py +20 -18
- synth_ai/lm/vendors/core/anthropic_api.py +36 -27
- synth_ai/lm/vendors/core/gemini_api.py +31 -36
- synth_ai/lm/vendors/core/mistral_api.py +19 -19
- synth_ai/lm/vendors/core/openai_api.py +42 -13
- synth_ai/lm/vendors/openai_standard.py +158 -101
- synth_ai/lm/vendors/openai_standard_responses.py +74 -61
- synth_ai/lm/vendors/retries.py +9 -1
- synth_ai/lm/vendors/supported/custom_endpoint.py +38 -28
- synth_ai/lm/vendors/supported/deepseek.py +10 -10
- synth_ai/lm/vendors/supported/grok.py +8 -8
- synth_ai/lm/vendors/supported/ollama.py +2 -1
- synth_ai/lm/vendors/supported/openrouter.py +11 -9
- synth_ai/lm/vendors/synth_client.py +425 -75
- synth_ai/lm/warmup.py +8 -7
- synth_ai/rl/__init__.py +30 -0
- synth_ai/rl/contracts.py +32 -0
- synth_ai/rl/env_keys.py +137 -0
- synth_ai/rl/secrets.py +19 -0
- synth_ai/scripts/verify_rewards.py +100 -0
- synth_ai/task/__init__.py +10 -0
- synth_ai/task/contracts.py +120 -0
- synth_ai/task/health.py +28 -0
- synth_ai/task/validators.py +12 -0
- synth_ai/tracing/__init__.py +22 -10
- synth_ai/tracing_v1/__init__.py +22 -20
- synth_ai/tracing_v3/__init__.py +7 -7
- synth_ai/tracing_v3/abstractions.py +56 -52
- synth_ai/tracing_v3/config.py +4 -2
- synth_ai/tracing_v3/db_config.py +6 -8
- synth_ai/tracing_v3/decorators.py +29 -30
- synth_ai/tracing_v3/examples/basic_usage.py +12 -12
- synth_ai/tracing_v3/hooks.py +24 -22
- synth_ai/tracing_v3/llm_call_record_helpers.py +85 -98
- synth_ai/tracing_v3/lm_call_record_abstractions.py +2 -4
- synth_ai/tracing_v3/migration_helper.py +3 -5
- synth_ai/tracing_v3/replica_sync.py +30 -32
- synth_ai/tracing_v3/session_tracer.py +158 -31
- synth_ai/tracing_v3/storage/__init__.py +1 -1
- synth_ai/tracing_v3/storage/base.py +8 -7
- synth_ai/tracing_v3/storage/config.py +4 -4
- synth_ai/tracing_v3/storage/factory.py +4 -4
- synth_ai/tracing_v3/storage/utils.py +9 -9
- synth_ai/tracing_v3/turso/__init__.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +9 -9
- synth_ai/tracing_v3/turso/manager.py +278 -48
- synth_ai/tracing_v3/turso/models.py +77 -19
- synth_ai/tracing_v3/utils.py +5 -5
- synth_ai/v0/tracing/abstractions.py +28 -28
- synth_ai/v0/tracing/base_client.py +9 -9
- synth_ai/v0/tracing/client_manager.py +7 -7
- synth_ai/v0/tracing/config.py +7 -7
- synth_ai/v0/tracing/context.py +6 -6
- synth_ai/v0/tracing/decorators.py +6 -5
- synth_ai/v0/tracing/events/manage.py +1 -1
- synth_ai/v0/tracing/events/store.py +5 -4
- synth_ai/v0/tracing/immediate_client.py +4 -5
- synth_ai/v0/tracing/local.py +3 -3
- synth_ai/v0/tracing/log_client_base.py +4 -5
- synth_ai/v0/tracing/retry_queue.py +5 -6
- synth_ai/v0/tracing/trackers.py +25 -25
- synth_ai/v0/tracing/upload.py +6 -0
- synth_ai/v0/tracing_v1/__init__.py +1 -1
- synth_ai/v0/tracing_v1/abstractions.py +28 -28
- synth_ai/v0/tracing_v1/base_client.py +9 -9
- synth_ai/v0/tracing_v1/client_manager.py +7 -7
- synth_ai/v0/tracing_v1/config.py +7 -7
- synth_ai/v0/tracing_v1/context.py +6 -6
- synth_ai/v0/tracing_v1/decorators.py +7 -6
- synth_ai/v0/tracing_v1/events/manage.py +1 -1
- synth_ai/v0/tracing_v1/events/store.py +5 -4
- synth_ai/v0/tracing_v1/immediate_client.py +4 -5
- synth_ai/v0/tracing_v1/local.py +3 -3
- synth_ai/v0/tracing_v1/log_client_base.py +4 -5
- synth_ai/v0/tracing_v1/retry_queue.py +5 -6
- synth_ai/v0/tracing_v1/trackers.py +25 -25
- synth_ai/v0/tracing_v1/upload.py +25 -24
- synth_ai/zyk/__init__.py +1 -0
- synth_ai-0.2.4.dev8.dist-info/METADATA +635 -0
- synth_ai-0.2.4.dev8.dist-info/RECORD +317 -0
- synth_ai/tui/__init__.py +0 -1
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -165
- synth_ai/tui/cli/query_experiments_v3.py +0 -165
- synth_ai/tui/dashboard.py +0 -329
- synth_ai-0.2.4.dev6.dist-info/METADATA +0 -203
- synth_ai-0.2.4.dev6.dist-info/RECORD +0 -299
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/top_level.txt +0 -0
synth_ai/learning/config.py ADDED

@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+
+@dataclass
+class FTJobConfig:
+    model: str
+    training_file_id: str
+    n_epochs: int = 1
+    batch_size: int = 1
+    upload_to_wasabi: bool = True
+
+    def hyperparameters(self) -> Dict[str, Any]:
+        if self.n_epochs < 1:
+            raise ValueError("n_epochs must be >= 1")
+        if self.batch_size < 1:
+            raise ValueError("batch_size must be >= 1")
+        return {"n_epochs": int(self.n_epochs), "batch_size": int(self.batch_size)}
+
+    def metadata(self) -> Dict[str, Any]:  # type: ignore[override]
+        return {"upload_to_wasabi": bool(self.upload_to_wasabi)}
+
+
+@dataclass
+class RLJobConfig:
+    model: str
+    task_app_url: str
+    trainer_id: str
+    batch_size: int = 1
+    group_size: int = 2
+    job_config_id: Optional[str] = None
+    inline_config: Optional[Dict[str, Any]] = None
+
+    def trainer_dict(self) -> Dict[str, Any]:
+        if self.batch_size < 1:
+            raise ValueError("batch_size must be >= 1")
+        if self.group_size < 2:
+            raise ValueError("group_size must be >= 2")
+        return {"batch_size": int(self.batch_size), "group_size": int(self.group_size)}
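The two dataclasses above validate their numeric fields and serialize them for job submission. A minimal usage sketch (the model name, file id, task-app URL, and trainer id are placeholders, not values from this release):

```python
# Illustrative use of the new config dataclasses; all concrete values are made up.
from synth_ai.learning.config import FTJobConfig, RLJobConfig

ft = FTJobConfig(model="Qwen/Qwen2.5-0.5B", training_file_id="file-123", n_epochs=2)
print(ft.hyperparameters())  # {'n_epochs': 2, 'batch_size': 1}
print(ft.metadata())         # {'upload_to_wasabi': True}

rl = RLJobConfig(model="Qwen/Qwen2.5-0.5B", task_app_url="http://localhost:8001", trainer_id="trainer-1")
print(rl.trainer_dict())     # {'batch_size': 1, 'group_size': 2}
```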
synth_ai/learning/constants.py ADDED

@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+# Terminal statuses normalized across FT and RL
+TERMINAL_STATUSES = {
+    "succeeded",
+    "failed",
+    "cancelled",
+    "canceled",
+    "error",
+    "completed",
+}
+
+# Terminal event types (success/failure) across FT and RL
+TERMINAL_EVENT_SUCCESS = {
+    "sft.completed",
+    "sft.workflow.completed",
+    "rl.job.completed",
+    "rl.train.completed",
+    "workflow.completed",
+}
+
+TERMINAL_EVENT_FAILURE = {
+    "sft.failed",
+    "sft.workflow.failed",
+    "rl.job.failed",
+    "workflow.failed",
+}
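A small sketch of how a poller might normalize statuses against these sets; the status and event strings below are examples, not captured backend output:

```python
# Normalize a backend status string against the new terminal-status constants.
from synth_ai.learning.constants import TERMINAL_EVENT_SUCCESS, TERMINAL_STATUSES

def is_terminal(status: str) -> bool:
    # Statuses arrive in mixed case from different backends; compare lowercased.
    return status.lower() in TERMINAL_STATUSES

assert is_terminal("SUCCEEDED")
assert "sft.completed" in TERMINAL_EVENT_SUCCESS
```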
synth_ai/learning/ft_client.py ADDED

@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+from ..http import AsyncHttpClient, HTTPError
+
+
+class FtClient:
+    def __init__(self, base_url: str, api_key: str, *, timeout: float = 30.0) -> None:
+        self._base_url = base_url.rstrip("/")
+        self._api_key = api_key
+        self._timeout = timeout
+
+    async def upload_training_file(self, path: str | Path, *, purpose: str = "fine-tune") -> str:
+        p = Path(path)
+        content = p.read_bytes()
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
+            data = {"purpose": purpose}
+            files = {"file": (p.name, content, _infer_content_type(p.name))}
+            js = await http.post_multipart("/api/learning/files", data=data, files=files)
+            if not isinstance(js, dict) or "id" not in js:
+                raise HTTPError(status=500, url="/api/learning/files", message="invalid_upload_response", body_snippet=str(js)[:200])
+            return str(js["id"])
+
+    async def create_sft_job(
+        self,
+        *,
+        model: str,
+        training_file_id: str,
+        hyperparameters: Dict[str, Any],
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        body = {
+            "training_type": "sft_offline",
+            "model": model,
+            "training_file_id": training_file_id,
+            "hyperparameters": dict(hyperparameters or {}),
+            "metadata": dict(metadata or {}),
+        }
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
+            return await http.post_json("/api/learning/jobs", json=body)
+
+    async def start_job(self, job_id: str) -> Dict[str, Any]:
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
+            return await http.post_json(f"/api/learning/jobs/{job_id}/start", json={})
+
+
+def _infer_content_type(filename: str) -> str:
+    name = filename.lower()
+    if name.endswith(".jsonl"):
+        return "application/jsonl"
+    if name.endswith(".json"):
+        return "application/json"
+    if name.endswith(".txt"):
+        return "text/plain"
+    return "application/octet-stream"
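A hypothetical end-to-end SFT submission built on FtClient plus the FTJobConfig shown earlier; the env var names, file path, model name, and the job-id key in the create response are assumptions for illustration, not documented behavior:

```python
# Sketch of uploading a JSONL dataset and starting an SFT job with FtClient.
import asyncio
import os

from synth_ai.learning.config import FTJobConfig
from synth_ai.learning.ft_client import FtClient


async def submit() -> None:
    client = FtClient(os.environ["BACKEND_BASE_URL"], os.environ["SYNTH_API_KEY"])
    file_id = await client.upload_training_file("train.jsonl")
    cfg = FTJobConfig(model="Qwen/Qwen2.5-0.5B", training_file_id=file_id)
    job = await client.create_sft_job(
        model=cfg.model,
        training_file_id=cfg.training_file_id,
        hyperparameters=cfg.hyperparameters(),
        metadata=cfg.metadata(),
    )
    job_id = job.get("job_id") or job.get("id")  # assumed response shape
    await client.start_job(str(job_id))


asyncio.run(submit())
```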
synth_ai/learning/gateway.py CHANGED
synth_ai/learning/health.py ADDED

@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Optional
+import aiohttp
+
+from ..http import AsyncHttpClient
+
+
+def _api_base(b: str) -> str:
+    b = (b or "").rstrip("/")
+    return b if b.endswith("/api") else f"{b}/api"
+
+
+async def backend_health(base_url: str, api_key: str) -> Dict[str, Any]:
+    async with AsyncHttpClient(base_url, api_key, timeout=15.0) as http:
+        js = await http.get(f"{_api_base(base_url)}/health")
+        return {"ok": True, "raw": js}
+
+
+async def task_app_health(task_app_url: str) -> Dict[str, Any]:
+    # Delegate to central task module for consistency
+    from synth_ai.task.health import task_app_health as _th
+
+    return await _th(task_app_url)
+
+
+async def pricing_preflight(base_url: str, api_key: str, *, job_type: str, gpu_type: str, estimated_seconds: float, container_count: int) -> Dict[str, Any]:
+    body = {
+        "job_type": job_type,
+        "gpu_type": gpu_type,
+        "estimated_seconds": float(estimated_seconds or 0.0),
+        "container_count": int(container_count or 1),
+    }
+    async with AsyncHttpClient(base_url, api_key, timeout=30.0) as http:
+        js = await http.post_json(f"{_api_base(base_url)}/v1/pricing/preflight", json=body)
+        return js if isinstance(js, dict) else {"raw": js}
+
+
+async def balance_autumn_normalized(base_url: str, api_key: str) -> Dict[str, Any]:
+    async with AsyncHttpClient(base_url, api_key, timeout=30.0) as http:
+        js = await http.get(f"{_api_base(base_url)}/v1/balance/autumn-normalized")
+        return js if isinstance(js, dict) else {"raw": js}
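A sketch of a backend health check and pricing preflight using the helpers above, assuming they are exposed from synth_ai.learning.health (the new 43-line module in this release); the env var names, GPU type, and time estimate are illustrative:

```python
# Check backend health and run a pricing preflight before submitting a job.
import asyncio
import os

from synth_ai.learning.health import backend_health, pricing_preflight


async def preflight() -> None:
    base = os.environ["BACKEND_BASE_URL"]
    key = os.environ["SYNTH_API_KEY"]
    print(await backend_health(base, key))
    print(
        await pricing_preflight(
            base,
            key,
            job_type="sft_offline",   # matches the training_type used by FtClient
            gpu_type="A100",          # illustrative value
            estimated_seconds=1800,
            container_count=1,
        )
    )


asyncio.run(preflight())
```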
synth_ai/learning/jobs.py ADDED

@@ -0,0 +1,205 @@
+from __future__ import annotations
+
+from typing import Any, Callable, Dict, List, Optional
+import time
+
+from .constants import TERMINAL_EVENT_FAILURE, TERMINAL_EVENT_SUCCESS, TERMINAL_STATUSES
+from ..http import AsyncHttpClient, sleep
+
+
+def _api_base(b: str) -> str:
+    b = (b or "").rstrip("/")
+    return b if b.endswith("/api") else f"{b}/api"
+
+
+class JobsApiResolver:
+    def __init__(self, base_url: str, *, strict: bool) -> None:
+        self._base = _api_base(base_url)
+        self._strict = strict
+
+    def status_urls(self, job_id: str) -> List[str]:
+        if self._strict:
+            return [f"{self._base}/learning/jobs/{job_id}"]
+        return [
+            f"{self._base}/learning/jobs/{job_id}",
+            f"{self._base}/rl/jobs/{job_id}",
+            f"{self._base}/orchestration/jobs/{job_id}",
+        ]
+
+    def events_urls(self, job_id: str, since: int) -> List[str]:
+        if self._strict:
+            return [f"{self._base}/learning/jobs/{job_id}/events?since_seq={since}&limit=200"]
+        return [
+            f"{self._base}/learning/jobs/{job_id}/events?since_seq={since}&limit=200",
+            f"{self._base}/orchestration/jobs/{job_id}/events?since_seq={since}&limit=200",
+            # RL /jobs/{id}/events is SSE in backend; avoid in JSON poller
+        ]
+
+    def metrics_url(self, job_id: str, after_step: int) -> str:
+        return f"{self._base}/learning/jobs/{job_id}/metrics?after_step={after_step}&limit=200"
+
+
+class JobHandle:
+    def __init__(self, base_url: str, api_key: str, job_id: str, *, strict: bool = True, timeout: float = 600.0) -> None:
+        self.base_url = base_url.rstrip("/")
+        self.api_key = api_key
+        self.job_id = job_id
+        self.strict = strict
+        self.timeout = timeout
+
+    async def poll_until_terminal(
+        self,
+        *,
+        interval_seconds: float = 2.0,
+        max_seconds: float | None = None,
+        empty_polls_threshold: int = 5,
+        startup_deadline_s: int = 45,
+        on_event: Optional[Callable[[Dict[str, Any]], None]] = None,
+        on_metric: Optional[Callable[[Dict[str, Any]], None]] = None,
+    ) -> Dict[str, Any]:
+        last_seq_by_stream: Dict[str, int] = {}
+        events_job_id: Optional[str] = None
+        last_status: Optional[str] = None
+        last_step_by_name: Dict[str, int] = {}
+        empty_polls = 0
+        saw_any_event = False
+        start_t = time.time()
+        resolver = JobsApiResolver(self.base_url, strict=self.strict)
+        detected_fine_tuned_model: Optional[str] = None
+
+        async with AsyncHttpClient(self.base_url, self.api_key, timeout=self.timeout) as http:
+            while True:
+                # Status
+                status_data: Optional[Dict[str, Any]] = None
+                for su in resolver.status_urls(self.job_id):
+                    try:
+                        status_data = await http.get(su)
+                        if isinstance(status_data, dict):
+                            break
+                    except Exception:
+                        continue
+                status = str((status_data or {}).get("status") or "").lower()
+                if status_data:
+                    linked = status_data.get("linked_job_id")
+                    if isinstance(linked, str) and linked and linked != events_job_id:
+                        events_job_id = linked
+                    # Capture fine_tuned_model if already present on status
+                    if not detected_fine_tuned_model:
+                        ftm = status_data.get("fine_tuned_model")
+                        if isinstance(ftm, str) and ftm:
+                            detected_fine_tuned_model = ftm
+                if status and status != last_status:
+                    last_status = status
+                    if on_event:
+                        try:
+                            on_event({"type": "job.status", "message": status})
+                        except Exception:
+                            pass
+
+                # Events
+                stream_ids = [self.job_id]
+                if events_job_id and events_job_id not in stream_ids:
+                    stream_ids.append(events_job_id)
+                total_events_this_cycle = 0
+                terminal_event_seen = False
+                terminal_event_status: Optional[str] = None
+                for ev_id in stream_ids:
+                    since = last_seq_by_stream.get(ev_id, 0)
+                    for eu in resolver.events_urls(ev_id, since):
+                        try:
+                            ev_js = await http.get(eu)
+                        except Exception:
+                            continue
+                        try:
+                            events = (ev_js or {}).get("events") or (ev_js or {}).get("data") or []
+                            if not isinstance(events, list):
+                                events = []
+                        except Exception:
+                            events = []
+                        total_events_this_cycle += len(events)
+                        if events:
+                            saw_any_event = True
+                        for e in events:
+                            seq_val = int(e.get("seq") or 0)
+                            if seq_val <= last_seq_by_stream.get(ev_id, 0):
+                                continue
+                            last_seq_by_stream[ev_id] = seq_val
+                            if on_event:
+                                try:
+                                    on_event(e)
+                                except Exception:
+                                    pass
+                            et = str(e.get("type") or e.get("event_type") or "").lower()
+                            # Capture fine_tuned_model from event data when available
+                            if not detected_fine_tuned_model:
+                                try:
+                                    data_obj = e.get("data") or {}
+                                    ftm = data_obj.get("fine_tuned_model") if isinstance(data_obj, dict) else None
+                                    if isinstance(ftm, str) and ftm:
+                                        detected_fine_tuned_model = ftm
+                                except Exception:
+                                    pass
+                            if et in TERMINAL_EVENT_SUCCESS:
+                                terminal_event_seen = True
+                                terminal_event_status = "succeeded"
+                            elif et in TERMINAL_EVENT_FAILURE:
+                                terminal_event_seen = True
+                                terminal_event_status = "failed"
+
+                # Metrics
+                try:
+                    after = max(last_step_by_name.values()) if last_step_by_name else -1
+                    mu = resolver.metrics_url(self.job_id, after)
+                    md = await http.get(mu)
+                    for p in (md or {}).get("points", []):
+                        name = str(p.get("name") or "")
+                        step = int(p.get("step") or -1)
+                        if step <= last_step_by_name.get(name, -1):
+                            continue
+                        last_step_by_name[name] = step
+                        if on_metric:
+                            try:
+                                on_metric(p)
+                            except Exception:
+                                pass
+                except Exception:
+                    pass
+
+                # Terminal decisions
+                if terminal_event_seen or (status and status in TERMINAL_STATUSES):
+                    # Best-effort enrichment of final result with fine_tuned_model
+                    result_status = terminal_event_status or status or "completed"
+                    final_res: Dict[str, Any] = {"status": result_status, "job_id": self.job_id}
+                    if not detected_fine_tuned_model:
+                        # Briefly try to re-fetch status to see if fine_tuned_model is persisted
+                        try:
+                            for su in resolver.status_urls(self.job_id):
+                                try:
+                                    final_status = await http.get(su)
+                                    if isinstance(final_status, dict):
+                                        ftm2 = final_status.get("fine_tuned_model")
+                                        if isinstance(ftm2, str) and ftm2:
+                                            detected_fine_tuned_model = ftm2
+                                            break
+                                except Exception:
+                                    continue
+                        except Exception:
+                            pass
+                    if detected_fine_tuned_model:
+                        final_res["fine_tuned_model"] = detected_fine_tuned_model
+                    return final_res
+
+                # Guards (relaxed): do not abort on consecutive empty polls
+                if total_events_this_cycle == 0:
+                    empty_polls += 1
+                else:
+                    empty_polls = 0
+                if not saw_any_event and (time.time() - start_t) > int(startup_deadline_s):
+                    raise AssertionError(
+                        f"No events observed within startup window ({startup_deadline_s}s). Investigate event streaming."
+                    )
+                await sleep(interval_seconds)
+                if max_seconds is not None and (time.time() - start_t) >= max_seconds:
+                    raise TimeoutError(f"Polling timed out after {max_seconds}s for job {self.job_id}")
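A sketch of driving JobHandle.poll_until_terminal with event and metric callbacks; the backend URL, API key env vars, and job id are placeholders:

```python
# Poll a learning job until it reaches a terminal state, printing progress.
import asyncio
import os

from synth_ai.learning.jobs import JobHandle


async def wait_for_job() -> None:
    handle = JobHandle(os.environ["BACKEND_BASE_URL"], os.environ["SYNTH_API_KEY"], "job-123")
    final = await handle.poll_until_terminal(
        interval_seconds=2.0,
        max_seconds=3600,
        on_event=lambda e: print("event:", e.get("type")),
        on_metric=lambda p: print("metric:", p.get("name"), p.get("step")),
    )
    print(final)  # e.g. {"status": "succeeded", "job_id": "job-123", ...}


asyncio.run(wait_for_job())
```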
@@ -18,16 +18,15 @@ from __future__ import annotations
 import asyncio
 import os
 import random
-from typing import
+from typing import Any
 
-from dotenv import load_dotenv
 from datasets import load_dataset
-
+from dotenv import load_dotenv
 from synth_ai.lm.core.main_v3 import LM, build_messages
 from synth_ai.lm.overrides import LMOverridesContext
 
 
-async def classify_one(lm: LM, text: str, label_names:
+async def classify_one(lm: LM, text: str, label_names: list[str]) -> str:
     labels_joined = ", ".join(label_names)
     system_message = (
         "You are an intent classifier for the Banking77 dataset. "

@@ -41,7 +40,7 @@ async def classify_one(lm: LM, text: str, label_names: List[str]) -> str:
     return (resp.raw_response or "").strip()
 
 
-def choose_label(pred: str, label_names:
+def choose_label(pred: str, label_names: list[str]) -> str:
     norm_pred = pred.strip().lower()
     label_lookup = {ln.lower(): ln for ln in label_names}
     mapped = label_lookup.get(norm_pred)

@@ -56,12 +55,18 @@ def choose_label(pred: str, label_names: List[str]) -> str:
     return max(label_names, key=score)
 
 
-async def eval_context(
+async def eval_context(
+    lm: LM,
+    items: list[tuple[str, str]],
+    label_names: list[str],
+    ctx_name: str,
+    specs: list[dict[str, Any]],
+) -> tuple[str, int, int]:
     correct = 0
     with LMOverridesContext(specs):
         tasks = [classify_one(lm, text, label_names) for text, _ in items]
         results = await asyncio.gather(*tasks, return_exceptions=True)
-        for (text, gold), pred in zip(items, results):
+        for (text, gold), pred in zip(items, results, strict=False):
             if isinstance(pred, Exception):
                 # Treat exceptions as incorrect
                 continue

@@ -81,7 +86,7 @@ async def main() -> None:
 
     print("Loading Banking77 dataset (split='test')...")
     ds = load_dataset("banking77", split="test")
-    label_names:
+    label_names: list[str] = ds.features["label"].names  # type: ignore
 
     idxs = random.sample(range(len(ds)), k=min(n, len(ds)))
     items = [

@@ -90,7 +95,7 @@ async def main() -> None:
     ]
 
     # Define a few override contexts to compare
-    contexts:
+    contexts: list[dict[str, Any]] = [
         {
             "name": "baseline (no overrides)",
             "overrides": [],

@@ -145,7 +150,7 @@ async def main() -> None:
     print(f"\nEvaluating {len(contexts)} contexts on {len(items)} Banking77 samples (async)...")
 
     # Evaluate each context sequentially but batched (each context classifies in parallel)
-    results:
+    results: list[tuple[str, int, int]] = []
     for ctx in contexts:
         name = ctx["name"]
         specs = ctx["overrides"]
@@ -27,18 +27,17 @@ from __future__ import annotations
 import asyncio
 import os
 import random
-from typing import Any, Dict, List, Optional
 
 from datasets import load_dataset
 
 # Use the v3 LM class present in this repo
 from synth_ai.lm.core.main_v3 import LM, build_messages
-from synth_ai.tracing_v3.session_tracer import SessionTracer
-from synth_ai.tracing_v3.abstractions import LMCAISEvent
-
 
 # Use Overrides context to demonstrate matching by content
 from synth_ai.lm.overrides import LMOverridesContext
+from synth_ai.tracing_v3.abstractions import LMCAISEvent
+from synth_ai.tracing_v3.session_tracer import SessionTracer
+
 INJECTION_RULES = [
     {"find": "accnt", "replace": "account"},
     {"find": "atm", "replace": "ATM"},

@@ -46,7 +45,7 @@ INJECTION_RULES = [
 ]
 
 
-async def classify_sample(lm: LM, text: str, label_names:
+async def classify_sample(lm: LM, text: str, label_names: list[str]) -> str:
     """Classify one Banking77 utterance and return the predicted label name."""
     labels_joined = ", ".join(label_names)
     system_message = (

@@ -77,7 +76,7 @@ async def main() -> None:
     # Columns: {"text": str, "label": int}; label names at ds.features["label"].names
     print("Loading Banking77 dataset (split='test')...")
     ds = load_dataset("banking77", split="test")
-    label_names:
+    label_names: list[str] = ds.features["label"].names  # type: ignore
 
     # Sample a few items for a quick demo
     n = int(os.getenv("N_SAMPLES", "8"))

@@ -116,7 +115,9 @@ async def main() -> None:
 
         is_correct = pred_label == gold_label
         correct += int(is_correct)
-        print(
+        print(
+            f"[{i}] text={text!r}\n gold={gold_label}\n pred={pred} -> mapped={pred_label} {'✅' if is_correct else '❌'}"
+        )
 
     if idxs:
         acc = correct / len(idxs)

@@ -137,7 +138,11 @@ async def main() -> None:
     with LMOverridesContext([{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]):
         _ = await classify_sample(lm_traced, test_text, label_names)
     # inspect trace
-    events = [
+    events = [
+        e
+        for e in (tracer.current_session.event_history if tracer.current_session else [])
+        if isinstance(e, LMCAISEvent)
+    ]
     assert events, "No LMCAISEvent recorded by SessionTracer"
     cr = events[-1].call_records[0]
     traced_user = ""

@@ -145,7 +150,7 @@ async def main() -> None:
         if m.role == "user":
             for part in m.parts:
                 if getattr(part, "type", None) == "text":
-                    traced_user +=
+                    traced_user += part.text or ""
     assert "ATM" in traced_user, f"Expected substitution in traced prompt; got: {traced_user!r}"
     print("LM path trace verified: substitution present in traced prompt.")
     await tracer.end_timestep()

@@ -155,7 +160,7 @@ async def main() -> None:
     try:
         import synth_ai.lm.provider_support.openai as _synth_openai_patch  # noqa: F401
         from openai import AsyncOpenAI
-
+
         base_url = os.getenv("OPENAI_BASE_URL", "https://api.groq.com/openai/v1")
         api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
         client = AsyncOpenAI(base_url=base_url, api_key=api_key)

@@ -163,8 +168,12 @@ async def main() -> None:
             {"role": "system", "content": "Echo user label."},
             {"role": "user", "content": f"Please classify: {test_text}"},
         ]
-        with LMOverridesContext(
-
+        with LMOverridesContext(
+            [{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]
+        ):
+            _ = await client.chat.completions.create(
+                model=model, messages=messages, temperature=0
+            )
         # Not all models echo input; instead, verify that our injected expectation matches
         expected_user = messages[1]["content"].replace("atm", "ATM")
         if messages[1]["content"] == expected_user:

@@ -176,13 +185,16 @@ async def main() -> None:
 
     # 3) Anthropic wrapper path (AsyncClient): ensure apply_injection is active
     try:
-        import synth_ai.lm.provider_support.anthropic as _synth_anthropic_patch  # noqa: F401
         import anthropic
+        import synth_ai.lm.provider_support.anthropic as _synth_anthropic_patch  # noqa: F401
+
         a_model = os.getenv("ANTHROPIC_MODEL", "claude-3-5-haiku-20241022")
        a_key = os.getenv("ANTHROPIC_API_KEY")
         if a_key:
             a_client = anthropic.AsyncClient(api_key=a_key)
-            with LMOverridesContext(
+            with LMOverridesContext(
+                [{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]
+            ):
                 _ = await a_client.messages.create(
                     model=a_model,
                     system="Echo user label.",
|