synth-ai 0.2.4.dev6__py3-none-any.whl → 0.2.4.dev8__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- synth_ai/__init__.py +18 -9
- synth_ai/cli/__init__.py +10 -5
- synth_ai/cli/balance.py +25 -32
- synth_ai/cli/calc.py +2 -3
- synth_ai/cli/demo.py +3 -5
- synth_ai/cli/legacy_root_backup.py +58 -32
- synth_ai/cli/man.py +22 -19
- synth_ai/cli/recent.py +9 -8
- synth_ai/cli/root.py +58 -13
- synth_ai/cli/status.py +13 -6
- synth_ai/cli/traces.py +45 -21
- synth_ai/cli/watch.py +40 -37
- synth_ai/config/base_url.py +47 -2
- synth_ai/core/experiment.py +1 -2
- synth_ai/environments/__init__.py +2 -6
- synth_ai/environments/environment/artifacts/base.py +3 -1
- synth_ai/environments/environment/db/sqlite.py +1 -1
- synth_ai/environments/environment/registry.py +19 -20
- synth_ai/environments/environment/resources/sqlite.py +2 -3
- synth_ai/environments/environment/rewards/core.py +3 -2
- synth_ai/environments/environment/tools/__init__.py +6 -4
- synth_ai/environments/examples/crafter_classic/__init__.py +1 -1
- synth_ai/environments/examples/crafter_classic/engine.py +13 -13
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +1 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +2 -1
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +2 -1
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +3 -2
- synth_ai/environments/examples/crafter_classic/environment.py +16 -15
- synth_ai/environments/examples/crafter_classic/taskset.py +2 -2
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +2 -3
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +2 -1
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +2 -2
- synth_ai/environments/examples/crafter_custom/crafter/config.py +2 -2
- synth_ai/environments/examples/crafter_custom/crafter/env.py +1 -5
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +1 -2
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +1 -2
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +5 -5
- synth_ai/environments/examples/crafter_custom/environment.py +13 -13
- synth_ai/environments/examples/crafter_custom/run_dataset.py +5 -5
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +2 -2
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +5 -4
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +2 -1
- synth_ai/environments/examples/enron/engine.py +18 -14
- synth_ai/environments/examples/enron/environment.py +12 -11
- synth_ai/environments/examples/enron/taskset.py +7 -7
- synth_ai/environments/examples/minigrid/__init__.py +6 -6
- synth_ai/environments/examples/minigrid/engine.py +6 -6
- synth_ai/environments/examples/minigrid/environment.py +6 -6
- synth_ai/environments/examples/minigrid/puzzle_loader.py +3 -2
- synth_ai/environments/examples/minigrid/taskset.py +13 -13
- synth_ai/environments/examples/nethack/achievements.py +1 -1
- synth_ai/environments/examples/nethack/engine.py +8 -7
- synth_ai/environments/examples/nethack/environment.py +10 -9
- synth_ai/environments/examples/nethack/helpers/__init__.py +8 -9
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +1 -1
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +2 -1
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +1 -1
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +3 -4
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +6 -5
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +5 -5
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +7 -6
- synth_ai/environments/examples/nethack/taskset.py +5 -5
- synth_ai/environments/examples/red/engine.py +9 -8
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +7 -7
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +3 -2
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +2 -1
- synth_ai/environments/examples/red/environment.py +18 -15
- synth_ai/environments/examples/red/taskset.py +5 -3
- synth_ai/environments/examples/sokoban/engine.py +16 -13
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +3 -2
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +7 -5
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +5 -4
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +3 -2
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +5 -4
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +1 -1
- synth_ai/environments/examples/sokoban/environment.py +15 -14
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +5 -3
- synth_ai/environments/examples/sokoban/puzzle_loader.py +3 -2
- synth_ai/environments/examples/sokoban/taskset.py +13 -10
- synth_ai/environments/examples/tictactoe/engine.py +6 -6
- synth_ai/environments/examples/tictactoe/environment.py +8 -7
- synth_ai/environments/examples/tictactoe/taskset.py +6 -5
- synth_ai/environments/examples/verilog/engine.py +4 -3
- synth_ai/environments/examples/verilog/environment.py +11 -10
- synth_ai/environments/examples/verilog/taskset.py +14 -12
- synth_ai/environments/examples/wordle/__init__.py +5 -5
- synth_ai/environments/examples/wordle/engine.py +32 -25
- synth_ai/environments/examples/wordle/environment.py +21 -16
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +6 -6
- synth_ai/environments/examples/wordle/taskset.py +20 -12
- synth_ai/environments/reproducibility/core.py +1 -1
- synth_ai/environments/reproducibility/tree.py +21 -21
- synth_ai/environments/service/app.py +3 -2
- synth_ai/environments/service/core_routes.py +104 -110
- synth_ai/environments/service/external_registry.py +1 -2
- synth_ai/environments/service/registry.py +1 -1
- synth_ai/environments/stateful/core.py +1 -2
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/api.py +4 -4
- synth_ai/environments/tasks/core.py +14 -12
- synth_ai/environments/tasks/filters.py +6 -4
- synth_ai/environments/tasks/utils.py +13 -11
- synth_ai/evals/base.py +2 -3
- synth_ai/experimental/synth_oss.py +4 -4
- synth_ai/http.py +102 -0
- synth_ai/inference/__init__.py +7 -0
- synth_ai/inference/client.py +20 -0
- synth_ai/jobs/client.py +246 -0
- synth_ai/learning/__init__.py +24 -0
- synth_ai/learning/client.py +149 -0
- synth_ai/learning/config.py +43 -0
- synth_ai/learning/constants.py +29 -0
- synth_ai/learning/ft_client.py +59 -0
- synth_ai/learning/gateway.py +1 -3
- synth_ai/learning/health.py +43 -0
- synth_ai/learning/jobs.py +205 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +15 -10
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +26 -14
- synth_ai/learning/prompts/mipro.py +61 -52
- synth_ai/learning/prompts/random_search.py +42 -43
- synth_ai/learning/prompts/run_mipro_banking77.py +32 -20
- synth_ai/learning/prompts/run_random_search_banking77.py +71 -52
- synth_ai/learning/rl_client.py +256 -0
- synth_ai/learning/sse.py +58 -0
- synth_ai/learning/validators.py +48 -0
- synth_ai/lm/__init__.py +5 -5
- synth_ai/lm/caching/ephemeral.py +9 -9
- synth_ai/lm/caching/handler.py +20 -20
- synth_ai/lm/caching/persistent.py +10 -10
- synth_ai/lm/config.py +3 -3
- synth_ai/lm/constants.py +7 -7
- synth_ai/lm/core/all.py +17 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +26 -41
- synth_ai/lm/core/main_v3.py +33 -10
- synth_ai/lm/core/synth_models.py +48 -0
- synth_ai/lm/core/vendor_clients.py +26 -22
- synth_ai/lm/injection.py +7 -8
- synth_ai/lm/overrides.py +21 -19
- synth_ai/lm/provider_support/__init__.py +1 -1
- synth_ai/lm/provider_support/anthropic.py +15 -15
- synth_ai/lm/provider_support/openai.py +23 -21
- synth_ai/lm/structured_outputs/handler.py +34 -32
- synth_ai/lm/structured_outputs/inject.py +24 -27
- synth_ai/lm/structured_outputs/rehabilitate.py +19 -15
- synth_ai/lm/tools/base.py +17 -16
- synth_ai/lm/unified_interface.py +17 -18
- synth_ai/lm/vendors/base.py +20 -18
- synth_ai/lm/vendors/core/anthropic_api.py +36 -27
- synth_ai/lm/vendors/core/gemini_api.py +31 -36
- synth_ai/lm/vendors/core/mistral_api.py +19 -19
- synth_ai/lm/vendors/core/openai_api.py +42 -13
- synth_ai/lm/vendors/openai_standard.py +158 -101
- synth_ai/lm/vendors/openai_standard_responses.py +74 -61
- synth_ai/lm/vendors/retries.py +9 -1
- synth_ai/lm/vendors/supported/custom_endpoint.py +38 -28
- synth_ai/lm/vendors/supported/deepseek.py +10 -10
- synth_ai/lm/vendors/supported/grok.py +8 -8
- synth_ai/lm/vendors/supported/ollama.py +2 -1
- synth_ai/lm/vendors/supported/openrouter.py +11 -9
- synth_ai/lm/vendors/synth_client.py +425 -75
- synth_ai/lm/warmup.py +8 -7
- synth_ai/rl/__init__.py +30 -0
- synth_ai/rl/contracts.py +32 -0
- synth_ai/rl/env_keys.py +137 -0
- synth_ai/rl/secrets.py +19 -0
- synth_ai/scripts/verify_rewards.py +100 -0
- synth_ai/task/__init__.py +10 -0
- synth_ai/task/contracts.py +120 -0
- synth_ai/task/health.py +28 -0
- synth_ai/task/validators.py +12 -0
- synth_ai/tracing/__init__.py +22 -10
- synth_ai/tracing_v1/__init__.py +22 -20
- synth_ai/tracing_v3/__init__.py +7 -7
- synth_ai/tracing_v3/abstractions.py +56 -52
- synth_ai/tracing_v3/config.py +4 -2
- synth_ai/tracing_v3/db_config.py +6 -8
- synth_ai/tracing_v3/decorators.py +29 -30
- synth_ai/tracing_v3/examples/basic_usage.py +12 -12
- synth_ai/tracing_v3/hooks.py +24 -22
- synth_ai/tracing_v3/llm_call_record_helpers.py +85 -98
- synth_ai/tracing_v3/lm_call_record_abstractions.py +2 -4
- synth_ai/tracing_v3/migration_helper.py +3 -5
- synth_ai/tracing_v3/replica_sync.py +30 -32
- synth_ai/tracing_v3/session_tracer.py +158 -31
- synth_ai/tracing_v3/storage/__init__.py +1 -1
- synth_ai/tracing_v3/storage/base.py +8 -7
- synth_ai/tracing_v3/storage/config.py +4 -4
- synth_ai/tracing_v3/storage/factory.py +4 -4
- synth_ai/tracing_v3/storage/utils.py +9 -9
- synth_ai/tracing_v3/turso/__init__.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +9 -9
- synth_ai/tracing_v3/turso/manager.py +278 -48
- synth_ai/tracing_v3/turso/models.py +77 -19
- synth_ai/tracing_v3/utils.py +5 -5
- synth_ai/v0/tracing/abstractions.py +28 -28
- synth_ai/v0/tracing/base_client.py +9 -9
- synth_ai/v0/tracing/client_manager.py +7 -7
- synth_ai/v0/tracing/config.py +7 -7
- synth_ai/v0/tracing/context.py +6 -6
- synth_ai/v0/tracing/decorators.py +6 -5
- synth_ai/v0/tracing/events/manage.py +1 -1
- synth_ai/v0/tracing/events/store.py +5 -4
- synth_ai/v0/tracing/immediate_client.py +4 -5
- synth_ai/v0/tracing/local.py +3 -3
- synth_ai/v0/tracing/log_client_base.py +4 -5
- synth_ai/v0/tracing/retry_queue.py +5 -6
- synth_ai/v0/tracing/trackers.py +25 -25
- synth_ai/v0/tracing/upload.py +6 -0
- synth_ai/v0/tracing_v1/__init__.py +1 -1
- synth_ai/v0/tracing_v1/abstractions.py +28 -28
- synth_ai/v0/tracing_v1/base_client.py +9 -9
- synth_ai/v0/tracing_v1/client_manager.py +7 -7
- synth_ai/v0/tracing_v1/config.py +7 -7
- synth_ai/v0/tracing_v1/context.py +6 -6
- synth_ai/v0/tracing_v1/decorators.py +7 -6
- synth_ai/v0/tracing_v1/events/manage.py +1 -1
- synth_ai/v0/tracing_v1/events/store.py +5 -4
- synth_ai/v0/tracing_v1/immediate_client.py +4 -5
- synth_ai/v0/tracing_v1/local.py +3 -3
- synth_ai/v0/tracing_v1/log_client_base.py +4 -5
- synth_ai/v0/tracing_v1/retry_queue.py +5 -6
- synth_ai/v0/tracing_v1/trackers.py +25 -25
- synth_ai/v0/tracing_v1/upload.py +25 -24
- synth_ai/zyk/__init__.py +1 -0
- synth_ai-0.2.4.dev8.dist-info/METADATA +635 -0
- synth_ai-0.2.4.dev8.dist-info/RECORD +317 -0
- synth_ai/tui/__init__.py +0 -1
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -165
- synth_ai/tui/cli/query_experiments_v3.py +0 -165
- synth_ai/tui/dashboard.py +0 -329
- synth_ai-0.2.4.dev6.dist-info/METADATA +0 -203
- synth_ai-0.2.4.dev6.dist-info/RECORD +0 -299
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/top_level.txt +0 -0
@@ -12,44 +12,46 @@ Run:
 from __future__ import annotations
 
 import asyncio
+import json
 import os
 import random
+import time
+from collections.abc import Sequence
 from dataclasses import dataclass, replace
+from pathlib import Path
 from types import SimpleNamespace
-from
-from typing import Any, Dict, List, Sequence, Tuple
+from typing import Any
 
-from dotenv import load_dotenv
 from datasets import load_dataset
-
-from synth_ai.lm.core.main_v3 import LM, build_messages
-import json
-import time
-from pathlib import Path
+from dotenv import load_dotenv
 from synth_ai.learning.prompts.random_search import random_search_compile
+from synth_ai.lm.core.main_v3 import LM, build_messages
+from tqdm import tqdm
 
 
-def choose_label(pred: str, label_names: List[str]) -> str:
+def choose_label(pred: str, label_names: list[str]) -> str:
     norm = (pred or "").strip().lower()
     d = {ln.lower(): ln for ln in label_names}
     if norm in d:
         return d[norm]
+
     def score(cand: str) -> int:
         c = cand.lower()
         return sum(1 for w in c.split() if w in norm)
+
     return max(label_names, key=score)
 
 
-def accuracy(pred: str, gold: str, labels: List[str]) -> float:
+def accuracy(pred: str, gold: str, labels: list[str]) -> float:
     return 1.0 if choose_label(pred, labels) == gold else 0.0
 
 
 @dataclass
 class StudentProgram:
     lm: LM
-    label_names: List[str]
+    label_names: list[str]
     instruction: str
-    demos: List[Tuple[str, str]]
+    demos: list[tuple[str, str]]
 
     def reset_copy(self):
         return replace(self, instruction=self.instruction, demos=list(self.demos))
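For reference, here is a self-contained sketch of how the new `choose_label` helper behaves; the labels below are made up for illustration and are not the actual Banking77 label set.

```python
# choose_label as added in the hunk above, reproduced so the example runs standalone.
def choose_label(pred: str, label_names: list[str]) -> str:
    norm = (pred or "").strip().lower()
    d = {ln.lower(): ln for ln in label_names}
    if norm in d:
        return d[norm]

    def score(cand: str) -> int:
        c = cand.lower()
        return sum(1 for w in c.split() if w in norm)

    return max(label_names, key=score)


labels = ["exchange rate", "card arrival", "lost or stolen card"]  # illustrative labels only
print(choose_label("Card Arrival", labels))                # exact match after lowercasing -> "card arrival"
print(choose_label("what is the exchange rate?", labels))  # word-overlap fallback -> "exchange rate"
```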
@@ -57,7 +59,7 @@ class StudentProgram:
     def deepcopy(self):
         return replace(self, instruction=str(self.instruction), demos=list(self.demos))
 
-    def with_demos(self, demos: List[Tuple[str, str]]):
+    def with_demos(self, demos: list[tuple[str, str]]):
         return replace(self, demos=list(demos))
 
     def run(self, x: str) -> str:
@@ -66,10 +68,12 @@ class StudentProgram:
         sys = self.instruction or "You are an intent classifier for Banking77."
         user = (f"Examples:\n{examples}\n\n" if examples else "") + f"Message: {x}\nLabel:"
         messages = build_messages(sys, user, images_bytes=None, model_name=self.lm.model)
+
         # Call LM synchronously via asyncio
         async def _call():
             resp = await self.lm.respond_async(messages=messages)
             return (resp.raw_response or "").strip()
+
         return asyncio.run(_call())
 
     async def _apredict(self, x: str):
@@ -91,13 +95,13 @@ def main():
 
     print("Loading Banking77 dataset (train/dev split of test for demo)...")
     ds = load_dataset("banking77")
-    label_names: List[str] = ds["test"].features["label"].names  # type: ignore
+    label_names: list[str] = ds["test"].features["label"].names  # type: ignore
 
     # Create small train/val from the test split for speed
     all_items = [(r["text"], label_names[int(r["label"])]) for r in ds["test"]]
     random.shuffle(all_items)
-    trainset: Sequence[Tuple[str, str]] = all_items[:40]
-    valset: Sequence[Tuple[str, str]] = all_items[40:60]  # 20 examples
+    trainset: Sequence[tuple[str, str]] = all_items[:40]
+    valset: Sequence[tuple[str, str]] = all_items[40:60]  # 20 examples
 
     student = StudentProgram(
         lm=lm,
@@ -110,17 +114,20 @@ def main():
         return accuracy(yhat, y, label_names)
 
     total_candidates = 3 + 3  # zero-shot, labeled few-shot, bootstrapped + 3 random seeds
-    print(
+    print(
+        f"Running Random Search optimizer ({total_candidates} candidates, parallel eval of 20 questions)..."
+    )
 
-    def eval_parallel(program: StudentProgram, dataset: Sequence[Tuple[str, str]], metric_fn):
+    def eval_parallel(program: StudentProgram, dataset: Sequence[tuple[str, str]], metric_fn):
         async def _run():
             xs = [x for x, _ in dataset]
             ys = [y for _, y in dataset]
-            preds:
+            preds: list[Optional[str]] = [None] * len(xs)
             sem = asyncio.Semaphore(int(os.getenv("CONCURRENCY", "5")))
 
             async def worker(i: int, x: str, y: str):
                 import time
+
                 t_start = time.monotonic()
                 try:
                     async with sem:
@@ -138,16 +145,18 @@ def main():
                 t_end = time.monotonic()
                 return i, y, "", t_start, t_end, {}
 
-            tasks = [asyncio.create_task(worker(i, x, y)) for i, (x, y) in enumerate(zip(xs, ys))]
+            tasks = [asyncio.create_task(worker(i, x, y)) for i, (x, y) in enumerate(zip(xs, ys, strict=False))]
             correct_sum = 0.0
             processed = 0
-            import
-
+            import statistics
+            import time
+
+            durations: list[float] = []
             in_tok_sum = 0
             out_tok_sum = 0
             in_tok_count = 0
             out_tok_count = 0
-            details: List[Dict[str, Any]] = []
+            details: list[dict[str, Any]] = []
             t_batch_start = time.monotonic()
             deadline = float(os.getenv("BATCH_DEADLINE_S", "20"))
             with tqdm(total=len(tasks), desc="Rollouts", leave=False) as pbar:
@@ -172,7 +181,10 @@ def main():
                         break
                     # Wait for at least one completion within remaining time (polling granularity <= 1s)
                     timeout = min(1.0, remaining)
-                    done, pending = await asyncio.wait(
+                    done, pending = await asyncio.wait(
+                        pending, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
+                    )
+                    import contextlib
                     for task in done:
                         try:
                             i, y_true, pred, t_start, t_end, usage = task.result()
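The deadline logic around this `asyncio.wait` call polls the outstanding tasks in slices of at most one second until a wall-clock budget runs out. A standalone sketch of that pattern, with hypothetical names (`drain_with_deadline` is not part of the package):

```python
import asyncio
import time


async def drain_with_deadline(tasks: set[asyncio.Task], deadline_s: float = 20.0) -> list:
    """Collect results from `tasks`, polling in <=1s slices until the deadline expires."""
    t0 = time.monotonic()
    pending = set(tasks)
    results = []
    while pending:
        remaining = deadline_s - (time.monotonic() - t0)
        if remaining <= 0:
            break  # deadline hit; leftover tasks are treated as timed out
        done, pending = await asyncio.wait(
            pending, timeout=min(1.0, remaining), return_when=asyncio.FIRST_COMPLETED
        )
        for t in done:
            if not t.cancelled() and t.exception() is None:
                results.append(t.result())
    return results
```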
@@ -182,11 +194,9 @@ def main():
                             durations.append(max(0.0, t_end - t_start))
                             preds[i] = pred
                             processed += 1
-                            try:
+                            with contextlib.suppress(Exception):
                                 correct_sum += float(metric_fn(pred, y_true))
-                            except Exception:
-                                pass
-                            try:
+                            with contextlib.suppress(Exception):
                                 pt = usage.get("prompt_tokens") or usage.get("input_tokens")
                                 ct = usage.get("completion_tokens") or usage.get("output_tokens")
                                 if isinstance(pt, (int, float)):
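`contextlib.suppress(Exception)` is the stdlib replacement for the deleted `try/except Exception: pass` blocks: a suppressed exception aborts only the `with` body and execution resumes after it. A tiny self-contained example:

```python
import contextlib

usage = {"prompt_tokens": "not-a-number"}
tokens = 0
with contextlib.suppress(Exception):
    tokens += int(usage["prompt_tokens"])  # raises ValueError, which is suppressed
print(tokens)  # 0 -- execution resumes after the with block
```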
@@ -195,30 +205,34 @@ def main():
                                 if isinstance(ct, (int, float)):
                                     out_tok_sum += int(ct)
                                     out_tok_count += 1
-
-
-
-
-
-
-
-
-
-
-
+                            details.append(
+                                {
+                                    "index": i,
+                                    "seconds": max(0.0, t_end - t_start),
+                                    "score": float(metric_fn(pred, y_true)),
+                                    "usage": {
+                                        "prompt_tokens": usage.get("prompt_tokens")
+                                        or usage.get("input_tokens"),
+                                        "completion_tokens": usage.get("completion_tokens")
+                                        or usage.get("output_tokens"),
+                                    },
+                                }
+                            )
                         pbar.update(1)
                         med = statistics.median(durations) if durations else 0.0
                         mx = max(durations) if durations else 0.0
                         avg_in = (in_tok_sum / in_tok_count) if in_tok_count else 0.0
                         avg_out = (out_tok_sum / out_tok_count) if out_tok_count else 0.0
-                        pbar.set_postfix(
-
-
-
-
-
-
-
+                        pbar.set_postfix(
+                            {
+                                "acc": f"{(correct_sum / processed):.2f}",
+                                "done": f"{processed}/{len(tasks)}",
+                                "med_s": f"{med:.1f}",
+                                "max_s": f"{mx:.1f}",
+                                "tin": f"{avg_in:.1f}",
+                                "tout": f"{avg_out:.1f}",
+                            }
+                        )
             # Compute score only from completed/successful rollouts (drop timeouts/cancelled)
             subs = [float(d.get("score", 0.0)) for d in details]
             result = SimpleNamespace(score=(sum(subs) / max(1, len(subs))), subscores=subs)
@@ -226,28 +240,33 @@ def main():
             result.mean_in = (in_tok_sum / in_tok_count) if in_tok_count else 0.0
             result.mean_out = (out_tok_sum / out_tok_count) if out_tok_count else 0.0
             return result
+
         return asyncio.run(_run())
+
     pbar = tqdm(total=total_candidates, desc="Candidates")
-    candidate_eval_details: Dict[int, Any] = {}
+    candidate_eval_details: dict[int, Any] = {}
+
     def on_cand(idx: int, score: float, res, intervention):
         pbar.update(1)
         pbar.set_postfix({"score": f"{score:.2f}"})
         # store per-instance details (for apples-to-apples)
-        try:
+        import contextlib
+        with contextlib.suppress(Exception):
             candidate_eval_details[idx] = {
                 "score": score,
                 "mean_in": getattr(res, "mean_in", None),
                 "mean_out": getattr(res, "mean_out", None),
                 "instances": getattr(res, "details", None),
             }
-        except Exception:
-            pass
         # visible summary line per candidate
-        kind =
+        kind = (
+            intervention.get("kind", "candidate") if isinstance(intervention, dict) else "candidate"
+        )
         label = intervention.get("label") if isinstance(intervention, dict) else None
         seed = intervention.get("seed") if isinstance(intervention, dict) else None
         processed = len(getattr(res, "details", []) or [])
         from tqdm import tqdm as _tqdm
+
         _tqdm.write(
             f"Candidate {idx}/{total_candidates} [{kind}{'' if label is None else f', label={label}'}{'' if seed is None else f', seed={seed}'}]: "
             f"score={score:.2f} | mean tin/tout={getattr(res, 'mean_in', 0):.1f}/{getattr(res, 'mean_out', 0):.1f} | N={processed}"
synth_ai/learning/rl_client.py
ADDED
@@ -0,0 +1,256 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, Callable
+import os
+import time
+
+from ..http import AsyncHttpClient, HTTPError, sleep
+
+
+def _api_base(b: str) -> str:
+    b = (b or "").rstrip("/")
+    return b if b.endswith("/api") else f"{b}/api"
+
+
+class RlClient:
+    """Lightweight RL client for provider-agnostic job control.
+
+    Notes:
+    - Uses learning/* for status/events/metrics and rl/* for creation/start.
+    - Trainer endpoints are resolved server-side via trainer_id.
+    """
+
+    def __init__(self, base_url: str, api_key: str, *, timeout: float = 600.0) -> None:
+        self._base_url = base_url.rstrip("/")
+        self._api_key = api_key
+        self._timeout = timeout
+
+    async def resolve_trainer_start_url(self, trainer_id: str) -> str:
+        """GET /api/rl/services/{id} → { training_start_url }"""
+        path = f"/api/rl/services/{trainer_id}"
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+            js = await http.get(path)
+        if not isinstance(js, dict):
+            raise HTTPError(status=500, url=path, message="invalid_service_response", body_snippet=str(js)[:200])
+        start_url = js.get("training_start_url")
+        if not isinstance(start_url, str) or not start_url:
+            raise HTTPError(status=500, url=path, message="missing_training_start_url", body_snippet=str(js)[:200])
+        return start_url
+
+    async def create_job(
+        self,
+        *,
+        model: str,
+        task_app_url: str,
+        trainer: Dict[str, Any],
+        trainer_id: Optional[str] = None,
+        job_config_id: Optional[str] = None,
+        inline_config: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        body = {
+            "job_type": "rl",
+            "data": {
+                "model": model,
+                "endpoint_base_url": task_app_url,
+                **({"job_config_id": job_config_id} if job_config_id else {}),
+                **({"config": inline_config} if inline_config else {}),
+                "trainer": {
+                    "batch_size": int(trainer.get("batch_size", 1)),
+                    "group_size": max(2, int(trainer.get("group_size", 2))),
+                },
+            },
+        }
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
+            js = await http.post_json(f"{_api_base(self._base_url)}/rl/jobs", json=body)
+        if not isinstance(js, dict):
+            raise HTTPError(status=500, url="/api/rl/jobs", message="invalid_create_response", body_snippet=str(js)[:200])
+        return js
+
+    async def start_job_if_supported(self, job_id: str) -> Optional[Dict[str, Any]]:
+        path = f"{_api_base(self._base_url)}/rl/jobs/{job_id}/start"
+        try:
+            async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+                return await http.post_json(path, json={})
+        except HTTPError as he:  # noqa: PERF203
+            if he.status == 404:
+                return None
+            raise
+
+    async def get_job(self, job_id: str) -> Dict[str, Any]:
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+            return await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}")
+
+    async def get_events(self, job_id: str, *, since_seq: int = 0, limit: int = 200) -> List[Dict[str, Any]]:
+        params = {"since_seq": since_seq, "limit": limit}
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+            try:
+                js = await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}/events", params=params)
+            except HTTPError as he:
+                try:
+                    print(
+                        f"[poll] events HTTPError status={he.status} url={he.url} since_seq={since_seq} body={(he.body_snippet or '')[:200]}"
+                    )
+                except Exception:
+                    pass
+                raise
+        if isinstance(js, dict):
+            evs = js.get("events") or js.get("data")
+            if isinstance(evs, list):
+                return evs
+        return []
+
+    async def get_metrics(self, job_id: str, *, after_step: int = -1, limit: int = 200) -> List[Dict[str, Any]]:
+        params = {"after_step": after_step, "limit": limit}
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+            js = await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}/metrics", params=params)
+        if isinstance(js, dict) and isinstance(js.get("points"), list):
+            return js["points"]
+        return []
+
+    async def poll_until_terminal(
+        self,
+        job_id: str,
+        *,
+        interval_seconds: float = 2.0,
+        max_seconds: float | None = None,
+        empty_polls_threshold: int = 5,
+        startup_deadline_s: int = 45,
+        on_event: Optional[Callable[[Dict[str, Any]], None]] = None,
+        on_metric: Optional[Callable[[Dict[str, Any]], None]] = None,
+    ) -> Dict[str, Any]:
+        last_seq_by_stream: Dict[str, int] = {}
+        events_job_id: Optional[str] = None
+        last_status: Optional[str] = None
+        last_step_by_name: Dict[str, int] = {}
+        empty_polls = 0
+        saw_any_event = False
+        start_t = time.time()
+        terminal = {"succeeded", "failed", "cancelled", "canceled", "error", "completed"}
+
+        while True:
+            status_data: Optional[Dict[str, Any]] = None
+            try:
+                status_data = await self.get_job(job_id)
+            except Exception:
+                status_data = None
+            if status_data is None:
+                try:
+                    print(f"[poll] get_job returned None base={self._base_url} job_id={job_id}")
+                except Exception:
+                    pass
+            status = str((status_data or {}).get("status") or "").lower()
+            if status_data:
+                linked = status_data.get("linked_job_id")
+                if isinstance(linked, str) and linked and linked != events_job_id:
+                    events_job_id = linked
+                    try:
+                        print(f"[poll] discovered linked_job_id stream={events_job_id}")
+                    except Exception:
+                        pass
+            if status and status != last_status:
+                last_status = status
+                # Status transitions only to avoid log spam
+                if on_event:
+                    try:
+                        on_event({"type": "rl.status", "message": status})
+                    except Exception:
+                        pass
+
+            # Events
+            stream_ids = [job_id]
+            if events_job_id and events_job_id not in stream_ids:
+                stream_ids.append(events_job_id)
+            try:
+                print(f"[poll] streams={stream_ids} intervals={interval_seconds}s since_map={last_seq_by_stream} empty_polls={empty_polls}")
+            except Exception:
+                pass
+            total_events_this_cycle = 0
+            terminal_event_seen = False
+            terminal_event_status: Optional[str] = None
+            for ev_id in stream_ids:
+                since = last_seq_by_stream.get(ev_id, 0)
+                try:
+                    events = await self.get_events(ev_id, since_seq=since, limit=200)
+                except HTTPError as he:
+                    try:
+                        print(f"[poll] get_events error status={he.status} url={he.url} since={since} body={(he.body_snippet or '')[:200]}")
+                    except Exception:
+                        pass
+                    events = []
+                except Exception as e:
+                    try:
+                        print(f"[poll] get_events unexpected error ev_id={ev_id} since={since} err={type(e).__name__}: {e}")
+                    except Exception:
+                        pass
+                    events = []
+                total_events_this_cycle += len(events)
+                if events:
+                    saw_any_event = True
+                for e in events:
+                    seq_val = int(e.get("seq") or 0)
+                    if seq_val <= last_seq_by_stream.get(ev_id, 0):
+                        continue
+                    last_seq_by_stream[ev_id] = seq_val
+                    if on_event:
+                        try:
+                            on_event(e)
+                        except Exception:
+                            pass
+                    et = str(e.get("type") or e.get("event_type") or "").lower()
+                    if et in ("rl.job.completed", "workflow.completed", "rl.train.completed"):
+                        terminal_event_seen = True
+                        terminal_event_status = "succeeded"
+                    elif et in ("rl.job.failed", "workflow.failed"):
+                        terminal_event_seen = True
+                        terminal_event_status = "failed"
+
+            # Metrics
+            try:
+                after = max(last_step_by_name.values()) if last_step_by_name else -1
+                points = await self.get_metrics(job_id, after_step=after, limit=200)
+                for p in points:
+                    name = str(p.get("name") or "")
+                    step = int(p.get("step") or -1)
+                    if step <= last_step_by_name.get(name, -1):
+                        continue
+                    last_step_by_name[name] = step
+                    if on_metric:
+                        try:
+                            on_metric(p)
+                        except Exception:
+                            pass
+            except Exception:
+                pass
+
+            if terminal_event_seen:
+                return {"status": terminal_event_status or status or "completed", "job_id": job_id}
+            if status and status in terminal:
+                return {"status": status, "job_id": job_id}
+
+            if total_events_this_cycle == 0:
+                empty_polls += 1
+            else:
+                empty_polls = 0
+            if empty_polls >= max(1, int(empty_polls_threshold)):
+                try:
+                    print(
+                        f"[poll] threshold hit: empty_polls={empty_polls} >= {empty_polls_threshold} streams={stream_ids} last_seq_map={last_seq_by_stream}"
+                    )
+                except Exception:
+                    pass
+                raise AssertionError(f"No new events detected for {empty_polls_threshold} consecutive polls. Check event ingestion.")
+
+            if not saw_any_event and (time.time() - start_t) > int(startup_deadline_s):
+                try:
+                    print(
+                        f"[poll] startup window exceeded: {startup_deadline_s}s base={self._base_url} job={job_id} streams={stream_ids} last_seq_map={last_seq_by_stream}"
+                    )
+                except Exception:
+                    pass
+                raise AssertionError(f"No events observed within startup window ({startup_deadline_s}s). Investigate event streaming.")
+
+            await sleep(interval_seconds)
+            if max_seconds is not None and (time.time() - start_t) >= max_seconds:
+                raise TimeoutError(f"Polling timed out after {max_seconds}s for job {job_id}")
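A hedged usage sketch for the new `RlClient`; the backend URL, API key, model name, and the assumption that the create-job response exposes an `id`/`job_id` field are placeholders, not values confirmed by this release:

```python
import asyncio

from synth_ai.learning.rl_client import RlClient


async def run_job() -> None:
    client = RlClient("https://backend.example.com", "sk-example-key")  # placeholder endpoint/key
    created = await client.create_job(
        model="example-model",                       # placeholder model id
        task_app_url="https://task-app.example.com",
        trainer={"batch_size": 2, "group_size": 4},
    )
    job_id = created.get("id") or created.get("job_id")  # field name assumed, not confirmed
    await client.start_job_if_supported(job_id)           # returns None if the start route is absent (404)
    final = await client.poll_until_terminal(job_id, on_event=print, on_metric=print)
    print("terminal state:", final)


asyncio.run(run_job())
```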
synth_ai/learning/sse.py
ADDED
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import json
+import time
+from typing import Any, Callable, Optional
+
+import aiohttp
+
+
+def _api_base(b: str) -> str:
+    b = (b or "").rstrip("/")
+    return b if b.endswith("/api") else f"{b}/api"
+
+
+async def stream_events(
+    base_url: str,
+    api_key: str,
+    job_id: str,
+    *,
+    seconds: int = 60,
+    on_event: Optional[Callable[[dict], None]] = None,
+) -> None:
+    if seconds <= 0:
+        return
+    headers = {"Accept": "text/event-stream", "Authorization": f"Bearer {api_key}"}
+    candidates = [
+        f"{_api_base(base_url)}/rl/jobs/{job_id}/events?since_seq=0",
+        f"{_api_base(base_url)}/learning/jobs/{job_id}/events?since_seq=0",
+    ]
+    for url in candidates:
+        try:
+            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session:
+                async with session.get(url, headers=headers) as resp:
+                    if resp.status != 200:
+                        continue
+                    start_t = time.time()
+                    async for raw in resp.content:
+                        line = raw.decode(errors="ignore").strip()
+                        if not line or line.startswith(":"):
+                            continue
+                        if not line.startswith("data:"):
+                            continue
+                        data = line[5:].strip()
+                        try:
+                            obj = json.loads(data)
+                        except Exception:
+                            continue
+                        if on_event:
+                            try:
+                                on_event(obj)
+                            except Exception:
+                                pass
+                        if (time.time() - start_t) >= seconds:
+                            return
+        except Exception:
+            continue
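A minimal sketch of tailing a job's SSE stream with the new helper; the base URL, API key, and job id are placeholders:

```python
import asyncio

from synth_ai.learning.sse import stream_events

# Print every decoded SSE event for up to 60 seconds, then return.
asyncio.run(
    stream_events(
        "https://backend.example.com",  # placeholder backend base URL
        "sk-example-key",               # placeholder API key
        "job_123",                      # placeholder job id
        seconds=60,
        on_event=print,
    )
)
```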
synth_ai/learning/validators.py
ADDED
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from pathlib import Path
+import json
+from typing import Any, Dict
+from urllib.parse import urlparse
+
+
+def validate_training_jsonl(path: str | Path, *, sample_lines: int = 50) -> None:
+    p = Path(path)
+    if not p.exists():
+        raise FileNotFoundError(str(p))
+    lines = p.read_text().splitlines()
+    if not lines:
+        raise ValueError("empty JSONL")
+    for i, line in enumerate(lines[: max(1, sample_lines)], start=1):
+        if not line.strip():
+            continue
+        try:
+            obj = json.loads(line)
+        except Exception as e:
+            raise ValueError(f"invalid json on line {i}: {e}") from e
+        msgs = obj.get("messages")
+        if not isinstance(msgs, list) or len(msgs) < 2:
+            raise ValueError(f"line {i}: missing messages[] with at least 2 turns")
+        roles = [m.get("role") for m in msgs if isinstance(m, dict)]
+        if not roles or not isinstance(roles[0], str):
+            raise ValueError(f"line {i}: missing first role")
+        for m in msgs:
+            if not isinstance(m, dict):
+                raise ValueError(f"line {i}: non-dict message")
+            if not isinstance(m.get("role"), str) or not isinstance(m.get("content"), str) or not m["content"].strip():
+                raise ValueError(f"line {i}: invalid role/content")
+
+
+def validate_task_app_url(url: str, *, name: str = "TASK_APP_BASE_URL") -> None:
+    from synth_ai.task.validators import validate_task_app_url as _vt
+
+    _vt(url, name=name)
+
+
+def validate_trainer_cfg_rl(trainer: Dict[str, Any]) -> None:
+    bs = int(trainer.get("batch_size", 1))
+    gs = int(trainer.get("group_size", 2))
+    if bs < 1:
+        raise ValueError("trainer.batch_size must be >= 1")
+    if gs < 2:
+        raise ValueError("trainer.group_size must be >= 2")
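A short sketch of calling the new validators before submitting a job; the file path and trainer values are illustrative only:

```python
from synth_ai.learning.validators import validate_trainer_cfg_rl, validate_training_jsonl

# Raises FileNotFoundError / ValueError on a missing or malformed file,
# checking at most the first 50 lines by default.
validate_training_jsonl("data/train.jsonl")

# Raises ValueError unless batch_size >= 1 and group_size >= 2.
validate_trainer_cfg_rl({"batch_size": 2, "group_size": 4})
```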
synth_ai/lm/__init__.py
CHANGED
@@ -4,24 +4,24 @@ Synth AI Language Model Interface.
 Provides a unified interface for multiple LLM providers including OpenAI and Synth.
 """
 
-from .config import
-from .
+from .config import OpenAIConfig, SynthConfig
+from .core.main_v3 import LM
 from .unified_interface import (
-    UnifiedLMProvider,
     OpenAIProvider,
     SynthProvider,
     UnifiedLMClient,
+    UnifiedLMProvider,
     create_provider,
 )
 from .vendors.synth_client import (
     AsyncSynthClient,
     SyncSynthClient,
     create_async_client,
-    create_sync_client,
     create_chat_completion_async,
     create_chat_completion_sync,
+    create_sync_client,
 )
-from .
+from .warmup import get_warmup_status, warmup_synth_model
 
 __all__ = [
     # Configuration
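For orientation, the import surface implied by the reordered `synth_ai/lm/__init__.py` can be exercised like this (imports only, no network calls; names taken from the hunk above):

```python
from synth_ai.lm import (
    LM,
    OpenAIConfig,
    SynthConfig,
    UnifiedLMClient,
    create_async_client,
    warmup_synth_model,
)

print(LM, OpenAIConfig, SynthConfig, UnifiedLMClient, create_async_client, warmup_synth_model)
```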