speedy-utils 1.1.6__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +1 -5
- llm_utils/chat_format/transform.py +9 -9
- llm_utils/group_messages.py +1 -1
- llm_utils/lm/async_lm/__init__.py +6 -1
- llm_utils/lm/async_lm/_utils.py +7 -4
- llm_utils/lm/async_lm/async_llm_task.py +465 -110
- llm_utils/lm/async_lm/async_lm.py +273 -665
- llm_utils/lm/async_lm/async_lm_base.py +405 -0
- llm_utils/lm/async_lm/lm_specific.py +136 -0
- llm_utils/lm/utils.py +1 -3
- llm_utils/scripts/vllm_load_balancer.py +49 -37
- speedy_utils/__init__.py +3 -1
- speedy_utils/common/notebook_utils.py +4 -4
- speedy_utils/common/report_manager.py +2 -3
- speedy_utils/common/utils_cache.py +233 -3
- speedy_utils/common/utils_io.py +2 -0
- speedy_utils/scripts/mpython.py +1 -3
- {speedy_utils-1.1.6.dist-info → speedy_utils-1.1.7.dist-info}/METADATA +1 -1
- speedy_utils-1.1.7.dist-info/RECORD +39 -0
- llm_utils/lm/chat_html.py +0 -246
- llm_utils/lm/lm_json.py +0 -68
- llm_utils/lm/sync_lm.py +0 -943
- speedy_utils-1.1.6.dist-info/RECORD +0 -40
- {speedy_utils-1.1.6.dist-info → speedy_utils-1.1.7.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.6.dist-info → speedy_utils-1.1.7.dist-info}/entry_points.txt +0 -0
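
The largest change in this release is the rewrite of `llm_utils/lm/async_lm/async_lm.py` (+273 −665), reproduced below. For orientation, here is a minimal usage sketch pieced together from the new signatures visible in that diff (`AsyncLM.__init__`, `__call__`, and `parse`). The endpoint, model name, and response model are placeholders, the import path assumes the module layout listed above, and exact defaults and behavior should be confirmed against the released package.

```python
# Hypothetical usage sketch based on the 1.1.7 signatures shown in the diff below.
# Assumes an OpenAI-compatible server (e.g. vLLM) is reachable on localhost:8000.
import asyncio

from pydantic import BaseModel

from llm_utils.lm.async_lm.async_lm import AsyncLM  # public re-export may differ


class Answer(BaseModel):
    answer: str
    confidence: float


async def main() -> None:
    lm = AsyncLM(
        model="my-model",          # placeholder model name
        response_model=Answer,     # set at init; parse() asserts it is not None
        port=8000,                 # placeholder port
        think=False,               # steers the /think vs /no_think system-prompt marker
        use_beta=False,            # False => JSON schema is appended to the instruction
    )

    # Plain call: returns (assistant_message_dict, full_messages)
    msg, history = await lm(prompt="Say hello in one word.")

    # Structured call: returns a ParsedOutput with .parsed, .messages, .completion
    result = await lm.parse(
        instruction="Answer the question and rate your confidence.",
        prompt="What is the capital of France?",
    )
    print(msg["content"], result.parsed)


if __name__ == "__main__":
    asyncio.run(main())
```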
llm_utils/lm/async_lm/async_lm.py

```diff
@@ -1,62 +1,51 @@
 # from ._utils import *
-import base64
-import hashlib
-import json
-import os
 from typing import (
     Any,
-    Dict,
     List,
     Literal,
     Optional,
-    Sequence,
     Type,
-    Union,
     cast,
-    overload,
 )

-from httpx import URL
 from loguru import logger
-from openai import
-from openai.pagination import AsyncPage as AsyncSyncPage
-
-# from openai.pagination import AsyncSyncPage
-from openai.types.chat import (
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionMessageParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionToolMessageParam,
-    ChatCompletionUserMessageParam,
-)
-from openai.types.model import Model
+from openai import AuthenticationError, BadRequestError, RateLimitError
 from pydantic import BaseModel
-
 from speedy_utils import jloads

+# from llm_utils.lm.async_lm.async_llm_task import OutputModelType
+from llm_utils.lm.async_lm.async_lm_base import AsyncLMBase
+
 from ._utils import (
     LegacyMsgs,
     Messages,
+    OutputModelType,
     ParsedOutput,
     RawMsgs,
-    TModel,
-    TParsed,
-    _blue,
-    _green,
-    _red,
-    _yellow,
-    get_tokenizer,
-    inspect_word_probs_async,
 )


-
+def jloads_safe(content: str) -> Any:
+    # if contain ```json, remove it
+    if "```json" in content:
+        content = content.split("```json")[1].strip().split("```")[0].strip()
+    try:
+        return jloads(content)
+    except Exception as e:
+        logger.error(
+            f"Failed to parse JSON content: {content[:100]}... with error: {e}"
+        )
+        raise ValueError(f"Invalid JSON content: {content}") from e
+
+
+class AsyncLM(AsyncLMBase):
     """Unified **async** language‑model wrapper with optional JSON parsing."""

     def __init__(
         self,
-        model: str
+        model: str,
         *,
+        response_model: Optional[type[BaseModel]] = None,
         temperature: float = 0.0,
         max_tokens: int = 2_000,
         host: str = "localhost",
```
```diff
@@ -64,167 +53,101 @@ class AsyncLM:
         base_url: Optional[str] = None,
         api_key: Optional[str] = None,
         cache: bool = True,
+        think: Literal[True, False, None] = None,
+        add_json_schema_to_instruction: Optional[bool] = None,
+        use_beta: bool = False,
         ports: Optional[List[int]] = None,
-
+        top_p: float = 1.0,
+        presence_penalty: float = 0.0,
+        top_k: int = 1,
+        repetition_penalty: float = 1.0,
+        frequency_penalty: Optional[float] = None,
     ) -> None:
-
-
-
-
-
-
-
-        self.openai_kwargs = openai_kwargs
-        self.do_cache = cache
-        self.ports = ports
-        self._init_port = port  # <-- store the port provided at init
-
-        # Async client
-
-    @property
-    def client(self) -> AsyncOpenAI:
-        # if have multiple ports
-        if self.ports:
-            import random
-
-            port = random.choice(self.ports)
-            api_base = f"http://{self.host}:{port}/v1"
-            logger.debug(f"Using port: {port}")
-        else:
-            api_base = self.base_url or f"http://{self.host}:{self.port}/v1"
-        client = AsyncOpenAI(
-            api_key=self.api_key, base_url=api_base, **self.openai_kwargs
+        super().__init__(
+            host=host,
+            port=port,
+            ports=ports,
+            base_url=base_url,
+            cache=cache,
+            api_key=api_key,
         )
-        return client

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    async def __call__(
-        self,
-        *,
-        prompt: str | None = ...,
-        messages: RawMsgs | None = ...,
-        response_format: Type[TModel],
-        return_openai_response: bool = ...,
-        **kwargs: Any,
-    ) -> TModel: ...
-
-    async def _set_model(self) -> None:
-        if not self.model:
-            models = await self.list_models(port=self.port, host=self.host)
-            self.model = models[0] if models else None
-            logger.info(
-                f"No model specified. Using the first available model. {self.model}"
-            )
-
-    async def __call__(
-        self,
-        prompt: Optional[str] = None,
-        messages: Optional[RawMsgs] = None,
-        response_format: Union[type[str], Type[BaseModel]] = str,
-        cache: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        return_openai_response: bool = False,
-        **kwargs: Any,
-    ):
-        if (prompt is None) == (messages is None):
-            raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
-
-        if prompt is not None:
-            messages = [{"role": "user", "content": prompt}]
-
-        assert messages is not None
-        # assert self.model is not None, "Model must be set before calling."
-        await self._set_model()
-
-        openai_msgs: Messages = (
-            self._convert_messages(cast(LegacyMsgs, messages))
-            if isinstance(messages[0], dict)
-            else cast(Messages, messages)
-        )
-
-        kw = dict(
-            self.openai_kwargs,
-            temperature=self.temperature,
-            max_tokens=max_tokens or self.max_tokens,
+        # Model behavior options
+        self.response_model = response_model
+        self.think = think
+        self._use_beta = use_beta
+        self.add_json_schema_to_instruction = add_json_schema_to_instruction
+        if not use_beta:
+            self.add_json_schema_to_instruction = True
+
+        # Store all model-related parameters in model_kwargs
+        self.model_kwargs = dict(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            presence_penalty=presence_penalty,
         )
-
-
-
-
-            openai_msgs,
-            response_format=response_format,
-            use_cache=use_cache,
-            **kw,
+        self.extra_body = dict(
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
         )

-
-            response = raw_response
-        else:
-            response = self._parse_output(raw_response, response_format)
-
-        self._last_log = [prompt, messages, raw_response]
-        return response
-
-    # ------------------------------------------------------------------ #
-    # Model invocation (async)
-    # ------------------------------------------------------------------ #
-    async def _call_raw(
+    async def _unified_client_call(
         self,
-        messages:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        messages: list[dict],
+        extra_body: Optional[dict] = None,
+        cache_suffix: str = "",
+    ) -> dict:
+        """Unified method for all client interactions with caching and error handling."""
+        converted_messages = self._convert_messages(messages)
+        cache_key = None
+        completion = None
+
+        # Handle caching
+        if self._cache:
+            cache_data = {
+                "messages": converted_messages,
+                "model_kwargs": self.model_kwargs,
+                "extra_body": extra_body or {},
+                "cache_suffix": cache_suffix,
+            }
+            cache_key = self._cache_key(cache_data, {}, str)
+            completion = self._load_cache(cache_key)
+
+        # Check for cached error responses
+        if (
+            completion
+            and isinstance(completion, dict)
+            and "error" in completion
+            and completion["error"]
+        ):
+            error_type = completion.get("error_type", "Unknown")
+            error_message = completion.get("error_message", "Cached error")
+            logger.warning(f"Found cached error ({error_type}): {error_message}")
+            raise ValueError(f"Cached {error_type}: {error_message}")

         try:
-
-
-
-                messages
-
-
-
-
-
-
-
-
-
+            # Get completion from API if not cached
+            if not completion:
+                call_kwargs = {
+                    "messages": converted_messages,
+                    **self.model_kwargs,
+                }
+                if extra_body:
+                    call_kwargs["extra_body"] = extra_body
+
+                completion = await self.client.chat.completions.create(**call_kwargs)
+
+                if hasattr(completion, "model_dump"):
+                    completion = completion.model_dump()
+                if cache_key:
+                    self._dump_cache(cache_key, completion)

         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
             logger.error(error_msg)
-
-            # Cache the error if it's a BadRequestError to avoid repeated calls
             if isinstance(exc, BadRequestError) and cache_key:
                 error_response = {
                     "error": True,
```
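
The new `_unified_client_call` above derives a cache key from the converted messages, `model_kwargs`, `extra_body`, and a `cache_suffix`, then reuses any cached completion, including cached `BadRequestError` results, which are re-raised instead of re-sent. The actual `_cache_key`, `_load_cache`, and `_dump_cache` helpers live in `AsyncLMBase` and are not shown in this diff; the snippet below is only an illustrative sketch of how such a request-level cache key can be derived, not the package's implementation.

```python
# Illustrative sketch of a request-level cache key; the real helpers live in
# AsyncLMBase and may differ (hash function, key layout, storage backend).
import hashlib
import json


def cache_key(messages: list, model_kwargs: dict, extra_body: dict, suffix: str) -> str:
    payload = {
        "messages": messages,
        "model_kwargs": model_kwargs,
        "extra_body": extra_body,
        "cache_suffix": suffix,
    }
    # Canonical JSON so identical requests map to the same key.
    blob = json.dumps(payload, sort_keys=True, default=str).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()


key = cache_key(
    [{"role": "user", "content": "hi"}],
    {"model": "my-model", "temperature": 0.0},
    {},
    "_call",
)
print(key[:16])  # identical inputs always produce the same key
```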
```diff
@@ -234,153 +157,180 @@ class AsyncLM:
                 }
                 self._dump_cache(cache_key, error_response)
                 logger.debug(f"Cached BadRequestError for key: {cache_key}")
-
             raise

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return completion
+
+    async def _call_and_parse(
+        self,
+        messages: list[dict],
+        response_model: Type[OutputModelType],
+        json_schema: dict,
+    ) -> tuple[dict, list[dict], OutputModelType]:
+        """Unified call and parse with cache and error handling."""
+        if self._use_beta:
+            return await self._call_and_parse_with_beta(
+                messages, response_model, json_schema
+            )
+
+        choice = None
+        try:
+            # Use unified client call
+            completion = await self._unified_client_call(
+                messages,
+                extra_body={**self.extra_body},
+                cache_suffix=f"_parse_{response_model.__name__}",
+            )
+
+            # Parse the response
+            choice = completion["choices"][0]["message"]
+            if "content" not in choice:
+                raise ValueError("Response choice must contain 'content' field.")
+
+            content = choice["content"]
+            if not content:
+                raise ValueError("Response content is empty")
+
+            parsed = response_model.model_validate(jloads_safe(content))
+
+        except Exception as e:
+            # Try fallback to beta mode if regular parsing fails
+            if not isinstance(
+                e, (AuthenticationError, RateLimitError, BadRequestError)
+            ):
+                content = choice.get("content", "N/A") if choice else "N/A"
+                logger.info(
+                    f"Regular parsing failed due to wrong format or content, now falling back to beta mode: {content=}, {e=}"
                 )
-
-
-
-                role="tool",
-                content=content,
-                tool_call_id=msg.get("tool_call_id") or "",
+                try:
+                    return await self._call_and_parse_with_beta(
+                        messages, response_model, json_schema
                     )
-
-
-
-
-
-
-
-
-
-
-
-
-        if response_format is str:
-            if isinstance(raw_response, dict) and "choices" in raw_response:
-                message = raw_response["choices"][0]["message"]
-                return message.get("content", "") or ""
-            return cast(str, raw_response)
-
-        model_cls = cast(Type[BaseModel], response_format)
-
-        if isinstance(raw_response, dict) and "choices" in raw_response:
-            message = raw_response["choices"][0]["message"]
-            if "parsed" in message:
-                return model_cls.model_validate(message["parsed"])
-            content = message.get("content")
-            if content is None:
-                raise ValueError("Model returned empty content")
-            try:
-                data = json.loads(content)
-                return model_cls.model_validate(data)
-            except Exception as exc:
-                raise ValueError(
-                    f"Failed to parse model output as JSON:\n{content}"
-                ) from exc
-
-        if isinstance(raw_response, model_cls):
-            return raw_response
-        if isinstance(raw_response, dict):
-            return model_cls.model_validate(raw_response)
+            except Exception as beta_e:
+                logger.warning(f"Beta mode fallback also failed: {beta_e}")
+                choice_info = choice if choice is not None else "N/A"
+                raise ValueError(
+                    f"Failed to parse model response with both regular and beta modes. "
+                    f"Regular error: {e}. Beta error: {beta_e}. "
+                    f"Model response message: {choice_info}"
+                ) from e
+            raise
+
+        assistant_msg = self._extract_assistant_message(choice)
+        full_messages = messages + [assistant_msg]

+        return completion, full_messages, cast(OutputModelType, parsed)
+
+    async def _call_and_parse_with_beta(
+        self,
+        messages: list[dict],
+        response_model: Type[OutputModelType],
+        json_schema: dict,
+    ) -> tuple[dict, list[dict], OutputModelType]:
+        """Call and parse for beta mode with guided JSON."""
+        choice = None
         try:
-
-
-
+            # Use unified client call with guided JSON
+            completion = await self._unified_client_call(
+                messages,
+                extra_body={"guided_json": json_schema, **self.extra_body},
+                cache_suffix=f"_beta_parse_{response_model.__name__}",
+            )
+
+            # Parse the response
+            choice = completion["choices"][0]["message"]
+            parsed = self._parse_complete_output(completion, response_model)
+
+        except Exception as e:
+            choice_info = choice if choice is not None else "N/A"
             raise ValueError(
-                f"
-            ) from
+                f"Failed to parse model response: {e}\nModel response message: {choice_info}"
+            ) from e

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        assistant_msg = self._extract_assistant_message(choice)
+        full_messages = messages + [assistant_msg]
+
+        return completion, full_messages, cast(OutputModelType, parsed)
+
+    def _extract_assistant_message(self, choice):  # -> dict[str, str] | dict[str, Any]:
+        # TODO this current assume choice is a dict with "reasoning_content" and "content"
+        has_reasoning = False
+        if "reasoning_content" in choice and isinstance(
+            choice["reasoning_content"], str
+        ):
+            reasoning_content = choice["reasoning_content"].strip()
+            has_reasoning = True
+
+        content = choice["content"]
+        _content = content.lstrip("\n")
+        if has_reasoning:
+            assistant_msg = {
+                "role": "assistant",
+                "content": f"<think>\n{reasoning_content}\n</think>\n\n{_content}",
+            }
+        else:
+            assistant_msg = {"role": "assistant", "content": _content}
+
+        return assistant_msg
+
+    async def __call__(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[RawMsgs] = None,
+    ):  # -> tuple[Any | dict[Any, Any], list[ChatCompletionMessagePar...:
+        """Unified async call for language model, returns (assistant_message.model_dump(), messages)."""
+        if (prompt is None) == (messages is None):
+            raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
+
+        if prompt is not None:
+            messages = [{"role": "user", "content": prompt}]
+
+        assert messages is not None
+
+        openai_msgs: Messages = (
+            self._convert_messages(cast(LegacyMsgs, messages))
+            if isinstance(messages[0], dict)
+            else cast(Messages, messages)
+        )
+
+        assert self.model_kwargs["model"] is not None, (
+            "Model must be set before making a call."
+        )
+
+        # Use unified client call
+        raw_response = await self._unified_client_call(
+            list(openai_msgs), cache_suffix="_call"
+        )
+
+        if hasattr(raw_response, "model_dump"):
+            raw_response = raw_response.model_dump()  # type: ignore
+
+        # Extract the assistant's message
+        assistant_msg = raw_response["choices"][0]["message"]
+        # Build the full messages list (input + assistant reply)
+        full_messages = list(messages) + [
+            {"role": assistant_msg["role"], "content": assistant_msg["content"]}
+        ]
+        # Return the OpenAI message as model_dump (if available) and the messages list
+        if hasattr(assistant_msg, "model_dump"):
+            msg_dump = assistant_msg.model_dump()
+        else:
+            msg_dump = dict(assistant_msg)
+        return msg_dump, full_messages

-    def _load_cache(self, key: str) -> Any | None:
-        path = self._cache_path(key)
-        if not os.path.exists(path):
-            return None
-        try:
-            with open(path) as fh:
-                return json.load(fh)
-        except Exception:
-            return None
-
-    # ------------------------------------------------------------------ #
-    # Missing methods from LM class
-    # ------------------------------------------------------------------ #
     async def parse(
         self,
-        response_model: Type[TParsed],
         instruction,
         prompt,
-
-
-
-
-        cache: Optional[bool] = None,
-        use_beta: bool = False,
-        **kwargs,
-    ) -> ParsedOutput[TParsed]:
-        """Parse response using guided JSON generation."""
-
-        if not use_beta:
-            assert add_json_schema_to_instruction, (
+    ) -> ParsedOutput[BaseModel]:
+        """Parse response using guided JSON generation. Returns (parsed.model_dump(), messages)."""
+        if not self._use_beta:
+            assert self.add_json_schema_to_instruction, (
                 "add_json_schema_to_instruction must be True when use_beta is False. otherwise model will not be able to parse the response."
             )

-
+        assert self.response_model is not None, "response_model must be set at init."
+        json_schema = self.response_model.model_json_schema()

         # Build system message content in a single, clear block
         assert instruction is not None, "Instruction must be provided."
```
```diff
@@ -388,122 +338,32 @@ class AsyncLM:
         system_content = instruction

         # Add schema if needed
-        system_content = self.
-            response_model,
-            add_json_schema_to_instruction,
+        system_content = self.build_system_prompt(
+            self.response_model,
+            self.add_json_schema_to_instruction,
             json_schema,
             system_content,
-            think=think,
+            think=self.think,
         )

-        # Rebuild messages with updated system message if needed
         messages = [
             {"role": "system", "content": system_content},
             {"role": "user", "content": prompt},
         ]  # type: ignore

-
-
-
-
-        model_kwargs["max_tokens"] = max_tokens
-        model_kwargs.update(kwargs)
-
-        use_cache = self.do_cache if cache is None else cache
-        cache_key = None
-        completion = None
-        choice = None
-        parsed = None
-
-        if use_cache:
-            cache_data = {
-                "messages": messages,
-                "model_kwargs": model_kwargs,
-                "guided_json": json_schema,
-                "response_format": response_model.__name__,
-                "use_beta": use_beta,
-            }
-            cache_key = self._cache_key(cache_data, {}, response_model)
-            completion = self._load_cache(cache_key)  # dict
-
-        if not completion:
-            completion, choice, parsed = await self._call_and_parse_completion(
-                messages,
-                response_model,
-                json_schema,
-                use_beta=use_beta,
-                model_kwargs=model_kwargs,
-            )
-
-            if cache_key:
-                self._dump_cache(cache_key, completion)
-        else:
-            # Extract choice and parsed from cached completion
-            choice = completion["choices"][0]["message"]
-            try:
-                parsed = self._parse_complete_output(completion, response_model)
-            except Exception as e:
-                raise ValueError(
-                    f"Failed to parse cached completion: {e}\nRaw: {choice.get('content')}"
-                ) from e
-
-        assert isinstance(completion, dict), (
-            "Completion must be a dictionary with OpenAI response format."
+        completion, full_messages, parsed = await self._call_and_parse(
+            messages,
+            self.response_model,
+            json_schema,
         )
-        self._last_log = [prompt, messages, completion]
-
-        reasoning_content = choice.get("reasoning_content", "").strip()
-        _content = choice.get("content", "").lstrip("\n")
-        content = f"<think>\n{reasoning_content}\n</think>\n\n{_content}"
-
-        full_messages = messages + [{"role": "assistant", "content": content}]

         return ParsedOutput(
             messages=full_messages,
+            parsed=cast(BaseModel, parsed),
             completion=completion,
-
+            model_kwargs=self.model_kwargs,
         )

-    def _build_system_prompt(
-        self,
-        response_model,
-        add_json_schema_to_instruction,
-        json_schema,
-        system_content,
-        think,
-    ):
-        if add_json_schema_to_instruction and response_model:
-            schema_block = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
-            # if schema_block not in system_content:
-            if "<output_json_schema>" in system_content:
-                # remove exsting schema block
-                import re  # replace
-
-                system_content = re.sub(
-                    r"<output_json_schema>.*?</output_json_schema>",
-                    "",
-                    system_content,
-                    flags=re.DOTALL,
-                )
-                system_content = system_content.strip()
-            system_content += schema_block
-
-        if think is True:
-            if "/think" in system_content:
-                pass
-            elif "/no_think" in system_content:
-                system_content = system_content.replace("/no_think", "/think")
-            else:
-                system_content += "\n\n/think"
-        elif think is False:
-            if "/no_think" in system_content:
-                pass
-            elif "/think" in system_content:
-                system_content = system_content.replace("/think", "/no_think")
-            else:
-                system_content += "\n\n/no_think"
-        return system_content
-
     def _parse_complete_output(
         self, completion: Any, response_model: Type[BaseModel]
     ) -> BaseModel:
```
```diff
@@ -516,264 +376,12 @@ class AsyncLM:

         content = completion["choices"][0]["message"]["content"]
         if not content:
-
-
-            # Try to extract tokens from the completion for debugging
-            input_tokens = None
-            try:
-                input_tokens = completion.get("usage", {}).get("prompt_tokens")
-            except Exception:
-                input_tokens = None
-
-            # Try to get the prompt/messages for tokenization
-            prompt = None
-            try:
-                prompt = completion.get("messages") or completion.get("prompt")
-            except Exception:
-                prompt = None
-
-            tokens_preview = ""
-            if prompt is not None:
-                try:
-                    tokenizer = get_tokenizer(self.model)
-                    if isinstance(prompt, list):
-                        prompt_text = "\n".join(
-                            m.get("content", "") for m in prompt if isinstance(m, dict)
-                        )
-                    else:
-                        prompt_text = str(prompt)
-                    tokens = tokenizer.encode(prompt_text)
-                    n_tokens = len(tokens)
-                    first_100 = tokens[:100]
-                    last_100 = tokens[-100:] if n_tokens > 100 else []
-                    tokens_preview = (
-                        f"\nInput tokens: {n_tokens}"
-                        f"\nFirst 100 tokens: {first_100}"
-                        f"\nLast 100 tokens: {last_100}"
-                    )
-                except Exception as exc:
-                    tokens_preview = f"\n[Tokenization failed: {exc}]"
-
-            raise ValueError(
-                f"Empty content in response."
-                f"\nInput tokens (if available): {input_tokens}"
-                f"{tokens_preview}"
-            )
+            raise ValueError("Response content is empty")

         try:
-            data =
+            data = jloads(content)
             return response_model.model_validate(data)
         except Exception as exc:
             raise ValueError(
-                f"Failed to
+                f"Failed to validate against response model {response_model.__name__}: {exc}\nRaw content: {content}"
             ) from exc
-
-    async def inspect_word_probs(
-        self,
-        messages: Optional[List[Dict[str, Any]]] = None,
-        tokenizer: Optional[Any] = None,
-        do_print=True,
-        add_think: bool = True,
-    ) -> tuple[List[Dict[str, Any]], Any, str]:
-        """
-        Inspect word probabilities in a language model response.
-
-        Args:
-            tokenizer: Tokenizer instance to encode words.
-            messages: List of messages to analyze.
-
-        Returns:
-            A tuple containing:
-            - List of word probabilities with their log probabilities.
-            - Token log probability dictionaries.
-            - Rendered string with colored word probabilities.
-        """
-        if messages is None:
-            messages = await self.last_messages(add_think=add_think)
-        if messages is None:
-            raise ValueError("No messages provided and no last messages available.")
-
-        if tokenizer is None:
-            tokenizer = get_tokenizer(self.model)
-
-        ret = await inspect_word_probs_async(self, tokenizer, messages)
-        if do_print:
-            print(ret[-1])
-        return ret
-
-    async def last_messages(
-        self, add_think: bool = True
-    ) -> Optional[List[Dict[str, str]]]:
-        """Get the last conversation messages including assistant response."""
-        if not hasattr(self, "last_log"):
-            return None
-
-        last_conv = self._last_log
-        messages = last_conv[1] if len(last_conv) > 1 else None
-        last_msg = last_conv[2]
-        if not isinstance(last_msg, dict):
-            last_conv[2] = last_conv[2].model_dump()  # type: ignore
-        msg = last_conv[2]
-        # Ensure msg is a dict
-        if hasattr(msg, "model_dump"):
-            msg = msg.model_dump()
-        message = msg["choices"][0]["message"]
-        reasoning = message.get("reasoning_content")
-        answer = message.get("content")
-        if reasoning and add_think:
-            final_answer = f"<think>{reasoning}</think>\n{answer}"
-        else:
-            final_answer = f"<think>\n\n</think>\n{answer}"
-        assistant = {"role": "assistant", "content": final_answer}
-        messages = messages + [assistant]  # type: ignore
-        return messages if messages else None
-
-    # ------------------------------------------------------------------ #
-    # Utility helpers
-    # ------------------------------------------------------------------ #
-    async def inspect_history(self) -> None:
-        """Inspect the conversation history with proper formatting."""
-        if not hasattr(self, "last_log"):
-            raise ValueError("No history available. Please call the model first.")
-
-        prompt, messages, response = self._last_log
-        if hasattr(response, "model_dump"):
-            response = response.model_dump()
-        if not messages:
-            messages = [{"role": "user", "content": prompt}]
-
-        print("\n\n")
-        print(_blue("[Conversation History]") + "\n")
-
-        for msg in messages:
-            role = msg["role"]
-            content = msg["content"]
-            print(_red(f"{role.capitalize()}:"))
-            if isinstance(content, str):
-                print(content.strip())
-            elif isinstance(content, list):
-                for item in content:
-                    if item.get("type") == "text":
-                        print(item["text"].strip())
-                    elif item.get("type") == "image_url":
-                        image_url = item["image_url"]["url"]
-                        if "base64" in image_url:
-                            len_base64 = len(image_url.split("base64,")[1])
-                            print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
-                        else:
-                            print(_blue(f"<image_url: {image_url}>"))
-            print("\n")
-
-        print(_red("Response:"))
-        if isinstance(response, dict) and response.get("choices"):
-            message = response["choices"][0].get("message", {})
-            reasoning = message.get("reasoning_content")
-            parsed = message.get("parsed")
-            content = message.get("content")
-            if reasoning:
-                print(_yellow("<think>"))
-                print(reasoning.strip())
-                print(_yellow("</think>\n"))
-            if parsed:
-                print(
-                    json.dumps(
-                        (
-                            parsed.model_dump()
-                            if hasattr(parsed, "model_dump")
-                            else parsed
-                        ),
-                        indent=2,
-                    )
-                    + "\n"
-                )
-            elif content:
-                print(content.strip())
-            else:
-                print(_green("[No content]"))
-            if len(response["choices"]) > 1:
-                print(
-                    _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
-                )
-        else:
-            print(_yellow("Warning: Not a standard OpenAI response object"))
-            if isinstance(response, str):
-                print(_green(response.strip()))
-            elif isinstance(response, dict):
-                print(_green(json.dumps(response, indent=2)))
-            else:
-                print(_green(str(response)))
-
-    # ------------------------------------------------------------------ #
-    # Misc helpers
-    # ------------------------------------------------------------------ #
-    def set_model(self, model: str) -> None:
-        self.model = model
-
-    @staticmethod
-    async def list_models(port=None, host="localhost") -> List[str]:
-        try:
-            client: AsyncOpenAI = AsyncLM(port=port, host=host).client  # type: ignore[arg-type]
-            base_url: URL = client.base_url
-            logger.debug(f"Base URL: {base_url}")
-            models: AsyncSyncPage[Model] = await client.models.list()  # type: ignore[assignment]
-            return [model.id for model in models.data]
-        except Exception as exc:
-            logger.error(f"Failed to list models: {exc}")
-            return []
-
-    async def _call_and_parse_completion(
-        self,
-        messages: list[dict],
-        response_model: Type[TParsed],
-        json_schema: dict,
-        use_beta: bool,
-        model_kwargs: dict,
-    ) -> tuple[dict, dict, TParsed]:
-        """Call vLLM or OpenAI-compatible endpoint and parse JSON response consistently."""
-        await self._set_model()  # Ensure model is set before making the call
-        # Convert messages to proper type
-        converted_messages = self._convert_messages(messages)  # type: ignore
-
-        if use_beta:
-            # Use guided JSON for structure enforcement
-            try:
-                completion = await self.client.chat.completions.create(
-                    model=str(self.model),  # type: ignore
-                    messages=converted_messages,
-                    extra_body={"guided_json": json_schema},  # type: ignore
-                    **model_kwargs,
-                )  # type: ignore
-            except Exception:
-                # Fallback if extra_body is not supported
-                completion = await self.client.chat.completions.create(
-                    model=str(self.model),  # type: ignore
-                    messages=converted_messages,
-                    response_format={"type": "json_object"},
-                    **model_kwargs,
-                )
-        else:
-            # Use OpenAI-style structured output
-            completion = await self.client.chat.completions.create(
-                model=str(self.model),  # type: ignore
-                messages=converted_messages,
-                response_format={"type": "json_object"},
-                **model_kwargs,
-            )
-
-        if hasattr(completion, "model_dump"):
-            completion = completion.model_dump()
-
-        choice = completion["choices"][0]["message"]
-
-        try:
-            parsed = (
-                self._parse_complete_output(completion, response_model)
-                if use_beta
-                else response_model.model_validate(jloads(choice.get("content")))
-            )
-        except Exception as e:
-            raise ValueError(
-                f"Failed to parse model response: {e}\nRaw: {choice.get('content')}"
-            ) from e
-
-        return completion, choice, parsed  # type: ignore
```
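
The new module-level helper `jloads_safe` (added in the first hunk) strips a Markdown JSON code fence from the model output before handing it to `speedy_utils.jloads`, and raises a `ValueError` on anything unparseable. Below is a standalone sketch of the same fence-stripping behavior, with `json.loads` standing in for `jloads` (which may accept a wider range of inputs) and without the logging the real helper performs.

```python
# Standalone sketch of jloads_safe's fence-stripping behavior; json.loads stands
# in for speedy_utils.jloads here.
import json
from typing import Any

FENCE = "`" * 3  # avoids writing a literal code fence inside this example


def jloads_safe_sketch(content: str) -> Any:
    marker = FENCE + "json"
    if marker in content:
        # Keep only the text between the opening ...json fence and the closing fence.
        content = content.split(marker)[1].strip().split(FENCE)[0].strip()
    try:
        return json.loads(content)
    except Exception as e:
        raise ValueError(f"Invalid JSON content: {content}") from e


reply = FENCE + 'json\n{"answer": "Paris", "confidence": 0.9}\n' + FENCE
print(jloads_safe_sketch(reply))  # {'answer': 'Paris', 'confidence': 0.9}
```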