synth-ai 0.0.0.dev3__py3-none-any.whl → 0.1.0.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (48) hide show
  1. public_tests/synth_sdk.py +389 -0
  2. public_tests/test_agent.py +538 -0
  3. public_tests/test_recursive_structured_outputs.py +180 -0
  4. public_tests/test_structured_outputs.py +100 -0
  5. synth_ai/__init__.py +1 -0
  6. synth_ai/zyk/__init__.py +3 -0
  7. synth_ai/zyk/lms/__init__.py +0 -0
  8. synth_ai/zyk/lms/caching/__init__.py +0 -0
  9. synth_ai/zyk/lms/caching/constants.py +1 -0
  10. synth_ai/zyk/lms/caching/dbs.py +0 -0
  11. synth_ai/zyk/lms/caching/ephemeral.py +50 -0
  12. synth_ai/zyk/lms/caching/handler.py +92 -0
  13. synth_ai/zyk/lms/caching/initialize.py +13 -0
  14. synth_ai/zyk/lms/caching/persistent.py +55 -0
  15. synth_ai/zyk/lms/config.py +8 -0
  16. synth_ai/zyk/lms/core/__init__.py +0 -0
  17. synth_ai/zyk/lms/core/all.py +35 -0
  18. synth_ai/zyk/lms/core/exceptions.py +9 -0
  19. synth_ai/zyk/lms/core/main.py +245 -0
  20. synth_ai/zyk/lms/core/vendor_clients.py +60 -0
  21. synth_ai/zyk/lms/cost/__init__.py +0 -0
  22. synth_ai/zyk/lms/cost/monitor.py +1 -0
  23. synth_ai/zyk/lms/cost/statefulness.py +1 -0
  24. synth_ai/zyk/lms/structured_outputs/__init__.py +0 -0
  25. synth_ai/zyk/lms/structured_outputs/handler.py +388 -0
  26. synth_ai/zyk/lms/structured_outputs/inject.py +185 -0
  27. synth_ai/zyk/lms/structured_outputs/rehabilitate.py +186 -0
  28. synth_ai/zyk/lms/vendors/__init__.py +0 -0
  29. synth_ai/zyk/lms/vendors/base.py +15 -0
  30. synth_ai/zyk/lms/vendors/constants.py +5 -0
  31. synth_ai/zyk/lms/vendors/core/__init__.py +0 -0
  32. synth_ai/zyk/lms/vendors/core/anthropic_api.py +191 -0
  33. synth_ai/zyk/lms/vendors/core/gemini_api.py +146 -0
  34. synth_ai/zyk/lms/vendors/core/openai_api.py +145 -0
  35. synth_ai/zyk/lms/vendors/local/__init__.py +0 -0
  36. synth_ai/zyk/lms/vendors/local/ollama.py +0 -0
  37. synth_ai/zyk/lms/vendors/openai_standard.py +141 -0
  38. synth_ai/zyk/lms/vendors/retries.py +3 -0
  39. synth_ai/zyk/lms/vendors/supported/__init__.py +0 -0
  40. synth_ai/zyk/lms/vendors/supported/deepseek.py +18 -0
  41. synth_ai/zyk/lms/vendors/supported/together.py +11 -0
  42. {synth_ai-0.0.0.dev3.dist-info → synth_ai-0.1.0.dev6.dist-info}/METADATA +1 -1
  43. synth_ai-0.1.0.dev6.dist-info/RECORD +46 -0
  44. synth_ai-0.1.0.dev6.dist-info/top_level.txt +2 -0
  45. synth_ai-0.0.0.dev3.dist-info/RECORD +0 -6
  46. synth_ai-0.0.0.dev3.dist-info/top_level.txt +0 -1
  47. {synth_ai-0.0.0.dev3.dist-info → synth_ai-0.1.0.dev6.dist-info}/LICENSE +0 -0
  48. {synth_ai-0.0.0.dev3.dist-info → synth_ai-0.1.0.dev6.dist-info}/WHEEL +0 -0
@@ -0,0 +1,245 @@
1
+ from typing import Any, Dict, List, Literal, Optional
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from synth_ai.zyk.lms.core.exceptions import StructuredOutputCoercionFailureException
6
+ from synth_ai.zyk.lms.core.vendor_clients import (
7
+ anthropic_naming_regexes,
8
+ get_client,
9
+ openai_naming_regexes,
10
+ )
11
+ from synth_ai.zyk.lms.structured_outputs.handler import StructuredOutputHandler
12
+ from synth_ai.zyk.lms.vendors.base import VendorBase
13
+
14
+
15
def build_messages(
    sys_msg: str,
    user_msg: str,
    images_bytes: Optional[List] = None,
    model_name: Optional[str] = None,
) -> List[Dict]:
    """Build a vendor-formatted chat message list from a system/user pair.

    Args:
        sys_msg: System prompt text.
        user_msg: User prompt text.
        images_bytes: Optional list of base64-encoded image payloads.
        model_name: Model identifier; used only to pick the vendor-specific
            image message format when images are supplied.

    Returns:
        A list of message dicts in the shape the target vendor expects.

    Raises:
        ValueError: If images are supplied for a model family that has no
            image format defined here.
    """
    # Fixed: the original used a mutable default argument ([]); normalize
    # a None default to a fresh list instead.
    images_bytes = images_bytes if images_bytes is not None else []
    if len(images_bytes) > 0 and any(
        regex.match(model_name) for regex in openai_naming_regexes
    ):
        # OpenAI multimodal format: data-URL image_url entries.
        # NOTE(review): media type is hard-coded to image/jpeg here but
        # image/png in the Anthropic branch — confirm the asymmetry is intended.
        return [
            {"role": "system", "content": sys_msg},
            {
                "role": "user",
                "content": [{"type": "text", "text": user_msg}]
                + [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{image_bytes}"},
                    }
                    for image_bytes in images_bytes
                ],
            },
        ]
    elif len(images_bytes) > 0 and any(
        regex.match(model_name) for regex in anthropic_naming_regexes
    ):
        # Anthropic multimodal format: base64 source blocks.
        system_info = {"role": "system", "content": sys_msg}
        user_info = {
            "role": "user",
            "content": [{"type": "text", "text": user_msg}]
            + [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": image_bytes,
                    },
                }
                for image_bytes in images_bytes
            ],
        }
        return [system_info, user_info]
    elif len(images_bytes) > 0:
        raise ValueError("Images are not yet supported for this model")
    else:
        # Text-only path: plain system + user messages.
        return [
            {"role": "system", "content": sys_msg},
            {"role": "user", "content": user_msg},
        ]
65
+
66
+
67
class LM:
    """High-level language-model wrapper.

    Plain calls go straight to the vendor client; calls with a
    ``response_model`` go through a StructuredOutputHandler, with an
    automatic fallback to a "forced_json" handler when structured-output
    coercion fails.
    """

    model_name: str
    client: VendorBase
    lm_config: Dict[str, Any]
    structured_output_handler: StructuredOutputHandler

    def __init__(
        self,
        model_name: str,
        formatting_model_name: str,
        temperature: float,
        max_retries: Literal["None", "Few", "Many"] = "Few",
        structured_output_mode: Literal[
            "stringified_json", "forced_json"
        ] = "stringified_json",
        synth_logging: bool = True,
    ):
        """Create a wrapper for ``model_name``.

        Args:
            model_name: Model served by the primary vendor client.
            formatting_model_name: Model used by the formatting/retry client.
            temperature: Sampling temperature stored in ``lm_config``.
            max_retries: Retry budget label, mapped to 0/2/5 attempts
                (unknown labels fall back to 2).
            structured_output_mode: Primary structured-output strategy.
            synth_logging: Passed through to the vendor client.
        """
        self.client = get_client(
            model_name,
            # Only forced_json mode requires a structured-output-capable client.
            with_formatting=structured_output_mode == "forced_json",
            synth_logging=synth_logging,
        )

        formatting_client = get_client(formatting_model_name, with_formatting=True)

        max_retries_dict = {"None": 0, "Few": 2, "Many": 5}
        self.structured_output_handler = StructuredOutputHandler(
            self.client,
            formatting_client,
            structured_output_mode,
            {"max_retries": max_retries_dict.get(max_retries, 2)},
        )
        # Fallback handler used when the primary mode fails to coerce output.
        self.backup_structured_output_handler = StructuredOutputHandler(
            self.client,
            formatting_client,
            "forced_json",
            {"max_retries": max_retries_dict.get(max_retries, 2)},
        )
        self.lm_config = {"temperature": temperature}
        self.model_name = model_name

    def respond_sync(
        self,
        system_message: Optional[str] = None,
        user_message: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        images_as_bytes: Optional[List[Any]] = None,
        response_model: Optional[BaseModel] = None,
        use_ephemeral_cache_only: bool = False,
    ):
        """Synchronous chat call.

        Provide either ``messages`` or a ``system_message``/``user_message``
        pair (exactly one of the two forms). When ``response_model`` is
        given, the reply is coerced into that model, falling back to the
        forced-JSON handler on coercion failure.
        """
        # Fixed: the original used a mutable default argument ([]).
        images_as_bytes = images_as_bytes if images_as_bytes is not None else []
        assert (system_message is None) == (
            user_message is None
        ), "Must provide both system_message and user_message or neither"
        assert (
            (messages is None) != (system_message is None)
        ), "Must provide either messages or system_message/user_message pair, but not both"

        if messages is None:
            messages = build_messages(
                system_message, user_message, images_as_bytes, self.model_name
            )

        if response_model:
            try:
                return self.structured_output_handler.call_sync(
                    messages,
                    model=self.model_name,
                    lm_config=self.lm_config,
                    response_model=response_model,
                    use_ephemeral_cache_only=use_ephemeral_cache_only,
                )
            except StructuredOutputCoercionFailureException:
                # Primary mode could not coerce output; retry in forced_json mode.
                return self.backup_structured_output_handler.call_sync(
                    messages,
                    model=self.model_name,
                    lm_config=self.lm_config,
                    response_model=response_model,
                    use_ephemeral_cache_only=use_ephemeral_cache_only,
                )
        else:
            return self.client._hit_api_sync(
                messages=messages,
                model=self.model_name,
                lm_config=self.lm_config,
                use_ephemeral_cache_only=use_ephemeral_cache_only,
            )

    async def respond_async(
        self,
        system_message: Optional[str] = None,
        user_message: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        images_as_bytes: Optional[List[Any]] = None,
        response_model: Optional[BaseModel] = None,
        use_ephemeral_cache_only: bool = False,
    ):
        """Async counterpart of :meth:`respond_sync`; same contract."""
        # Fixed: the original used a mutable default argument ([]).
        images_as_bytes = images_as_bytes if images_as_bytes is not None else []
        assert (system_message is None) == (
            user_message is None
        ), "Must provide both system_message and user_message or neither"
        assert (
            (messages is None) != (system_message is None)
        ), "Must provide either messages or system_message/user_message pair, but not both"

        if messages is None:
            messages = build_messages(
                system_message, user_message, images_as_bytes, self.model_name
            )

        if response_model:
            try:
                return await self.structured_output_handler.call_async(
                    messages,
                    model=self.model_name,
                    lm_config=self.lm_config,
                    response_model=response_model,
                    use_ephemeral_cache_only=use_ephemeral_cache_only,
                )
            except StructuredOutputCoercionFailureException:
                # Primary mode could not coerce output; retry in forced_json mode.
                return await self.backup_structured_output_handler.call_async(
                    messages,
                    model=self.model_name,
                    lm_config=self.lm_config,
                    response_model=response_model,
                    use_ephemeral_cache_only=use_ephemeral_cache_only,
                )
        else:
            return await self.client._hit_api_async(
                messages=messages,
                model=self.model_name,
                lm_config=self.lm_config,
                use_ephemeral_cache_only=use_ephemeral_cache_only,
            )
207
+
208
+
209
# Manual smoke-test entry point: exercises the structured-output path with a
# nested pydantic model against a live OpenAI endpoint (requires credentials).
if __name__ == "__main__":
    import asyncio

    # Update json instructions to handle nested pydantic?
    class Thought(BaseModel):
        # Parallel lists: each key pairs with the stringified-JSON value
        # at the same index.
        argument_keys: List[str] = Field(description="The keys of the arguments")
        argument_values: List[str] = Field(
            description="Stringified JSON for the values of the arguments"
        )

    class TestModel(BaseModel):
        emotion: str = Field(description="The emotion expressed")
        concern: str = Field(description="The concern expressed")
        action: str = Field(description="The action to be taken")
        # Nested model — the case the TODO above is about.
        thought: Thought = Field(description="The thought process")

        class Config:
            # NOTE(review): `schema_extra` is the pydantic v1 spelling
            # (v2 renamed it to `json_schema_extra`) — confirm the pinned
            # pydantic major version.
            schema_extra = {"required": ["thought", "emotion", "concern", "action"]}

    # forced_json mode routes through the vendor's native structured endpoint.
    lm = LM(
        model_name="gpt-4o-mini",
        formatting_model_name="gpt-4o-mini",
        temperature=1,
        max_retries="Few",
        structured_output_mode="forced_json",
    )
    print(
        asyncio.run(
            lm.respond_async(
                system_message="You are a helpful assistant ",
                user_message="Hello, how are you?",
                images_as_bytes=[],
                response_model=TestModel,
                use_ephemeral_cache_only=False,
            )
        )
    )
@@ -0,0 +1,60 @@
1
+ import re
2
+ from typing import Any, List, Pattern
3
+
4
+ from synth_ai.zyk.lms.core.all import (
5
+ AnthropicClient,
6
+ DeepSeekClient,
7
+ GeminiClient,
8
+ # OpenAIClient,
9
+ OpenAIStructuredOutputClient,
10
+ TogetherClient,
11
+ )
12
+
13
# Model-name patterns used to route a model string to its vendor client.
# Match order matters in get_client: Together's catch-all "org/model"
# pattern is intentionally checked last.
openai_naming_regexes: List[Pattern] = [
    # o1 / o3 families (optional suffix) and all gpt-* models, optionally
    # prefixed with "ft:" for fine-tuned variants.
    # Fixed: the original character class [1,3] also matched a literal
    # comma, so the bogus name "o," was accepted as an OpenAI model.
    re.compile(r"^(ft:)?(o[13](-.*)?|gpt-.*)$"),
]
# Subset of OpenAI models usable as formatting/repair models.
openai_formatting_model_regexes: List[Pattern] = [
    re.compile(r"^(ft:)?gpt-4o(-.*)?$"),
]
anthropic_naming_regexes: List[Pattern] = [
    re.compile(r"^claude-.*$"),
]
gemini_naming_regexes: List[Pattern] = [
    re.compile(r"^gemini-.*$"),
]
deepseek_naming_regexes: List[Pattern] = [
    re.compile(r"^deepseek-.*$"),
]
# Together-hosted models are addressed as "organization/model".
together_naming_regexes: List[Pattern] = [
    re.compile(r"^.*\/.*$"),
]
31
+
32
+
33
def get_client(
    model_name: str,
    with_formatting: bool = False,
    synth_logging: bool = True,
) -> Any:
    """Return the vendor client that should serve ``model_name``.

    The vendor is chosen by testing the name against each vendor's naming
    patterns in a fixed precedence order; Together's catch-all "org/model"
    pattern is deliberately tried last.

    Raises:
        ValueError: If no vendor pattern matches ``model_name``.
    """

    def matches(patterns: List[Pattern]) -> bool:
        # True when any of the vendor's compiled patterns accepts the name.
        return any(pattern.match(model_name) for pattern in patterns)

    if matches(openai_naming_regexes):
        return OpenAIStructuredOutputClient(
            synth_logging=synth_logging,
        )
    if matches(anthropic_naming_regexes):
        anthropic_client = AnthropicClient()
        if with_formatting:
            # Borrow OpenAI's structured-output path for formatting support.
            anthropic_client._hit_api_async_structured_output = (
                OpenAIStructuredOutputClient(
                    synth_logging=synth_logging
                )._hit_api_async
            )
        return anthropic_client
    if matches(gemini_naming_regexes):
        return GeminiClient()
    if matches(deepseek_naming_regexes):
        return DeepSeekClient()
    if matches(together_naming_regexes):
        return TogetherClient()
    raise ValueError(f"Invalid model name: {model_name}")
File without changes
@@ -0,0 +1 @@
1
+ #TODO
@@ -0,0 +1 @@
1
+ # Maybe some kind of ephemeral cache
File without changes
@@ -0,0 +1,388 @@
1
+ import time
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union
4
+
5
+ from pydantic import BaseModel
6
+
7
+ from synth_ai.zyk.lms.core.exceptions import StructuredOutputCoercionFailureException
8
+ from synth_ai.zyk.lms.structured_outputs.inject import (
9
+ inject_structured_output_instructions,
10
+ )
11
+ from synth_ai.zyk.lms.structured_outputs.rehabilitate import (
12
+ fix_errant_forced_async,
13
+ fix_errant_forced_sync,
14
+ fix_errant_stringified_json_async,
15
+ fix_errant_stringified_json_sync,
16
+ pull_out_structured_output,
17
+ )
18
+ from synth_ai.zyk.lms.vendors.base import VendorBase
19
+ from synth_ai.zyk.lms.vendors.constants import SPECIAL_BASE_TEMPS
20
+
21
+
22
+ class StructuredHandlerBase(ABC):
23
+ core_client: VendorBase
24
+ retry_client: VendorBase
25
+ handler_params: Dict[str, Any]
26
+ structured_output_mode: Literal["stringified_json", "forced_json"]
27
+
28
+ def __init__(
29
+ self,
30
+ core_client: VendorBase,
31
+ retry_client: VendorBase,
32
+ handler_params: Optional[Dict[str, Any]] = None,
33
+ structured_output_mode: Literal[
34
+ "stringified_json", "forced_json"
35
+ ] = "stringified_json",
36
+ ):
37
+ self.core_client = core_client
38
+ self.retry_client = retry_client
39
+ self.handler_params = (
40
+ handler_params if handler_params is not None else {"retries": 3}
41
+ )
42
+ self.structured_output_mode = structured_output_mode
43
+
44
+ async def call_async(
45
+ self,
46
+ messages: List[Dict[str, Any]],
47
+ model: str,
48
+ response_model: BaseModel,
49
+ temperature: float = 0.0,
50
+ use_ephemeral_cache_only: bool = False,
51
+ ) -> BaseModel:
52
+ if temperature == 0.0:
53
+ temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
54
+ # print("Calling from base")
55
+ return await self._process_call_async(
56
+ messages=messages,
57
+ model=model,
58
+ response_model=response_model,
59
+ api_call_method=self.core_client._hit_api_async_structured_output
60
+ if (not not response_model and self.structured_output_mode == "forced_json")
61
+ else self.core_client._hit_api_async,
62
+ temperature=temperature,
63
+ use_ephemeral_cache_only=use_ephemeral_cache_only,
64
+ )
65
+
66
+ def call_sync(
67
+ self,
68
+ messages: List[Dict[str, Any]],
69
+ response_model: BaseModel,
70
+ model: str,
71
+ temperature: float = 0.0,
72
+ use_ephemeral_cache_only: bool = False,
73
+ ) -> BaseModel:
74
+ if temperature == 0.0:
75
+ temperature = SPECIAL_BASE_TEMPS.get(model, 0.0)
76
+ return self._process_call_sync(
77
+ messages=messages,
78
+ model=model,
79
+ response_model=response_model,
80
+ api_call_method=self.core_client._hit_api_sync_structured_output
81
+ if (not not response_model and self.structured_output_mode == "forced_json")
82
+ else self.core_client._hit_api_sync,
83
+ temperature=temperature,
84
+ use_ephemeral_cache_only=use_ephemeral_cache_only,
85
+ )
86
+
87
+ @abstractmethod
88
+ async def _process_call_async(
89
+ self,
90
+ messages: List[Dict[str, Any]],
91
+ model: str,
92
+ response_model: BaseModel,
93
+ api_call_method,
94
+ use_ephemeral_cache_only: bool = False,
95
+ ) -> BaseModel:
96
+ pass
97
+
98
+ @abstractmethod
99
+ def _process_call_sync(
100
+ self,
101
+ messages: List[Dict[str, Any]],
102
+ model: str,
103
+ response_model: BaseModel,
104
+ api_call_method,
105
+ use_ephemeral_cache_only: bool = False,
106
+ ) -> BaseModel:
107
+ pass
108
+
109
+
110
class StringifiedJSONHandler(StructuredHandlerBase):
    """Structured-output handler for vendors without a native JSON mode.

    Injects JSON-formatting instructions into the prompt, parses the model's
    free-text reply into ``response_model``, and on parse failure first tries
    an LLM-based repair pass, then retries the whole call with the accumulated
    error messages appended to the instructions.
    """

    core_client: VendorBase
    retry_client: VendorBase
    handler_params: Dict[str, Any]

    def __init__(
        self,
        core_client: VendorBase,
        retry_client: VendorBase,
        handler_params: Dict[str, Any] = {"retries": 3},
    ):
        # NOTE(review): mutable default argument; harmless only while the
        # dict is never mutated — confirm before reusing this pattern.
        super().__init__(
            core_client,
            retry_client,
            handler_params,
            structured_output_mode="stringified_json",
        )

    async def _process_call_async(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        temperature: float,
        api_call_method: Callable,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Call the vendor, parse the reply into response_model, retry on failure.

        Returns a parsed response_model instance, or — when the API layer
        returns a non-string (a cache hit) — that object unchanged.

        Raises:
            StructuredOutputCoercionFailureException: When all retries are
                exhausted without a successful parse or repair.
        """
        # print("In _process_call_async")
        assert isinstance(
            api_call_method, Callable
        ), "api_call_method must be a callable"
        assert (
            response_model is not None
        ), "Don't use this handler for unstructured outputs"
        # NOTE(review): default here is 2 but __init__'s default dict says
        # {"retries": 3} — confirm which budget is intended.
        remaining_retries = self.handler_params.get("retries", 2)
        previously_failed_error_messages = []
        structured_output = None

        while remaining_retries > 0:
            # Re-inject instructions each attempt so prior failures are
            # included as feedback to the model.
            messages_with_json_formatting_instructions = (
                inject_structured_output_instructions(
                    messages=messages,
                    response_model=response_model,
                    previously_failed_error_messages=previously_failed_error_messages,
                )
            )
            t0 = time.time()
            raw_text_response_or_cached_hit = await api_call_method(
                messages=messages_with_json_formatting_instructions,
                model=model,
                lm_config={"response_model": None, "temperature": temperature},
                use_ephemeral_cache_only=use_ephemeral_cache_only,
            )
            # print(f"Time to get response: {time.time() - t0}")
            # A non-string result is an already-structured cache hit:
            # pass it through untouched.
            if not isinstance(raw_text_response_or_cached_hit, str):
                return raw_text_response_or_cached_hit
            else:
                raw_text_response = raw_text_response_or_cached_hit
            try:
                structured_output = pull_out_structured_output(
                    raw_text_response, response_model
                )
                break
            # except Exception as e:
            #     try:
            #         structured_output = await fix_errant_stringified_json_async(raw_text_response, response_model)
            #         break
            except Exception as e:
                # Parsing failed: attempt an LLM-based repair of the raw text.
                try:
                    # t0 = time.time()
                    # print(f"Got error {e}, attempting to fix")
                    structured_output = await fix_errant_forced_async(
                        messages_with_json_formatting_instructions,
                        raw_text_response,
                        response_model,
                        "gpt-4o-mini",
                    )

                    # print(f"Time to fix: {time.time() - t0}")
                    break
                except Exception as e:
                    # Repair also failed: record the attempt + error and retry.
                    previously_failed_error_messages.append(
                        f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
                    )
                    remaining_retries -= 1
        if structured_output is None:
            raise StructuredOutputCoercionFailureException(
                "Failed to get structured output"
            )
        return structured_output

    def _process_call_sync(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        temperature: float,
        api_call_method: Callable,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Sync counterpart of :meth:`_process_call_async`; same retry contract."""
        assert isinstance(
            api_call_method, Callable
        ), "api_call_method must be a callable"
        assert (
            response_model is not None
        ), "Don't use this handler for unstructured outputs"
        remaining_retries = self.handler_params.get("retries", 2)
        previously_failed_error_messages = []
        structured_output = None

        while remaining_retries > 0:
            # Re-inject instructions each attempt so prior failures are
            # included as feedback to the model.
            messages_with_json_formatting_instructions = (
                inject_structured_output_instructions(
                    messages=messages,
                    response_model=response_model,
                    previously_failed_error_messages=previously_failed_error_messages,
                )
            )
            # t0 = time.time()
            raw_text_response_or_cached_hit = api_call_method(
                messages=messages_with_json_formatting_instructions,
                model=model,
                lm_config={"response_model": None, "temperature": temperature},
                use_ephemeral_cache_only=use_ephemeral_cache_only,
            )
            # print(f"Time to get response: {time.time() - t0}")
            # A non-string result is an already-structured cache hit:
            # pass it through untouched.
            if not isinstance(raw_text_response_or_cached_hit, str):
                return raw_text_response_or_cached_hit
            else:
                raw_text_response = raw_text_response_or_cached_hit
            try:
                structured_output = pull_out_structured_output(
                    raw_text_response, response_model
                )
                break
            # except Exception:
            #     try:
            #         structured_output = fix_errant_stringified_json_sync(raw_text_response, response_model)
            #         break
            except Exception as e:
                # Parsing failed: attempt an LLM-based repair of the raw text.
                # NOTE(review): unlike the async path, this call does not pass
                # the instruction messages — confirm fix_errant_forced_sync's
                # signature really differs from the async variant.
                try:
                    # t0 = time.time()
                    # print(f"Got error {e}, attempting to fix")
                    structured_output = fix_errant_forced_sync(
                        raw_text_response, response_model, "gpt-4o-mini"
                    )
                    break
                    # print(f"Time to fix: {time.time() - t0}")
                except Exception as e:
                    # Repair also failed: record the attempt + error and retry.
                    previously_failed_error_messages.append(
                        f"Generated attempt and got error. Attempt:\n\n{raw_text_response}\n\nError:\n\n{e}"
                    )
                    remaining_retries -= 1
        if structured_output is None:
            raise StructuredOutputCoercionFailureException(
                "Failed to get structured output"
            )
        return structured_output
268
+
269
+
270
class ForcedJSONHandler(StructuredHandlerBase):
    """Structured-output handler backed by the vendor's native forced-JSON
    (structured output) endpoint — no prompt injection or reparsing needed."""

    core_client: VendorBase
    retry_client: VendorBase
    handler_params: Dict[str, Any]

    def __init__(
        self,
        core_client: VendorBase,
        retry_client: VendorBase,
        handler_params: Dict[str, Any] = {},
    ):
        super().__init__(
            core_client,
            retry_client,
            handler_params,
            structured_output_mode="forced_json",
        )

    async def _process_call_async(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        api_call_method: Callable,
        temperature: float = 0.0,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Forward the call directly; the vendor endpoint enforces the schema."""
        assert (
            response_model is not None
        ), "Don't use this handler for unstructured outputs"
        call_kwargs = dict(
            messages=messages,
            model=model,
            response_model=response_model,
            temperature=temperature,
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )
        return await api_call_method(**call_kwargs)

    def _process_call_sync(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        api_call_method: Callable,
        temperature: float = 0.0,
        use_ephemeral_cache_only: bool = False,
    ) -> BaseModel:
        """Sync counterpart of :meth:`_process_call_async`."""
        assert (
            response_model is not None
        ), "Don't use this handler for unstructured outputs"
        call_kwargs = dict(
            messages=messages,
            model=model,
            response_model=response_model,
            temperature=temperature,
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )
        return api_call_method(**call_kwargs)
328
+
329
+
330
class StructuredOutputHandler:
    """Facade that picks a concrete structured-output strategy by mode and
    forwards calls to it, resolving the temperature from ``lm_config``."""

    handler: Union[StringifiedJSONHandler, ForcedJSONHandler]
    mode: Literal["stringified_json", "forced_json"]
    handler_params: Dict[str, Any]

    def __init__(
        self,
        core_client: VendorBase,
        retry_client: VendorBase,
        mode: Literal["stringified_json", "forced_json"],
        handler_params: Dict[str, Any] = {},
    ):
        self.mode = mode
        # Select the concrete handler class for the requested mode.
        if self.mode == "stringified_json":
            handler_cls = StringifiedJSONHandler
        elif self.mode == "forced_json":
            handler_cls = ForcedJSONHandler
        else:
            raise ValueError(f"Invalid mode: {mode}")
        self.handler = handler_cls(core_client, retry_client, handler_params)

    async def call_async(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        use_ephemeral_cache_only: bool = False,
        lm_config: Dict[str, Any] = {},
    ) -> BaseModel:
        """Async structured call; temperature defaults to the model's base temp."""
        effective_temperature = lm_config.get(
            "temperature", SPECIAL_BASE_TEMPS.get(model, 0.0)
        )
        return await self.handler.call_async(
            messages=messages,
            model=model,
            response_model=response_model,
            temperature=effective_temperature,
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )

    def call_sync(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        response_model: BaseModel,
        use_ephemeral_cache_only: bool = False,
        lm_config: Dict[str, Any] = {},
    ) -> BaseModel:
        """Sync structured call; temperature defaults to the model's base temp."""
        effective_temperature = lm_config.get(
            "temperature", SPECIAL_BASE_TEMPS.get(model, 0.0)
        )
        return self.handler.call_sync(
            messages=messages,
            model=model,
            response_model=response_model,
            temperature=effective_temperature,
            use_ephemeral_cache_only=use_ephemeral_cache_only,
        )