synth-ai 0.1.0.dev4__py3-none-any.whl → 0.1.0.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- public_tests/synth_sdk.py +389 -0
- public_tests/test_agent.py +538 -0
- public_tests/test_recursive_structured_outputs.py +180 -0
- public_tests/test_structured_outputs.py +100 -0
- synth_ai/zyk/lms/__init__.py +0 -0
- synth_ai/zyk/lms/caching/__init__.py +0 -0
- synth_ai/zyk/lms/caching/constants.py +1 -0
- synth_ai/zyk/lms/caching/dbs.py +0 -0
- synth_ai/zyk/lms/caching/ephemeral.py +50 -0
- synth_ai/zyk/lms/caching/handler.py +92 -0
- synth_ai/zyk/lms/caching/initialize.py +13 -0
- synth_ai/zyk/lms/caching/persistent.py +55 -0
- synth_ai/zyk/lms/config.py +8 -0
- synth_ai/zyk/lms/core/__init__.py +0 -0
- synth_ai/zyk/lms/core/all.py +35 -0
- synth_ai/zyk/lms/core/exceptions.py +9 -0
- synth_ai/zyk/lms/core/main.py +245 -0
- synth_ai/zyk/lms/core/vendor_clients.py +60 -0
- synth_ai/zyk/lms/cost/__init__.py +0 -0
- synth_ai/zyk/lms/cost/monitor.py +1 -0
- synth_ai/zyk/lms/cost/statefulness.py +1 -0
- synth_ai/zyk/lms/structured_outputs/__init__.py +0 -0
- synth_ai/zyk/lms/structured_outputs/handler.py +388 -0
- synth_ai/zyk/lms/structured_outputs/inject.py +185 -0
- synth_ai/zyk/lms/structured_outputs/rehabilitate.py +186 -0
- synth_ai/zyk/lms/vendors/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/base.py +15 -0
- synth_ai/zyk/lms/vendors/constants.py +5 -0
- synth_ai/zyk/lms/vendors/core/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/core/anthropic_api.py +191 -0
- synth_ai/zyk/lms/vendors/core/gemini_api.py +146 -0
- synth_ai/zyk/lms/vendors/core/openai_api.py +145 -0
- synth_ai/zyk/lms/vendors/local/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/local/ollama.py +0 -0
- synth_ai/zyk/lms/vendors/openai_standard.py +141 -0
- synth_ai/zyk/lms/vendors/retries.py +3 -0
- synth_ai/zyk/lms/vendors/supported/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/supported/deepseek.py +18 -0
- synth_ai/zyk/lms/vendors/supported/together.py +11 -0
- {synth_ai-0.1.0.dev4.dist-info → synth_ai-0.1.0.dev6.dist-info}/METADATA +1 -1
- synth_ai-0.1.0.dev6.dist-info/RECORD +46 -0
- synth_ai-0.1.0.dev6.dist-info/top_level.txt +2 -0
- synth_ai-0.1.0.dev4.dist-info/RECORD +0 -7
- synth_ai-0.1.0.dev4.dist-info/top_level.txt +0 -1
- {synth_ai-0.1.0.dev4.dist-info → synth_ai-0.1.0.dev6.dist-info}/LICENSE +0 -0
- {synth_ai-0.1.0.dev4.dist-info → synth_ai-0.1.0.dev6.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Literal, Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
from synth_ai.zyk.lms.core.exceptions import StructuredOutputCoercionFailureException
|
|
6
|
+
from synth_ai.zyk.lms.core.vendor_clients import (
|
|
7
|
+
anthropic_naming_regexes,
|
|
8
|
+
get_client,
|
|
9
|
+
openai_naming_regexes,
|
|
10
|
+
)
|
|
11
|
+
from synth_ai.zyk.lms.structured_outputs.handler import StructuredOutputHandler
|
|
12
|
+
from synth_ai.zyk.lms.vendors.base import VendorBase
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def build_messages(
    sys_msg: str,
    user_msg: str,
    images_bytes: Optional[List] = None,
    model_name: Optional[str] = None,
) -> List[Dict]:
    """Build a chat message list (system + user) in the vendor's expected shape.

    Args:
        sys_msg: System prompt text.
        user_msg: User prompt text.
        images_bytes: Optional base64-encoded image payloads to attach.
        model_name: Model identifier, used only to pick the vendor-specific
            image attachment format when images are present.

    Returns:
        A two-element list of message dicts.

    Raises:
        ValueError: If images are supplied for a model that does not support
            them (or no model_name was given to identify the vendor).
    """
    # Fix: avoid a mutable default argument (the old `= []` list was shared
    # across calls); None means "no images".
    if images_bytes is None:
        images_bytes = []

    if not images_bytes:
        # Text-only: every vendor accepts the plain system/user pair.
        return [
            {"role": "system", "content": sys_msg},
            {"role": "user", "content": user_msg},
        ]

    # Fix: guard model_name before regex-matching; previously a None
    # model_name with images raised TypeError inside regex.match.
    if model_name is not None and any(
        regex.match(model_name) for regex in openai_naming_regexes
    ):
        # OpenAI format: images go in as data-URL "image_url" content parts.
        return [
            {"role": "system", "content": sys_msg},
            {
                "role": "user",
                "content": [{"type": "text", "text": user_msg}]
                + [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{image_bytes}"},
                    }
                    for image_bytes in images_bytes
                ],
            },
        ]

    if model_name is not None and any(
        regex.match(model_name) for regex in anthropic_naming_regexes
    ):
        # Anthropic format: images are base64 "source" blocks.
        system_info = {"role": "system", "content": sys_msg}
        user_info = {
            "role": "user",
            "content": [{"type": "text", "text": user_msg}]
            + [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": image_bytes,
                    },
                }
                for image_bytes in images_bytes
            ],
        }
        return [system_info, user_info]

    raise ValueError("Images are not yet supported for this model")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class LM:
    """High-level language-model wrapper.

    Resolves a vendor client from the model name and routes calls either
    straight to the vendor API (unstructured) or through a
    StructuredOutputHandler (when a response_model is supplied), with a
    forced_json fallback handler if the primary coercion mode fails.
    """

    model_name: str
    client: VendorBase
    lm_config: Dict[str, Any]
    structured_output_handler: StructuredOutputHandler

    def __init__(
        self,
        model_name: str,
        formatting_model_name: str,
        temperature: float,
        max_retries: Literal["None", "Few", "Many"] = "Few",
        structured_output_mode: Literal[
            "stringified_json", "forced_json"
        ] = "stringified_json",
        synth_logging: bool = True,
    ):
        """Create the vendor client, formatting client, and output handlers.

        Args:
            model_name: Model used for the main completion calls.
            formatting_model_name: Model used to repair/coerce malformed
                structured outputs.
            temperature: Sampling temperature stored in lm_config.
            max_retries: Coarse retry budget ("None"=0, "Few"=2, "Many"=5).
            structured_output_mode: Primary structured-output strategy.
            synth_logging: Passed through to the vendor client.
        """
        self.client = get_client(
            model_name,
            # forced_json requires a client that supports native structured output.
            with_formatting=structured_output_mode == "forced_json",
            synth_logging=synth_logging,
        )

        # Separate client dedicated to reformatting malformed outputs.
        formatting_client = get_client(formatting_model_name, with_formatting=True)

        max_retries_dict = {"None": 0, "Few": 2, "Many": 5}
        self.structured_output_handler = StructuredOutputHandler(
            self.client,
            formatting_client,
            structured_output_mode,
            {"max_retries": max_retries_dict.get(max_retries, 2)},
        )
        # Fallback: if the primary mode fails to coerce, retry with forced_json.
        self.backup_structured_output_handler = StructuredOutputHandler(
            self.client,
            formatting_client,
            "forced_json",
            {"max_retries": max_retries_dict.get(max_retries, 2)},
        )
        self.lm_config = {"temperature": temperature}
        self.model_name = model_name

    def respond_sync(
        self,
        system_message: Optional[str] = None,
        user_message: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        images_as_bytes: Optional[List[Any]] = None,
        response_model: Optional[BaseModel] = None,
        use_ephemeral_cache_only: bool = False,
    ):
        """Synchronous completion.

        Provide EITHER ``messages`` OR the ``system_message``/``user_message``
        pair (never both). When ``response_model`` is given, the structured
        output handler is used, with the forced_json handler as fallback on
        coercion failure.
        """
        # Fix: no mutable default argument for images_as_bytes.
        if images_as_bytes is None:
            images_as_bytes = []
        assert (system_message is None) == (
            user_message is None
        ), "Must provide both system_message and user_message or neither"
        assert (
            (messages is None) != (system_message is None)
        ), "Must provide either messages or system_message/user_message pair, but not both"

        if messages is None:
            messages = build_messages(
                system_message, user_message, images_as_bytes, self.model_name
            )

        if response_model:
            try:
                return self.structured_output_handler.call_sync(
                    messages,
                    model=self.model_name,
                    lm_config=self.lm_config,
                    response_model=response_model,
                    use_ephemeral_cache_only=use_ephemeral_cache_only,
                )
            except StructuredOutputCoercionFailureException:
                # Primary coercion failed; retry with the forced_json handler.
                return self.backup_structured_output_handler.call_sync(
                    messages,
                    model=self.model_name,
                    lm_config=self.lm_config,
                    response_model=response_model,
                    use_ephemeral_cache_only=use_ephemeral_cache_only,
                )
        else:
            return self.client._hit_api_sync(
                messages=messages,
                model=self.model_name,
                lm_config=self.lm_config,
                use_ephemeral_cache_only=use_ephemeral_cache_only,
            )

    async def respond_async(
        self,
        system_message: Optional[str] = None,
        user_message: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        images_as_bytes: Optional[List[Any]] = None,
        response_model: Optional[BaseModel] = None,
        use_ephemeral_cache_only: bool = False,
    ):
        """Async counterpart of respond_sync; identical argument contract."""
        # Fix: no mutable default argument for images_as_bytes.
        if images_as_bytes is None:
            images_as_bytes = []
        assert (system_message is None) == (
            user_message is None
        ), "Must provide both system_message and user_message or neither"
        assert (
            (messages is None) != (system_message is None)
        ), "Must provide either messages or system_message/user_message pair, but not both"

        if messages is None:
            messages = build_messages(
                system_message, user_message, images_as_bytes, self.model_name
            )

        if response_model:
            try:
                return await self.structured_output_handler.call_async(
                    messages,
                    model=self.model_name,
                    lm_config=self.lm_config,
                    response_model=response_model,
                    use_ephemeral_cache_only=use_ephemeral_cache_only,
                )
            except StructuredOutputCoercionFailureException:
                # Primary coercion failed; retry with the forced_json handler.
                return await self.backup_structured_output_handler.call_async(
                    messages,
                    model=self.model_name,
                    lm_config=self.lm_config,
                    response_model=response_model,
                    use_ephemeral_cache_only=use_ephemeral_cache_only,
                )
        else:
            return await self.client._hit_api_async(
                messages=messages,
                model=self.model_name,
                lm_config=self.lm_config,
                use_ephemeral_cache_only=use_ephemeral_cache_only,
            )
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# Manual smoke test: exercises forced_json structured output end-to-end.
# Requires live API credentials; not run under normal imports.
if __name__ == "__main__":
    import asyncio

    # Update json instructions to handle nested pydantic?
    # Nested model used to verify that nested Pydantic schemas round-trip.
    class Thought(BaseModel):
        argument_keys: List[str] = Field(description="The keys of the arguments")
        argument_values: List[str] = Field(
            description="Stringified JSON for the values of the arguments"
        )

    # Top-level response model with a nested `Thought` field.
    class TestModel(BaseModel):
        emotion: str = Field(description="The emotion expressed")
        concern: str = Field(description="The concern expressed")
        action: str = Field(description="The action to be taken")
        thought: Thought = Field(description="The thought process")

        # NOTE(review): `schema_extra` under `class Config` is pydantic-v1
        # style; confirm the pinned pydantic version (v2 uses model_config /
        # json_schema_extra).
        class Config:
            schema_extra = {"required": ["thought", "emotion", "concern", "action"]}

    lm = LM(
        model_name="gpt-4o-mini",
        formatting_model_name="gpt-4o-mini",
        temperature=1,
        max_retries="Few",
        structured_output_mode="forced_json",
    )
    # Drive the async path once and print the coerced TestModel instance.
    print(
        asyncio.run(
            lm.respond_async(
                system_message="You are a helpful assistant ",
                user_message="Hello, how are you?",
                images_as_bytes=[],
                response_model=TestModel,
                use_ephemeral_cache_only=False,
            )
        )
    )
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Any, List, Pattern
|
|
3
|
+
|
|
4
|
+
from synth_ai.zyk.lms.core.all import (
|
|
5
|
+
AnthropicClient,
|
|
6
|
+
DeepSeekClient,
|
|
7
|
+
GeminiClient,
|
|
8
|
+
# OpenAIClient,
|
|
9
|
+
OpenAIStructuredOutputClient,
|
|
10
|
+
TogetherClient,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
# Model-name patterns used to route a model id to its vendor client.
openai_naming_regexes: List[Pattern] = [
    # OpenAI models, optionally fine-tuned ("ft:" prefix): o1/o3 reasoning
    # models or anything starting with "gpt-".
    # Fix: the original "[1,3]" was a character class that also matched a
    # literal comma (so "o," passed); "[13]" matches only o1 and o3.
    re.compile(r"^(ft:)?(o[13](-.*)?|gpt-.*)$"),
]
openai_formatting_model_regexes: List[Pattern] = [
    # Models permitted to act as the JSON-formatting/repair model.
    re.compile(r"^(ft:)?gpt-4o(-.*)?$"),
]
anthropic_naming_regexes: List[Pattern] = [
    re.compile(r"^claude-.*$"),
]
gemini_naming_regexes: List[Pattern] = [
    re.compile(r"^gemini-.*$"),
]
deepseek_naming_regexes: List[Pattern] = [
    re.compile(r"^deepseek-.*$"),
]
together_naming_regexes: List[Pattern] = [
    # Together model ids look like "org/model"; any name containing "/"
    # matches, so this must be checked last in get_client.
    re.compile(r"^.*\/.*$"),
]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_client(
    model_name: str,
    with_formatting: bool = False,
    synth_logging: bool = True,
) -> Any:
    """Resolve the vendor client for a model name.

    Args:
        model_name: Model identifier matched against the vendor regex lists.
        with_formatting: When True, Anthropic clients additionally get an
            OpenAI-backed async structured-output method attached.
        synth_logging: Forwarded to the OpenAI structured-output client.

    Returns:
        A vendor client instance.

    Raises:
        ValueError: If no vendor pattern matches the model name.
    """

    def _is(patterns: List[Pattern]) -> bool:
        # True when any vendor pattern matches the model name.
        return any(pattern.match(model_name) for pattern in patterns)

    if _is(openai_naming_regexes):
        return OpenAIStructuredOutputClient(
            synth_logging=synth_logging,
        )
    if _is(anthropic_naming_regexes):
        client = AnthropicClient()
        if with_formatting:
            # Borrow OpenAI's structured-output call for Anthropic formatting.
            client._hit_api_async_structured_output = OpenAIStructuredOutputClient(
                synth_logging=synth_logging
            )._hit_api_async
        return client
    if _is(gemini_naming_regexes):
        return GeminiClient()
    if _is(deepseek_naming_regexes):
        return DeepSeekClient()
    if _is(together_naming_regexes):
        return TogetherClient()
    raise ValueError(f"Invalid model name: {model_name}")
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
#TODO
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Maybe some kind of ephemeral cache
|
|
File without changes
|
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from synth_ai.zyk.lms.core.exceptions import StructuredOutputCoercionFailureException
|
|
8
|
+
from synth_ai.zyk.lms.structured_outputs.inject import (
|
|
9
|
+
inject_structured_output_instructions,
|
|
10
|
+
)
|
|
11
|
+
from synth_ai.zyk.lms.structured_outputs.rehabilitate import (
|
|
12
|
+
fix_errant_forced_async,
|
|
13
|
+
fix_errant_forced_sync,
|
|
14
|
+
fix_errant_stringified_json_async,
|
|
15
|
+
fix_errant_stringified_json_sync,
|
|
16
|
+
pull_out_structured_output,
|
|
17
|
+
)
|
|
18
|
+
from synth_ai.zyk.lms.vendors.base import VendorBase
|
|
19
|
+
from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class StructuredHandlerBase(ABC):
    """Shared plumbing for structured-output handlers.

    Resolves the effective temperature and selects the vendor API method
    (native structured output vs. plain completion) before delegating to the
    subclass's ``_process_call_sync`` / ``_process_call_async``.
    """

    core_client: VendorBase
    retry_client: VendorBase
    handler_params: Dict[str, Any]
    structured_output_mode: Literal["stringified_json", "forced_json"]

    def __init__(
        self,
        core_client: VendorBase,
        retry_client: VendorBase,
        handler_params: Optional[Dict[str, Any]] = None,
        structured_output_mode: Literal[
            "stringified_json", "forced_json"
        ] = "stringified_json",
    ):
        self.core_client = core_client
        self.retry_client = retry_client
        # None means "use the default retry budget".
        self.handler_params = (
            handler_params if handler_params is not None else {"retries": 3}
        )
        self.structured_output_mode = structured_output_mode

    async def call_async(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        temperature: float = 0.0,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Async entry point; see call_sync for the argument contract."""
        if temperature == 0.0:
            # 0.0 is treated as "unset": fall back to the model-specific base.
            temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
        # Fix: replaced the obfuscated `not not response_model` with an
        # explicit None check.
        use_structured_api = (
            response_model is not None
            and self.structured_output_mode == "forced_json"
        )
        return await self._process_call_async(
            messages=messages,
            model=model,
            response_model=response_model,
            api_call_method=self.core_client._hit_api_async_structured_output
            if use_structured_api
            else self.core_client._hit_api_async,
            temperature=temperature,
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )

    def call_sync(
        self,
        messages: List[Dict[str, Any]],
        response_model: BaseModel,
        model: str,
        temperature: float = 0.0,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Resolve temperature, pick the API method, and delegate.

        Args:
            messages: Chat messages to send.
            response_model: Pydantic model class describing the output shape.
            model: Model identifier.
            temperature: Sampling temperature; 0.0 means "use the model's
                special base temperature if one is registered".
            use_ephemeral_cache_only: Restrict caching to the ephemeral cache.
        """
        if temperature == 0.0:
            temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
        use_structured_api = (
            response_model is not None
            and self.structured_output_mode == "forced_json"
        )
        return self._process_call_sync(
            messages=messages,
            model=model,
            response_model=response_model,
            api_call_method=self.core_client._hit_api_sync_structured_output
            if use_structured_api
            else self.core_client._hit_api_sync,
            temperature=temperature,
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )

    # Fix: the abstract signatures now include `temperature`, which both
    # call_sync/call_async always pass and both concrete subclasses require;
    # previously the declared interface and the actual calls disagreed.
    @abstractmethod
    async def _process_call_async(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        api_call_method,
        temperature: float = 0.0,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        pass

    @abstractmethod
    def _process_call_sync(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        api_call_method,
        temperature: float = 0.0,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        pass
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class StringifiedJSONHandler(StructuredHandlerBase):
    """Structured outputs via prompt-injected JSON instructions.

    The model is asked (through injected instructions) to emit stringified
    JSON; the response is then parsed into the response_model. On parse
    failure a formatting model attempts a repair; remaining failures are fed
    back into the next retry's instructions.
    """

    core_client: VendorBase
    retry_client: VendorBase
    handler_params: Dict[str, Any]

    def __init__(
        self,
        core_client: VendorBase,
        retry_client: VendorBase,
        handler_params: Optional[Dict[str, Any]] = None,
    ):
        # Fix: no mutable default argument; None preserves the old
        # {"retries": 3} default.
        super().__init__(
            core_client,
            retry_client,
            handler_params if handler_params is not None else {"retries": 3},
            structured_output_mode="stringified_json",
        )

    async def _process_call_async(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        temperature: float,
        api_call_method: Callable,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Async retry loop: inject instructions, call, parse, repair."""
        # Fix: `callable()` is the idiomatic check; isinstance against
        # typing.Callable is discouraged.
        assert callable(api_call_method), "api_call_method must be a callable"
        assert (
            response_model is not None
        ), "Don't use this handler for unstructured outputs"
        remaining_retries = self.handler_params.get("retries", 2)
        previously_failed_error_messages = []
        structured_output = None

        while remaining_retries > 0:
            messages_with_json_formatting_instructions = (
                inject_structured_output_instructions(
                    messages=messages,
                    response_model=response_model,
                    previously_failed_error_messages=previously_failed_error_messages,
                )
            )
            raw_text_response_or_cached_hit = await api_call_method(
                messages=messages_with_json_formatting_instructions,
                model=model,
                lm_config={"response_model": None, "temperature": temperature},
                use_ephemeral_cache_only=use_ephemeral_cache_only,
            )
            # A cache hit comes back already parsed; return it untouched.
            if not isinstance(raw_text_response_or_cached_hit, str):
                return raw_text_response_or_cached_hit
            raw_text_response = raw_text_response_or_cached_hit
            try:
                structured_output = pull_out_structured_output(
                    raw_text_response, response_model
                )
                break
            except Exception:
                try:
                    # Parsing failed: ask the formatting model to repair the
                    # malformed output.
                    structured_output = await fix_errant_forced_async(
                        messages_with_json_formatting_instructions,
                        raw_text_response,
                        response_model,
                        "gpt-4o-mini",
                    )
                    break
                except Exception as e:
                    # Record the failed attempt so the next retry's prompt
                    # can steer away from it.
                    previously_failed_error_messages.append(
                        f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
                    )
            remaining_retries -= 1
        if structured_output is None:
            raise StructuredOutputCoercionFailureException(
                "Failed to get structured output"
            )
        return structured_output

    def _process_call_sync(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        temperature: float,
        api_call_method: Callable,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Sync counterpart of _process_call_async; same retry/repair loop."""
        assert callable(api_call_method), "api_call_method must be a callable"
        assert (
            response_model is not None
        ), "Don't use this handler for unstructured outputs"
        remaining_retries = self.handler_params.get("retries", 2)
        previously_failed_error_messages = []
        structured_output = None

        while remaining_retries > 0:
            messages_with_json_formatting_instructions = (
                inject_structured_output_instructions(
                    messages=messages,
                    response_model=response_model,
                    previously_failed_error_messages=previously_failed_error_messages,
                )
            )
            raw_text_response_or_cached_hit = api_call_method(
                messages=messages_with_json_formatting_instructions,
                model=model,
                lm_config={"response_model": None, "temperature": temperature},
                use_ephemeral_cache_only=use_ephemeral_cache_only,
            )
            # A cache hit comes back already parsed; return it untouched.
            if not isinstance(raw_text_response_or_cached_hit, str):
                return raw_text_response_or_cached_hit
            raw_text_response = raw_text_response_or_cached_hit
            try:
                structured_output = pull_out_structured_output(
                    raw_text_response, response_model
                )
                break
            except Exception:
                try:
                    # NOTE(review): unlike the async path, the sync repair
                    # call does not pass the instruction messages — confirm
                    # fix_errant_forced_sync's signature before unifying.
                    structured_output = fix_errant_forced_sync(
                        raw_text_response, response_model, "gpt-4o-mini"
                    )
                    break
                except Exception as e:
                    previously_failed_error_messages.append(
                        f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
                    )
            remaining_retries -= 1
        if structured_output is None:
            raise StructuredOutputCoercionFailureException(
                "Failed to get structured output"
            )
        return structured_output
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class ForcedJSONHandler(StructuredHandlerBase):
    """Structured outputs via the vendor's native structured-output API.

    Simply forwards to the (already-selected) structured API method; no
    retry/repair loop is needed because the vendor enforces the schema.
    """

    core_client: VendorBase
    retry_client: VendorBase
    handler_params: Dict[str, Any]

    def __init__(
        self,
        core_client: VendorBase,
        retry_client: VendorBase,
        handler_params: Optional[Dict[str, Any]] = None,
    ):
        # Fix: no mutable default argument; None preserves the old {} default
        # (and is NOT turned into the base class's {"retries": 3}).
        super().__init__(
            core_client,
            retry_client,
            handler_params if handler_params is not None else {},
            structured_output_mode="forced_json",
        )

    async def _process_call_async(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        api_call_method: Callable,
        temperature: float = 0.0,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Forward the call to the vendor's async structured-output method."""
        assert (
            response_model is not None
        ), "Don't use this handler for unstructured outputs"
        return await api_call_method(
            messages=messages,
            model=model,
            response_model=response_model,
            temperature=temperature,
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )

    def _process_call_sync(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        api_call_method: Callable,
        temperature: float = 0.0,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Forward the call to the vendor's sync structured-output method."""
        assert (
            response_model is not None
        ), "Don't use this handler for unstructured outputs"
        return api_call_method(
            messages=messages,
            model=model,
            response_model=response_model,
            temperature=temperature,
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
class StructuredOutputHandler:
    """Facade selecting a stringified_json or forced_json handler.

    Translates the caller's lm_config into the concrete handler's
    temperature/cache arguments.
    """

    handler: Union[StringifiedJSONHandler, ForcedJSONHandler]
    mode: Literal["stringified_json", "forced_json"]
    handler_params: Dict[str, Any]

    def __init__(
        self,
        core_client: VendorBase,
        retry_client: VendorBase,
        mode: Literal["stringified_json", "forced_json"],
        handler_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the concrete handler for `mode`.

        Raises:
            ValueError: If `mode` is not a supported strategy.
        """
        # Fix: no mutable default argument; None preserves the old {} default.
        if handler_params is None:
            handler_params = {}
        self.mode = mode
        if self.mode == "stringified_json":
            self.handler = StringifiedJSONHandler(
                core_client, retry_client, handler_params
            )
        elif self.mode == "forced_json":
            self.handler = ForcedJSONHandler(core_client, retry_client, handler_params)
        else:
            raise ValueError(f"Invalid mode: {mode}")

    async def call_async(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        use_ephemeral_cache_only: bool = False,
        lm_config: Optional[Dict[str, Any]] = None,
    ) -> BaseModel:
        """Async structured call; temperature is read from lm_config."""
        # Fix: no mutable default argument for lm_config.
        if lm_config is None:
            lm_config = {}
        return await self.handler.call_async(
            messages=messages,
            model=model,
            response_model=response_model,
            temperature=lm_config.get(
                "temperature", SPECIAL_BASE_TEMPS.get(model, 0.0)
            ),
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )

    def call_sync(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        use_ephemeral_cache_only: bool = False,
        lm_config: Optional[Dict[str, Any]] = None,
    ) -> BaseModel:
        """Sync structured call; temperature is read from lm_config."""
        if lm_config is None:
            lm_config = {}
        return self.handler.call_sync(
            messages=messages,
            model=model,
            response_model=response_model,
            temperature=lm_config.get(
                "temperature", SPECIAL_BASE_TEMPS.get(model, 0.0)
            ),
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )
|