speedy-utils 1.1.9__tar.gz → 1.1.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/PKG-INFO +1 -1
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/pyproject.toml +1 -1
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/__init__.py +2 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/lm/async_lm/async_llm_task.py +5 -1
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/lm/async_lm/async_lm.py +34 -55
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/lm/async_lm/async_lm_base.py +5 -173
- speedy_utils-1.1.11/src/llm_utils/lm/openai_memoize.py +72 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/scripts/vllm_serve.py +2 -1
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/__init__.py +1 -3
- speedy_utils-1.1.11/src/speedy_utils/common/utils_cache.py +664 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/common/utils_io.py +14 -2
- speedy_utils-1.1.9/src/speedy_utils/common/utils_cache.py +0 -494
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/README.md +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/chat_format/__init__.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/chat_format/display.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/chat_format/transform.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/chat_format/utils.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/group_messages.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/lm/__init__.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/lm/async_lm/__init__.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/lm/async_lm/_utils.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/lm/async_lm/lm_specific.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/lm/utils.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/scripts/README.md +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/llm_utils/scripts/vllm_load_balancer.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/all.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/common/__init__.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/common/clock.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/common/function_decorator.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/common/logger.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/common/notebook_utils.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/common/report_manager.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/common/utils_misc.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/common/utils_print.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/multi_worker/__init__.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/multi_worker/process.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/multi_worker/thread.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/scripts/__init__.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/scripts/mpython.py +0 -0
- {speedy_utils-1.1.9 → speedy_utils-1.1.11}/src/speedy_utils/scripts/openapi_client_codegen.py +0 -0
```diff
--- speedy_utils-1.1.9/src/llm_utils/lm/async_lm/async_llm_task.py
+++ speedy_utils-1.1.11/src/llm_utils/lm/async_lm/async_llm_task.py
@@ -389,7 +389,7 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         input_data: InputModelType,
         expected_response: Optional[OutputModelType] = None,
         label: Optional[str] = None,
-        cache_dir: pathlib.Path =
+        cache_dir: Optional[pathlib.Path] = None,
     ) -> OutputModelType:
         """
         Generate training data for both thinking and non-thinking modes.
@@ -415,6 +415,10 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         # Create non-thinking mode equivalent
         no_think_messages = self._create_no_think_messages(think_messages)

+        # Use default cache directory if none provided
+        if cache_dir is None:
+            cache_dir = self.DEFAULT_CACHE_DIR or pathlib.Path("./cache")
+
         # Save training data
         self._save_training_data(
             input_data=input_data,
```
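The first change above makes `cache_dir` optional and resolves the default at call time. A minimal sketch of that fallback, where `DEFAULT_CACHE_DIR` and `resolve_cache_dir` are illustrative stand-ins rather than names taken from the package:

```python
import pathlib
from typing import Optional

# Illustrative stand-in for the class-level default consulted in the diff.
DEFAULT_CACHE_DIR: Optional[pathlib.Path] = None


def resolve_cache_dir(cache_dir: Optional[pathlib.Path] = None) -> pathlib.Path:
    # Mirrors the added lines 418-420: explicit argument wins, then the
    # configured default, then ./cache.
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR or pathlib.Path("./cache")
    return cache_dir


print(resolve_cache_dir())                               # falls back to ./cache
print(resolve_cache_dir(pathlib.Path("/tmp/training")))  # explicit path wins
```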
```diff
--- speedy_utils-1.1.9/src/llm_utils/lm/async_lm/async_lm.py
+++ speedy_utils-1.1.11/src/llm_utils/lm/async_lm/async_lm.py
@@ -96,67 +96,37 @@ class AsyncLM(AsyncLMBase):

     async def _unified_client_call(
         self,
-        messages:
+        messages: RawMsgs,
         extra_body: Optional[dict] = None,
-
+        max_tokens: Optional[int] = None,
     ) -> dict:
-        """Unified method for all client interactions
-        converted_messages =
-
-
-
-
-        if
-
+        """Unified method for all client interactions (caching handled by MAsyncOpenAI)."""
+        converted_messages: Messages = (
+            self._convert_messages(cast(LegacyMsgs, messages))
+            if messages and isinstance(messages[0], dict)
+            else cast(Messages, messages)
+        )
+        # override max_tokens if provided
+        if max_tokens is not None:
+            self.model_kwargs["max_tokens"] = max_tokens
+
+        try:
+            # Get completion from API (caching handled by MAsyncOpenAI)
+            call_kwargs = {
                 "messages": converted_messages,
-
-                "extra_body": extra_body or {},
-                "cache_suffix": cache_suffix,
+                **self.model_kwargs,
             }
-
-
-
-            # Check for cached error responses
-            if (
-                completion
-                and isinstance(completion, dict)
-                and "error" in completion
-                and completion["error"]
-            ):
-                error_type = completion.get("error_type", "Unknown")
-                error_message = completion.get("error_message", "Cached error")
-                logger.warning(f"Found cached error ({error_type}): {error_message}")
-                raise ValueError(f"Cached {error_type}: {error_message}")
+            if extra_body:
+                call_kwargs["extra_body"] = extra_body

-
-
-            if
-
-                "messages": converted_messages,
-                **self.model_kwargs,
-            }
-            if extra_body:
-                call_kwargs["extra_body"] = extra_body
-
-            completion = await self.client.chat.completions.create(**call_kwargs)
-
-            if hasattr(completion, "model_dump"):
-                completion = completion.model_dump()
-            if cache_key:
-                self._dump_cache(cache_key, completion)
+            completion = await self.client.chat.completions.create(**call_kwargs)
+
+            if hasattr(completion, "model_dump"):
+                completion = completion.model_dump()

         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
             logger.error(error_msg)
-            if isinstance(exc, BadRequestError) and cache_key:
-                error_response = {
-                    "error": True,
-                    "error_type": "BadRequestError",
-                    "error_message": str(exc),
-                    "choices": [],
-                }
-                self._dump_cache(cache_key, error_response)
-                logger.debug(f"Cached BadRequestError for key: {cache_key}")
             raise

         return completion
@@ -179,7 +149,6 @@ class AsyncLM(AsyncLMBase):
         completion = await self._unified_client_call(
             messages,
             extra_body={**self.extra_body},
-            cache_suffix=f"_parse_{response_model.__name__}",
         )

         # Parse the response
@@ -234,7 +203,6 @@ class AsyncLM(AsyncLMBase):
         completion = await self._unified_client_call(
             messages,
             extra_body={"guided_json": json_schema, **self.extra_body},
-            cache_suffix=f"_beta_parse_{response_model.__name__}",
         )

         # Parse the response
@@ -277,6 +245,7 @@ class AsyncLM(AsyncLMBase):
         self,
         prompt: Optional[str] = None,
         messages: Optional[RawMsgs] = None,
+        max_tokens: Optional[int] = None,
     ):  # -> tuple[Any | dict[Any, Any], list[ChatCompletionMessagePar...:
         """Unified async call for language model, returns (assistant_message.model_dump(), messages)."""
         if (prompt is None) == (messages is None):
@@ -299,7 +268,7 @@ class AsyncLM(AsyncLMBase):

         # Use unified client call
         raw_response = await self._unified_client_call(
-            list(openai_msgs),
+            list(openai_msgs), max_tokens=max_tokens
         )

         if hasattr(raw_response, "model_dump"):
@@ -385,3 +354,13 @@ class AsyncLM(AsyncLMBase):
             raise ValueError(
                 f"Failed to validate against response model {response_model.__name__}: {exc}\nRaw content: {content}"
             ) from exc
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if hasattr(self, "_last_client"):
+            last_client = self._last_client  # type: ignore
+            await last_client._client.aclose()
+        else:
+            logger.warning("No last client to close")
```
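The last hunk adds `__aenter__`/`__aexit__`, so an `AsyncLM` instance can be used as an async context manager that closes the most recently created HTTP client on exit. A hedged usage sketch; the constructor keyword arguments are placeholders, since `AsyncLM.__init__` is not part of this diff:

```python
import asyncio

from llm_utils.lm.async_lm.async_lm import AsyncLM


async def main() -> None:
    # model/host/port are assumed parameter names for illustration only.
    async with AsyncLM(model="my-model", host="localhost", port=8000) as lm:
        ...  # issue calls here; the unified call path now accepts max_tokens=...
    # On exit, __aexit__ closes lm._last_client._client if a client was created,
    # otherwise it logs "No last client to close".


asyncio.run(main())
```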
```diff
--- speedy_utils-1.1.9/src/llm_utils/lm/async_lm/async_lm_base.py
+++ speedy_utils-1.1.11/src/llm_utils/lm/async_lm/async_lm_base.py
@@ -1,6 +1,4 @@
 # from ._utils import *
-import base64
-import hashlib
 import json
 import os
 from typing import (
@@ -26,6 +24,8 @@ from openai.types.chat import (
 from openai.types.model import Model
 from pydantic import BaseModel

+from llm_utils.lm.openai_memoize import MAsyncOpenAI
+
 from ._utils import (
     LegacyMsgs,
     Messages,
@@ -56,7 +56,7 @@ class AsyncLMBase:
         self._init_port = port  # <-- store the port provided at init

     @property
-    def client(self) ->
+    def client(self) -> MAsyncOpenAI:
         # if have multiple ports
         if self.ports:
             import random
@@ -66,9 +66,10 @@ class AsyncLMBase:
             logger.debug(f"Using port: {port}")
         else:
             api_base = self.base_url or f"http://{self._host}:{self._port}/v1"
-        client =
+        client = MAsyncOpenAI(
            api_key=self.api_key,
            base_url=api_base,
+            cache=self._cache,
        )
         self._last_client = client
         return client
@@ -176,175 +177,6 @@ class AsyncLMBase:
                 f"Model did not return valid JSON:\n---\n{raw_response}"
             ) from exc

-    # ------------------------------------------------------------------ #
-    # Simple disk cache (sync)
-    # ------------------------------------------------------------------ #
-    @staticmethod
-    def _cache_key(
-        messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
-    ) -> str:
-        tag = response_format.__name__ if response_format is not str else "text"
-        blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
-        return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
-
-    @staticmethod
-    def _cache_path(key: str) -> str:
-        return os.path.expanduser(f"~/.cache/lm/{key}.json")
-
-    def _dump_cache(self, key: str, val: Any) -> None:
-        try:
-            path = self._cache_path(key)
-            os.makedirs(os.path.dirname(path), exist_ok=True)
-            with open(path, "w") as fh:
-                if isinstance(val, BaseModel):
-                    json.dump(val.model_dump(mode="json"), fh)
-                else:
-                    json.dump(val, fh)
-        except Exception as exc:
-            logger.debug(f"cache write skipped: {exc}")
-
-    def _load_cache(self, key: str) -> Any | None:
-        path = self._cache_path(key)
-        if not os.path.exists(path):
-            return None
-        try:
-            with open(path) as fh:
-                return json.load(fh)
-        except Exception:
-            return None
-
-    # async def inspect_word_probs(
-    #     self,
-    #     messages: Optional[List[Dict[str, Any]]] = None,
-    #     tokenizer: Optional[Any] = None,
-    #     do_print=True,
-    #     add_think: bool = True,
-    # ) -> tuple[List[Dict[str, Any]], Any, str]:
-    #     """
-    #     Inspect word probabilities in a language model response.
-
-    #     Args:
-    #         tokenizer: Tokenizer instance to encode words.
-    #         messages: List of messages to analyze.
-
-    #     Returns:
-    #         A tuple containing:
-    #         - List of word probabilities with their log probabilities.
-    #         - Token log probability dictionaries.
-    #         - Rendered string with colored word probabilities.
-    #     """
-    #     if messages is None:
-    #         messages = await self.last_messages(add_think=add_think)
-    #     if messages is None:
-    #         raise ValueError("No messages provided and no last messages available.")
-
-    #     if tokenizer is None:
-    #         tokenizer = get_tokenizer(self.model)
-
-    #     ret = await inspect_word_probs_async(self, tokenizer, messages)
-    #     if do_print:
-    #         print(ret[-1])
-    #     return ret
-
-    # async def last_messages(
-    #     self, add_think: bool = True
-    # ) -> Optional[List[Dict[str, str]]]:
-    #     """Get the last conversation messages including assistant response."""
-    #     if not hasattr(self, "last_log"):
-    #         return None
-
-    #     last_conv = self._last_log
-    #     messages = last_conv[1] if len(last_conv) > 1 else None
-    #     last_msg = last_conv[2]
-    #     if not isinstance(last_msg, dict):
-    #         last_conv[2] = last_conv[2].model_dump()  # type: ignore
-    #     msg = last_conv[2]
-    #     # Ensure msg is a dict
-    #     if hasattr(msg, "model_dump"):
-    #         msg = msg.model_dump()
-    #     message = msg["choices"][0]["message"]
-    #     reasoning = message.get("reasoning_content")
-    #     answer = message.get("content")
-    #     if reasoning and add_think:
-    #         final_answer = f"<think>{reasoning}</think>\n{answer}"
-    #     else:
-    #         final_answer = f"<think>\n\n</think>\n{answer}"
-    #     assistant = {"role": "assistant", "content": final_answer}
-    #     messages = messages + [assistant]  # type: ignore
-    #     return messages if messages else None
-
-    # async def inspect_history(self) -> None:
-    #     """Inspect the conversation history with proper formatting."""
-    #     if not hasattr(self, "last_log"):
-    #         raise ValueError("No history available. Please call the model first.")
-
-    #     prompt, messages, response = self._last_log
-    #     if hasattr(response, "model_dump"):
-    #         response = response.model_dump()
-    #     if not messages:
-    #         messages = [{"role": "user", "content": prompt}]
-
-    #     print("\n\n")
-    #     print(_blue("[Conversation History]") + "\n")
-
-    #     for msg in messages:
-    #         role = msg["role"]
-    #         content = msg["content"]
-    #         print(_red(f"{role.capitalize()}:"))
-    #         if isinstance(content, str):
-    #             print(content.strip())
-    #         elif isinstance(content, list):
-    #             for item in content:
-    #                 if item.get("type") == "text":
-    #                     print(item["text"].strip())
-    #                 elif item.get("type") == "image_url":
-    #                     image_url = item["image_url"]["url"]
-    #                     if "base64" in image_url:
-    #                         len_base64 = len(image_url.split("base64,")[1])
-    #                         print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
-    #                     else:
-    #                         print(_blue(f"<image_url: {image_url}>"))
-    #         print("\n")
-
-    #     print(_red("Response:"))
-    #     if isinstance(response, dict) and response.get("choices"):
-    #         message = response["choices"][0].get("message", {})
-    #         reasoning = message.get("reasoning_content")
-    #         parsed = message.get("parsed")
-    #         content = message.get("content")
-    #         if reasoning:
-    #             print(_yellow("<think>"))
-    #             print(reasoning.strip())
-    #             print(_yellow("</think>\n"))
-    #         if parsed:
-    #             print(
-    #                 json.dumps(
-    #                     (
-    #                         parsed.model_dump()
-    #                         if hasattr(parsed, "model_dump")
-    #                         else parsed
-    #                     ),
-    #                     indent=2,
-    #                 )
-    #                 + "\n"
-    #             )
-    #         elif content:
-    #             print(content.strip())
-    #         else:
-    #             print(_green("[No content]"))
-    #         if len(response["choices"]) > 1:
-    #             print(
-    #                 _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
-    #             )
-    #     else:
-    #         print(_yellow("Warning: Not a standard OpenAI response object"))
-    #         if isinstance(response, str):
-    #             print(_green(response.strip()))
-    #         elif isinstance(response, dict):
-    #             print(_green(json.dumps(response, indent=2)))
-    #         else:
-    #             print(_green(str(response)))
-
     # ------------------------------------------------------------------ #
     # Misc helpers
     # ------------------------------------------------------------------ #
```
```diff
--- /dev/null
+++ speedy_utils-1.1.11/src/llm_utils/lm/openai_memoize.py
@@ -0,0 +1,72 @@
+from openai import OpenAI, AsyncOpenAI
+
+from speedy_utils.common.utils_cache import memoize
+
+
+class MOpenAI(OpenAI):
+    """
+    MOpenAI(*args, **kwargs)
+
+    Subclass of OpenAI that transparently memoizes the instance's `post` method.
+
+    This class forwards all constructor arguments to the OpenAI base class and then
+    replaces the instance's `post` method with a memoized wrapper:
+
+    Behavior
+    - The memoized `post` caches responses based on the arguments with which it is
+      invoked, preventing repeated identical requests from invoking the underlying
+      OpenAI API repeatedly.
+    - Because `post` is replaced on the instance, the cache is by-default tied to
+      the MOpenAI instance (per-instance cache).
+    - Any initialization arguments are passed unchanged to OpenAI.__init__.
+
+    Notes and cautions
+    - The exact semantics of caching (cache key construction, expiry, max size,
+      persistence) depend on the implementation of `memoize`. Ensure that the
+      provided `memoize` supports the desired behavior (e.g., hashing of mutable
+      inputs, thread-safety, TTL, cache invalidation).
+    - If the original `post` method has important side effects or relies on
+      non-deterministic behavior, memoization may change program behavior.
+    - If you need a shared cache across instances, or more advanced cache controls,
+      modify `memoize` or wrap at a class/static level instead of assigning to the
+      bound method.
+
+    Example
+        m = MOpenAI(api_key="...", model="gpt-4")
+        r1 = m.post("Hello")  # executes API call and caches result
+        r2 = m.post("Hello")  # returns cached result (no API call)
+    """
+
+    def __init__(self, *args, cache=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        if cache:
+            self.post = memoize(self.post)
+
+
+class MAsyncOpenAI(AsyncOpenAI):
+    """
+    MAsyncOpenAI(*args, **kwargs)
+
+    Async subclass of AsyncOpenAI that transparently memoizes the instance's `post` method.
+
+    This class forwards all constructor arguments to the AsyncOpenAI base class and then
+    replaces the instance's `post` method with a memoized wrapper:
+
+    Behavior
+    - The memoized `post` caches responses based on the arguments with which it is
+      invoked, preventing repeated identical requests from invoking the underlying
+      OpenAI API repeatedly.
+    - Because `post` is replaced on the instance, the cache is by-default tied to
+      the MAsyncOpenAI instance (per-instance cache).
+    - Any initialization arguments are passed unchanged to AsyncOpenAI.__init__.
+
+    Example
+        m = MAsyncOpenAI(api_key="...", model="gpt-4")
+        r1 = await m.post("Hello")  # executes API call and caches result
+        r2 = await m.post("Hello")  # returns cached result (no API call)
+    """
+
+    def __init__(self, *args, cache=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        if cache:
+            self.post = memoize(self.post)
```
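In the OpenAI Python client, `chat.completions.create` is ultimately routed through the client's `post` method, so memoizing `post` on the instance should let repeated identical requests be served from the cache. A hedged usage sketch; the API key, base URL, and model name are placeholders, and the exact cache semantics (keying, persistence, TTL) depend on `speedy_utils`' `memoize`:

```python
import asyncio

from llm_utils.lm.openai_memoize import MAsyncOpenAI


async def main() -> None:
    # cache=True is the default; pass cache=False to keep the stock behaviour.
    client = MAsyncOpenAI(api_key="sk-placeholder", base_url="http://localhost:8000/v1")
    kwargs = dict(
        model="my-model",
        messages=[{"role": "user", "content": "Hello"}],
    )
    first = await client.chat.completions.create(**kwargs)   # real HTTP request
    second = await client.chat.completions.create(**kwargs)  # expected cache hit
    print(first.choices[0].message.content == second.choices[0].message.content)


asyncio.run(main())
```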
```diff
--- speedy_utils-1.1.9/src/llm_utils/scripts/vllm_serve.py
+++ speedy_utils-1.1.11/src/llm_utils/scripts/vllm_serve.py
@@ -72,6 +72,7 @@ import openai
 import requests
 from loguru import logger

+from llm_utils.lm.openai_memoize import MOpenAI
 from speedy_utils.common.utils_io import load_by_ext

 LORA_DIR: str = os.environ.get("LORA_DIR", "/loras")
@@ -82,7 +83,7 @@ logger.info(f"LORA_DIR: {LORA_DIR}")

 def model_list(host_port: str, api_key: str = "abc") -> None:
     """List models from the vLLM server."""
-    client =
+    client = MOpenAI(base_url=f"http://{host_port}/v1", api_key=api_key)
     models = client.models.list()
     for model in models:
         print(f"Model ID: {model.id}")
```
```diff
--- speedy_utils-1.1.9/src/speedy_utils/__init__.py
+++ speedy_utils-1.1.11/src/speedy_utils/__init__.py
@@ -108,7 +108,7 @@ from .common.notebook_utils import (
 )

 # Cache utilities
-from .common.utils_cache import
+from .common.utils_cache import identify, identify_uuid, memoize

 # IO utilities
 from .common.utils_io import (
@@ -197,7 +197,6 @@ __all__ = [
     # Function decorators
     "retry_runtime",
     # Cache utilities
-    "amemoize",
     "memoize",
     "identify",
     "identify_uuid",
@@ -227,5 +226,4 @@ __all__ = [
     "multi_thread",
     # Notebook utilities
     "change_dir",
-    "amemoize",
 ]
```
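The export changes drop `amemoize` from the public API while keeping `memoize`, `identify`, and `identify_uuid`; since this release applies `memoize` to an async method (`AsyncOpenAI.post` in `openai_memoize.py`), it presumably now covers both sync and async callables. A hedged sketch of the remaining cache exports; the behaviour of `identify`/`identify_uuid` is inferred from their names:

```python
from speedy_utils import identify, identify_uuid, memoize


@memoize
def slow_square(x: int) -> int:
    print("computing", x)
    return x * x


print(slow_square(4))  # prints "computing 4", then 16
print(slow_square(4))  # expected cache hit: 16 without recomputation

# Presumed content-hashing helpers: deterministic IDs for picklable objects.
print(identify({"a": 1}))
print(identify_uuid({"a": 1}))
```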