speedy-utils 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +1 -5
- llm_utils/chat_format/transform.py +9 -9
- llm_utils/group_messages.py +1 -1
- llm_utils/lm/async_lm/__init__.py +6 -1
- llm_utils/lm/async_lm/_utils.py +7 -4
- llm_utils/lm/async_lm/async_llm_task.py +472 -110
- llm_utils/lm/async_lm/async_lm.py +273 -665
- llm_utils/lm/async_lm/async_lm_base.py +407 -0
- llm_utils/lm/async_lm/lm_specific.py +136 -0
- llm_utils/lm/utils.py +1 -3
- llm_utils/scripts/vllm_load_balancer.py +49 -37
- speedy_utils/__init__.py +3 -1
- speedy_utils/common/notebook_utils.py +4 -4
- speedy_utils/common/report_manager.py +2 -3
- speedy_utils/common/utils_cache.py +233 -3
- speedy_utils/common/utils_io.py +2 -0
- speedy_utils/scripts/mpython.py +1 -3
- {speedy_utils-1.1.6.dist-info → speedy_utils-1.1.8.dist-info}/METADATA +1 -1
- speedy_utils-1.1.8.dist-info/RECORD +39 -0
- llm_utils/lm/chat_html.py +0 -246
- llm_utils/lm/lm_json.py +0 -68
- llm_utils/lm/sync_lm.py +0 -943
- speedy_utils-1.1.6.dist-info/RECORD +0 -40
- {speedy_utils-1.1.6.dist-info → speedy_utils-1.1.8.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.6.dist-info → speedy_utils-1.1.8.dist-info}/entry_points.txt +0 -0
llm_utils/lm/sync_lm.py
DELETED
|
@@ -1,943 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
# ============================================================================= #
|
|
3
|
-
# SYNCHRONOUS LANGUAGE MODEL WRAPPER WITH OPENAI COMPATIBILITY
|
|
4
|
-
# ============================================================================= #
|
|
5
|
-
#
|
|
6
|
-
# Title & Intent:
|
|
7
|
-
# Unified synchronous language model interface with caching, type safety, and OpenAI API compatibility
|
|
8
|
-
#
|
|
9
|
-
# High-level Summary:
|
|
10
|
-
# This module provides a comprehensive synchronous wrapper for language models that supports both
|
|
11
|
-
# string prompts and structured Pydantic model responses. It includes intelligent caching with
|
|
12
|
-
# content-based hashing, automatic retry logic for rate limits, and seamless integration with
|
|
13
|
-
# OpenAI-compatible APIs. The LM class handles message formatting, response parsing, token counting,
|
|
14
|
-
# and provides detailed logging and debugging capabilities for production use.
|
|
15
|
-
#
|
|
16
|
-
# Public API / Data Contracts:
|
|
17
|
-
# • LM(model, temperature=0.0, max_tokens=2000, host="localhost", port=None, **kwargs) - Main wrapper class
|
|
18
|
-
# • LM.__call__(prompt=None, messages=None, response_format=str, cache=None, **kwargs) -> str | BaseModel
|
|
19
|
-
# • LM.list_models(port=None) -> List[str] - Enumerate available models
|
|
20
|
-
# • LM.count_tokens(messages, model=None) -> int - Token counting utility
|
|
21
|
-
# • LM.price(messages, model=None, response_tokens=0) -> float - Cost estimation
|
|
22
|
-
# • LM.set_model(model_name) -> None - Runtime model switching
|
|
23
|
-
# • TModel = TypeVar("TModel", bound=BaseModel) - Generic type for structured responses
|
|
24
|
-
# • Messages = List[ChatCompletionMessageParam] - Typed message format
|
|
25
|
-
# • RawMsgs = Union[Messages, LegacyMsgs] - Flexible input format
|
|
26
|
-
#
|
|
27
|
-
# Invariants / Constraints:
|
|
28
|
-
# • MUST provide either 'prompt' or 'messages' parameter, but not both
|
|
29
|
-
# • MUST set model name before making API calls (auto-detection available)
|
|
30
|
-
# • response_format=str MUST return string; response_format=PydanticModel MUST return model instance
|
|
31
|
-
# • Caching MUST use content-based hashing for reproducible results
|
|
32
|
-
# • MUST handle OpenAI rate limits with exponential backoff (up to 3 retries)
|
|
33
|
-
# • MUST preserve message order and format during transformations
|
|
34
|
-
# • Token counting SHOULD use tiktoken when available, fall back to character estimation
|
|
35
|
-
# • MUST validate Pydantic responses and retry on parsing failures
|
|
36
|
-
#
|
|
37
|
-
# Usage Example:
|
|
38
|
-
# ```python
|
|
39
|
-
# from llm_utils.lm.sync_lm import LM
|
|
40
|
-
# from pydantic import BaseModel
|
|
41
|
-
#
|
|
42
|
-
# class CodeResponse(BaseModel):
|
|
43
|
-
# language: str
|
|
44
|
-
# code: str
|
|
45
|
-
# explanation: str
|
|
46
|
-
#
|
|
47
|
-
# # String response
|
|
48
|
-
# lm = LM(model="gpt-4o-mini", temperature=0.1)
|
|
49
|
-
# response = lm(prompt="Write a Python hello world")
|
|
50
|
-
# print(response) # Returns string
|
|
51
|
-
#
|
|
52
|
-
# # Structured response
|
|
53
|
-
# code_response = lm(
|
|
54
|
-
# prompt="Write a Python function to calculate fibonacci",
|
|
55
|
-
# response_format=CodeResponse
|
|
56
|
-
# )
|
|
57
|
-
# print(f"Language: {code_response.language}") # Returns CodeResponse instance
|
|
58
|
-
#
|
|
59
|
-
# # Message-based conversation
|
|
60
|
-
# messages = [
|
|
61
|
-
# {"role": "system", "content": "You are a helpful coding assistant"},
|
|
62
|
-
# {"role": "user", "content": "Explain async/await in Python"}
|
|
63
|
-
# ]
|
|
64
|
-
# response = lm(messages=messages, max_tokens=1000)
|
|
65
|
-
# ```
|
|
66
|
-
#
|
|
67
|
-
# TODO & Future Work:
|
|
68
|
-
# • Add streaming response support for long-form generation
|
|
69
|
-
# • Implement fine-grained token usage tracking per conversation
|
|
70
|
-
# • Add support for function calling and tool use
|
|
71
|
-
# • Optimize caching strategy for conversation contexts
|
|
72
|
-
# • Add async context manager support for resource cleanup
|
|
73
|
-
#
|
|
74
|
-
# ============================================================================= #
|
|
75
|
-
"""
|
|
76
|
-
|
|
77
|
-
from __future__ import annotations
|
|
78
|
-
|
|
79
|
-
import base64
|
|
80
|
-
import hashlib
|
|
81
|
-
import json
|
|
82
|
-
import os
|
|
83
|
-
from abc import ABC
|
|
84
|
-
from functools import lru_cache
|
|
85
|
-
from typing import (
|
|
86
|
-
Any,
|
|
87
|
-
Dict,
|
|
88
|
-
List,
|
|
89
|
-
Literal,
|
|
90
|
-
Optional,
|
|
91
|
-
Sequence,
|
|
92
|
-
Type,
|
|
93
|
-
TypeVar,
|
|
94
|
-
Union,
|
|
95
|
-
cast,
|
|
96
|
-
overload,
|
|
97
|
-
)
|
|
98
|
-
|
|
99
|
-
from loguru import logger
|
|
100
|
-
from openai import AuthenticationError, OpenAI, RateLimitError
|
|
101
|
-
from openai.pagination import SyncPage
|
|
102
|
-
from openai.types.chat import (
|
|
103
|
-
ChatCompletionAssistantMessageParam,
|
|
104
|
-
ChatCompletionMessageParam,
|
|
105
|
-
ChatCompletionSystemMessageParam,
|
|
106
|
-
ChatCompletionToolMessageParam,
|
|
107
|
-
ChatCompletionUserMessageParam,
|
|
108
|
-
)
|
|
109
|
-
from openai.types.model import Model
|
|
110
|
-
from pydantic import BaseModel
|
|
111
|
-
|
|
112
|
-
from llm_utils.chat_format.display import get_conversation_one_turn
|
|
113
|
-
from speedy_utils.common.utils_io import jdumps
|
|
114
|
-
|
|
115
|
-
# --------------------------------------------------------------------------- #
|
|
116
|
-
# type helpers
|
|
117
|
-
# --------------------------------------------------------------------------- #
|
|
118
|
-
# Generic parameter for structured responses; bound to pydantic ``BaseModel``.
TModel = TypeVar("TModel", bound=BaseModel)
# Messages already expressed with the OpenAI SDK message-param types.
Messages = List[ChatCompletionMessageParam]  # final, already-typed messages
# Plain ``{"role": ..., "content": ...}`` string dictionaries.
LegacyMsgs = List[Dict[str, str]]  # old "role/content" dicts
# Either of the two message representations above.
RawMsgs = Union[Messages, LegacyMsgs]  # what __call__ accepts
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
# --------------------------------------------------------------------------- #
|
|
125
|
-
# color formatting helpers
|
|
126
|
-
# --------------------------------------------------------------------------- #
|
|
127
|
-
def _red(text: str) -> str:
|
|
128
|
-
"""Format text with red color."""
|
|
129
|
-
return f"\x1b[31m{text}\x1b[0m"
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
def _green(text: str) -> str:
|
|
133
|
-
"""Format text with green color."""
|
|
134
|
-
return f"\x1b[32m{text}\x1b[0m"
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
def _blue(text: str) -> str:
|
|
138
|
-
"""Format text with blue color."""
|
|
139
|
-
return f"\x1b[34m{text}\x1b[0m"
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
def _yellow(text: str) -> str:
|
|
143
|
-
"""Format text with yellow color."""
|
|
144
|
-
return f"\x1b[33m{text}\x1b[0m"
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
# from functools import lru_cache
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
# @lru_cache(maxsize=10)
|
|
151
|
-
# def get_tok(tokenizer_name):
|
|
152
|
-
# from transformers import AutoTokenizer
|
|
153
|
-
|
|
154
|
-
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
|
|
155
|
-
# return tokenizer
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
class LM:
|
|
159
|
-
"""
|
|
160
|
-
Unified language-model wrapper.
|
|
161
|
-
|
|
162
|
-
• `response_format=str` → returns `str`
|
|
163
|
-
• `response_format=YourPydanticModel` → returns that model instance
|
|
164
|
-
"""
|
|
165
|
-
|
|
166
|
-
# --------------------------------------------------------------------- #
|
|
167
|
-
# ctor / plumbing
|
|
168
|
-
# --------------------------------------------------------------------- #
|
|
169
|
-
def __init__(
    self,
    model: str | None = None,
    *,
    temperature: float = 0.0,
    max_tokens: int = 2_000,
    host: str = "localhost",
    port: Optional[int | str] = None,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    cache: bool = True,
    **openai_kwargs: Any,
) -> None:
    """Record request defaults and create the OpenAI client.

    Args:
        model: Model name; may stay ``None`` and be set later.
        temperature: Default sampling temperature for requests.
        max_tokens: Default completion token limit.
        host: Host used to build the endpoint when ``port`` is given.
        port: Port used to build ``http://{host}:{port}/v1``.
        base_url: Explicit endpoint; wins over ``host``/``port`` when truthy.
        api_key: API key; falls back to ``OPENAI_API_KEY`` and then ``"abc"``.
        cache: Default for whether responses are cached.
        **openai_kwargs: Extra keyword arguments stored for later requests.
    """
    # An explicit URL takes precedence; otherwise derive one from host/port.
    if base_url:
        endpoint = base_url
    elif port:
        endpoint = f"http://{host}:{port}/v1"
    else:
        endpoint = None

    key = api_key if api_key else os.getenv("OPENAI_API_KEY", "abc")

    self.model = model
    self.temperature = temperature
    self.max_tokens = max_tokens
    self.base_url = endpoint
    self.api_key = key
    self.openai_kwargs = openai_kwargs
    self.do_cache = cache
    self._init_port = port  # remembered so a model can be auto-selected later

    self.client = OpenAI(api_key=key, base_url=endpoint)
|
|
192
|
-
|
|
193
|
-
def set_model(self, model: str) -> None:
|
|
194
|
-
"""Set the model name after initialization."""
|
|
195
|
-
self.model = model
|
|
196
|
-
|
|
197
|
-
# --------------------------------------------------------------------- #
|
|
198
|
-
# public API – typed overloads
|
|
199
|
-
# --------------------------------------------------------------------- #
|
|
200
|
-
@overload
|
|
201
|
-
def __call__(
|
|
202
|
-
self,
|
|
203
|
-
*,
|
|
204
|
-
prompt: str | None = ...,
|
|
205
|
-
messages: RawMsgs | None = ...,
|
|
206
|
-
response_format: type[str] = str,
|
|
207
|
-
return_openai_response: bool = ...,
|
|
208
|
-
**kwargs: Any,
|
|
209
|
-
) -> str: ...
|
|
210
|
-
|
|
211
|
-
@overload
|
|
212
|
-
def __call__(
|
|
213
|
-
self,
|
|
214
|
-
*,
|
|
215
|
-
prompt: str | None = ...,
|
|
216
|
-
messages: RawMsgs | None = ...,
|
|
217
|
-
response_format: Type[TModel],
|
|
218
|
-
return_openai_response: bool = ...,
|
|
219
|
-
**kwargs: Any,
|
|
220
|
-
) -> TModel: ...
|
|
221
|
-
|
|
222
|
-
# single implementation
|
|
223
|
-
def __call__(
|
|
224
|
-
self,
|
|
225
|
-
prompt: Optional[str] = None,
|
|
226
|
-
messages: Optional[RawMsgs] = None,
|
|
227
|
-
response_format: Union[type[str], Type[BaseModel]] = str,
|
|
228
|
-
cache: Optional[bool] = None,
|
|
229
|
-
max_tokens: Optional[int] = None,
|
|
230
|
-
return_openai_response: bool = False,
|
|
231
|
-
**kwargs: Any,
|
|
232
|
-
):
|
|
233
|
-
# argument validation ------------------------------------------------
|
|
234
|
-
if (prompt is None) == (messages is None):
|
|
235
|
-
raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
|
|
236
|
-
|
|
237
|
-
if prompt is not None:
|
|
238
|
-
messages = [{"role": "user", "content": prompt}]
|
|
239
|
-
|
|
240
|
-
assert messages is not None # for type-checker
|
|
241
|
-
|
|
242
|
-
# If model is not specified, but port is provided, use the first available model
|
|
243
|
-
if self.model is None:
|
|
244
|
-
port = self._init_port
|
|
245
|
-
if port:
|
|
246
|
-
available_models = self.list_models(port=port)
|
|
247
|
-
if available_models:
|
|
248
|
-
self.model = available_models[0]
|
|
249
|
-
logger.debug(f"Auto-selected model: {self.model}")
|
|
250
|
-
else:
|
|
251
|
-
raise ValueError("No models available to select from.")
|
|
252
|
-
else:
|
|
253
|
-
raise AssertionError("Model must be set before calling.")
|
|
254
|
-
|
|
255
|
-
openai_msgs: Messages = (
|
|
256
|
-
self._convert_messages(cast(LegacyMsgs, messages))
|
|
257
|
-
if isinstance(messages[0], dict) # legacy style
|
|
258
|
-
else cast(Messages, messages) # already typed
|
|
259
|
-
)
|
|
260
|
-
|
|
261
|
-
kw = dict(
|
|
262
|
-
self.openai_kwargs,
|
|
263
|
-
temperature=self.temperature,
|
|
264
|
-
max_tokens=max_tokens or self.max_tokens,
|
|
265
|
-
)
|
|
266
|
-
kw.update(kwargs)
|
|
267
|
-
use_cache = self.do_cache if cache is None else cache
|
|
268
|
-
|
|
269
|
-
raw_response = self._call_raw(
|
|
270
|
-
openai_msgs,
|
|
271
|
-
response_format=response_format,
|
|
272
|
-
use_cache=use_cache,
|
|
273
|
-
**kw,
|
|
274
|
-
)
|
|
275
|
-
|
|
276
|
-
if return_openai_response:
|
|
277
|
-
response = raw_response
|
|
278
|
-
else:
|
|
279
|
-
response = self._parse_output(raw_response, response_format)
|
|
280
|
-
|
|
281
|
-
self.last_log = [prompt, messages, raw_response]
|
|
282
|
-
return response
|
|
283
|
-
|
|
284
|
-
def inspect_history(self) -> None:
    """Print the most recent conversation and its response to stdout.

    Reads ``self.last_log`` (``[prompt, messages, response]``) and prints each
    message followed by the response, using the ANSI colour helpers.

    Raises:
        ValueError: If ``last_log`` has not been set on this instance.
    """
    if not hasattr(self, "last_log"):
        raise ValueError("No history available. Please call the model first.")

    prompt, messages, response = self.last_log
    # Ensure response is a dictionary
    if hasattr(response, "model_dump"):
        response = response.model_dump()

    # Fall back to a single user message built from the prompt.
    if not messages:
        messages = [{"role": "user", "content": prompt}]

    print("\n\n")
    print(_blue("[Conversation History]") + "\n")

    # Print all messages in the conversation
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        print(_red(f"{role.capitalize()}:"))

        if isinstance(content, str):
            print(content.strip())
        elif isinstance(content, list):
            # Handle multimodal content
            for item in content:
                if item.get("type") == "text":
                    print(item["text"].strip())
                elif item.get("type") == "image_url":
                    image_url = item["image_url"]["url"]
                    if "base64" in image_url:
                        # Print only the payload length, not the encoded image.
                        len_base64 = len(image_url.split("base64,")[1])
                        print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
                    else:
                        print(_blue(f"<image_url: {image_url}>"))
        # NOTE(review): assumed to run once per message - confirm nesting.
        print("\n")

    # Print the response - now always an OpenAI completion
    print(_red("Response:"))

    # Handle OpenAI response object
    if isinstance(response, dict) and "choices" in response and response["choices"]:
        message = response["choices"][0].get("message", {})

        # Check for reasoning content (if available)
        reasoning = message.get("reasoning_content")

        # Check for parsed content (structured mode)
        parsed = message.get("parsed")

        # Get regular content
        content = message.get("content")

        # Display reasoning if available
        if reasoning:
            print(_yellow("<think>"))
            print(reasoning.strip())
            print(_yellow("</think>"))
            print()

        # Display parsed content for structured responses
        if parsed:
            # print(_green('<Parsed Structure>'))
            if hasattr(parsed, "model_dump"):
                print(jdumps(parsed.model_dump(), indent=2))
            else:
                print(jdumps(parsed, indent=2))
            # print(_green('</Parsed Structure>'))
            print()

        else:
            if content:
                # print(_green("<Content>"))
                print(content.strip())
                # print(_green("</Content>"))
            else:
                print(_green("[No content]"))

        # Show if there were multiple completions
        if len(response["choices"]) > 1:
            print(
                _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
            )
    else:
        # Fallback for non-standard response objects or cached responses
        print(_yellow("Warning: Not a standard OpenAI response object"))
        if isinstance(response, str):
            print(_green(response.strip()))
        elif isinstance(response, dict):
            print(_green(json.dumps(response, indent=2)))
        else:
            print(_green(str(response)))
|
|
376
|
-
|
|
377
|
-
# print("\n\n")
|
|
378
|
-
|
|
379
|
-
# --------------------------------------------------------------------- #
|
|
380
|
-
# low-level OpenAI call
|
|
381
|
-
# --------------------------------------------------------------------- #
|
|
382
|
-
def _call_raw(
|
|
383
|
-
self,
|
|
384
|
-
messages: Sequence[ChatCompletionMessageParam],
|
|
385
|
-
response_format: Union[type[str], Type[BaseModel]],
|
|
386
|
-
use_cache: bool,
|
|
387
|
-
**kw: Any,
|
|
388
|
-
):
|
|
389
|
-
assert self.model is not None, "Model must be set before making a call."
|
|
390
|
-
model: str = self.model
|
|
391
|
-
|
|
392
|
-
cache_key = (
|
|
393
|
-
self._cache_key(messages, kw, response_format) if use_cache else None
|
|
394
|
-
)
|
|
395
|
-
if cache_key and (hit := self._load_cache(cache_key)) is not None:
|
|
396
|
-
return hit
|
|
397
|
-
|
|
398
|
-
try:
|
|
399
|
-
# structured mode
|
|
400
|
-
if response_format is not str and issubclass(response_format, BaseModel):
|
|
401
|
-
openai_response = self.client.beta.chat.completions.parse(
|
|
402
|
-
model=model,
|
|
403
|
-
messages=list(messages),
|
|
404
|
-
response_format=response_format, # type: ignore[arg-type]
|
|
405
|
-
**kw,
|
|
406
|
-
)
|
|
407
|
-
# plain-text mode
|
|
408
|
-
else:
|
|
409
|
-
openai_response = self.client.chat.completions.create(
|
|
410
|
-
model=model,
|
|
411
|
-
messages=list(messages),
|
|
412
|
-
**kw,
|
|
413
|
-
)
|
|
414
|
-
|
|
415
|
-
except (AuthenticationError, RateLimitError) as exc: # pragma: no cover
|
|
416
|
-
logger.error(exc)
|
|
417
|
-
raise
|
|
418
|
-
|
|
419
|
-
if cache_key:
|
|
420
|
-
self._dump_cache(cache_key, openai_response)
|
|
421
|
-
|
|
422
|
-
return openai_response
|
|
423
|
-
|
|
424
|
-
# --------------------------------------------------------------------- #
|
|
425
|
-
# legacy → typed messages
|
|
426
|
-
# --------------------------------------------------------------------- #
|
|
427
|
-
@staticmethod
|
|
428
|
-
def _convert_messages(msgs: LegacyMsgs) -> Messages:
|
|
429
|
-
converted: Messages = []
|
|
430
|
-
for msg in msgs:
|
|
431
|
-
role = msg["role"]
|
|
432
|
-
content = msg["content"]
|
|
433
|
-
if role == "user":
|
|
434
|
-
converted.append(
|
|
435
|
-
ChatCompletionUserMessageParam(role="user", content=content)
|
|
436
|
-
)
|
|
437
|
-
elif role == "assistant":
|
|
438
|
-
converted.append(
|
|
439
|
-
ChatCompletionAssistantMessageParam(
|
|
440
|
-
role="assistant", content=content
|
|
441
|
-
)
|
|
442
|
-
)
|
|
443
|
-
elif role == "system":
|
|
444
|
-
converted.append(
|
|
445
|
-
ChatCompletionSystemMessageParam(role="system", content=content)
|
|
446
|
-
)
|
|
447
|
-
elif role == "tool":
|
|
448
|
-
converted.append(
|
|
449
|
-
ChatCompletionToolMessageParam(
|
|
450
|
-
role="tool",
|
|
451
|
-
content=content,
|
|
452
|
-
tool_call_id=msg.get("tool_call_id") or "", # str, never None
|
|
453
|
-
)
|
|
454
|
-
)
|
|
455
|
-
else:
|
|
456
|
-
# fall back to raw dict for unknown roles
|
|
457
|
-
converted.append({"role": role, "content": content}) # type: ignore[arg-type]
|
|
458
|
-
return converted
|
|
459
|
-
|
|
460
|
-
# --------------------------------------------------------------------- #
|
|
461
|
-
# final parse (needed for plain-text or cache hits only)
|
|
462
|
-
# --------------------------------------------------------------------- #
|
|
463
|
-
@staticmethod
|
|
464
|
-
def _parse_output(
|
|
465
|
-
raw_response: Any,
|
|
466
|
-
response_format: Union[type[str], Type[BaseModel]],
|
|
467
|
-
) -> str | BaseModel:
|
|
468
|
-
# Convert any object to dict if needed
|
|
469
|
-
if hasattr(raw_response, "model_dump"):
|
|
470
|
-
raw_response = raw_response.model_dump()
|
|
471
|
-
|
|
472
|
-
if response_format is str:
|
|
473
|
-
# Extract the content from OpenAI response dict
|
|
474
|
-
if isinstance(raw_response, dict) and "choices" in raw_response:
|
|
475
|
-
message = raw_response["choices"][0]["message"]
|
|
476
|
-
return message.get("content", "") or ""
|
|
477
|
-
return cast(str, raw_response)
|
|
478
|
-
|
|
479
|
-
# For the type-checker: we *know* it's a BaseModel subclass here.
|
|
480
|
-
model_cls = cast(Type[BaseModel], response_format)
|
|
481
|
-
|
|
482
|
-
# Handle structured response
|
|
483
|
-
if isinstance(raw_response, dict) and "choices" in raw_response:
|
|
484
|
-
message = raw_response["choices"][0]["message"]
|
|
485
|
-
|
|
486
|
-
# Check if already parsed by OpenAI client
|
|
487
|
-
if "parsed" in message:
|
|
488
|
-
return model_cls.model_validate(message["parsed"])
|
|
489
|
-
|
|
490
|
-
# Need to parse the content
|
|
491
|
-
content = message.get("content")
|
|
492
|
-
if content is None:
|
|
493
|
-
raise ValueError("Model returned empty content")
|
|
494
|
-
|
|
495
|
-
try:
|
|
496
|
-
data = json.loads(content)
|
|
497
|
-
return model_cls.model_validate(data)
|
|
498
|
-
except Exception as exc:
|
|
499
|
-
raise ValueError(
|
|
500
|
-
f"Failed to parse model output as JSON:\n{content}"
|
|
501
|
-
) from exc
|
|
502
|
-
|
|
503
|
-
# Handle cached response or other formats
|
|
504
|
-
if isinstance(raw_response, model_cls):
|
|
505
|
-
return raw_response
|
|
506
|
-
if isinstance(raw_response, dict):
|
|
507
|
-
return model_cls.model_validate(raw_response)
|
|
508
|
-
|
|
509
|
-
# Try parsing as JSON string
|
|
510
|
-
try:
|
|
511
|
-
data = json.loads(raw_response)
|
|
512
|
-
return model_cls.model_validate(data)
|
|
513
|
-
except Exception as exc:
|
|
514
|
-
raise ValueError(
|
|
515
|
-
f"Model did not return valid JSON:\n---\n{raw_response}"
|
|
516
|
-
) from exc
|
|
517
|
-
|
|
518
|
-
# --------------------------------------------------------------------- #
|
|
519
|
-
# tiny disk cache
|
|
520
|
-
# --------------------------------------------------------------------- #
|
|
521
|
-
@staticmethod
|
|
522
|
-
def _cache_key(
|
|
523
|
-
messages: Any,
|
|
524
|
-
kw: Any,
|
|
525
|
-
response_format: Union[type[str], Type[BaseModel]],
|
|
526
|
-
) -> str:
|
|
527
|
-
tag = response_format.__name__ if response_format is not str else "text"
|
|
528
|
-
blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
|
|
529
|
-
return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
|
|
530
|
-
|
|
531
|
-
@staticmethod
|
|
532
|
-
def _cache_path(key: str) -> str:
|
|
533
|
-
return os.path.expanduser(f"~/.cache/lm/{key}.json")
|
|
534
|
-
|
|
535
|
-
def _dump_cache(self, key: str, val: Any) -> None:
    """Write *val* as JSON to the cache file for *key* (best effort).

    Pydantic models are dumped in JSON mode first. Any failure is logged at
    debug level and otherwise ignored.
    """
    try:
        target = self._cache_path(key)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, "w") as out:
            payload = val.model_dump(mode="json") if isinstance(val, BaseModel) else val
            json.dump(payload, out)
    except Exception as exc:  # pragma: no cover
        logger.debug(f"cache write skipped: {exc}")
|
|
546
|
-
|
|
547
|
-
def _load_cache(self, key: str) -> Any | None:
|
|
548
|
-
path = self._cache_path(key)
|
|
549
|
-
if not os.path.exists(path):
|
|
550
|
-
return None
|
|
551
|
-
try:
|
|
552
|
-
with open(path) as fh:
|
|
553
|
-
return json.load(fh)
|
|
554
|
-
except Exception: # pragma: no cover
|
|
555
|
-
return None
|
|
556
|
-
|
|
557
|
-
@staticmethod
def list_models(
    port=None, host="localhost", base_url: Optional[str] = None
) -> List[str]:
    """Return the model ids reported by an OpenAI-compatible server.

    Args:
        port: When truthy, the endpoint is ``http://{host}:{port}/v1``.
        host: Host used together with ``port``.
        base_url: Endpoint used when ``port`` is falsy.

    Raises:
        ValueError: On any failure, with a message chosen from the text of
            the underlying error.
    """
    endpoint = f"http://{host}:{port}/v1" if port else base_url
    try:
        client: OpenAI = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY", "abc"),
            base_url=endpoint or None,
        )
        page: SyncPage[Model] = client.models.list()
        return [entry.id for entry in page.data]
    except Exception as exc:
        error_msg = str(exc)

        # Choose the message from the error text, then raise once.
        if "404" in error_msg or "Not Found" in error_msg:
            detail = (
                f"No OpenAI-compatible API found at {endpoint}. "
                f"The endpoint appears to be running a different service "
                f"(possibly Jupyter Server). Please check the port number."
            )
        elif "Connection" in error_msg:
            detail = (
                f"Cannot connect to {endpoint}. "
                f"Please verify the service is running and accessible."
            )
        else:
            detail = f"Failed to list models from {endpoint}: {error_msg}"
        raise ValueError(detail) from exc
|
|
588
|
-
|
|
589
|
-
def parse(
|
|
590
|
-
self,
|
|
591
|
-
response_model: Type[BaseModel],
|
|
592
|
-
instruction: Optional[str] = None,
|
|
593
|
-
prompt: Optional[str] = None,
|
|
594
|
-
messages: Optional[RawMsgs] = None,
|
|
595
|
-
think: Literal[True, False, None] = None,
|
|
596
|
-
add_json_schema_to_instruction: bool = False,
|
|
597
|
-
temperature: Optional[float] = None,
|
|
598
|
-
max_tokens: Optional[int] = None,
|
|
599
|
-
return_openai_response: bool = False,
|
|
600
|
-
cache: Optional[bool] = True,
|
|
601
|
-
**kwargs,
|
|
602
|
-
):
|
|
603
|
-
if messages is None:
|
|
604
|
-
assert instruction is not None, "Instruction must be provided."
|
|
605
|
-
assert prompt is not None, "Prompt must be provided."
|
|
606
|
-
messages = [
|
|
607
|
-
{
|
|
608
|
-
"role": "system",
|
|
609
|
-
"content": instruction,
|
|
610
|
-
},
|
|
611
|
-
{
|
|
612
|
-
"role": "user",
|
|
613
|
-
"content": prompt,
|
|
614
|
-
},
|
|
615
|
-
] # type: ignore
|
|
616
|
-
|
|
617
|
-
post_fix = ""
|
|
618
|
-
json_schema = response_model.model_json_schema()
|
|
619
|
-
if add_json_schema_to_instruction and response_model:
|
|
620
|
-
_schema = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
|
|
621
|
-
post_fix += _schema
|
|
622
|
-
|
|
623
|
-
if think:
|
|
624
|
-
post_fix += "\n\n/think"
|
|
625
|
-
elif not think:
|
|
626
|
-
post_fix += "\n\n/no_think"
|
|
627
|
-
|
|
628
|
-
assert isinstance(messages, list), "Messages must be a list."
|
|
629
|
-
assert len(messages) > 0, "Messages cannot be empty."
|
|
630
|
-
assert messages[0]["role"] == "system", (
|
|
631
|
-
"First message must be a system message with instruction."
|
|
632
|
-
)
|
|
633
|
-
messages[0]["content"] += post_fix # type: ignore
|
|
634
|
-
|
|
635
|
-
model_kwargs = {}
|
|
636
|
-
if temperature is not None:
|
|
637
|
-
model_kwargs["temperature"] = temperature
|
|
638
|
-
if max_tokens is not None:
|
|
639
|
-
model_kwargs["max_tokens"] = max_tokens
|
|
640
|
-
model_kwargs.update(kwargs)
|
|
641
|
-
|
|
642
|
-
use_cache = self.do_cache if cache is None else cache
|
|
643
|
-
cache_key = None
|
|
644
|
-
if use_cache:
|
|
645
|
-
cache_data = {
|
|
646
|
-
"messages": messages,
|
|
647
|
-
"model_kwargs": model_kwargs,
|
|
648
|
-
"guided_json": json_schema,
|
|
649
|
-
"response_format": response_model.__name__,
|
|
650
|
-
}
|
|
651
|
-
cache_key = self._cache_key(cache_data, {}, response_model)
|
|
652
|
-
cached_response = self._load_cache(cache_key)
|
|
653
|
-
self.last_log = [prompt, messages, cached_response]
|
|
654
|
-
if cached_response is not None:
|
|
655
|
-
if return_openai_response:
|
|
656
|
-
return cached_response
|
|
657
|
-
return self._parse_complete_output(cached_response, response_model)
|
|
658
|
-
|
|
659
|
-
completion = self.client.chat.completions.create(
|
|
660
|
-
model=self.model, # type: ignore
|
|
661
|
-
messages=messages, # type: ignore
|
|
662
|
-
extra_body={"guided_json": json_schema},
|
|
663
|
-
**model_kwargs,
|
|
664
|
-
)
|
|
665
|
-
|
|
666
|
-
if cache_key:
|
|
667
|
-
self._dump_cache(cache_key, completion)
|
|
668
|
-
|
|
669
|
-
self.last_log = [prompt, messages, completion]
|
|
670
|
-
if return_openai_response:
|
|
671
|
-
return completion
|
|
672
|
-
return self._parse_complete_output(completion, response_model)
|
|
673
|
-
|
|
674
|
-
def _parse_complete_output(
|
|
675
|
-
self, completion: Any, response_model: Type[BaseModel]
|
|
676
|
-
) -> BaseModel:
|
|
677
|
-
"""Parse completion output to response model."""
|
|
678
|
-
if hasattr(completion, "model_dump"):
|
|
679
|
-
completion = completion.model_dump()
|
|
680
|
-
|
|
681
|
-
if "choices" not in completion or not completion["choices"]:
|
|
682
|
-
raise ValueError("No choices in OpenAI response")
|
|
683
|
-
|
|
684
|
-
content = completion["choices"][0]["message"]["content"]
|
|
685
|
-
if not content:
|
|
686
|
-
raise ValueError("Empty content in response")
|
|
687
|
-
|
|
688
|
-
try:
|
|
689
|
-
data = json.loads(content)
|
|
690
|
-
return response_model.model_validate(data)
|
|
691
|
-
except Exception as exc:
|
|
692
|
-
raise ValueError(
|
|
693
|
-
f"Failed to parse response as {response_model.__name__}: {content}"
|
|
694
|
-
) from exc
|
|
695
|
-
|
|
696
|
-
def inspect_word_probs(
    self,
    messages: Optional[List[Dict[str, Any]]] = None,
    tokenizer: Optional[Any] = None,
    do_print=True,
    add_think: bool = True,
) -> tuple[List[Dict[str, Any]], Any, str]:
    """
    Inspect word probabilities for a conversation.

    Args:
        messages: Messages to analyze; defaults to ``self.last_messages()``.
        tokenizer: Tokenizer to use; defaults to ``get_tokenizer(self.model)``.
        do_print: Print the last element of the result when true.
        add_think: Forwarded to ``last_messages`` when messages are omitted.

    Returns:
        The value returned by the module-level ``inspect_word_probs`` helper.

    Raises:
        ValueError: If no messages are given and none are recorded.
    """
    convo = messages
    if convo is None:
        convo = self.last_messages(add_think=add_think)
    if convo is None:
        raise ValueError("No messages provided and no last messages available.")

    tok = tokenizer if tokenizer is not None else get_tokenizer(self.model)

    # Inside the method body this name resolves to the module-level function.
    result = inspect_word_probs(self, tok, convo)
    if do_print:
        print(result[-1])
    return result
|
|
728
|
-
|
|
729
|
-
def last_messages(self, add_think: bool = True) -> Optional[List[Dict[str, str]]]:
|
|
730
|
-
last_conv = self.last_log
|
|
731
|
-
messages = last_conv[1] if len(last_conv) > 1 else None
|
|
732
|
-
last_msg = last_conv[2]
|
|
733
|
-
if not isinstance(last_msg, dict):
|
|
734
|
-
last_conv[2] = last_conv[2].model_dump() # type: ignore
|
|
735
|
-
msg = last_conv[2]
|
|
736
|
-
# Ensure msg is a dict
|
|
737
|
-
if hasattr(msg, "model_dump"):
|
|
738
|
-
msg = msg.model_dump()
|
|
739
|
-
message = msg["choices"][0]["message"]
|
|
740
|
-
reasoning = message.get("reasoning_content")
|
|
741
|
-
answer = message.get("content")
|
|
742
|
-
if reasoning and add_think:
|
|
743
|
-
final_answer = f"<think>{reasoning}</think>\n{answer}"
|
|
744
|
-
else:
|
|
745
|
-
final_answer = f"<think>\n\n</think>\n{answer}"
|
|
746
|
-
assistant = {"role": "assistant", "content": final_answer}
|
|
747
|
-
messages = messages + [assistant] # type: ignore
|
|
748
|
-
return messages if messages else None
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
@lru_cache(maxsize=10)
def get_tokenizer(model_name: str) -> Any:
    """Return the ``transformers`` tokenizer for ``model_name``.

    Results are memoized per model name by ``lru_cache`` (at most 10 entries),
    so repeated requests for the same name reuse the loaded tokenizer.

    Args:
        model_name: Identifier handed to ``AutoTokenizer.from_pretrained``.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns for that name.
    """
    # Function-scope import: transformers is only imported on first use.
    from transformers import AutoTokenizer  # type: ignore

    # NOTE(review): trust_remote_code=True permits the model repository to run
    # its own Python code while loading - only pass model names you trust.
    return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
def _compute_word_log_probs(
    tokenizer: Any,
    lm_client: Any,
    messages: Any,
) -> tuple[List[Dict[str, Any]], Any]:
    """Score each whitespace-delimited word of the rendered chat prompt."""
    import numpy as np

    # Render the conversation as raw text (no tokenization, no generation
    # prompt) so literal newlines survive.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    # Generate a single token; the only purpose of the call is to get the
    # per-token log-probabilities of the prompt itself. ``prompt_logprobs``
    # goes through ``extra_body``, i.e. it is a server-side extension rather
    # than a standard parameter (presumably vLLM - TODO confirm).
    response = lm_client.client.completions.create(
        model=lm_client.model,  # type: ignore
        prompt=prompt,
        max_tokens=1,
        logprobs=1,
        extra_body={"prompt_logprobs": 0},
    )
    token_logprob_dicts = response.choices[0].prompt_logprobs  # type: ignore
    if not token_logprob_dicts:
        # Previously this surfaced as an opaque TypeError/IndexError below.
        raise ValueError("Completion response contains no prompt_logprobs.")

    # The first entry is replaced in place with a fixed record for the start
    # marker (presumably because the server reports no log-probability for
    # the first prompt token - TODO confirm).
    # NOTE(review): "<|im_start|>" is hard-coded, so this assumes a chat
    # template whose rendered prompt starts with that marker.
    start_id = tokenizer.encode("<|im_start|>")[0]
    token_logprob_dicts[0] = {
        str(start_id): {
            "logprob": -1,
            "rank": 1,
            "decoded_token": "<|im_start|>",
        }
    }

    # One flat record per (token id, data) pair across all positions.
    tokens: List[Dict[str, Any]] = [
        {"id": int(tid), **tdata}
        for td in token_logprob_dicts
        for tid, tdata in td.items()
    ]

    # The local tokenizer must agree with the server token-for-token.
    tokenized = tokenizer.tokenize(prompt)
    if len(tokenized) != len(tokens):
        raise ValueError(f"Token count mismatch: {len(tokenized)} vs {len(tokens)}")
    for idx, tok in enumerate(tokens):
        if tokenized[idx] != tok["decoded_token"]:
            raise AssertionError(
                f"Token mismatch at {idx}: "
                f"{tokenized[idx]} != {tok['decoded_token']}"
            )

    # Turn newlines into a standalone sentinel word so they survive split().
    words = prompt.replace("\n", " <NL> ").split()

    word_log_probs: List[Dict[str, Any]] = []
    token_idx = 0
    for word in words:
        # Map the sentinel back to a real newline before encoding.
        target = "\n" if word == "<NL>" else word
        count = len(tokenizer.encode(target, add_special_tokens=False))
        if count == 0:
            continue

        subs = tokens[token_idx : token_idx + count]
        if not subs:
            # Token stream exhausted: there is no evidence left to score the
            # remaining words, so stop instead of reporting probability 1.0.
            break
        # Average over the tokens actually consumed; a word re-encoded on its
        # own can yield more ids than are left in the prompt's token stream.
        avg_logprob = sum(s["logprob"] for s in subs) / len(subs)
        prob = float(np.exp(avg_logprob))
        word_log_probs.append({"word": target, "probability": prob})
        token_idx += count

    return word_log_probs, token_logprob_dicts  # type: ignore


def _render_by_logprob(word_log_probs: List[Dict[str, Any]]) -> str:
    """Return an ANSI true-color string shading words from red (low) to green (high)."""
    if not word_log_probs:
        return ""

    probs = [entry["probability"] for entry in word_log_probs]
    min_p, max_p = min(probs), max(probs)
    # Fall back to 1.0 when every probability is equal to avoid dividing by 0.
    span = (max_p - min_p) or 1.0

    parts: List[str] = []
    for entry in word_log_probs:
        word = entry["word"]
        if word == "\n":
            # Keep real line breaks; they are not colored.
            parts.append("\n")
            continue

        norm = (entry["probability"] - min_p) / span
        r = int(255 * (1 - norm))  # red is strongest for the least likely word
        g = int(255 * norm)  # green is strongest for the most likely word
        b = 0
        parts.append(f"\x1b[38;2;{r};{g};{b}m{word}\x1b[0m" + " ")

    return "".join(parts).rstrip()


def inspect_word_probs(lm, tokenizer, messages):
    """Compute and render per-word probabilities of a conversation.

    Args:
        lm: Object exposing ``model`` and ``client.completions.create``.
        tokenizer: Object exposing ``apply_chat_template``, ``encode`` and
            ``tokenize``.
        messages: Chat messages passed to ``tokenizer.apply_chat_template``.

    Returns:
        A 3-tuple ``(word_probs, token_logprob_dicts, rendered)``:
        ``word_probs`` is a list of ``{"word", "probability"}`` dicts, where
        the probability is ``exp`` of the mean token log-probability of the
        word; ``token_logprob_dicts`` is the response's ``prompt_logprobs``
        list with its first entry replaced in place; ``rendered`` is the
        ANSI-colored text.

    Raises:
        ValueError: If the response has no prompt log-probabilities, or the
            local token count differs from the server's.
        AssertionError: If a local token differs from the server's decoded
            token at the same position.
    """
    word_probs, token_logprob_dicts = _compute_word_log_probs(tokenizer, lm, messages)
    return word_probs, token_logprob_dicts, _render_by_logprob(word_probs)
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
class LLMTask(ABC):
    """
    Typed, callable task built on top of a language-model instance.

    A sub-class is expected to provide, as class attributes:
        - lm: the language-model instance whose ``parse`` method is called
        - InputModel: Pydantic model describing the task input
        - OutputModel: Pydantic model describing the task output

    Optional class attributes:
        - temperature: float forwarded to ``lm.parse`` (default 0.6)
        - think: bool forwarded to ``lm.parse`` (default False)
        - add_json_schema: bool forwarded to ``lm.parse`` as
          ``add_json_schema_to_instruction`` (default False)

    The sub-class docstring is forwarded to the LM as the instruction; an
    empty string is used when there is none.

    Example
    ```python
    class DemoTask(LLMTask):
        "TODO: put the system-prompt instruction here"

        lm = LM(port=8130, cache=False, model="gpt-3.5-turbo")

        class InputModel(BaseModel):
            text_to_translate: str

        class OutputModel(BaseModel):
            translation: str
            glossary_use: str

        temperature = 0.6
        think = False

    demo_task = DemoTask()
    demo_task({'text_to_translate': 'Translate from english to vietnamese: Hello how are you'})
    ```
    """

    # Declared without values: sub-classes must assign all three.
    lm: "LM"
    InputModel: Type[BaseModel]
    OutputModel: Type[BaseModel]

    # Defaults forwarded to ``lm.parse`` unless a sub-class overrides them.
    temperature: float = 0.6
    think: bool = False
    add_json_schema: bool = False

    def __call__(self, data: BaseModel | dict) -> BaseModel:
        """Run the task on ``data`` and return the value of ``lm.parse``.

        Args:
            data: Either a Pydantic model instance, used as-is, or a dict of
                keyword arguments used to build ``InputModel``.

        Raises:
            NotImplementedError: If ``InputModel``, ``OutputModel`` or ``lm``
                is not defined on the task.
        """
        required_attrs = ("InputModel", "OutputModel", "lm")
        if not all(hasattr(self, attr) for attr in required_attrs):
            raise NotImplementedError(
                f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes."
            )

        if isinstance(data, BaseModel):
            item = data
        else:
            item = self.InputModel(**data)

        # The input is serialized to JSON and sent as the prompt.
        return self.lm.parse(
            prompt=item.model_dump_json(),
            instruction=self.__doc__ or "",
            response_model=self.OutputModel,
            temperature=self.temperature,
            think=self.think,
            add_json_schema_to_instruction=self.add_json_schema,
        )

    def generate_training_data(
        self, input_dict: Dict[str, Any], output: Dict[str, Any]
    ):
        """Build a one-turn training conversation from an input/output pair.

        Both dicts are passed through ``InputModel`` / ``OutputModel`` and
        serialized to JSON; the task docstring is used as the system message.
        The result is whatever ``get_conversation_one_turn`` returns
        (described by the original author as a ShareGPT-like format).
        """
        instruction = self.__doc__ or ""
        user_json = self.InputModel(**input_dict).model_dump_json()  # type: ignore[attr-defined]
        assistant_json = self.OutputModel(**output).model_dump_json()  # type: ignore[attr-defined]
        return get_conversation_one_turn(
            system_msg=instruction,
            user_msg=user_json,
            assistant_msg=assistant_json,
        )

    # Alias kept for compatibility with other LLMTask implementations.
    run = __call__
|