speedy-utils 0.1.30__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- llm_utils/__init__.py +3 -3
- llm_utils/lm/__init__.py +12 -0
- llm_utils/lm/base_lm.py +337 -0
- llm_utils/lm/chat_session.py +115 -0
- llm_utils/lm/pydantic_lm.py +195 -0
- llm_utils/lm/text_lm.py +130 -0
- llm_utils/lm/utils.py +130 -0
- {speedy_utils-0.1.30.dist-info → speedy_utils-1.0.1.dist-info}/METADATA +1 -1
- {speedy_utils-0.1.30.dist-info → speedy_utils-1.0.1.dist-info}/RECORD +11 -6
- llm_utils/lm.py +0 -742
- {speedy_utils-0.1.30.dist-info → speedy_utils-1.0.1.dist-info}/WHEEL +0 -0
- {speedy_utils-0.1.30.dist-info → speedy_utils-1.0.1.dist-info}/entry_points.txt +0 -0
llm_utils/__init__.py
CHANGED
@@ -8,7 +8,7 @@ from .chat_format import (
     build_chatml_input,
     format_msgs,
 )
-from .lm import
+from .lm import PydanticLM, TextLM
 from .group_messages import (
     split_indices_by_length,
     group_messages_by_len,
@@ -23,8 +23,8 @@ __all__ = [
     "display_conversations",
     "build_chatml_input",
     "format_msgs",
-    "OAI_LM",
-    "LM",
     "split_indices_by_length",
     "group_messages_by_len",
+    "PydanticLM",
+    "TextLM",
 ]
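
The export change above replaces the old OAI_LM/LM names with PydanticLM and TextLM. For reference, a minimal usage sketch of the new import surface (the model name and port are placeholders; both classes are defined in the added llm_utils/lm files below and are assumed to accept the LM constructor arguments shown in base_lm.py):

    from llm_utils import PydanticLM, TextLM

    text_lm = TextLM(model="my-model", port=9150)   # plain-text completions
    print(text_lm(prompt="Say hi in one word."))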
llm_utils/lm/__init__.py
ADDED
llm_utils/lm/base_lm.py
ADDED
@@ -0,0 +1,337 @@
+import os
+import random
+import time
+from typing import (
+    Any,
+    List,
+    Literal,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    Dict,
+    overload,
+    Tuple,
+)
+from pydantic import BaseModel
+from speedy_utils import dump_json_or_pickle, identify_uuid, load_json_or_pickle
+from loguru import logger
+from copy import deepcopy
+import numpy as np
+import tempfile
+import fcntl
+
+
+class LM:
+
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        model_type: Literal["chat", "text"] = "chat",
+        temperature: float = 0.0,
+        max_tokens: int = 2000,
+        cache: bool = True,
+        callbacks: Optional[Any] = None,
+        num_retries: int = 3,
+        host: str = "localhost",
+        port: Optional[int] = None,
+        ports: Optional[List[int]] = None,
+        api_key: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+        **kwargs,
+    ):
+        from openai import OpenAI
+
+        self.ports = ports
+        self.host = host
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
+        resolved_base_url_from_kwarg = kwargs.get("base_url")
+        if resolved_base_url_from_kwarg is not None and not isinstance(
+            resolved_base_url_from_kwarg, str
+        ):
+            logger.warning(
+                f"base_url in kwargs was not a string ({type(resolved_base_url_from_kwarg)}), ignoring."
+            )
+            resolved_base_url_from_kwarg = None
+        resolved_base_url: Optional[str] = resolved_base_url_from_kwarg
+        if resolved_base_url is None:
+            selected_port = port
+            if selected_port is None and ports is not None and len(ports) > 0:
+                selected_port = ports[0]
+            if selected_port is not None:
+                resolved_base_url = f"http://{host}:{selected_port}/v1"
+        self.base_url = resolved_base_url
+
+        if model is None:
+            if self.base_url:
+                try:
+                    model_list = self.list_models()
+                    if model_list:
+                        model_name_from_list = model_list[0]
+                        model = model_name_from_list
+                        logger.info(f"Using default model: {model}")
+                    else:
+                        logger.warning(
+                            f"No models found at {self.base_url}. Please specify a model."
+                        )
+                except Exception as e:
+                    example_cmd = (
+                        "LM.start_server('unsloth/gemma-3-1b-it')\n"
+                        "# Or manually run: svllm serve --model unsloth/gemma-3-1b-it --gpus 0 -hp localhost:9150"
+                    )
+                    logger.error(
+                        f"Failed to list models from {self.base_url}: {e}\n"
+                        f"Make sure your model server is running and accessible.\n"
+                        f"Example to start a server:\n{example_cmd}"
+                    )
+            else:
+                logger.warning(
+                    "base_url not configured, cannot fetch default model. Please specify a model."
+                )
+        assert (
+            model is not None
+        ), "Model name must be provided or discoverable via list_models"
+
+        # Remove 'openai/' prefix if present
+        if model.startswith("openai/"):
+            model = model[7:]
+
+        self.kwargs = {"temperature": temperature, "max_tokens": max_tokens, **kwargs}
+        self.model = model
+        self.model_type = model_type
+        self.num_retries = num_retries
+        self.do_cache = cache
+        self.callbacks = callbacks
+
+        # Initialize OpenAI client
+        self.openai_client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+        self.system_prompt = system_prompt
+
+    def dump_cache(
+        self, id: str, result: Union[str, BaseModel, List[Union[str, BaseModel]]]
+    ):
+        try:
+            cache_file = f"~/.cache/oai_lm/{self.model}/{id}.pkl"
+            cache_file = os.path.expanduser(cache_file)
+            dump_json_or_pickle(result, cache_file)
+        except Exception as e:
+            logger.warning(f"Cache dump failed: {e}")
+
+    def load_cache(
+        self, id: str
+    ) -> Optional[Union[str, BaseModel, List[Union[str, BaseModel]]]]:
+        try:
+            cache_file = f"~/.cache/oai_lm/{self.model}/{id}.pkl"
+            cache_file = os.path.expanduser(cache_file)
+            if not os.path.exists(cache_file):
+                return None
+            return load_json_or_pickle(cache_file)
+        except Exception as e:
+            logger.warning(f"Cache load failed for {id}: {e}")
+            return None
+
+    def list_models(self) -> List[str]:
+        from openai import OpenAI
+
+        if not self.base_url:
+            raise ValueError("Cannot list models: base_url is not configured.")
+        if not self.api_key:
+            logger.warning(
+                "API key not available for listing models. Using default 'abc'."
+            )
+        api_key_str = str(self.api_key) if self.api_key is not None else "abc"
+        base_url_str = str(self.base_url) if self.base_url is not None else None
+        if isinstance(self.base_url, float):
+            raise TypeError(
+                f"base_url must be a string or None, got float: {self.base_url}"
+            )
+        client = OpenAI(base_url=base_url_str, api_key=api_key_str)
+        try:
+            page = client.models.list()
+            return [d.id for d in page.data]
+        except Exception as e:
+            logger.error(f"Error listing models: {e}")
+            return []
+
+    def get_least_used_port(self) -> int:
+        if self.ports is None:
+            raise ValueError("Ports must be configured to pick the least used port.")
+        if not self.ports:
+            raise ValueError("Ports list is empty, cannot pick a port.")
+        return self._pick_least_used_port(self.ports)
+
+    def _pick_least_used_port(self, ports: List[int]) -> int:
+        global_lock_file = "/tmp/ports.lock"
+        with open(global_lock_file, "w") as lock_file:
+            fcntl.flock(lock_file, fcntl.LOCK_EX)
+            try:
+                port_use: Dict[int, int] = {}
+                for port in ports:
+                    file_counter = f"/tmp/port_use_counter_{port}.npy"
+                    if os.path.exists(file_counter):
+                        try:
+                            counter = np.load(file_counter)
+                        except Exception as e:
+                            logger.warning(f"Corrupted usage file {file_counter}: {e}")
+                            counter = np.array([0])
+                    else:
+                        counter = np.array([0])
+                    port_use[port] = counter[0]
+                if not port_use:
+                    if ports:
+                        raise ValueError(
+                            "Port usage data is empty, cannot pick a port."
+                        )
+                    else:
+                        raise ValueError("No ports provided to pick from.")
+                lsp = min(port_use, key=lambda k: port_use[k])
+                self._update_port_use(lsp, 1)
+            finally:
+                fcntl.flock(lock_file, fcntl.LOCK_UN)
+        return lsp
+
+    def _update_port_use(self, port: int, increment: int):
+        file_counter = f"/tmp/port_use_counter_{port}.npy"
+        file_counter_lock = f"/tmp/port_use_counter_{port}.lock"
+        with open(file_counter_lock, "w") as lock_file:
+            fcntl.flock(lock_file, fcntl.LOCK_EX)
+            try:
+                if os.path.exists(file_counter):
+                    try:
+                        counter = np.load(file_counter)
+                    except Exception as e:
+                        logger.warning(f"Corrupted usage file {file_counter}: {e}")
+                        counter = np.array([0])
+                else:
+                    counter = np.array([0])
+                counter[0] += increment
+                self._atomic_save(counter, file_counter)
+            finally:
+                fcntl.flock(lock_file, fcntl.LOCK_UN)
+
+    def _atomic_save(self, array: np.ndarray, filename: str):
+        tmp_dir: str = os.path.dirname(filename) or "."
+        with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
+            np.save(tmp, array)
+            temp_name: str = tmp.name
+        os.replace(temp_name, filename)
+
+    def _prepare_call_inputs(
+        self,
+        messages: List[Any],
+        max_tokens: Optional[int],
+        port: Optional[int],
+        use_loadbalance: Optional[bool],
+        cache: Optional[bool],
+        **kwargs,
+    ) -> Tuple[dict, bool, Optional[int], List[Any]]:
+        """Prepare inputs for the LLM call."""
+        # Prepare kwargs
+        effective_kwargs = {**self.kwargs, **kwargs}
+        if max_tokens is not None:
+            effective_kwargs["max_tokens"] = max_tokens
+
+        # Set effective cache
+        effective_cache = cache if cache is not None else self.do_cache
+
+        # Setup port
+        current_port = port
+        if self.ports and not current_port:
+            current_port = (
+                self.get_least_used_port()
+                if use_loadbalance
+                else random.choice(self.ports)
+            )
+        if current_port:
+            base_url = f"http://{self.host}:{current_port}/v1"
+            effective_kwargs["base_url"] = base_url
+            # Update client with new base_url
+            from openai import OpenAI
+
+            self.openai_client = OpenAI(api_key=self.api_key, base_url=base_url)
+
+        return effective_kwargs, effective_cache, current_port, messages
+
+    def _call_llm(
+        self,
+        dspy_main_input: List[Any],
+        current_port: Optional[int],
+        use_loadbalance: Optional[bool],
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError("This method should be implemented in subclasses.")
+        # """Call the OpenAI API directly and get raw output (no retries)."""
+        # try:
+        #     # Handle message list
+        #     response = self.openai_client.chat.completions.create(
+        #         model=self.model, messages=dspy_main_input, **kwargs
+        #     )
+
+        #     # Update port usage stats if needed
+        #     if current_port and use_loadbalance is True:
+        #         self._update_port_use(current_port, -1)
+
+        #     return response.choices[0].message.content
+
+        # except Exception as e:
+        #     logger.error(f"API call failed: {e}")
+        #     raise
+
+    def _generate_cache_key_base(
+        self,
+        messages: List[Any],
+        effective_kwargs: dict,
+    ) -> List[Any]:
+        """Base method to generate cache key components."""
+        return [
+            messages,
+            effective_kwargs.get("temperature"),
+            effective_kwargs.get("max_tokens"),
+            self.model,
+        ]
+
+    def _store_in_cache_base(
+        self, effective_cache: bool, id_for_cache: Optional[str], result: Any
+    ):
+        """Base method to store result in cache if caching is enabled."""
+        if effective_cache and id_for_cache:
+            self.dump_cache(id_for_cache, result)
+
+    def __call__(
+        self,
+        prompt: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+        messages: Optional[List[Any]] = None,
+        **kwargs,
+    ) -> Any:
+        """
+        If have prompt but not messages, convert prompt to messages.
+        If both raise
+        If neither, raise
+        """
+        if prompt is not None and messages is not None:
+            raise ValueError("Cannot provide both prompt and messages.")
+        if prompt is None and messages is None:
+            raise ValueError("Either prompt or messages must be provided.")
+
+        # Convert prompt to messages if needed
+        if prompt is not None:
+            effective_system_prompt = system_prompt or self.system_prompt
+            if effective_system_prompt is not None:
+                messages = [
+                    {"role": "system", "content": effective_system_prompt},
+                    {"role": "user", "content": prompt},
+                ]
+            else:
+                messages = [{"role": "user", "content": prompt}]
+
+        # Call the LLM with the prepared inputs
+        assert messages is not None, "messages must not be None"
+        return self.forward_messages(messages=messages, **kwargs)
+
+    def forward_messages(
+        self,
+        messages: List[Any],
+        **kwargs,
+    ) -> str:
+        raise NotImplementedError
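
LM itself is abstract: forward_messages raises NotImplementedError and is implemented by the subclasses in the files below, while __call__ converts a prompt (plus optional system prompt) into a chat message list and forwards everything else to forward_messages. A minimal sketch of how a concrete subclass plugs in (EchoLM is hypothetical and not part of the package; no server is contacted):

    from llm_utils.lm.base_lm import LM

    class EchoLM(LM):
        def forward_messages(self, messages, **kwargs):
            # A real subclass would call self.openai_client here;
            # this stub just echoes the last user message.
            return messages[-1]["content"]

    lm = EchoLM(model="dummy-model", cache=False)
    print(lm(prompt="hello", system_prompt="Be terse."))  # -> "hello"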
llm_utils/lm/chat_session.py
ADDED
@@ -0,0 +1,115 @@
+from copy import deepcopy
+import time  # Add time for possible retries
+from typing import (
+    Any,
+    List,
+    Literal,
+    Optional,
+    TypedDict,
+    Type,
+    Union,
+    TypeVar,
+    overload,
+)
+from pydantic import BaseModel
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class Message(TypedDict):
+    role: Literal["user", "assistant", "system"]
+    content: str | BaseModel
+
+
+class ChatSession:
+    def __init__(
+        self,
+        lm: Any,
+        system_prompt: Optional[str] = None,
+        history: List[Message] = [],
+        callback=None,
+        response_format: Optional[Type[BaseModel]] = None,
+    ):
+        self.lm = deepcopy(lm)
+        self.history = deepcopy(history)
+        self.callback = callback
+        self.response_format = response_format
+        if system_prompt:
+            system_message: Message = {
+                "role": "system",
+                "content": system_prompt,
+            }
+            self.history.insert(0, system_message)
+
+    def __len__(self):
+        return len(self.history)
+
+    @overload
+    def __call__(
+        self, text, response_format: Type[T], display=False, max_prev_turns=3, **kwargs
+    ) -> T: ...
+    @overload
+    def __call__(
+        self,
+        text,
+        response_format: None = None,
+        display=False,
+        max_prev_turns=3,
+        **kwargs,
+    ) -> str: ...
+    def __call__(
+        self,
+        text,
+        response_format: Optional[Type[BaseModel]] = None,
+        display=False,
+        max_prev_turns=3,
+        **kwargs,
+    ) -> Union[str, BaseModel]:
+        current_response_format = response_format or self.response_format
+        self.history.append({"role": "user", "content": text})
+        output = self.lm(
+            messages=self.parse_history(),
+            response_format=current_response_format,
+            **kwargs,
+        )
+        if isinstance(output, BaseModel):
+            self.history.append({"role": "assistant", "content": output})
+        else:
+            assert response_format is None
+            self.history.append({"role": "assistant", "content": output})
+        if display:
+            self.inspect_history(max_prev_turns=max_prev_turns)
+        if self.callback:
+            self.callback(self, output)
+        return output
+
+    def send_message(self, text, **kwargs):
+        return self.__call__(text, **kwargs)
+
+    def parse_history(self, indent=None):
+        parsed_history = []
+        for m in self.history:
+            if isinstance(m["content"], str):
+                parsed_history.append(m)
+            elif isinstance(m["content"], BaseModel):
+                parsed_history.append(
+                    {
+                        "role": m["role"],
+                        "content": m["content"].model_dump_json(indent=indent),
+                    }
+                )
+            else:
+                raise ValueError(f"Unexpected content type: {type(m['content'])}")
+        return parsed_history
+
+    def inspect_history(self, max_prev_turns=3):
+        from llm_utils import display_chat_messages_as_html
+
+        h = self.parse_history(indent=2)
+        try:
+            from IPython.display import clear_output
+
+            clear_output()
+            display_chat_messages_as_html(h[-max_prev_turns * 2 :])
+        except:
+            pass
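
ChatSession wraps an LM instance and keeps the running history, serializing any Pydantic content back to JSON before the next call. A usage sketch, assuming an OpenAI-compatible server on the placeholder host/port and that TextLM returns plain strings:

    from llm_utils import TextLM
    from llm_utils.lm.chat_session import ChatSession

    session = ChatSession(TextLM(model="my-model", port=9150), system_prompt="Answer briefly.")
    print(session("What is the capital of France?"))
    print(session.send_message("And its population?"))
    print(len(session))  # 5: system prompt plus two user/assistant pairs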
llm_utils/lm/pydantic_lm.py
ADDED
@@ -0,0 +1,195 @@
+import random
+import time
+from typing import Any, List, Optional, Type, Union, cast, TypeVar, Generic
+
+from openai import AuthenticationError, RateLimitError
+from pydantic import BaseModel
+
+from speedy_utils.common.logger import logger
+from speedy_utils.common.utils_cache import identify_uuid
+
+from .base_lm import LM
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class PydanticLM(LM):
+    """
+    Language model that returns outputs as Pydantic models.
+    """
+
+    def _generate_cache_key(
+        self,
+        messages: List[Any],
+        response_format: Optional[Type[BaseModel]],
+        kwargs: dict,
+    ) -> str:
+        """
+        Generate a cache key based on input parameters.
+        """
+        cache_key_base = self._generate_cache_key_base(messages, kwargs)
+        cache_key_base.insert(
+            1, (response_format.model_json_schema() if response_format else None)
+        )
+        return identify_uuid(str(cache_key_base))
+
+    def _parse_cached_result(
+        self, cached_result: Any, response_format: Optional[Type[BaseModel]]
+    ) -> Optional[BaseModel]:
+        """
+        Parse cached result into a BaseModel instance.
+        """
+        if isinstance(cached_result, BaseModel):
+            return cached_result
+        elif isinstance(cached_result, str):
+            if response_format is None:
+                raise ValueError(
+                    "response_format must be provided to parse cached string result."
+                )
+            import json
+
+            return response_format.model_validate_json(cached_result)
+        elif (
+            isinstance(cached_result, list)
+            and cached_result
+            and isinstance(cached_result[0], (str, BaseModel))
+        ):
+            first = cached_result[0]
+            if isinstance(first, BaseModel):
+                return first
+            elif isinstance(first, str):
+                if response_format is None:
+                    raise ValueError(
+                        "response_format must be provided to parse cached string result."
+                    )
+                import json
+
+                return response_format.model_validate_json(first)
+        else:
+            logger.warning(
+                f"Cached result has unexpected type {type(cached_result)}. Ignoring cache."
+            )
+            return None
+
+    def _check_cache(
+        self,
+        effective_cache: bool,
+        messages: List[Any],
+        response_format: Optional[Type[BaseModel]],
+        effective_kwargs: dict,
+    ):
+        """Check if result is in cache and return it if available."""
+        if not effective_cache:
+            return None, None
+
+        cache_id = self._generate_cache_key(messages, response_format, effective_kwargs)
+        cached_result = self.load_cache(cache_id)
+        parsed_cache = self._parse_cached_result(cached_result, response_format)
+
+        return cache_id, parsed_cache
+
+    def _call_llm(
+        self,
+        dspy_main_input: List[Any],
+        response_format: Type[BaseModel],
+        current_port: Optional[int],
+        use_loadbalance: Optional[bool],
+        **kwargs,
+    ):
+        """Call the LLM with response format support using OpenAI's parse method."""
+        # Use messages directly
+        messages = dspy_main_input
+
+        # Use OpenAI's parse method for structured output
+        try:
+            response = self.openai_client.beta.chat.completions.parse(
+                model=self.model,
+                messages=messages,
+                response_format=response_format,
+                **kwargs,
+            )
+        except AuthenticationError as e:
+            logger.error(f"Authentication error: {e}")
+            raise
+        except TimeoutError as e:
+            logger.error(f"Timeout error: {e}")
+            raise
+        except RateLimitError as e:
+            logger.error(f"Rate limit exceeded: {e}")
+            raise
+        # Update port usage stats if needed
+        if current_port and use_loadbalance is True:
+            self._update_port_use(current_port, -1)
+
+        return response.choices[0].message.parsed
+
+    def _parse_llm_output(
+        self, llm_output: Any, response_format: Optional[Type[BaseModel]]
+    ) -> BaseModel:
+        """Parse the LLM output into the correct format."""
+        if isinstance(llm_output, BaseModel):
+            return llm_output
+        elif isinstance(llm_output, dict):
+            if not response_format:
+                raise ValueError("response_format required to parse dict output.")
+            return response_format.model_validate(llm_output)
+        elif isinstance(llm_output, str):
+            if not response_format:
+                raise ValueError("response_format required to parse string output.")
+            import json
+
+            return response_format.model_validate_json(llm_output)
+        else:
+            if not response_format:
+                raise ValueError("response_format required to parse output.")
+            return response_format.model_validate_json(str(llm_output))
+
+    def _store_in_cache(
+        self, effective_cache: bool, cache_id: Optional[str], result: BaseModel
+    ):
+        """Store the result in cache if caching is enabled."""
+        if result and isinstance(result, BaseModel):
+            self._store_in_cache_base(
+                effective_cache, cache_id, result.model_dump_json()
+            )
+
+    def forward_messages(
+        self,
+        response_format: Type[T],
+        messages: List[Any],
+        cache: Optional[bool] = None,
+        port: Optional[int] = None,
+        use_loadbalance: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> T:
+        # 1. Prepare inputs
+        effective_kwargs, effective_cache, current_port, dspy_main_input = (
+            self._prepare_call_inputs(
+                messages, max_tokens, port, use_loadbalance, cache, **kwargs
+            )
+        )
+
+        # 2. Check cache
+        cache_id, cached_result = self._check_cache(
+            effective_cache, messages, response_format, effective_kwargs
+        )
+        if cached_result:
+            return cast(T, cached_result)
+
+        # 3. Call LLM using OpenAI's parse method
+        llm_output = self._call_llm(
+            dspy_main_input,
+            response_format,
+            current_port,
+            use_loadbalance,
+            **effective_kwargs,
+        )
+
+        # 4. Parse output
+        result = self._parse_llm_output(llm_output, response_format)
+
+        # 5. Store in cache
+        self._store_in_cache(effective_cache, cache_id, result)
+
+        return cast(T, result)
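
PydanticLM.forward_messages ties the pieces together: prepare kwargs and port, check the on-disk cache under ~/.cache/oai_lm/<model>/, call the beta parse endpoint, validate the output against the supplied schema, and cache the JSON dump. A usage sketch with placeholder model/port, assuming the server supports structured outputs via beta.chat.completions.parse:

    from pydantic import BaseModel
    from llm_utils import PydanticLM

    class Person(BaseModel):
        name: str
        age: int

    lm = PydanticLM(model="my-model", port=9150)
    person = lm(prompt="Extract the person: Alice is 30.", response_format=Person)
    print(person.name, person.age)  # typed access; an identical second call is served from the cache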