speedy-utils 1.1.23__py3-none-any.whl → 1.1.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +12 -8
- llm_utils/chat_format/__init__.py +2 -0
- llm_utils/chat_format/display.py +115 -44
- llm_utils/lm/__init__.py +14 -6
- llm_utils/lm/llm.py +413 -0
- llm_utils/lm/llm_signature.py +35 -0
- llm_utils/lm/mixins.py +379 -0
- llm_utils/lm/openai_memoize.py +18 -7
- llm_utils/lm/signature.py +26 -37
- llm_utils/lm/utils.py +61 -76
- speedy_utils/__init__.py +31 -2
- speedy_utils/all.py +30 -1
- speedy_utils/common/utils_cache.py +142 -1
- speedy_utils/common/utils_io.py +36 -26
- speedy_utils/common/utils_misc.py +25 -1
- speedy_utils/multi_worker/thread.py +145 -58
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.25.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.25.dist-info}/RECORD +20 -19
- llm_utils/lm/llm_as_a_judge.py +0 -390
- llm_utils/lm/llm_task.py +0 -614
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.25.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.25.dist-info}/entry_points.txt +0 -0
llm_utils/lm/utils.py
CHANGED
|
@@ -14,6 +14,7 @@ from openai import OpenAI
|
|
|
14
14
|
|
|
15
15
|
try:
|
|
16
16
|
import psutil
|
|
17
|
+
|
|
17
18
|
HAS_PSUTIL = True
|
|
18
19
|
except ImportError:
|
|
19
20
|
HAS_PSUTIL = False
|
|
@@ -26,7 +27,7 @@ _VLLM_PROCESSES: List[subprocess.Popen] = []
|
|
|
26
27
|
|
|
27
28
|
def _extract_port_from_vllm_cmd(vllm_cmd: str) -> int:
|
|
28
29
|
"""Extract port from VLLM command string."""
|
|
29
|
-
port_match = re.search(r
|
|
30
|
+
port_match = re.search(r"--port\s+(\d+)", vllm_cmd)
|
|
30
31
|
if port_match:
|
|
31
32
|
return int(port_match.group(1))
|
|
32
33
|
return 8000
|
|
@@ -34,39 +35,39 @@ def _extract_port_from_vllm_cmd(vllm_cmd: str) -> int:
|
|
|
34
35
|
|
|
35
36
|
def _parse_env_vars_from_cmd(cmd: str) -> tuple[dict[str, str], str]:
|
|
36
37
|
"""Parse environment variables from command string.
|
|
37
|
-
|
|
38
|
+
|
|
38
39
|
Args:
|
|
39
40
|
cmd: Command string that may contain environment variables like 'VAR=value command...'
|
|
40
|
-
|
|
41
|
+
|
|
41
42
|
Returns:
|
|
42
43
|
Tuple of (env_dict, cleaned_cmd) where env_dict contains parsed env vars
|
|
43
44
|
and cleaned_cmd is the command without the env vars.
|
|
44
45
|
"""
|
|
45
46
|
import shlex
|
|
46
|
-
|
|
47
|
+
|
|
47
48
|
# Split the command while preserving quoted strings
|
|
48
49
|
parts = shlex.split(cmd)
|
|
49
|
-
|
|
50
|
+
|
|
50
51
|
env_vars = {}
|
|
51
52
|
cmd_parts = []
|
|
52
|
-
|
|
53
|
+
|
|
53
54
|
for part in parts:
|
|
54
|
-
if
|
|
55
|
+
if "=" in part and not part.startswith("-"):
|
|
55
56
|
# Check if this looks like an environment variable
|
|
56
57
|
# Should be KEY=VALUE format, not contain spaces (unless quoted), and KEY should be uppercase
|
|
57
|
-
key_value = part.split(
|
|
58
|
+
key_value = part.split("=", 1)
|
|
58
59
|
if len(key_value) == 2:
|
|
59
60
|
key, value = key_value
|
|
60
|
-
if key.isupper() and key.replace(
|
|
61
|
+
if key.isupper() and key.replace("_", "").isalnum():
|
|
61
62
|
env_vars[key] = value
|
|
62
63
|
continue
|
|
63
|
-
|
|
64
|
+
|
|
64
65
|
# Not an env var, add to command parts
|
|
65
66
|
cmd_parts.append(part)
|
|
66
|
-
|
|
67
|
+
|
|
67
68
|
# Reconstruct the cleaned command
|
|
68
|
-
cleaned_cmd =
|
|
69
|
-
|
|
69
|
+
cleaned_cmd = " ".join(cmd_parts)
|
|
70
|
+
|
|
70
71
|
return env_vars, cleaned_cmd
|
|
71
72
|
|
|
72
73
|
|
|
@@ -74,38 +75,33 @@ def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:
|
|
|
74
75
|
"""Start VLLM server and wait for ready."""
|
|
75
76
|
# Parse environment variables from command
|
|
76
77
|
env_vars, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
|
|
77
|
-
|
|
78
|
+
|
|
78
79
|
port = _extract_port_from_vllm_cmd(cleaned_cmd)
|
|
79
|
-
|
|
80
|
+
|
|
80
81
|
logger.info(f"Starting VLLM server: {cleaned_cmd}")
|
|
81
82
|
if env_vars:
|
|
82
83
|
logger.info(f"Environment variables: {env_vars}")
|
|
83
84
|
logger.info(f"VLLM output logged to: /tmp/vllm_{port}.txt")
|
|
84
|
-
|
|
85
|
-
with open(f
|
|
85
|
+
|
|
86
|
+
with open(f"/tmp/vllm_{port}.txt", "w") as log_file:
|
|
86
87
|
log_file.write(f"VLLM Server started at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
|
|
87
88
|
log_file.write(f"Command: {cleaned_cmd}\n")
|
|
88
89
|
if env_vars:
|
|
89
90
|
log_file.write(f"Environment: {env_vars}\n")
|
|
90
91
|
log_file.write(f"Port: {port}\n")
|
|
91
92
|
log_file.write("-" * 50 + "\n")
|
|
92
|
-
|
|
93
|
+
|
|
93
94
|
# Prepare environment for subprocess
|
|
94
95
|
env = os.environ.copy()
|
|
95
96
|
env.update(env_vars)
|
|
96
97
|
|
|
97
|
-
with open(f
|
|
98
|
+
with open(f"/tmp/vllm_{port}.txt", "a") as log_file:
|
|
98
99
|
process = subprocess.Popen(
|
|
99
|
-
cleaned_cmd.split(),
|
|
100
|
-
stdout=log_file,
|
|
101
|
-
stderr=subprocess.STDOUT,
|
|
102
|
-
text=True,
|
|
103
|
-
preexec_fn=os.setsid,
|
|
104
|
-
env=env
|
|
100
|
+
cleaned_cmd.split(), stdout=log_file, stderr=subprocess.STDOUT, text=True, preexec_fn=os.setsid, env=env
|
|
105
101
|
)
|
|
106
|
-
|
|
102
|
+
|
|
107
103
|
_VLLM_PROCESSES.append(process)
|
|
108
|
-
|
|
104
|
+
|
|
109
105
|
start_time = time.time()
|
|
110
106
|
while time.time() - start_time < timeout:
|
|
111
107
|
try:
|
|
@@ -115,26 +111,24 @@ def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:
|
|
|
115
111
|
return process
|
|
116
112
|
except requests.RequestException:
|
|
117
113
|
pass
|
|
118
|
-
|
|
114
|
+
|
|
119
115
|
if process.poll() is not None:
|
|
120
116
|
stdout, stderr = process.communicate()
|
|
121
117
|
raise RuntimeError(
|
|
122
|
-
f"VLLM server terminated unexpectedly. "
|
|
123
|
-
f"Return code: {process.returncode}, "
|
|
124
|
-
f"stderr: {stderr[:200]}..."
|
|
118
|
+
f"VLLM server terminated unexpectedly. Return code: {process.returncode}, stderr: {stderr[:200]}..."
|
|
125
119
|
)
|
|
126
|
-
|
|
120
|
+
|
|
127
121
|
time.sleep(2)
|
|
128
|
-
|
|
122
|
+
|
|
129
123
|
process.terminate()
|
|
130
124
|
try:
|
|
131
125
|
process.wait(timeout=5)
|
|
132
126
|
except subprocess.TimeoutExpired:
|
|
133
127
|
process.kill()
|
|
134
|
-
|
|
128
|
+
|
|
135
129
|
if process in _VLLM_PROCESSES:
|
|
136
130
|
_VLLM_PROCESSES.remove(process)
|
|
137
|
-
|
|
131
|
+
|
|
138
132
|
raise RuntimeError(f"VLLM server failed to start within {timeout}s on port {port}")
|
|
139
133
|
|
|
140
134
|
|
|
@@ -142,7 +136,7 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
142
136
|
"""Kill VLLM server on port."""
|
|
143
137
|
killed = False
|
|
144
138
|
logger.info(f"Checking VLLM server on port {port}")
|
|
145
|
-
|
|
139
|
+
|
|
146
140
|
processes_to_remove = []
|
|
147
141
|
for process in _VLLM_PROCESSES:
|
|
148
142
|
try:
|
|
@@ -151,8 +145,8 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
151
145
|
if HAS_PSUTIL:
|
|
152
146
|
try:
|
|
153
147
|
proc = psutil.Process(process.pid)
|
|
154
|
-
cmdline =
|
|
155
|
-
if f
|
|
148
|
+
cmdline = " ".join(proc.cmdline())
|
|
149
|
+
if f"--port {port}" in cmdline or f"--port={port}" in cmdline:
|
|
156
150
|
logger.info(f"Killing tracked VLLM process {process.pid} on port {port}")
|
|
157
151
|
os.killpg(os.getpgid(process.pid), signal.SIGTERM)
|
|
158
152
|
try:
|
|
@@ -164,7 +158,7 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
164
158
|
killed_process = True
|
|
165
159
|
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
|
166
160
|
pass
|
|
167
|
-
|
|
161
|
+
|
|
168
162
|
if not HAS_PSUTIL or not killed_process:
|
|
169
163
|
logger.info(f"Killing tracked VLLM process {process.pid}")
|
|
170
164
|
try:
|
|
@@ -177,24 +171,23 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
177
171
|
killed = True
|
|
178
172
|
except (ProcessLookupError, OSError):
|
|
179
173
|
pass
|
|
180
|
-
|
|
174
|
+
|
|
181
175
|
processes_to_remove.append(process)
|
|
182
176
|
else:
|
|
183
177
|
processes_to_remove.append(process)
|
|
184
178
|
except (ProcessLookupError, OSError):
|
|
185
179
|
processes_to_remove.append(process)
|
|
186
|
-
|
|
180
|
+
|
|
187
181
|
for process in processes_to_remove:
|
|
188
182
|
if process in _VLLM_PROCESSES:
|
|
189
183
|
_VLLM_PROCESSES.remove(process)
|
|
190
|
-
|
|
184
|
+
|
|
191
185
|
if not killed and HAS_PSUTIL:
|
|
192
186
|
try:
|
|
193
|
-
for proc in psutil.process_iter([
|
|
187
|
+
for proc in psutil.process_iter(["pid", "cmdline"]):
|
|
194
188
|
try:
|
|
195
|
-
cmdline =
|
|
196
|
-
if
|
|
197
|
-
(f'--port {port}' in cmdline or f'--port={port}' in cmdline)):
|
|
189
|
+
cmdline = " ".join(proc.info["cmdline"] or [])
|
|
190
|
+
if "vllm" in cmdline.lower() and (f"--port {port}" in cmdline or f"--port={port}" in cmdline):
|
|
198
191
|
logger.info(f"Killing untracked VLLM process {proc.info['pid']} on port {port}")
|
|
199
192
|
proc.terminate()
|
|
200
193
|
try:
|
|
@@ -207,13 +200,13 @@ def _kill_vllm_on_port(port: int) -> bool:
|
|
|
207
200
|
continue
|
|
208
201
|
except Exception as e:
|
|
209
202
|
logger.warning(f"Error searching processes on port {port}: {e}")
|
|
210
|
-
|
|
203
|
+
|
|
211
204
|
if killed:
|
|
212
205
|
logger.info(f"Killed VLLM server on port {port}")
|
|
213
206
|
time.sleep(2)
|
|
214
207
|
else:
|
|
215
208
|
logger.info(f"No VLLM server on port {port}")
|
|
216
|
-
|
|
209
|
+
|
|
217
210
|
return killed
|
|
218
211
|
|
|
219
212
|
|
|
@@ -262,32 +255,24 @@ def _is_server_running(port: int) -> bool:
|
|
|
262
255
|
return False
|
|
263
256
|
|
|
264
257
|
|
|
265
|
-
def get_base_client(
|
|
266
|
-
client=None,
|
|
267
|
-
cache: bool = True,
|
|
268
|
-
api_key="abc",
|
|
269
|
-
vllm_cmd=None,
|
|
270
|
-
vllm_process=None
|
|
271
|
-
) -> OpenAI:
|
|
258
|
+
def get_base_client(client=None, cache: bool = True, api_key="abc", vllm_cmd=None, vllm_process=None) -> OpenAI:
|
|
272
259
|
"""Get OpenAI client from various inputs."""
|
|
273
260
|
from llm_utils import MOpenAI
|
|
274
261
|
|
|
275
|
-
open_ai_class = OpenAI if not cache else MOpenAI
|
|
276
|
-
|
|
277
262
|
if client is None:
|
|
278
263
|
if vllm_cmd is not None:
|
|
279
264
|
# Parse environment variables from command to get clean command for port extraction
|
|
280
265
|
_, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
|
|
281
266
|
port = _extract_port_from_vllm_cmd(cleaned_cmd)
|
|
282
|
-
return
|
|
267
|
+
return MOpenAI(base_url=f"http://localhost:{port}/v1", api_key=api_key, cache=cache)
|
|
283
268
|
else:
|
|
284
|
-
|
|
269
|
+
raise ValueError("Either client or vllm_cmd must be provided.")
|
|
285
270
|
elif isinstance(client, int):
|
|
286
|
-
return
|
|
271
|
+
return MOpenAI(base_url=f"http://localhost:{client}/v1", api_key=api_key, cache=cache)
|
|
287
272
|
elif isinstance(client, str):
|
|
288
|
-
return
|
|
273
|
+
return MOpenAI(base_url=client, api_key=api_key, cache=cache)
|
|
289
274
|
elif isinstance(client, OpenAI):
|
|
290
|
-
return client
|
|
275
|
+
return MOpenAI(base_url=client.base_url, api_key=api_key, cache=cache)
|
|
291
276
|
else:
|
|
292
277
|
raise ValueError("Invalid client type. Must be OpenAI, port (int), base_url (str), or None.")
|
|
293
278
|
|
|
@@ -296,17 +281,17 @@ def _is_lora_path(path: str) -> bool:
|
|
|
296
281
|
"""Check if path is LoRA adapter directory."""
|
|
297
282
|
if not os.path.isdir(path):
|
|
298
283
|
return False
|
|
299
|
-
adapter_config_path = os.path.join(path,
|
|
284
|
+
adapter_config_path = os.path.join(path, "adapter_config.json")
|
|
300
285
|
return os.path.isfile(adapter_config_path)
|
|
301
286
|
|
|
302
287
|
|
|
303
288
|
def _get_port_from_client(client: OpenAI) -> Optional[int]:
|
|
304
289
|
"""Extract port from OpenAI client base_url."""
|
|
305
|
-
if hasattr(client,
|
|
290
|
+
if hasattr(client, "base_url") and client.base_url:
|
|
306
291
|
base_url = str(client.base_url)
|
|
307
|
-
if
|
|
292
|
+
if "localhost:" in base_url:
|
|
308
293
|
try:
|
|
309
|
-
port_part = base_url.split(
|
|
294
|
+
port_part = base_url.split("localhost:")[1].split("/")[0]
|
|
310
295
|
return int(port_part)
|
|
311
296
|
except (IndexError, ValueError):
|
|
312
297
|
pass
|
|
@@ -315,14 +300,14 @@ def _get_port_from_client(client: OpenAI) -> Optional[int]:
|
|
|
315
300
|
|
|
316
301
|
def _load_lora_adapter(lora_path: str, port: int) -> str:
|
|
317
302
|
"""Load LoRA adapter from path."""
|
|
318
|
-
lora_name = os.path.basename(lora_path.rstrip(
|
|
303
|
+
lora_name = os.path.basename(lora_path.rstrip("/\\"))
|
|
319
304
|
if not lora_name:
|
|
320
305
|
lora_name = os.path.basename(os.path.dirname(lora_path))
|
|
321
|
-
|
|
306
|
+
|
|
322
307
|
response = requests.post(
|
|
323
|
-
f
|
|
324
|
-
headers={
|
|
325
|
-
json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)}
|
|
308
|
+
f"http://localhost:{port}/v1/load_lora_adapter",
|
|
309
|
+
headers={"accept": "application/json", "Content-Type": "application/json"},
|
|
310
|
+
json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)},
|
|
326
311
|
)
|
|
327
312
|
response.raise_for_status()
|
|
328
313
|
return lora_name
|
|
@@ -331,14 +316,14 @@ def _load_lora_adapter(lora_path: str, port: int) -> str:
|
|
|
331
316
|
def _unload_lora_adapter(lora_path: str, port: int) -> None:
|
|
332
317
|
"""Unload LoRA adapter."""
|
|
333
318
|
try:
|
|
334
|
-
lora_name = os.path.basename(lora_path.rstrip(
|
|
319
|
+
lora_name = os.path.basename(lora_path.rstrip("/\\"))
|
|
335
320
|
if not lora_name:
|
|
336
321
|
lora_name = os.path.basename(os.path.dirname(lora_path))
|
|
337
|
-
|
|
322
|
+
|
|
338
323
|
response = requests.post(
|
|
339
|
-
f
|
|
340
|
-
headers={
|
|
341
|
-
json={"lora_name": lora_name, "lora_int_id": 0}
|
|
324
|
+
f"http://localhost:{port}/v1/unload_lora_adapter",
|
|
325
|
+
headers={"accept": "application/json", "Content-Type": "application/json"},
|
|
326
|
+
json={"lora_name": lora_name, "lora_int_id": 0},
|
|
342
327
|
)
|
|
343
328
|
response.raise_for_status()
|
|
344
329
|
except requests.RequestException as e:
|
speedy_utils/__init__.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
# • timef(func) -> Callable - Function execution time decorator
|
|
17
17
|
# • retry_runtime(sleep_seconds: int, max_retry: int, exceptions) -> Callable
|
|
18
18
|
# • memoize(func) -> Callable - Function result caching decorator
|
|
19
|
+
# • imemoize(func) -> Callable - In-memory caching decorator (global persistent)
|
|
19
20
|
# • identify(obj: Any) -> str - Generate unique object identifier
|
|
20
21
|
# • identify_uuid(obj: Any) -> str - Generate UUID-based object identifier
|
|
21
22
|
# • load_by_ext(fname: Union[str, list[str]]) -> Any - Auto-detect file format loader
|
|
@@ -79,7 +80,24 @@ from glob import glob
|
|
|
79
80
|
from multiprocessing import Pool
|
|
80
81
|
from pathlib import Path
|
|
81
82
|
from threading import Lock
|
|
82
|
-
from typing import
|
|
83
|
+
from typing import (
|
|
84
|
+
Any,
|
|
85
|
+
Awaitable,
|
|
86
|
+
Callable as TypingCallable,
|
|
87
|
+
Dict,
|
|
88
|
+
Generic,
|
|
89
|
+
Iterable,
|
|
90
|
+
List,
|
|
91
|
+
Literal,
|
|
92
|
+
Mapping,
|
|
93
|
+
Optional,
|
|
94
|
+
Sequence,
|
|
95
|
+
Set,
|
|
96
|
+
Tuple,
|
|
97
|
+
Type,
|
|
98
|
+
TypeVar,
|
|
99
|
+
Union,
|
|
100
|
+
)
|
|
83
101
|
|
|
84
102
|
# Third-party imports
|
|
85
103
|
import numpy as np
|
|
@@ -108,7 +126,7 @@ from .common.notebook_utils import (
|
|
|
108
126
|
)
|
|
109
127
|
|
|
110
128
|
# Cache utilities
|
|
111
|
-
from .common.utils_cache import identify, identify_uuid, memoize
|
|
129
|
+
from .common.utils_cache import identify, identify_uuid, imemoize, memoize
|
|
112
130
|
|
|
113
131
|
# IO utilities
|
|
114
132
|
from .common.utils_io import (
|
|
@@ -124,6 +142,7 @@ from .common.utils_io import (
|
|
|
124
142
|
# Misc utilities
|
|
125
143
|
from .common.utils_misc import (
|
|
126
144
|
convert_to_builtin_python,
|
|
145
|
+
dedup,
|
|
127
146
|
flatten_list,
|
|
128
147
|
get_arg_names,
|
|
129
148
|
is_notebook,
|
|
@@ -171,12 +190,20 @@ __all__ = [
|
|
|
171
190
|
"defaultdict",
|
|
172
191
|
# Typing
|
|
173
192
|
"Any",
|
|
193
|
+
"Awaitable",
|
|
174
194
|
"Callable",
|
|
195
|
+
"TypingCallable",
|
|
175
196
|
"Dict",
|
|
176
197
|
"Generic",
|
|
198
|
+
"Iterable",
|
|
177
199
|
"List",
|
|
178
200
|
"Literal",
|
|
201
|
+
"Mapping",
|
|
179
202
|
"Optional",
|
|
203
|
+
"Sequence",
|
|
204
|
+
"Set",
|
|
205
|
+
"Tuple",
|
|
206
|
+
"Type",
|
|
180
207
|
"TypeVar",
|
|
181
208
|
"Union",
|
|
182
209
|
# Third-party
|
|
@@ -198,6 +225,7 @@ __all__ = [
|
|
|
198
225
|
"retry_runtime",
|
|
199
226
|
# Cache utilities
|
|
200
227
|
"memoize",
|
|
228
|
+
"imemoize",
|
|
201
229
|
"identify",
|
|
202
230
|
"identify_uuid",
|
|
203
231
|
# IO utilities
|
|
@@ -214,6 +242,7 @@ __all__ = [
|
|
|
214
242
|
"get_arg_names",
|
|
215
243
|
"is_notebook",
|
|
216
244
|
"convert_to_builtin_python",
|
|
245
|
+
"dedup",
|
|
217
246
|
# Print utilities
|
|
218
247
|
"display_pretty_table_html",
|
|
219
248
|
"flatten_dict",
|
speedy_utils/all.py
CHANGED
|
@@ -71,7 +71,24 @@ from glob import glob
|
|
|
71
71
|
from multiprocessing import Pool
|
|
72
72
|
from pathlib import Path
|
|
73
73
|
from threading import Lock
|
|
74
|
-
from typing import
|
|
74
|
+
from typing import (
|
|
75
|
+
Any,
|
|
76
|
+
Awaitable,
|
|
77
|
+
Callable as TypingCallable,
|
|
78
|
+
Dict,
|
|
79
|
+
Generic,
|
|
80
|
+
Iterable,
|
|
81
|
+
List,
|
|
82
|
+
Literal,
|
|
83
|
+
Mapping,
|
|
84
|
+
Optional,
|
|
85
|
+
Sequence,
|
|
86
|
+
Set,
|
|
87
|
+
Tuple,
|
|
88
|
+
Type,
|
|
89
|
+
TypeVar,
|
|
90
|
+
Union,
|
|
91
|
+
)
|
|
75
92
|
|
|
76
93
|
# Third-party imports
|
|
77
94
|
import numpy as np
|
|
@@ -115,6 +132,9 @@ from speedy_utils import ( # Clock module; Function decorators; Cache utilities
|
|
|
115
132
|
timef,
|
|
116
133
|
)
|
|
117
134
|
|
|
135
|
+
|
|
136
|
+
choice = random.choice
|
|
137
|
+
|
|
118
138
|
# Define __all__ explicitly with all exports
|
|
119
139
|
__all__ = [
|
|
120
140
|
# Standard library
|
|
@@ -146,12 +166,20 @@ __all__ = [
|
|
|
146
166
|
"defaultdict",
|
|
147
167
|
# Typing
|
|
148
168
|
"Any",
|
|
169
|
+
"Awaitable",
|
|
149
170
|
"Callable",
|
|
171
|
+
"TypingCallable",
|
|
150
172
|
"Dict",
|
|
151
173
|
"Generic",
|
|
174
|
+
"Iterable",
|
|
152
175
|
"List",
|
|
153
176
|
"Literal",
|
|
177
|
+
"Mapping",
|
|
154
178
|
"Optional",
|
|
179
|
+
"Sequence",
|
|
180
|
+
"Set",
|
|
181
|
+
"Tuple",
|
|
182
|
+
"Type",
|
|
155
183
|
"TypeVar",
|
|
156
184
|
"Union",
|
|
157
185
|
# Third-party
|
|
@@ -199,4 +227,5 @@ __all__ = [
|
|
|
199
227
|
# Multi-worker processing
|
|
200
228
|
"multi_process",
|
|
201
229
|
"multi_thread",
|
|
230
|
+
"choice",
|
|
202
231
|
]
|
speedy_utils/common/utils_cache.py
CHANGED
|
@@ -44,6 +44,9 @@ _MEM_CACHES: "weakref.WeakKeyDictionary[Callable[..., Any], cachetools.LRUCache]
|
|
|
44
44
|
weakref.WeakKeyDictionary()
|
|
45
45
|
)
|
|
46
46
|
|
|
47
|
+
# Global memory cache for imemoize (persists across IPython reloads)
|
|
48
|
+
_GLOBAL_MEMORY_CACHE: dict[str, Any] = {}
|
|
49
|
+
|
|
47
50
|
# Backward-compat global symbol (internal only; not exported)
|
|
48
51
|
LRU_MEM_CACHE = cachetools.LRUCache(maxsize=256)
|
|
49
52
|
|
|
@@ -680,4 +683,142 @@ def memoize(
|
|
|
680
683
|
return decorator(_func)
|
|
681
684
|
|
|
682
685
|
|
|
683
|
-
|
|
686
|
+
# --------------------------------------------------------------------------------------
|
|
687
|
+
# In-memory memoize with global persistent cache
|
|
688
|
+
# --------------------------------------------------------------------------------------
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
@overload
|
|
692
|
+
def imemoize(
|
|
693
|
+
_func: Callable[P, R],
|
|
694
|
+
*,
|
|
695
|
+
keys: Optional[list[str]] = ...,
|
|
696
|
+
key: Optional[Callable[..., Any]] = ...,
|
|
697
|
+
ignore_self: bool = ...,
|
|
698
|
+
) -> Callable[P, R]: ...
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
@overload
|
|
702
|
+
def imemoize(
|
|
703
|
+
_func: Callable[P, Awaitable[R]],
|
|
704
|
+
*,
|
|
705
|
+
keys: Optional[list[str]] = ...,
|
|
706
|
+
key: Optional[Callable[..., Any]] = ...,
|
|
707
|
+
ignore_self: bool = ...,
|
|
708
|
+
) -> Callable[P, Awaitable[R]]: ...
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
@overload
|
|
712
|
+
def imemoize(
|
|
713
|
+
_func: None = ...,
|
|
714
|
+
*,
|
|
715
|
+
keys: Optional[list[str]] = ...,
|
|
716
|
+
key: Optional[Callable[..., Any]] = ...,
|
|
717
|
+
ignore_self: bool = ...,
|
|
718
|
+
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
@overload
|
|
722
|
+
def imemoize( # type: ignore
|
|
723
|
+
_func: None = ...,
|
|
724
|
+
*,
|
|
725
|
+
keys: Optional[list[str]] = ...,
|
|
726
|
+
key: Optional[Callable[..., Any]] = ...,
|
|
727
|
+
ignore_self: bool = ...,
|
|
728
|
+
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ...
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
def imemoize(
|
|
732
|
+
_func: Optional[Callable[P, Any]] = None,
|
|
733
|
+
*,
|
|
734
|
+
keys: Optional[list[str]] = None,
|
|
735
|
+
key: Optional[Callable[..., Any]] = None,
|
|
736
|
+
ignore_self: bool = True,
|
|
737
|
+
):
|
|
738
|
+
"""
|
|
739
|
+
In-memory memoization decorator with global persistent cache.
|
|
740
|
+
|
|
741
|
+
Unlike regular memoize, this uses a global memory cache that persists
|
|
742
|
+
across IPython %load executions. The cache key is based on the function's
|
|
743
|
+
source code combined with runtime arguments, making it suitable for
|
|
744
|
+
notebook environments where functions may be reloaded.
|
|
745
|
+
|
|
746
|
+
Args:
|
|
747
|
+
keys: list of argument names to include in key (optional).
|
|
748
|
+
key: custom callable (*args, **kwargs) -> hashable for keying (optional).
|
|
749
|
+
ignore_self: ignore 'self' when building cache key for bound methods.
|
|
750
|
+
|
|
751
|
+
Example:
|
|
752
|
+
@imemoize
|
|
753
|
+
def expensive_computation(x):
|
|
754
|
+
import time
|
|
755
|
+
time.sleep(2)
|
|
756
|
+
return x * x
|
|
757
|
+
|
|
758
|
+
# First call computes and caches
|
|
759
|
+
result1 = expensive_computation(5)
|
|
760
|
+
|
|
761
|
+
# Second call retrieves from memory cache
|
|
762
|
+
result2 = expensive_computation(5)
|
|
763
|
+
|
|
764
|
+
# Even after %load file.py in IPython, the cache persists
|
|
765
|
+
"""
|
|
766
|
+
|
|
767
|
+
def decorator(func: Callable[P, Any]) -> Callable[P, Any]:
|
|
768
|
+
is_async = inspect.iscoroutinefunction(func)
|
|
769
|
+
|
|
770
|
+
if is_async:
|
|
771
|
+
@functools.wraps(func)
|
|
772
|
+
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
|
|
773
|
+
# Compute cache key based on function source + args
|
|
774
|
+
func_source, sub_dir, key_id = _compute_cache_components(
|
|
775
|
+
func, args, kwargs, ignore_self, keys, key
|
|
776
|
+
)
|
|
777
|
+
cache_key = identify((func_source, sub_dir, key_id))
|
|
778
|
+
|
|
779
|
+
# Check global memory cache
|
|
780
|
+
with mem_lock:
|
|
781
|
+
if cache_key in _GLOBAL_MEMORY_CACHE:
|
|
782
|
+
return _GLOBAL_MEMORY_CACHE[cache_key]
|
|
783
|
+
|
|
784
|
+
# Compute result and store in cache
|
|
785
|
+
result = await func(*args, **kwargs)
|
|
786
|
+
|
|
787
|
+
with mem_lock:
|
|
788
|
+
_GLOBAL_MEMORY_CACHE[cache_key] = result
|
|
789
|
+
|
|
790
|
+
return result
|
|
791
|
+
|
|
792
|
+
return async_wrapper
|
|
793
|
+
else:
|
|
794
|
+
@functools.wraps(func)
|
|
795
|
+
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
|
|
796
|
+
# Compute cache key based on function source + args
|
|
797
|
+
func_source, sub_dir, key_id = _compute_cache_components(
|
|
798
|
+
func, args, kwargs, ignore_self, keys, key
|
|
799
|
+
)
|
|
800
|
+
cache_key = identify((func_source, sub_dir, key_id))
|
|
801
|
+
|
|
802
|
+
# Check global memory cache
|
|
803
|
+
with mem_lock:
|
|
804
|
+
if cache_key in _GLOBAL_MEMORY_CACHE:
|
|
805
|
+
return _GLOBAL_MEMORY_CACHE[cache_key]
|
|
806
|
+
|
|
807
|
+
# Compute result and store in cache
|
|
808
|
+
result = func(*args, **kwargs)
|
|
809
|
+
|
|
810
|
+
with mem_lock:
|
|
811
|
+
_GLOBAL_MEMORY_CACHE[cache_key] = result
|
|
812
|
+
|
|
813
|
+
return result
|
|
814
|
+
|
|
815
|
+
return sync_wrapper
|
|
816
|
+
|
|
817
|
+
# Support both @imemoize and @imemoize(...)
|
|
818
|
+
if _func is None:
|
|
819
|
+
return decorator
|
|
820
|
+
else:
|
|
821
|
+
return decorator(_func)
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
__all__ = ["memoize", "imemoize", "identify"]
|