synth-ai 0.2.4.dev6__py3-none-any.whl → 0.2.4.dev8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +18 -9
- synth_ai/cli/__init__.py +10 -5
- synth_ai/cli/balance.py +25 -32
- synth_ai/cli/calc.py +2 -3
- synth_ai/cli/demo.py +3 -5
- synth_ai/cli/legacy_root_backup.py +58 -32
- synth_ai/cli/man.py +22 -19
- synth_ai/cli/recent.py +9 -8
- synth_ai/cli/root.py +58 -13
- synth_ai/cli/status.py +13 -6
- synth_ai/cli/traces.py +45 -21
- synth_ai/cli/watch.py +40 -37
- synth_ai/config/base_url.py +47 -2
- synth_ai/core/experiment.py +1 -2
- synth_ai/environments/__init__.py +2 -6
- synth_ai/environments/environment/artifacts/base.py +3 -1
- synth_ai/environments/environment/db/sqlite.py +1 -1
- synth_ai/environments/environment/registry.py +19 -20
- synth_ai/environments/environment/resources/sqlite.py +2 -3
- synth_ai/environments/environment/rewards/core.py +3 -2
- synth_ai/environments/environment/tools/__init__.py +6 -4
- synth_ai/environments/examples/crafter_classic/__init__.py +1 -1
- synth_ai/environments/examples/crafter_classic/engine.py +13 -13
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +1 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +2 -1
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +2 -1
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +3 -2
- synth_ai/environments/examples/crafter_classic/environment.py +16 -15
- synth_ai/environments/examples/crafter_classic/taskset.py +2 -2
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +2 -3
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +2 -1
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +2 -2
- synth_ai/environments/examples/crafter_custom/crafter/config.py +2 -2
- synth_ai/environments/examples/crafter_custom/crafter/env.py +1 -5
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +1 -2
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +1 -2
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +5 -5
- synth_ai/environments/examples/crafter_custom/environment.py +13 -13
- synth_ai/environments/examples/crafter_custom/run_dataset.py +5 -5
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +2 -2
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +5 -4
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +2 -1
- synth_ai/environments/examples/enron/engine.py +18 -14
- synth_ai/environments/examples/enron/environment.py +12 -11
- synth_ai/environments/examples/enron/taskset.py +7 -7
- synth_ai/environments/examples/minigrid/__init__.py +6 -6
- synth_ai/environments/examples/minigrid/engine.py +6 -6
- synth_ai/environments/examples/minigrid/environment.py +6 -6
- synth_ai/environments/examples/minigrid/puzzle_loader.py +3 -2
- synth_ai/environments/examples/minigrid/taskset.py +13 -13
- synth_ai/environments/examples/nethack/achievements.py +1 -1
- synth_ai/environments/examples/nethack/engine.py +8 -7
- synth_ai/environments/examples/nethack/environment.py +10 -9
- synth_ai/environments/examples/nethack/helpers/__init__.py +8 -9
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +1 -1
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +2 -1
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +1 -1
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +3 -4
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +6 -5
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +5 -5
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +7 -6
- synth_ai/environments/examples/nethack/taskset.py +5 -5
- synth_ai/environments/examples/red/engine.py +9 -8
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +7 -7
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +3 -2
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +2 -1
- synth_ai/environments/examples/red/environment.py +18 -15
- synth_ai/environments/examples/red/taskset.py +5 -3
- synth_ai/environments/examples/sokoban/engine.py +16 -13
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +3 -2
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +7 -5
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +5 -4
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +3 -2
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +5 -4
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +1 -1
- synth_ai/environments/examples/sokoban/environment.py +15 -14
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +5 -3
- synth_ai/environments/examples/sokoban/puzzle_loader.py +3 -2
- synth_ai/environments/examples/sokoban/taskset.py +13 -10
- synth_ai/environments/examples/tictactoe/engine.py +6 -6
- synth_ai/environments/examples/tictactoe/environment.py +8 -7
- synth_ai/environments/examples/tictactoe/taskset.py +6 -5
- synth_ai/environments/examples/verilog/engine.py +4 -3
- synth_ai/environments/examples/verilog/environment.py +11 -10
- synth_ai/environments/examples/verilog/taskset.py +14 -12
- synth_ai/environments/examples/wordle/__init__.py +5 -5
- synth_ai/environments/examples/wordle/engine.py +32 -25
- synth_ai/environments/examples/wordle/environment.py +21 -16
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +6 -6
- synth_ai/environments/examples/wordle/taskset.py +20 -12
- synth_ai/environments/reproducibility/core.py +1 -1
- synth_ai/environments/reproducibility/tree.py +21 -21
- synth_ai/environments/service/app.py +3 -2
- synth_ai/environments/service/core_routes.py +104 -110
- synth_ai/environments/service/external_registry.py +1 -2
- synth_ai/environments/service/registry.py +1 -1
- synth_ai/environments/stateful/core.py +1 -2
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/api.py +4 -4
- synth_ai/environments/tasks/core.py +14 -12
- synth_ai/environments/tasks/filters.py +6 -4
- synth_ai/environments/tasks/utils.py +13 -11
- synth_ai/evals/base.py +2 -3
- synth_ai/experimental/synth_oss.py +4 -4
- synth_ai/http.py +102 -0
- synth_ai/inference/__init__.py +7 -0
- synth_ai/inference/client.py +20 -0
- synth_ai/jobs/client.py +246 -0
- synth_ai/learning/__init__.py +24 -0
- synth_ai/learning/client.py +149 -0
- synth_ai/learning/config.py +43 -0
- synth_ai/learning/constants.py +29 -0
- synth_ai/learning/ft_client.py +59 -0
- synth_ai/learning/gateway.py +1 -3
- synth_ai/learning/health.py +43 -0
- synth_ai/learning/jobs.py +205 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +15 -10
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +26 -14
- synth_ai/learning/prompts/mipro.py +61 -52
- synth_ai/learning/prompts/random_search.py +42 -43
- synth_ai/learning/prompts/run_mipro_banking77.py +32 -20
- synth_ai/learning/prompts/run_random_search_banking77.py +71 -52
- synth_ai/learning/rl_client.py +256 -0
- synth_ai/learning/sse.py +58 -0
- synth_ai/learning/validators.py +48 -0
- synth_ai/lm/__init__.py +5 -5
- synth_ai/lm/caching/ephemeral.py +9 -9
- synth_ai/lm/caching/handler.py +20 -20
- synth_ai/lm/caching/persistent.py +10 -10
- synth_ai/lm/config.py +3 -3
- synth_ai/lm/constants.py +7 -7
- synth_ai/lm/core/all.py +17 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +26 -41
- synth_ai/lm/core/main_v3.py +33 -10
- synth_ai/lm/core/synth_models.py +48 -0
- synth_ai/lm/core/vendor_clients.py +26 -22
- synth_ai/lm/injection.py +7 -8
- synth_ai/lm/overrides.py +21 -19
- synth_ai/lm/provider_support/__init__.py +1 -1
- synth_ai/lm/provider_support/anthropic.py +15 -15
- synth_ai/lm/provider_support/openai.py +23 -21
- synth_ai/lm/structured_outputs/handler.py +34 -32
- synth_ai/lm/structured_outputs/inject.py +24 -27
- synth_ai/lm/structured_outputs/rehabilitate.py +19 -15
- synth_ai/lm/tools/base.py +17 -16
- synth_ai/lm/unified_interface.py +17 -18
- synth_ai/lm/vendors/base.py +20 -18
- synth_ai/lm/vendors/core/anthropic_api.py +36 -27
- synth_ai/lm/vendors/core/gemini_api.py +31 -36
- synth_ai/lm/vendors/core/mistral_api.py +19 -19
- synth_ai/lm/vendors/core/openai_api.py +42 -13
- synth_ai/lm/vendors/openai_standard.py +158 -101
- synth_ai/lm/vendors/openai_standard_responses.py +74 -61
- synth_ai/lm/vendors/retries.py +9 -1
- synth_ai/lm/vendors/supported/custom_endpoint.py +38 -28
- synth_ai/lm/vendors/supported/deepseek.py +10 -10
- synth_ai/lm/vendors/supported/grok.py +8 -8
- synth_ai/lm/vendors/supported/ollama.py +2 -1
- synth_ai/lm/vendors/supported/openrouter.py +11 -9
- synth_ai/lm/vendors/synth_client.py +425 -75
- synth_ai/lm/warmup.py +8 -7
- synth_ai/rl/__init__.py +30 -0
- synth_ai/rl/contracts.py +32 -0
- synth_ai/rl/env_keys.py +137 -0
- synth_ai/rl/secrets.py +19 -0
- synth_ai/scripts/verify_rewards.py +100 -0
- synth_ai/task/__init__.py +10 -0
- synth_ai/task/contracts.py +120 -0
- synth_ai/task/health.py +28 -0
- synth_ai/task/validators.py +12 -0
- synth_ai/tracing/__init__.py +22 -10
- synth_ai/tracing_v1/__init__.py +22 -20
- synth_ai/tracing_v3/__init__.py +7 -7
- synth_ai/tracing_v3/abstractions.py +56 -52
- synth_ai/tracing_v3/config.py +4 -2
- synth_ai/tracing_v3/db_config.py +6 -8
- synth_ai/tracing_v3/decorators.py +29 -30
- synth_ai/tracing_v3/examples/basic_usage.py +12 -12
- synth_ai/tracing_v3/hooks.py +24 -22
- synth_ai/tracing_v3/llm_call_record_helpers.py +85 -98
- synth_ai/tracing_v3/lm_call_record_abstractions.py +2 -4
- synth_ai/tracing_v3/migration_helper.py +3 -5
- synth_ai/tracing_v3/replica_sync.py +30 -32
- synth_ai/tracing_v3/session_tracer.py +158 -31
- synth_ai/tracing_v3/storage/__init__.py +1 -1
- synth_ai/tracing_v3/storage/base.py +8 -7
- synth_ai/tracing_v3/storage/config.py +4 -4
- synth_ai/tracing_v3/storage/factory.py +4 -4
- synth_ai/tracing_v3/storage/utils.py +9 -9
- synth_ai/tracing_v3/turso/__init__.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +9 -9
- synth_ai/tracing_v3/turso/manager.py +278 -48
- synth_ai/tracing_v3/turso/models.py +77 -19
- synth_ai/tracing_v3/utils.py +5 -5
- synth_ai/v0/tracing/abstractions.py +28 -28
- synth_ai/v0/tracing/base_client.py +9 -9
- synth_ai/v0/tracing/client_manager.py +7 -7
- synth_ai/v0/tracing/config.py +7 -7
- synth_ai/v0/tracing/context.py +6 -6
- synth_ai/v0/tracing/decorators.py +6 -5
- synth_ai/v0/tracing/events/manage.py +1 -1
- synth_ai/v0/tracing/events/store.py +5 -4
- synth_ai/v0/tracing/immediate_client.py +4 -5
- synth_ai/v0/tracing/local.py +3 -3
- synth_ai/v0/tracing/log_client_base.py +4 -5
- synth_ai/v0/tracing/retry_queue.py +5 -6
- synth_ai/v0/tracing/trackers.py +25 -25
- synth_ai/v0/tracing/upload.py +6 -0
- synth_ai/v0/tracing_v1/__init__.py +1 -1
- synth_ai/v0/tracing_v1/abstractions.py +28 -28
- synth_ai/v0/tracing_v1/base_client.py +9 -9
- synth_ai/v0/tracing_v1/client_manager.py +7 -7
- synth_ai/v0/tracing_v1/config.py +7 -7
- synth_ai/v0/tracing_v1/context.py +6 -6
- synth_ai/v0/tracing_v1/decorators.py +7 -6
- synth_ai/v0/tracing_v1/events/manage.py +1 -1
- synth_ai/v0/tracing_v1/events/store.py +5 -4
- synth_ai/v0/tracing_v1/immediate_client.py +4 -5
- synth_ai/v0/tracing_v1/local.py +3 -3
- synth_ai/v0/tracing_v1/log_client_base.py +4 -5
- synth_ai/v0/tracing_v1/retry_queue.py +5 -6
- synth_ai/v0/tracing_v1/trackers.py +25 -25
- synth_ai/v0/tracing_v1/upload.py +25 -24
- synth_ai/zyk/__init__.py +1 -0
- synth_ai-0.2.4.dev8.dist-info/METADATA +635 -0
- synth_ai-0.2.4.dev8.dist-info/RECORD +317 -0
- synth_ai/tui/__init__.py +0 -1
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -165
- synth_ai/tui/cli/query_experiments_v3.py +0 -165
- synth_ai/tui/dashboard.py +0 -329
- synth_ai-0.2.4.dev6.dist-info/METADATA +0 -203
- synth_ai-0.2.4.dev6.dist-info/RECORD +0 -299
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/top_level.txt +0 -0
@@ -8,10 +8,12 @@ in the requested structured format (Pydantic models).
|
|
8
8
|
import logging
|
9
9
|
import time
|
10
10
|
from abc import ABC, abstractmethod
|
11
|
-
from
|
11
|
+
from collections.abc import Callable
|
12
|
+
from typing import Any, Literal
|
12
13
|
|
13
14
|
from pydantic import BaseModel
|
14
15
|
|
16
|
+
from synth_ai.lm.constants import SPECIAL_BASE_TEMPS
|
15
17
|
from synth_ai.lm.core.exceptions import StructuredOutputCoercionFailureException
|
16
18
|
from synth_ai.lm.structured_outputs.inject import (
|
17
19
|
inject_structured_output_instructions,
|
@@ -22,7 +24,6 @@ from synth_ai.lm.structured_outputs.rehabilitate import (
|
|
22
24
|
pull_out_structured_output,
|
23
25
|
)
|
24
26
|
from synth_ai.lm.vendors.base import BaseLMResponse, VendorBase
|
25
|
-
from synth_ai.lm.constants import SPECIAL_BASE_TEMPS
|
26
27
|
|
27
28
|
logger = logging.getLogger(__name__)
|
28
29
|
|
@@ -30,26 +31,27 @@ logger = logging.getLogger(__name__)
|
|
30
31
|
class StructuredHandlerBase(ABC):
|
31
32
|
"""
|
32
33
|
Abstract base class for structured output handlers.
|
33
|
-
|
34
|
+
|
34
35
|
Handles the logic for ensuring language models return properly formatted
|
35
36
|
structured outputs, with retry logic and error handling.
|
36
|
-
|
37
|
+
|
37
38
|
Attributes:
|
38
39
|
core_client: Primary vendor client for API calls
|
39
40
|
retry_client: Client used for retry attempts (may use different model)
|
40
41
|
handler_params: Configuration parameters including retry count
|
41
42
|
structured_output_mode: Either "stringified_json" or "forced_json"
|
42
43
|
"""
|
44
|
+
|
43
45
|
core_client: VendorBase
|
44
46
|
retry_client: VendorBase
|
45
|
-
handler_params:
|
47
|
+
handler_params: dict[str, Any]
|
46
48
|
structured_output_mode: Literal["stringified_json", "forced_json"]
|
47
49
|
|
48
50
|
def __init__(
|
49
51
|
self,
|
50
52
|
core_client: VendorBase,
|
51
53
|
retry_client: VendorBase,
|
52
|
-
handler_params:
|
54
|
+
handler_params: dict[str, Any] | None = None,
|
53
55
|
structured_output_mode: Literal["stringified_json", "forced_json"] = "stringified_json",
|
54
56
|
):
|
55
57
|
self.core_client = core_client
|
@@ -59,7 +61,7 @@ class StructuredHandlerBase(ABC):
|
|
59
61
|
|
60
62
|
async def call_async(
|
61
63
|
self,
|
62
|
-
messages:
|
64
|
+
messages: list[dict[str, Any]],
|
63
65
|
model: str,
|
64
66
|
response_model: BaseModel,
|
65
67
|
temperature: float = 0.0,
|
@@ -74,7 +76,7 @@ class StructuredHandlerBase(ABC):
|
|
74
76
|
model=model,
|
75
77
|
response_model=response_model,
|
76
78
|
api_call_method=self.core_client._hit_api_async_structured_output
|
77
|
-
if (
|
79
|
+
if (response_model and self.structured_output_mode == "forced_json")
|
78
80
|
else self.core_client._hit_api_async,
|
79
81
|
temperature=temperature,
|
80
82
|
use_ephemeral_cache_only=use_ephemeral_cache_only,
|
@@ -83,7 +85,7 @@ class StructuredHandlerBase(ABC):
|
|
83
85
|
|
84
86
|
def call_sync(
|
85
87
|
self,
|
86
|
-
messages:
|
88
|
+
messages: list[dict[str, Any]],
|
87
89
|
response_model: BaseModel,
|
88
90
|
model: str,
|
89
91
|
temperature: float = 0.0,
|
@@ -97,7 +99,7 @@ class StructuredHandlerBase(ABC):
|
|
97
99
|
model=model,
|
98
100
|
response_model=response_model,
|
99
101
|
api_call_method=self.core_client._hit_api_sync_structured_output
|
100
|
-
if (
|
102
|
+
if (response_model and self.structured_output_mode == "forced_json")
|
101
103
|
else self.core_client._hit_api_sync,
|
102
104
|
temperature=temperature,
|
103
105
|
use_ephemeral_cache_only=use_ephemeral_cache_only,
|
@@ -107,7 +109,7 @@ class StructuredHandlerBase(ABC):
|
|
107
109
|
@abstractmethod
|
108
110
|
async def _process_call_async(
|
109
111
|
self,
|
110
|
-
messages:
|
112
|
+
messages: list[dict[str, Any]],
|
111
113
|
model: str,
|
112
114
|
response_model: BaseModel,
|
113
115
|
api_call_method,
|
@@ -119,7 +121,7 @@ class StructuredHandlerBase(ABC):
|
|
119
121
|
@abstractmethod
|
120
122
|
def _process_call_sync(
|
121
123
|
self,
|
122
|
-
messages:
|
124
|
+
messages: list[dict[str, Any]],
|
123
125
|
model: str,
|
124
126
|
response_model: BaseModel,
|
125
127
|
api_call_method,
|
@@ -132,24 +134,24 @@ class StructuredHandlerBase(ABC):
|
|
132
134
|
class StringifiedJSONHandler(StructuredHandlerBase):
|
133
135
|
core_client: VendorBase
|
134
136
|
retry_client: VendorBase
|
135
|
-
handler_params:
|
137
|
+
handler_params: dict[str, Any]
|
136
138
|
|
137
139
|
def __init__(
|
138
140
|
self,
|
139
141
|
core_client: VendorBase,
|
140
142
|
retry_client: VendorBase,
|
141
|
-
handler_params:
|
143
|
+
handler_params: dict[str, Any] | None = None,
|
142
144
|
):
|
143
145
|
super().__init__(
|
144
146
|
core_client,
|
145
147
|
retry_client,
|
146
|
-
handler_params,
|
148
|
+
handler_params or {"retries": 3},
|
147
149
|
structured_output_mode="stringified_json",
|
148
150
|
)
|
149
151
|
|
150
152
|
async def _process_call_async(
|
151
153
|
self,
|
152
|
-
messages:
|
154
|
+
messages: list[dict[str, Any]],
|
153
155
|
model: str,
|
154
156
|
response_model: BaseModel,
|
155
157
|
temperature: float,
|
@@ -170,7 +172,7 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
170
172
|
response_model=response_model,
|
171
173
|
previously_failed_error_messages=previously_failed_error_messages,
|
172
174
|
)
|
173
|
-
t0 = time.time()
|
175
|
+
# t0 = time.time() # unused
|
174
176
|
raw_text_response_or_cached_hit = await api_call_method(
|
175
177
|
messages=messages_with_json_formatting_instructions,
|
176
178
|
model=model,
|
@@ -184,7 +186,7 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
184
186
|
assert type(raw_text_response_or_cached_hit) in [str, BaseLMResponse], (
|
185
187
|
f"Expected str or BaseLMResponse, got {type(raw_text_response_or_cached_hit)}"
|
186
188
|
)
|
187
|
-
if
|
189
|
+
if isinstance(raw_text_response_or_cached_hit, BaseLMResponse):
|
188
190
|
# print("Got cached hit, returning directly")
|
189
191
|
raw_text_response = raw_text_response_or_cached_hit.raw_response
|
190
192
|
else:
|
@@ -242,7 +244,7 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
242
244
|
|
243
245
|
def _process_call_sync(
|
244
246
|
self,
|
245
|
-
messages:
|
247
|
+
messages: list[dict[str, Any]],
|
246
248
|
model: str,
|
247
249
|
response_model: BaseModel,
|
248
250
|
temperature: float,
|
@@ -277,7 +279,7 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
277
279
|
assert type(raw_text_response_or_cached_hit) in [str, BaseLMResponse], (
|
278
280
|
f"Expected str or BaseLMResponse, got {type(raw_text_response_or_cached_hit)}"
|
279
281
|
)
|
280
|
-
if
|
282
|
+
if isinstance(raw_text_response_or_cached_hit, BaseLMResponse):
|
281
283
|
logger.info("Got cached hit, returning directly")
|
282
284
|
raw_text_response = raw_text_response_or_cached_hit.raw_response
|
283
285
|
else:
|
@@ -320,26 +322,26 @@ class StringifiedJSONHandler(StructuredHandlerBase):
|
|
320
322
|
class ForcedJSONHandler(StructuredHandlerBase):
|
321
323
|
core_client: VendorBase
|
322
324
|
retry_client: VendorBase
|
323
|
-
handler_params:
|
325
|
+
handler_params: dict[str, Any]
|
324
326
|
|
325
327
|
def __init__(
|
326
328
|
self,
|
327
329
|
core_client: VendorBase,
|
328
330
|
retry_client: VendorBase,
|
329
|
-
handler_params:
|
331
|
+
handler_params: dict[str, Any] | None = None,
|
330
332
|
reasoning_effort: str = "high",
|
331
333
|
):
|
332
334
|
super().__init__(
|
333
335
|
core_client,
|
334
336
|
retry_client,
|
335
|
-
handler_params,
|
337
|
+
handler_params or {"retries": 3},
|
336
338
|
structured_output_mode="forced_json",
|
337
339
|
)
|
338
340
|
self.reasoning_effort = reasoning_effort
|
339
341
|
|
340
342
|
async def _process_call_async(
|
341
343
|
self,
|
342
|
-
messages:
|
344
|
+
messages: list[dict[str, Any]],
|
343
345
|
model: str,
|
344
346
|
response_model: BaseModel,
|
345
347
|
api_call_method: Callable,
|
@@ -360,7 +362,7 @@ class ForcedJSONHandler(StructuredHandlerBase):
|
|
360
362
|
|
361
363
|
def _process_call_sync(
|
362
364
|
self,
|
363
|
-
messages:
|
365
|
+
messages: list[dict[str, Any]],
|
364
366
|
model: str,
|
365
367
|
response_model: BaseModel,
|
366
368
|
api_call_method: Callable,
|
@@ -380,16 +382,16 @@ class ForcedJSONHandler(StructuredHandlerBase):
|
|
380
382
|
|
381
383
|
|
382
384
|
class StructuredOutputHandler:
|
383
|
-
handler:
|
385
|
+
handler: StringifiedJSONHandler | ForcedJSONHandler
|
384
386
|
mode: Literal["stringified_json", "forced_json"]
|
385
|
-
handler_params:
|
387
|
+
handler_params: dict[str, Any]
|
386
388
|
|
387
389
|
def __init__(
|
388
390
|
self,
|
389
391
|
core_client: VendorBase,
|
390
392
|
retry_client: VendorBase,
|
391
393
|
mode: Literal["stringified_json", "forced_json"],
|
392
|
-
handler_params:
|
394
|
+
handler_params: dict[str, Any] = {},
|
393
395
|
):
|
394
396
|
self.mode = mode
|
395
397
|
if self.mode == "stringified_json":
|
@@ -402,11 +404,11 @@ class StructuredOutputHandler:
|
|
402
404
|
|
403
405
|
async def call_async(
|
404
406
|
self,
|
405
|
-
messages:
|
407
|
+
messages: list[dict[str, Any]],
|
406
408
|
model: str,
|
407
409
|
response_model: BaseModel,
|
408
410
|
use_ephemeral_cache_only: bool = False,
|
409
|
-
lm_config:
|
411
|
+
lm_config: dict[str, Any] = {},
|
410
412
|
reasoning_effort: str = "high",
|
411
413
|
) -> BaseLMResponse:
|
412
414
|
# print("Output handler call async")
|
@@ -421,11 +423,11 @@ class StructuredOutputHandler:
|
|
421
423
|
|
422
424
|
def call_sync(
|
423
425
|
self,
|
424
|
-
messages:
|
426
|
+
messages: list[dict[str, Any]],
|
425
427
|
model: str,
|
426
428
|
response_model: BaseModel,
|
427
429
|
use_ephemeral_cache_only: bool = False,
|
428
|
-
lm_config:
|
430
|
+
lm_config: dict[str, Any] = {},
|
429
431
|
reasoning_effort: str = "high",
|
430
432
|
) -> BaseLMResponse:
|
431
433
|
return self.handler.call_sync(
|
@@ -1,22 +1,19 @@
|
|
1
1
|
import json
|
2
|
+
import warnings
|
2
3
|
from typing import (
|
3
4
|
Any,
|
4
|
-
|
5
|
-
List,
|
5
|
+
Literal,
|
6
6
|
Optional,
|
7
|
-
|
8
|
-
Type,
|
9
|
-
get_type_hints,
|
7
|
+
Union,
|
10
8
|
get_args,
|
11
9
|
get_origin,
|
12
|
-
|
13
|
-
Literal,
|
10
|
+
get_type_hints,
|
14
11
|
)
|
12
|
+
|
15
13
|
from pydantic import BaseModel
|
16
|
-
import warnings
|
17
14
|
|
18
15
|
|
19
|
-
def generate_type_map() ->
|
16
|
+
def generate_type_map() -> dict[Any, str]:
|
20
17
|
base_types = {
|
21
18
|
int: "int",
|
22
19
|
float: "float",
|
@@ -26,8 +23,8 @@ def generate_type_map() -> Dict[Any, str]:
|
|
26
23
|
}
|
27
24
|
|
28
25
|
collection_types = {
|
29
|
-
|
30
|
-
|
26
|
+
list: "List",
|
27
|
+
dict: "Dict",
|
31
28
|
Optional: "Optional",
|
32
29
|
}
|
33
30
|
|
@@ -37,19 +34,19 @@ def generate_type_map() -> Dict[Any, str]:
|
|
37
34
|
for collection, collection_name in collection_types.items():
|
38
35
|
if collection is Optional:
|
39
36
|
type_map[Optional[base_type]] = name
|
40
|
-
elif collection is
|
37
|
+
elif collection is dict:
|
41
38
|
# Handle generic Dict type
|
42
|
-
type_map[
|
39
|
+
type_map[dict] = "Dict[Any,Any]"
|
43
40
|
# Provide both key and value types for Dict
|
44
|
-
type_map[
|
41
|
+
type_map[dict[base_type, base_type]] = f"{collection_name}[{name},{name}]"
|
45
42
|
# Handle Dict[Any, Any] explicitly
|
46
|
-
type_map[
|
43
|
+
type_map[dict[Any, Any]] = "Dict[Any,Any]"
|
47
44
|
else:
|
48
45
|
type_map[collection[base_type]] = f"{collection_name}[{name}]"
|
49
46
|
return type_map
|
50
47
|
|
51
48
|
|
52
|
-
def generate_example_dict() ->
|
49
|
+
def generate_example_dict() -> dict[str, Any]:
|
53
50
|
example_values = {
|
54
51
|
"str": "<Your type-str response here>",
|
55
52
|
"int": "<Your type-int response here>",
|
@@ -101,10 +98,10 @@ def get_type_string(type_hint):
|
|
101
98
|
return f"{type_hint.__name__}({', '.join(f'{k}: {v}' for k, v in field_types.items())})"
|
102
99
|
else:
|
103
100
|
return base_type_examples.get(type_hint, ("Unknown", "unknown"))[0]
|
104
|
-
elif origin in (list,
|
101
|
+
elif origin in (list, list):
|
105
102
|
elem_type = get_type_string(args[0])
|
106
103
|
return f"List[{elem_type}]"
|
107
|
-
elif origin in (dict,
|
104
|
+
elif origin in (dict, dict):
|
108
105
|
key_type = get_type_string(args[0])
|
109
106
|
value_type = get_type_string(args[1])
|
110
107
|
return f"Dict[{key_type}, {value_type}]"
|
@@ -167,10 +164,10 @@ def get_example_value(type_hint):
|
|
167
164
|
return example, union_docs
|
168
165
|
else:
|
169
166
|
return base_type_examples.get(type_hint, ("Unknown", "unknown"))[1], []
|
170
|
-
elif origin in (list,
|
167
|
+
elif origin in (list, list):
|
171
168
|
value, docs = get_example_value(args[0])
|
172
169
|
return [value], docs
|
173
|
-
elif origin in (dict,
|
170
|
+
elif origin in (dict, dict):
|
174
171
|
if not args or len(args) < 2:
|
175
172
|
warnings.warn(
|
176
173
|
f"Dictionary type hint {type_hint} missing type arguments. "
|
@@ -224,9 +221,9 @@ def get_example_value(type_hint):
|
|
224
221
|
def add_json_instructions_to_messages(
|
225
222
|
system_message,
|
226
223
|
user_message,
|
227
|
-
response_model:
|
228
|
-
previously_failed_error_messages:
|
229
|
-
) ->
|
224
|
+
response_model: type[BaseModel] | None = None,
|
225
|
+
previously_failed_error_messages: list[str] = [],
|
226
|
+
) -> tuple[str, str]:
|
230
227
|
if response_model:
|
231
228
|
type_hints = get_type_hints(response_model)
|
232
229
|
# print("Type hints", type_hints)
|
@@ -283,10 +280,10 @@ Here are some error traces from previous attempts:
|
|
283
280
|
|
284
281
|
|
285
282
|
def inject_structured_output_instructions(
|
286
|
-
messages:
|
287
|
-
response_model:
|
288
|
-
previously_failed_error_messages:
|
289
|
-
) ->
|
283
|
+
messages: list[dict[str, str]],
|
284
|
+
response_model: type[BaseModel] | None = None,
|
285
|
+
previously_failed_error_messages: list[str] = [],
|
286
|
+
) -> list[dict[str, str]]:
|
290
287
|
prev_system_message_content = messages[0]["content"]
|
291
288
|
prev_user_message_content = messages[1]["content"]
|
292
289
|
system_message, user_message = add_json_instructions_to_messages(
|
@@ -2,15 +2,13 @@ import ast
|
|
2
2
|
import json
|
3
3
|
import logging
|
4
4
|
import re
|
5
|
-
from typing import Dict, List, Type, Union
|
6
5
|
|
7
6
|
from pydantic import BaseModel
|
8
7
|
|
9
|
-
from synth_ai.lm.vendors.base import VendorBase
|
10
8
|
from synth_ai.lm.vendors.core.openai_api import OpenAIStructuredOutputClient
|
11
9
|
|
12
10
|
|
13
|
-
def pull_out_structured_output(response_raw: str, response_model:
|
11
|
+
def pull_out_structured_output(response_raw: str, response_model: type[BaseModel]) -> BaseModel:
|
14
12
|
logger = logging.getLogger(__name__)
|
15
13
|
# logger.debug(f"Raw response received: {response_raw}")
|
16
14
|
|
@@ -36,7 +34,7 @@ def pull_out_structured_output(response_raw: str, response_model: Type[BaseModel
|
|
36
34
|
try:
|
37
35
|
response = json.loads(response_prepared)
|
38
36
|
final = response_model(**response)
|
39
|
-
except json.JSONDecodeError
|
37
|
+
except json.JSONDecodeError:
|
40
38
|
# Attempt to parse using ast.literal_eval as a fallback
|
41
39
|
response_prepared = response_prepared.replace("\n", "").replace("\\n", "")
|
42
40
|
response_prepared = response_prepared.replace('\\"', '"')
|
@@ -46,18 +44,22 @@ def pull_out_structured_output(response_raw: str, response_model: Type[BaseModel
|
|
46
44
|
except Exception as inner_e:
|
47
45
|
raise ValueError(
|
48
46
|
f"Failed to parse response as {response_model}: {inner_e} - {response_prepared}"
|
49
|
-
)
|
47
|
+
) from inner_e
|
50
48
|
except Exception as e:
|
51
|
-
raise ValueError(
|
49
|
+
raise ValueError(
|
50
|
+
f"Failed to parse response as {response_model}: {e} - {response_prepared}"
|
51
|
+
) from e
|
52
52
|
assert isinstance(final, BaseModel), "Structured output must be a Pydantic model"
|
53
53
|
return final
|
54
54
|
|
55
55
|
|
56
56
|
def fix_errant_stringified_json_sync(
|
57
57
|
response_raw: str,
|
58
|
-
response_model:
|
59
|
-
models:
|
58
|
+
response_model: type[BaseModel],
|
59
|
+
models: list[str] | None = None,
|
60
60
|
) -> BaseModel:
|
61
|
+
if models is None:
|
62
|
+
models = ["gpt-4o-mini", "gpt-4o"]
|
61
63
|
try:
|
62
64
|
return pull_out_structured_output(response_raw, response_model)
|
63
65
|
except ValueError as e:
|
@@ -85,14 +87,16 @@ def fix_errant_stringified_json_sync(
|
|
85
87
|
return pull_out_structured_output(fixed_response, response_model)
|
86
88
|
except Exception as e:
|
87
89
|
pass
|
88
|
-
raise ValueError("Failed to fix response using any model")
|
90
|
+
raise ValueError("Failed to fix response using any model") from None
|
89
91
|
|
90
92
|
|
91
93
|
async def fix_errant_stringified_json_async(
|
92
94
|
response_raw: str,
|
93
|
-
response_model:
|
94
|
-
models:
|
95
|
+
response_model: type[BaseModel],
|
96
|
+
models: list[str] | None = None,
|
95
97
|
) -> BaseModel:
|
98
|
+
if models is None:
|
99
|
+
models = ["gpt-4o-mini", "gpt-4o"]
|
96
100
|
try:
|
97
101
|
return pull_out_structured_output(response_raw, response_model)
|
98
102
|
except ValueError as e:
|
@@ -119,13 +123,13 @@ async def fix_errant_stringified_json_async(
|
|
119
123
|
return pull_out_structured_output(fixed_response, response_model)
|
120
124
|
except Exception as e:
|
121
125
|
pass
|
122
|
-
raise ValueError("Failed to fix response using any model")
|
126
|
+
raise ValueError("Failed to fix response using any model") from None
|
123
127
|
|
124
128
|
|
125
129
|
async def fix_errant_forced_async(
|
126
|
-
messages:
|
130
|
+
messages: list[dict],
|
127
131
|
response_raw: str,
|
128
|
-
response_model:
|
132
|
+
response_model: type[BaseModel],
|
129
133
|
model: str,
|
130
134
|
) -> BaseModel:
|
131
135
|
try:
|
@@ -157,7 +161,7 @@ async def fix_errant_forced_async(
|
|
157
161
|
|
158
162
|
def fix_errant_forced_sync(
|
159
163
|
response_raw: str,
|
160
|
-
response_model:
|
164
|
+
response_model: type[BaseModel],
|
161
165
|
model: str,
|
162
166
|
) -> BaseModel:
|
163
167
|
client = OpenAIStructuredOutputClient()
|
synth_ai/lm/tools/base.py
CHANGED
@@ -4,7 +4,7 @@ Base class for LM tools.
|
|
4
4
|
This module provides the base class for defining tools that can be used with language models.
|
5
5
|
"""
|
6
6
|
|
7
|
-
from typing import
|
7
|
+
from typing import Any
|
8
8
|
|
9
9
|
from pydantic import BaseModel
|
10
10
|
|
@@ -12,25 +12,26 @@ from pydantic import BaseModel
|
|
12
12
|
class BaseTool(BaseModel):
|
13
13
|
"""
|
14
14
|
Base class for defining tools that can be used with language models.
|
15
|
-
|
15
|
+
|
16
16
|
Attributes:
|
17
17
|
name: The name of the tool
|
18
18
|
arguments: Pydantic model defining the tool's arguments
|
19
19
|
description: Human-readable description of what the tool does
|
20
20
|
strict: Whether to enforce strict schema validation (default True)
|
21
21
|
"""
|
22
|
+
|
22
23
|
name: str
|
23
|
-
arguments:
|
24
|
+
arguments: type[BaseModel]
|
24
25
|
description: str = ""
|
25
26
|
strict: bool = True
|
26
27
|
|
27
|
-
def to_openai_tool(self) ->
|
28
|
+
def to_openai_tool(self) -> dict[str, Any]:
|
28
29
|
"""
|
29
30
|
Convert the tool to OpenAI's tool format.
|
30
|
-
|
31
|
+
|
31
32
|
Returns:
|
32
33
|
dict: Tool definition in OpenAI's expected format
|
33
|
-
|
34
|
+
|
34
35
|
Note:
|
35
36
|
- Ensures additionalProperties is False for strict validation
|
36
37
|
- Fixes array items that lack explicit types
|
@@ -40,7 +41,7 @@ class BaseTool(BaseModel):
|
|
40
41
|
schema["additionalProperties"] = False
|
41
42
|
|
42
43
|
if "properties" in schema:
|
43
|
-
for
|
44
|
+
for _prop_name, prop_schema in schema["properties"].items():
|
44
45
|
if prop_schema.get("type") == "array":
|
45
46
|
items_schema = prop_schema.get("items", {})
|
46
47
|
if not isinstance(items_schema, dict) or not items_schema.get("type"):
|
@@ -63,13 +64,13 @@ class BaseTool(BaseModel):
|
|
63
64
|
},
|
64
65
|
}
|
65
66
|
|
66
|
-
def to_anthropic_tool(self) ->
|
67
|
+
def to_anthropic_tool(self) -> dict[str, Any]:
|
67
68
|
"""
|
68
69
|
Convert the tool to Anthropic's tool format.
|
69
|
-
|
70
|
+
|
70
71
|
Returns:
|
71
72
|
dict: Tool definition in Anthropic's expected format
|
72
|
-
|
73
|
+
|
73
74
|
Note:
|
74
75
|
Anthropic uses a different format with input_schema instead of parameters.
|
75
76
|
"""
|
@@ -86,13 +87,13 @@ class BaseTool(BaseModel):
|
|
86
87
|
},
|
87
88
|
}
|
88
89
|
|
89
|
-
def to_mistral_tool(self) ->
|
90
|
+
def to_mistral_tool(self) -> dict[str, Any]:
|
90
91
|
"""
|
91
92
|
Convert the tool to Mistral's tool format.
|
92
|
-
|
93
|
+
|
93
94
|
Returns:
|
94
95
|
dict: Tool definition in Mistral's expected format
|
95
|
-
|
96
|
+
|
96
97
|
Note:
|
97
98
|
Mistral requires explicit handling of array types and enum values.
|
98
99
|
"""
|
@@ -130,13 +131,13 @@ class BaseTool(BaseModel):
|
|
130
131
|
},
|
131
132
|
}
|
132
133
|
|
133
|
-
def to_gemini_tool(self) ->
|
134
|
+
def to_gemini_tool(self) -> dict[str, Any]:
|
134
135
|
"""
|
135
136
|
Convert the tool to Gemini's tool format.
|
136
|
-
|
137
|
+
|
137
138
|
Returns:
|
138
139
|
dict: Tool definition in Gemini's expected format
|
139
|
-
|
140
|
+
|
140
141
|
Note:
|
141
142
|
Gemini uses a simpler format without the nested "function" key.
|
142
143
|
"""
|
synth_ai/lm/unified_interface.py
CHANGED
@@ -3,12 +3,11 @@ Unified interface for LM providers.
|
|
3
3
|
Provides a consistent API for OpenAI and Synth backends.
|
4
4
|
"""
|
5
5
|
|
6
|
-
import os
|
7
6
|
import logging
|
8
7
|
from abc import ABC, abstractmethod
|
9
|
-
from typing import
|
8
|
+
from typing import Any
|
10
9
|
|
11
|
-
from .config import
|
10
|
+
from .config import OpenAIConfig, SynthConfig
|
12
11
|
|
13
12
|
logger = logging.getLogger(__name__)
|
14
13
|
|
@@ -18,8 +17,8 @@ class UnifiedLMProvider(ABC):
|
|
18
17
|
|
19
18
|
@abstractmethod
|
20
19
|
async def create_chat_completion(
|
21
|
-
self, model: str, messages:
|
22
|
-
) ->
|
20
|
+
self, model: str, messages: list[dict[str, Any]], **kwargs
|
21
|
+
) -> dict[str, Any]:
|
23
22
|
"""Create a chat completion."""
|
24
23
|
pass
|
25
24
|
|
@@ -37,7 +36,7 @@ class UnifiedLMProvider(ABC):
|
|
37
36
|
class OpenAIProvider(UnifiedLMProvider):
|
38
37
|
"""OpenAI provider implementation."""
|
39
38
|
|
40
|
-
def __init__(self, api_key:
|
39
|
+
def __init__(self, api_key: str | None = None, **kwargs):
|
41
40
|
"""
|
42
41
|
Initialize OpenAI provider.
|
43
42
|
|
@@ -47,8 +46,8 @@ class OpenAIProvider(UnifiedLMProvider):
|
|
47
46
|
"""
|
48
47
|
try:
|
49
48
|
from openai import AsyncOpenAI
|
50
|
-
except ImportError:
|
51
|
-
raise ImportError("OpenAI package not installed. Run: pip install openai")
|
49
|
+
except ImportError as err:
|
50
|
+
raise ImportError("OpenAI package not installed. Run: pip install openai") from err
|
52
51
|
|
53
52
|
# Use provided key or load from environment
|
54
53
|
if api_key is None:
|
@@ -59,8 +58,8 @@ class OpenAIProvider(UnifiedLMProvider):
|
|
59
58
|
logger.info("Initialized OpenAI provider")
|
60
59
|
|
61
60
|
async def create_chat_completion(
|
62
|
-
self, model: str, messages:
|
63
|
-
) ->
|
61
|
+
self, model: str, messages: list[dict[str, Any]], **kwargs
|
62
|
+
) -> dict[str, Any]:
|
64
63
|
"""Create a chat completion using OpenAI."""
|
65
64
|
response = await self.client.chat.completions.create(
|
66
65
|
model=model, messages=messages, **kwargs
|
@@ -82,7 +81,7 @@ class OpenAIProvider(UnifiedLMProvider):
|
|
82
81
|
class SynthProvider(UnifiedLMProvider):
|
83
82
|
"""Synth provider implementation."""
|
84
83
|
|
85
|
-
def __init__(self, config:
|
84
|
+
def __init__(self, config: SynthConfig | None = None, **kwargs):
|
86
85
|
"""
|
87
86
|
Initialize Synth provider.
|
88
87
|
|
@@ -96,8 +95,8 @@ class SynthProvider(UnifiedLMProvider):
|
|
96
95
|
self.client = AsyncSynthClient(self.config)
|
97
96
|
|
98
97
|
async def create_chat_completion(
|
99
|
-
self, model: str, messages:
|
100
|
-
) ->
|
98
|
+
self, model: str, messages: list[dict[str, Any]], **kwargs
|
99
|
+
) -> dict[str, Any]:
|
101
100
|
"""Create a chat completion using Synth."""
|
102
101
|
return await self.client.chat_completions_create(model=model, messages=messages, **kwargs)
|
103
102
|
|
@@ -156,9 +155,9 @@ class UnifiedLMClient:
|
|
156
155
|
default_provider: Default provider to use ("openai" or "synth")
|
157
156
|
"""
|
158
157
|
self.default_provider = default_provider
|
159
|
-
self._providers:
|
158
|
+
self._providers: dict[str, UnifiedLMProvider] = {}
|
160
159
|
|
161
|
-
async def _get_provider(self, provider:
|
160
|
+
async def _get_provider(self, provider: str | None = None) -> UnifiedLMProvider:
|
162
161
|
"""Get or create a provider instance."""
|
163
162
|
provider_name = provider or self.default_provider
|
164
163
|
|
@@ -168,8 +167,8 @@ class UnifiedLMClient:
|
|
168
167
|
return self._providers[provider_name]
|
169
168
|
|
170
169
|
async def create_chat_completion(
|
171
|
-
self, model: str, messages:
|
172
|
-
) ->
|
170
|
+
self, model: str, messages: list[dict[str, Any]], provider: str | None = None, **kwargs
|
171
|
+
) -> dict[str, Any]:
|
173
172
|
"""
|
174
173
|
Create a chat completion using specified or default provider.
|
175
174
|
|
@@ -185,7 +184,7 @@ class UnifiedLMClient:
|
|
185
184
|
provider_instance = await self._get_provider(provider)
|
186
185
|
return await provider_instance.create_chat_completion(model, messages, **kwargs)
|
187
186
|
|
188
|
-
async def warmup(self, model: str, provider:
|
187
|
+
async def warmup(self, model: str, provider: str | None = None, **kwargs) -> bool:
|
189
188
|
"""Warm up a model on specified provider."""
|
190
189
|
provider_instance = await self._get_provider(provider)
|
191
190
|
return await provider_instance.warmup(model, **kwargs)
|