speedy-utils 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +1 -5
- llm_utils/chat_format/display.py +17 -4
- llm_utils/chat_format/transform.py +9 -9
- llm_utils/group_messages.py +1 -1
- llm_utils/lm/async_lm/__init__.py +7 -0
- llm_utils/lm/async_lm/_utils.py +201 -0
- llm_utils/lm/async_lm/async_llm_task.py +509 -0
- llm_utils/lm/async_lm/async_lm.py +387 -0
- llm_utils/lm/async_lm/async_lm_base.py +405 -0
- llm_utils/lm/async_lm/lm_specific.py +136 -0
- llm_utils/lm/utils.py +1 -3
- llm_utils/scripts/vllm_load_balancer.py +244 -147
- speedy_utils/__init__.py +3 -1
- speedy_utils/common/notebook_utils.py +4 -4
- speedy_utils/common/report_manager.py +2 -3
- speedy_utils/common/utils_cache.py +233 -3
- speedy_utils/common/utils_io.py +2 -0
- speedy_utils/scripts/mpython.py +1 -3
- {speedy_utils-1.1.5.dist-info → speedy_utils-1.1.7.dist-info}/METADATA +1 -1
- speedy_utils-1.1.7.dist-info/RECORD +39 -0
- llm_utils/lm/async_lm.py +0 -942
- llm_utils/lm/chat_html.py +0 -246
- llm_utils/lm/lm_json.py +0 -68
- llm_utils/lm/sync_lm.py +0 -943
- speedy_utils-1.1.5.dist-info/RECORD +0 -37
- {speedy_utils-1.1.5.dist-info → speedy_utils-1.1.7.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.5.dist-info → speedy_utils-1.1.7.dist-info}/entry_points.txt +0 -0
llm_utils/lm/async_lm.py
DELETED
|
@@ -1,942 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
import base64
|
|
3
|
-
import hashlib
|
|
4
|
-
import json
|
|
5
|
-
import os
|
|
6
|
-
from abc import ABC
|
|
7
|
-
from functools import cache, lru_cache
|
|
8
|
-
from typing import (
|
|
9
|
-
Any,
|
|
10
|
-
Dict,
|
|
11
|
-
Generic,
|
|
12
|
-
List,
|
|
13
|
-
Literal,
|
|
14
|
-
Optional,
|
|
15
|
-
Sequence,
|
|
16
|
-
Type,
|
|
17
|
-
TypeVar,
|
|
18
|
-
Union,
|
|
19
|
-
cast,
|
|
20
|
-
overload,
|
|
21
|
-
)
|
|
22
|
-
from typing_extensions import TypedDict
|
|
23
|
-
from httpx import URL
|
|
24
|
-
from loguru import logger
|
|
25
|
-
from numpy import isin
|
|
26
|
-
from openai import AsyncOpenAI, AuthenticationError, BadRequestError, RateLimitError
|
|
27
|
-
from openai.pagination import AsyncPage as AsyncSyncPage
|
|
28
|
-
|
|
29
|
-
# from openai.pagination import AsyncSyncPage
|
|
30
|
-
from openai.types.chat import (
|
|
31
|
-
ChatCompletionAssistantMessageParam,
|
|
32
|
-
ChatCompletionMessageParam,
|
|
33
|
-
ChatCompletionSystemMessageParam,
|
|
34
|
-
ChatCompletionToolMessageParam,
|
|
35
|
-
ChatCompletionUserMessageParam,
|
|
36
|
-
)
|
|
37
|
-
from openai.types.model import Model
|
|
38
|
-
from pydantic import BaseModel
|
|
39
|
-
from pydantic import ValidationError
|
|
40
|
-
from llm_utils.chat_format.display import get_conversation_one_turn
|
|
41
|
-
|
|
42
|
-
# --------------------------------------------------------------------------- #
|
|
43
|
-
# type helpers
|
|
44
|
-
# --------------------------------------------------------------------------- #
|
|
45
|
-
# Type variable for a caller-supplied pydantic model class (used by the
# typed ``__call__`` overload that returns an instance of that class).
TModel = TypeVar("TModel", bound=BaseModel)
# Chat messages expressed with the OpenAI SDK's message-param types.
Messages = List[ChatCompletionMessageParam]
# Plain list-of-dict messages with string keys and string values.
LegacyMsgs = List[Dict[str, str]]
# Either of the two message representations above.
RawMsgs = Union[Messages, LegacyMsgs]
|
|
49
|
-
|
|
50
|
-
# --------------------------------------------------------------------------- #
|
|
51
|
-
# color helpers (unchanged)
|
|
52
|
-
# --------------------------------------------------------------------------- #
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def _color(code: int, text: str) -> str:
|
|
56
|
-
return f"\x1b[{code}m{text}\x1b[0m"
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def _red(t):
    """Return ``t`` wrapped in the ANSI red foreground escape (SGR 31)."""
    sgr_red = 31
    return _color(sgr_red, t)
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def _green(t):
    """Return ``t`` wrapped in the ANSI green foreground escape (SGR 32)."""
    sgr_green = 32
    return _color(sgr_green, t)
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
def _blue(t):
    """Return ``t`` wrapped in the ANSI blue foreground escape (SGR 34)."""
    sgr_blue = 34
    return _color(sgr_blue, t)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def _yellow(t):
    """Return ``t`` wrapped in the ANSI yellow foreground escape (SGR 33)."""
    sgr_yellow = 33
    return _color(sgr_yellow, t)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
# Type variable for the pydantic model a response is parsed into.
TParsed = TypeVar("TParsed", bound=BaseModel)


class ParsedOutput(TypedDict, Generic[TParsed]):
    """Typed dictionary bundling a conversation, its raw completion and the parsed model."""

    # Message list; ``AsyncLM.parse`` stores the request messages followed by
    # the completion payload here.
    messages: List
    # Completion payload as produced by the caller (a dict in ``AsyncLM.parse``).
    completion: Any
    # Instance of the requested response model.
    parsed: TParsed
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
class AsyncLM:
|
|
85
|
-
"""Unified **async** language‑model wrapper with optional JSON parsing."""
|
|
86
|
-
|
|
87
|
-
def __init__(
|
|
88
|
-
self,
|
|
89
|
-
model: str | None = None,
|
|
90
|
-
*,
|
|
91
|
-
temperature: float = 0.0,
|
|
92
|
-
max_tokens: int = 2_000,
|
|
93
|
-
host: str = "localhost",
|
|
94
|
-
port: Optional[int | str] = None,
|
|
95
|
-
base_url: Optional[str] = None,
|
|
96
|
-
api_key: Optional[str] = None,
|
|
97
|
-
cache: bool = True,
|
|
98
|
-
ports: Optional[List[int]] = None,
|
|
99
|
-
**openai_kwargs: Any,
|
|
100
|
-
) -> None:
|
|
101
|
-
self.model = model
|
|
102
|
-
self.temperature = temperature
|
|
103
|
-
self.max_tokens = max_tokens
|
|
104
|
-
self.port = port
|
|
105
|
-
self.host = host
|
|
106
|
-
self.base_url = base_url or (f"http://{host}:{port}/v1" if port else None)
|
|
107
|
-
self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
|
|
108
|
-
self.openai_kwargs = openai_kwargs
|
|
109
|
-
self.do_cache = cache
|
|
110
|
-
self.ports = ports
|
|
111
|
-
self._init_port = port # <-- store the port provided at init
|
|
112
|
-
|
|
113
|
-
# Async client
|
|
114
|
-
|
|
115
|
-
@property
def client(self) -> AsyncOpenAI:
    """Build an ``AsyncOpenAI`` client for the endpoint this instance targets.

    A new client object is constructed on every access.

    * When ``self.ports`` is non-empty, one port is drawn at random per access
      and combined with ``self.host``.
    * Otherwise ``self.base_url`` is used, or a URL derived from
      ``self.host``/``self.port`` when a port is configured.
    * When neither a base URL nor a port is configured, ``base_url=None`` is
      passed so the SDK applies its own default, instead of the malformed
      ``http://<host>:None/v1`` that string formatting would produce.

    Returns:
        A freshly constructed ``AsyncOpenAI`` instance.
    """
    if self.ports:
        import random

        port = random.choice(self.ports)
        api_base = f"http://{self.host}:{port}/v1"
        logger.debug(f"Using port: {port}")
    elif self.base_url:
        api_base = self.base_url
    elif self.port:
        api_base = f"http://{self.host}:{self.port}/v1"
    else:
        # Nothing configured: leave endpoint selection to the SDK.
        api_base = None
    return AsyncOpenAI(
        api_key=self.api_key, base_url=api_base, **self.openai_kwargs
    )
|
|
130
|
-
|
|
131
|
-
# ------------------------------------------------------------------ #
|
|
132
|
-
# Public API – typed overloads
|
|
133
|
-
# ------------------------------------------------------------------ #
|
|
134
|
-
@overload
async def __call__(
    self,
    *,
    prompt: str | None = ...,
    messages: RawMsgs | None = ...,
    response_format: type[str] = str,
    return_openai_response: bool = ...,
    **kwargs: Any,
) -> str: ...


@overload
async def __call__(
    self,
    *,
    prompt: str | None = ...,
    messages: RawMsgs | None = ...,
    response_format: Type[TModel],
    return_openai_response: bool = ...,
    **kwargs: Any,
) -> TModel: ...


async def __call__(
    self,
    prompt: Optional[str] = None,
    messages: Optional[RawMsgs] = None,
    response_format: Union[type[str], Type[BaseModel]] = str,
    cache: Optional[bool] = None,
    max_tokens: Optional[int] = None,
    return_openai_response: bool = False,
    **kwargs: Any,
):
    """Run one chat request and return text, a parsed model, or the raw response.

    Args:
        prompt: Single user prompt; mutually exclusive with ``messages``.
        messages: Full message list; mutually exclusive with ``prompt``.
        response_format: ``str`` for text, or a pydantic model class.
        cache: Per-call cache switch; ``None`` uses the instance default.
        max_tokens: Per-call token limit; a falsy value uses the instance default.
        return_openai_response: Return the raw response instead of parsing it.
        **kwargs: Extra request options; they override the instance defaults.

    Raises:
        ValueError: If both or neither of ``prompt``/``messages`` are given.
    """
    has_prompt = prompt is not None
    has_messages = messages is not None
    if has_prompt == has_messages:
        raise ValueError("Provide *either* `prompt` or `messages` (but not both).")

    if has_prompt:
        messages = [{"role": "user", "content": prompt}]
    assert messages is not None

    if not self.model:
        # No model configured: ask the server and take the first one listed.
        available = await self.list_models(port=self.port, host=self.host)
        self.model = available[0] if available else None
        logger.info(
            f"No model specified. Using the first available model. {self.model}"
        )

    if isinstance(messages[0], dict):
        openai_msgs: Messages = self._convert_messages(cast(LegacyMsgs, messages))
    else:
        openai_msgs = cast(Messages, messages)

    # Instance-level options first, then per-call overrides.
    request_options = {
        **self.openai_kwargs,
        "temperature": self.temperature,
        "max_tokens": max_tokens or self.max_tokens,
    }
    request_options.update(kwargs)
    use_cache = cache if cache is not None else self.do_cache

    raw_response = await self._call_raw(
        openai_msgs,
        response_format=response_format,
        use_cache=use_cache,
        **request_options,
    )

    response = (
        raw_response
        if return_openai_response
        else self._parse_output(raw_response, response_format)
    )

    self.last_log = [prompt, messages, raw_response]
    return response
|
|
208
|
-
|
|
209
|
-
# ------------------------------------------------------------------ #
|
|
210
|
-
# Model invocation (async)
|
|
211
|
-
# ------------------------------------------------------------------ #
|
|
212
|
-
async def _call_raw(
|
|
213
|
-
self,
|
|
214
|
-
messages: Sequence[ChatCompletionMessageParam],
|
|
215
|
-
response_format: Union[type[str], Type[BaseModel]],
|
|
216
|
-
use_cache: bool,
|
|
217
|
-
**kw: Any,
|
|
218
|
-
):
|
|
219
|
-
assert self.model is not None, "Model must be set before making a call."
|
|
220
|
-
model: str = self.model
|
|
221
|
-
|
|
222
|
-
cache_key = (
|
|
223
|
-
self._cache_key(messages, kw, response_format) if use_cache else None
|
|
224
|
-
)
|
|
225
|
-
if cache_key and (hit := self._load_cache(cache_key)) is not None:
|
|
226
|
-
# Check if cached value is an error
|
|
227
|
-
if isinstance(hit, dict) and hit.get("error"):
|
|
228
|
-
error_type = hit.get("error_type", "Unknown")
|
|
229
|
-
error_msg = hit.get("error_message", "Cached error")
|
|
230
|
-
logger.warning(f"Found cached error ({error_type}): {error_msg}")
|
|
231
|
-
# Re-raise as a ValueError with meaningful message
|
|
232
|
-
raise ValueError(f"Cached {error_type}: {error_msg}")
|
|
233
|
-
return hit
|
|
234
|
-
|
|
235
|
-
try:
|
|
236
|
-
if response_format is not str and issubclass(response_format, BaseModel):
|
|
237
|
-
openai_response = await self.client.beta.chat.completions.parse(
|
|
238
|
-
model=model,
|
|
239
|
-
messages=list(messages),
|
|
240
|
-
response_format=response_format, # type: ignore[arg-type]
|
|
241
|
-
**kw,
|
|
242
|
-
)
|
|
243
|
-
else:
|
|
244
|
-
openai_response = await self.client.chat.completions.create(
|
|
245
|
-
model=model,
|
|
246
|
-
messages=list(messages),
|
|
247
|
-
**kw,
|
|
248
|
-
)
|
|
249
|
-
|
|
250
|
-
except (AuthenticationError, RateLimitError, BadRequestError) as exc:
|
|
251
|
-
error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
|
|
252
|
-
logger.error(error_msg)
|
|
253
|
-
|
|
254
|
-
# Cache the error if it's a BadRequestError to avoid repeated calls
|
|
255
|
-
if isinstance(exc, BadRequestError) and cache_key:
|
|
256
|
-
error_response = {
|
|
257
|
-
"error": True,
|
|
258
|
-
"error_type": "BadRequestError",
|
|
259
|
-
"error_message": str(exc),
|
|
260
|
-
"choices": [],
|
|
261
|
-
}
|
|
262
|
-
self._dump_cache(cache_key, error_response)
|
|
263
|
-
logger.debug(f"Cached BadRequestError for key: {cache_key}")
|
|
264
|
-
|
|
265
|
-
raise
|
|
266
|
-
|
|
267
|
-
if cache_key:
|
|
268
|
-
self._dump_cache(cache_key, openai_response)
|
|
269
|
-
|
|
270
|
-
return openai_response
|
|
271
|
-
|
|
272
|
-
# ------------------------------------------------------------------ #
|
|
273
|
-
# Utilities below are unchanged (sync I/O is acceptable)
|
|
274
|
-
# ------------------------------------------------------------------ #
|
|
275
|
-
@staticmethod
|
|
276
|
-
def _convert_messages(msgs: LegacyMsgs) -> Messages:
|
|
277
|
-
converted: Messages = []
|
|
278
|
-
for msg in msgs:
|
|
279
|
-
role = msg["role"]
|
|
280
|
-
content = msg["content"]
|
|
281
|
-
if role == "user":
|
|
282
|
-
converted.append(
|
|
283
|
-
ChatCompletionUserMessageParam(role="user", content=content)
|
|
284
|
-
)
|
|
285
|
-
elif role == "assistant":
|
|
286
|
-
converted.append(
|
|
287
|
-
ChatCompletionAssistantMessageParam(
|
|
288
|
-
role="assistant", content=content
|
|
289
|
-
)
|
|
290
|
-
)
|
|
291
|
-
elif role == "system":
|
|
292
|
-
converted.append(
|
|
293
|
-
ChatCompletionSystemMessageParam(role="system", content=content)
|
|
294
|
-
)
|
|
295
|
-
elif role == "tool":
|
|
296
|
-
converted.append(
|
|
297
|
-
ChatCompletionToolMessageParam(
|
|
298
|
-
role="tool",
|
|
299
|
-
content=content,
|
|
300
|
-
tool_call_id=msg.get("tool_call_id") or "",
|
|
301
|
-
)
|
|
302
|
-
)
|
|
303
|
-
else:
|
|
304
|
-
converted.append({"role": role, "content": content}) # type: ignore[arg-type]
|
|
305
|
-
return converted
|
|
306
|
-
|
|
307
|
-
@staticmethod
|
|
308
|
-
def _parse_output(
|
|
309
|
-
raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
|
|
310
|
-
) -> str | BaseModel:
|
|
311
|
-
if hasattr(raw_response, "model_dump"):
|
|
312
|
-
raw_response = raw_response.model_dump()
|
|
313
|
-
|
|
314
|
-
if response_format is str:
|
|
315
|
-
if isinstance(raw_response, dict) and "choices" in raw_response:
|
|
316
|
-
message = raw_response["choices"][0]["message"]
|
|
317
|
-
return message.get("content", "") or ""
|
|
318
|
-
return cast(str, raw_response)
|
|
319
|
-
|
|
320
|
-
model_cls = cast(Type[BaseModel], response_format)
|
|
321
|
-
|
|
322
|
-
if isinstance(raw_response, dict) and "choices" in raw_response:
|
|
323
|
-
message = raw_response["choices"][0]["message"]
|
|
324
|
-
if "parsed" in message:
|
|
325
|
-
return model_cls.model_validate(message["parsed"])
|
|
326
|
-
content = message.get("content")
|
|
327
|
-
if content is None:
|
|
328
|
-
raise ValueError("Model returned empty content")
|
|
329
|
-
try:
|
|
330
|
-
data = json.loads(content)
|
|
331
|
-
return model_cls.model_validate(data)
|
|
332
|
-
except Exception as exc:
|
|
333
|
-
raise ValueError(
|
|
334
|
-
f"Failed to parse model output as JSON:\n{content}"
|
|
335
|
-
) from exc
|
|
336
|
-
|
|
337
|
-
if isinstance(raw_response, model_cls):
|
|
338
|
-
return raw_response
|
|
339
|
-
if isinstance(raw_response, dict):
|
|
340
|
-
return model_cls.model_validate(raw_response)
|
|
341
|
-
|
|
342
|
-
try:
|
|
343
|
-
data = json.loads(raw_response)
|
|
344
|
-
return model_cls.model_validate(data)
|
|
345
|
-
except Exception as exc:
|
|
346
|
-
raise ValueError(
|
|
347
|
-
f"Model did not return valid JSON:\n---\n{raw_response}"
|
|
348
|
-
) from exc
|
|
349
|
-
|
|
350
|
-
# ------------------------------------------------------------------ #
|
|
351
|
-
# Simple disk cache (sync)
|
|
352
|
-
# ------------------------------------------------------------------ #
|
|
353
|
-
@staticmethod
|
|
354
|
-
def _cache_key(
|
|
355
|
-
messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
|
|
356
|
-
) -> str:
|
|
357
|
-
tag = response_format.__name__ if response_format is not str else "text"
|
|
358
|
-
blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
|
|
359
|
-
return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
|
|
360
|
-
|
|
361
|
-
@staticmethod
|
|
362
|
-
def _cache_path(key: str) -> str:
|
|
363
|
-
return os.path.expanduser(f"~/.cache/lm/{key}.json")
|
|
364
|
-
|
|
365
|
-
def _dump_cache(self, key: str, val: Any) -> None:
    """Write ``val`` as JSON to the cache file for ``key``, best effort.

    Pydantic models are dumped in JSON mode first. Any failure is logged at
    debug level and otherwise ignored, so caching never breaks a request.
    """
    try:
        target = self._cache_path(key)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, "w") as out:
            payload = val.model_dump(mode="json") if isinstance(val, BaseModel) else val
            json.dump(payload, out)
    except Exception as exc:
        logger.debug(f"cache write skipped: {exc}")
|
|
376
|
-
|
|
377
|
-
def _load_cache(self, key: str) -> Any | None:
|
|
378
|
-
path = self._cache_path(key)
|
|
379
|
-
if not os.path.exists(path):
|
|
380
|
-
return None
|
|
381
|
-
try:
|
|
382
|
-
with open(path) as fh:
|
|
383
|
-
return json.load(fh)
|
|
384
|
-
except Exception:
|
|
385
|
-
return None
|
|
386
|
-
|
|
387
|
-
# ------------------------------------------------------------------ #
|
|
388
|
-
# Missing methods from LM class
|
|
389
|
-
# ------------------------------------------------------------------ #
|
|
390
|
-
async def parse(
    self,
    response_model: Type[TParsed],
    instruction: Optional[str] = None,
    prompt: Optional[str] = None,
    messages: Optional[RawMsgs] = None,
    think: Literal[True, False, None] = None,
    add_json_schema_to_instruction: bool = False,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    cache: Optional[bool] = True,
    **kwargs,
) -> ParsedOutput[TParsed]:
    """Request a completion constrained to ``response_model``'s JSON schema.

    The schema is sent to the server as ``extra_body={"guided_json": ...}``
    and the returned content is validated into ``response_model``.

    Args:
        response_model: Pydantic model class describing the expected output.
        instruction: System instruction; required when ``messages`` is omitted.
        prompt: User prompt; required when ``messages`` is omitted.
        messages: Full message list whose first entry must be a system message.
            The caller's list and dicts are left unmodified.
        think: ``True`` appends ``/think`` and ``False`` appends ``/no_think``
            to the system message; ``None`` appends neither.
        add_json_schema_to_instruction: Also append the JSON schema to the
            system message.
        temperature: Optional sampling temperature for this call.
        max_tokens: Optional token limit for this call.
        cache: Per-call cache switch; ``None`` uses the instance default.
        **kwargs: Extra keyword arguments forwarded to the completion call.

    Returns:
        A ``ParsedOutput`` with the messages (plus the completion), the
        completion dict and the validated model instance.

    Raises:
        AssertionError: If required inputs are missing or malformed.
        ValueError: If the completion cannot be parsed into ``response_model``.
    """
    if messages is None:
        assert instruction is not None, "Instruction must be provided."
        assert prompt is not None, "Prompt must be provided."
        messages = [
            {
                "role": "system",
                "content": instruction,
            },
            {
                "role": "user",
                "content": prompt,
            },
        ]  # type: ignore

    post_fix = ""
    json_schema = response_model.model_json_schema()
    if add_json_schema_to_instruction and response_model:
        _schema = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
        post_fix += _schema

    # ``think`` is tri-state: None means "leave the prompt alone", so only an
    # explicit choice adds a switch.
    if think is None:
        pass
    elif think:
        post_fix += "\n\n/think"
    else:
        post_fix += "\n\n/no_think"

    assert isinstance(messages, list), "Messages must be a list."
    assert len(messages) > 0, "Messages cannot be empty."
    assert messages[0]["role"] == "system", (
        "First message must be a system message with instruction."
    )
    if post_fix:
        # Work on a copy of the system message so repeated calls with the same
        # ``messages`` object do not keep accumulating suffixes.
        system_msg = dict(messages[0])
        system_msg["content"] = system_msg["content"] + post_fix
        messages = [system_msg, *messages[1:]]  # type: ignore

    model_kwargs = {}
    if temperature is not None:
        model_kwargs["temperature"] = temperature
    if max_tokens is not None:
        model_kwargs["max_tokens"] = max_tokens
    model_kwargs.update(kwargs)

    use_cache = self.do_cache if cache is None else cache
    cache_key = None
    completion = None
    if use_cache:
        cache_data = {
            "messages": messages,
            "model_kwargs": model_kwargs,
            "guided_json": json_schema,
            "response_format": response_model.__name__,
            # Keyed by model too, so another model's answer is never reused.
            "model": self.model,
        }
        cache_key = self._cache_key(cache_data, {}, response_model)
        completion = self._load_cache(cache_key)  # dict
    if not completion:
        completion = await self.client.chat.completions.create(
            model=self.model,  # type: ignore
            messages=messages,  # type: ignore
            extra_body={"guided_json": json_schema},
            **model_kwargs,
        )
        completion = completion.model_dump()
        if cache_key:
            self._dump_cache(cache_key, completion)
    assert isinstance(completion, dict), (
        "Completion must be a dictionary with OpenAI response format."
    )
    self.last_log = [prompt, messages, completion]

    output = cast(TParsed, self._parse_complete_output(completion, response_model))
    full_messages = messages + [completion]
    return ParsedOutput(
        messages=full_messages,
        completion=completion,
        parsed=output,
    )
|
|
477
|
-
|
|
478
|
-
def _parse_complete_output(
|
|
479
|
-
self, completion: Any, response_model: Type[BaseModel]
|
|
480
|
-
) -> BaseModel:
|
|
481
|
-
"""Parse completion output to response model."""
|
|
482
|
-
if hasattr(completion, "model_dump"):
|
|
483
|
-
completion = completion.model_dump()
|
|
484
|
-
|
|
485
|
-
if "choices" not in completion or not completion["choices"]:
|
|
486
|
-
raise ValueError("No choices in OpenAI response")
|
|
487
|
-
|
|
488
|
-
content = completion["choices"][0]["message"]["content"]
|
|
489
|
-
if not content:
|
|
490
|
-
# Enhanced error for debugging: show input tokens and their count
|
|
491
|
-
|
|
492
|
-
# Try to extract tokens from the completion for debugging
|
|
493
|
-
input_tokens = None
|
|
494
|
-
try:
|
|
495
|
-
input_tokens = completion.get('usage', {}).get('prompt_tokens')
|
|
496
|
-
except Exception:
|
|
497
|
-
input_tokens = None
|
|
498
|
-
|
|
499
|
-
# Try to get the prompt/messages for tokenization
|
|
500
|
-
prompt = None
|
|
501
|
-
try:
|
|
502
|
-
prompt = completion.get('messages') or completion.get('prompt')
|
|
503
|
-
except Exception:
|
|
504
|
-
prompt = None
|
|
505
|
-
|
|
506
|
-
tokens_preview = ''
|
|
507
|
-
if prompt is not None:
|
|
508
|
-
try:
|
|
509
|
-
tokenizer = get_tokenizer(self.model)
|
|
510
|
-
if isinstance(prompt, list):
|
|
511
|
-
prompt_text = '\n'.join(
|
|
512
|
-
m.get('content', '') for m in prompt if isinstance(m, dict)
|
|
513
|
-
)
|
|
514
|
-
else:
|
|
515
|
-
prompt_text = str(prompt)
|
|
516
|
-
tokens = tokenizer.encode(prompt_text)
|
|
517
|
-
n_tokens = len(tokens)
|
|
518
|
-
first_100 = tokens[:100]
|
|
519
|
-
last_100 = tokens[-100:] if n_tokens > 100 else []
|
|
520
|
-
tokens_preview = (
|
|
521
|
-
f'\nInput tokens: {n_tokens}'
|
|
522
|
-
f'\nFirst 100 tokens: {first_100}'
|
|
523
|
-
f'\nLast 100 tokens: {last_100}'
|
|
524
|
-
)
|
|
525
|
-
except Exception as exc:
|
|
526
|
-
tokens_preview = f'\n[Tokenization failed: {exc}]'
|
|
527
|
-
|
|
528
|
-
raise ValueError(
|
|
529
|
-
f'Empty content in response.'
|
|
530
|
-
f'\nInput tokens (if available): {input_tokens}'
|
|
531
|
-
f'{tokens_preview}'
|
|
532
|
-
)
|
|
533
|
-
|
|
534
|
-
try:
|
|
535
|
-
data = json.loads(content)
|
|
536
|
-
return response_model.model_validate(data)
|
|
537
|
-
except Exception as exc:
|
|
538
|
-
raise ValueError(
|
|
539
|
-
f"Failed to parse response as {response_model.__name__}: {content}"
|
|
540
|
-
) from exc
|
|
541
|
-
|
|
542
|
-
async def inspect_word_probs(
    self,
    messages: Optional[List[Dict[str, Any]]] = None,
    tokenizer: Optional[Any] = None,
    do_print=True,
    add_think: bool = True,
) -> tuple[List[Dict[str, Any]], Any, str]:
    """
    Compute per-word probabilities for a conversation and optionally print them.

    Args:
        messages: Messages to analyse; defaults to the last logged exchange.
        tokenizer: Tokenizer to use; defaults to the one for ``self.model``.
        do_print: Print the rendered (last) element of the result when true.
        add_think: Forwarded to ``last_messages`` when ``messages`` is omitted.

    Returns:
        A tuple containing:
        - List of word probabilities.
        - Token log probability dictionaries.
        - Rendered string with colored word probabilities.

    Raises:
        ValueError: If no messages were given and none are logged.
    """
    if messages is None:
        messages = await self.last_messages(add_think=add_think)
        if messages is None:
            raise ValueError("No messages provided and no last messages available.")

    tokenizer = get_tokenizer(self.model) if tokenizer is None else tokenizer

    result = await inspect_word_probs_async(self, tokenizer, messages)
    if do_print:
        rendered = result[-1]
        print(rendered)
    return result
|
|
574
|
-
|
|
575
|
-
async def last_messages(
    self, add_think: bool = True
) -> Optional[List[Dict[str, str]]]:
    """Return the last logged messages with the assistant reply appended.

    The reply is rendered as ``<think>…</think>\\n<answer>``; the think block
    holds the reasoning only when it exists and ``add_think`` is true.
    Returns ``None`` when nothing has been logged yet. The logged response is
    converted to a dict in place if it is not one already.
    """
    if not hasattr(self, "last_log"):
        return None

    log = self.last_log
    history = log[1] if len(log) > 1 else None
    if not isinstance(log[2], dict):
        log[2] = log[2].model_dump()  # type: ignore
    payload = log[2]
    # Ensure the payload is a dict
    if hasattr(payload, "model_dump"):
        payload = payload.model_dump()

    reply = payload["choices"][0]["message"]
    reasoning = reply.get("reasoning_content")
    answer = reply.get("content")
    final_answer = (
        f"<think>{reasoning}</think>\n{answer}"
        if reasoning and add_think
        else f"<think>\n\n</think>\n{answer}"
    )

    combined = history + [{"role": "assistant", "content": final_answer}]  # type: ignore
    return combined or None
|
|
601
|
-
|
|
602
|
-
# ------------------------------------------------------------------ #
|
|
603
|
-
# Utility helpers
|
|
604
|
-
# ------------------------------------------------------------------ #
|
|
605
|
-
async def inspect_history(self) -> None:
    """Print the last logged exchange (messages, then response) with ANSI colours.

    Raises:
        ValueError: If nothing has been logged yet.
    """
    if not hasattr(self, "last_log"):
        raise ValueError("No history available. Please call the model first.")

    prompt, messages, response = self.last_log
    if hasattr(response, "model_dump"):
        response = response.model_dump()
    messages = messages or [{"role": "user", "content": prompt}]

    print("\n\n")
    print(_blue("[Conversation History]") + "\n")

    for entry in messages:
        role, body = entry["role"], entry["content"]
        print(_red(f"{role.capitalize()}:"))
        if isinstance(body, str):
            print(body.strip())
        elif isinstance(body, list):
            # Multi-part content: text parts are printed, images summarised.
            for part in body:
                kind = part.get("type")
                if kind == "text":
                    print(part["text"].strip())
                elif kind == "image_url":
                    image_url = part["image_url"]["url"]
                    if "base64" in image_url:
                        len_base64 = len(image_url.split("base64,")[1])
                        print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
                    else:
                        print(_blue(f"<image_url: {image_url}>"))
        print("\n")

    print(_red("Response:"))
    if not (isinstance(response, dict) and response.get("choices")):
        print(_yellow("Warning: Not a standard OpenAI response object"))
        if isinstance(response, str):
            print(_green(response.strip()))
        elif isinstance(response, dict):
            print(_green(json.dumps(response, indent=2)))
        else:
            print(_green(str(response)))
        return

    choices = response["choices"]
    message = choices[0].get("message", {})
    reasoning = message.get("reasoning_content")
    parsed = message.get("parsed")
    content = message.get("content")

    if reasoning:
        print(_yellow("<think>"))
        print(reasoning.strip())
        print(_yellow("</think>\n"))

    if parsed:
        printable = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed
        print(json.dumps(printable, indent=2) + "\n")
    elif content:
        print(content.strip())
    else:
        print(_green("[No content]"))

    if len(choices) > 1:
        extra = len(choices) - 1
        print(_blue(f"\n(Plus {extra} other completions)"))
|
|
676
|
-
|
|
677
|
-
# ------------------------------------------------------------------ #
|
|
678
|
-
# Misc helpers
|
|
679
|
-
# ------------------------------------------------------------------ #
|
|
680
|
-
def set_model(self, model: str) -> None:
    """Replace the model name stored on this instance."""
    setattr(self, "model", model)
|
|
682
|
-
|
|
683
|
-
@staticmethod
async def list_models(port=None, host="localhost") -> List[str]:
    """Return the model ids reported by the server at ``host``/``port``.

    Any failure is logged and reported as an empty list.
    """
    try:
        api = AsyncLM(port=port, host=host).client  # type: ignore[arg-type]
        logger.debug(f"Base URL: {api.base_url}")
        page = await api.models.list()  # type: ignore[assignment]
        model_ids = []
        for entry in page.data:
            model_ids.append(entry.id)
        return model_ids
    except Exception as exc:
        logger.error(f"Failed to list models: {exc}")
        return []
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
# --------------------------------------------------------------------------- #
|
|
697
|
-
# Module-level utility functions (async versions)
|
|
698
|
-
# --------------------------------------------------------------------------- #
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
@lru_cache(maxsize=10)
def get_tokenizer(model_name: str) -> Any:
    """Load the tokenizer for ``model_name``; up to 10 results are memoised."""
    # Imported lazily so the dependency is only needed when a tokenizer is used.
    from transformers import AutoTokenizer  # type: ignore

    # NOTE(review): trust_remote_code=True presumably allows code shipped with
    # the model repository to run during loading - confirm model names passed
    # here come from trusted sources.
    return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
async def inspect_word_probs_async(lm, tokenizer, messages):
|
|
711
|
-
"""Async version of inspect_word_probs."""
|
|
712
|
-
|
|
713
|
-
import numpy as np
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
async def compute_word_log_probs(
|
|
717
|
-
tokenizer: Any,
|
|
718
|
-
lm_client: Any,
|
|
719
|
-
) -> tuple[List[Dict[str, Any]], Any]:
|
|
720
|
-
# Build a prompt that preserves literal newlines
|
|
721
|
-
prompt = tokenizer.apply_chat_template(
|
|
722
|
-
messages,
|
|
723
|
-
tokenize=False, # Don't tokenize yet, we need raw text
|
|
724
|
-
add_generation_prompt=False, # No generation prompt needed
|
|
725
|
-
)
|
|
726
|
-
|
|
727
|
-
# Request token logprobs
|
|
728
|
-
response = await lm_client.client.completions.create(
|
|
729
|
-
model=lm_client.model, # type: ignore
|
|
730
|
-
prompt=prompt,
|
|
731
|
-
max_tokens=1,
|
|
732
|
-
logprobs=1,
|
|
733
|
-
extra_body={"prompt_logprobs": 0},
|
|
734
|
-
)
|
|
735
|
-
token_logprob_dicts = response.choices[0].prompt_logprobs # type: ignore
|
|
736
|
-
|
|
737
|
-
# Override first token to known start marker
|
|
738
|
-
start_id = tokenizer.encode("<|im_start|>")[0]
|
|
739
|
-
token_logprob_dicts[0] = {
|
|
740
|
-
str(start_id): {
|
|
741
|
-
"logprob": -1,
|
|
742
|
-
"rank": 1,
|
|
743
|
-
"decoded_token": "<|im_start|>",
|
|
744
|
-
}
|
|
745
|
-
}
|
|
746
|
-
|
|
747
|
-
# Flatten tokens
|
|
748
|
-
tokens: List[Dict[str, Any]] = [
|
|
749
|
-
{"id": int(tid), **tdata}
|
|
750
|
-
for td in token_logprob_dicts
|
|
751
|
-
for tid, tdata in td.items()
|
|
752
|
-
]
|
|
753
|
-
|
|
754
|
-
# Validate tokenization
|
|
755
|
-
tokenized = tokenizer.tokenize(prompt)
|
|
756
|
-
if len(tokenized) != len(tokens):
|
|
757
|
-
raise ValueError(f"Token count mismatch: {len(tokenized)} vs {len(tokens)}")
|
|
758
|
-
for idx, tok in enumerate(tokens):
|
|
759
|
-
if tokenized[idx] != tok["decoded_token"]:
|
|
760
|
-
raise AssertionError(
|
|
761
|
-
f"Token mismatch at {idx}: "
|
|
762
|
-
f"{tokenized[idx]} != {tok['decoded_token']}"
|
|
763
|
-
)
|
|
764
|
-
|
|
765
|
-
# Split on newline sentinel
|
|
766
|
-
split_prompt = prompt.replace("\n", " <NL> ")
|
|
767
|
-
words = split_prompt.split()
|
|
768
|
-
|
|
769
|
-
word_log_probs: List[Dict[str, Any]] = []
|
|
770
|
-
token_idx = 0
|
|
771
|
-
|
|
772
|
-
for word in words:
|
|
773
|
-
# Map sentinel back to actual newline for encoding
|
|
774
|
-
target = "\n" if word == "<NL>" else word
|
|
775
|
-
sub_ids = tokenizer.encode(target, add_special_tokens=False)
|
|
776
|
-
count = len(sub_ids)
|
|
777
|
-
if count == 0:
|
|
778
|
-
continue
|
|
779
|
-
|
|
780
|
-
subs = tokens[token_idx : token_idx + count]
|
|
781
|
-
avg_logprob = sum(s["logprob"] for s in subs) / count
|
|
782
|
-
prob = float(np.exp(avg_logprob))
|
|
783
|
-
word_log_probs.append({"word": target, "probability": prob})
|
|
784
|
-
token_idx += count
|
|
785
|
-
|
|
786
|
-
return word_log_probs, token_logprob_dicts # type: ignore
|
|
787
|
-
|
|
788
|
-
def render_by_logprob(word_log_probs: List[Dict[str, Any]]) -> str:
    """
    Build an ANSI truecolor string that shades each word from red to green.

    Each entry must provide a ``"word"`` and a ``"probability"`` key.
    Probabilities are min-max normalised over all entries; the lowest
    maps to pure red and the highest to pure green. Entries whose word
    is ``"\\n"`` are emitted as a bare line break without colouring.
    Trailing whitespace is stripped from the result, and an empty input
    produces an empty string.
    """
    if not word_log_probs:
        return ""

    lowest = min(item["probability"] for item in word_log_probs)
    highest = max(item["probability"] for item in word_log_probs)
    # When every probability is identical the spread is zero; divide by 1.0
    # instead so the normalised value is simply 0.
    spread = (highest - lowest) or 1.0

    def paint(item: Dict[str, Any]) -> str:
        """Return the rendered fragment for one entry."""
        text = item["word"]
        if text == "\n":
            # Keep real line breaks as they are.
            return "\n"
        scaled = (item["probability"] - lowest) / spread
        red = int(255 * (1 - scaled))  # strongest for the least likely words
        green = int(255 * scaled)  # strongest for the most likely words
        return f"\x1b[38;2;{red};{green};0m{text}\x1b[0m "

    return "".join(paint(item) for item in word_log_probs).rstrip()
|
|
815
|
-
|
|
816
|
-
word_probs, token_logprob_dicts = await compute_word_log_probs(tokenizer, lm)
|
|
817
|
-
return word_probs, token_logprob_dicts, render_by_logprob(word_probs)
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
# --------------------------------------------------------------------------- #
|
|
821
|
-
# Async LLMTask class
|
|
822
|
-
# --------------------------------------------------------------------------- #
|
|
823
|
-
|
|
824
|
-
# Type variable for a task's input schema; bound to BaseModel.
InputModelType = TypeVar("InputModelType", bound=BaseModel)
|
|
825
|
-
# Type variable for a task's output schema; bound to BaseModel.
OutputModelType = TypeVar("OutputModelType", bound=BaseModel)
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
    """
    Async callable wrapper around an AsyncLM endpoint.

    Sub-classes must set:
        • lm              – the async language-model instance
        • InputModel      – a Pydantic input class
        • OutputModel     – a Pydantic output class

    The two model classes may instead be given as generic parameters,
    e.g. ``class MyTask(AsyncLLMTask[MyInput, MyOutput])``; generic
    parameters take precedence over the class attributes.

    Optional flags:
        • temperature     – float (default 0.6); used when a call passes
                            ``temperature=None``
        • think           – bool  (if the backend supports "chain-of-thought")
        • add_json_schema – bool  (include schema in the instruction)
        • cache           – bool  (default False); OR-ed with the per-call
                            ``cache`` argument

    The **docstring** of each sub-class is sent as the LM instruction.
    Example
    ```python
        class DemoTask(AsyncLLMTask):
            "TODO: SYSTEM_PROMPT_INSTRUCTION HERE"

            lm = AsyncLM(port=8130, cache=False, model="gpt-3.5-turbo")

            class InputModel(BaseModel):
                text_to_translate:str

            class OutputModel(BaseModel):
                translation:str
                glossary_use:str

            temperature = 0.6
            think=False

        demo_task = DemoTask()
        result = await demo_task({'text_to_translate': 'Translate from english to vietnamese: Hello how are you'})
    ```
    """

    lm: "AsyncLM"
    # These hold model *classes*, not instances.
    InputModel: type[InputModelType]
    OutputModel: type[OutputModelType]

    temperature: float = 0.6
    think: bool = False
    add_json_schema: bool = False
    cache: bool = False

    def _config_error(self) -> NotImplementedError:
        """Build the error raised when the sub-class is not fully configured."""
        return NotImplementedError(
            f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes or use proper generic typing."
        )

    def _resolve_models(self) -> tuple[Any, Any]:
        """
        Return ``(input_model, output_model)`` for this task.

        Generic parameters declared on this class or any ancestor
        (``AsyncLLMTask[In, Out]``) are preferred; otherwise the
        ``InputModel`` / ``OutputModel`` attributes are used.

        Raises:
            NotImplementedError: if neither source provides both models.
        """
        # Walk the MRO so a plain subclass of a parameterised task still
        # finds its parent's parameters, and so a mixin listed first does
        # not hide them.
        for klass in type(self).__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                origin = getattr(base, "__origin__", None)
                args = getattr(base, "__args__", ())
                if (
                    isinstance(origin, type)
                    and issubclass(origin, AsyncLLMTask)
                    and len(args) >= 2
                    # Skip unresolved TypeVars / forward references.
                    and isinstance(args[0], type)
                    and isinstance(args[1], type)
                ):
                    return args[0], args[1]

        if not hasattr(self, "InputModel") or not hasattr(self, "OutputModel"):
            raise self._config_error()
        return self.InputModel, self.OutputModel

    async def __call__(
        self,
        data: BaseModel | dict,
        temperature: Optional[float] = 0.1,
        cache: bool = False,
        think: Optional[bool] = None,  # if not None, overrides self.think
    ) -> tuple[OutputModelType, List[Dict[str, Any]]]:
        """
        Run the task on one input and return the parsed output.

        Args:
            data: a ready-made model instance, or a dict used to build
                the input model.
            temperature: sampling temperature for this call. ``None``
                falls back to ``self.temperature``; an explicit ``0`` is
                passed through unchanged.
            cache: enable caching for this call (also enabled when
                ``self.cache`` is true).
            think: overrides ``self.think`` when not ``None``.

        Returns:
            ``(parsed, messages)`` taken from the ``"parsed"`` and
            ``"messages"`` entries of the ``lm.parse`` result.

        Raises:
            NotImplementedError: if ``lm`` or the models are not configured.
            TypeError: if ``data`` is a dict and the input model is not a
                ``BaseModel`` subclass.
        """
        if not hasattr(self, "lm"):
            raise self._config_error()
        input_model, output_model = self._resolve_models()

        # Ensure input_model is a class before calling
        if isinstance(data, BaseModel):
            item = data
        elif isinstance(input_model, type) and issubclass(input_model, BaseModel):
            item = input_model(**data)
        else:
            raise TypeError("InputModel must be a subclass of BaseModel")

        assert isinstance(output_model, type) and issubclass(output_model, BaseModel), (
            "OutputModel must be a subclass of BaseModel"
        )

        result = await self.lm.parse(
            prompt=item.model_dump_json(),
            instruction=self.__doc__ or "",
            response_model=output_model,
            # `is not None` rather than `or`: a caller asking for
            # temperature 0 must not silently get the class temperature.
            temperature=temperature if temperature is not None else self.temperature,
            think=think if think is not None else self.think,
            add_json_schema_to_instruction=self.add_json_schema,
            cache=self.cache or cache,
        )

        return (
            cast(OutputModelType, result["parsed"]),  # type: ignore
            cast(List[dict], result["messages"]),  # type: ignore
        )

    def generate_training_data(
        self, input_dict: Dict[str, Any], output: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Return a ShareGPT-like training record for one input/output pair.

        The task docstring is the system message; ``input_dict`` and
        ``output`` are validated through the input and output models and
        serialised to JSON as the user and assistant messages.

        Returns:
            ``{"messages": ...}`` holding the result of
            ``get_conversation_one_turn``.

        Raises:
            NotImplementedError: if the models are not configured.
        """
        # Same model lookup as __call__, so tasks declared through
        # generic parameters work here too.
        input_model, output_model = self._resolve_models()
        system_prompt = self.__doc__ or ""
        user_msg = input_model(**input_dict).model_dump_json()
        assistant_msg = output_model(**output).model_dump_json()
        messages = get_conversation_one_turn(
            system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
        )
        return {"messages": messages}

    arun = __call__  # alias for compatibility with other LLMTask implementations
|