speedy-utils 1.0.13__py3-none-any.whl → 1.0.15__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the registry.
- llm_utils/__init__.py +10 -15
- llm_utils/lm/alm.py +24 -3
- llm_utils/lm/chat_html.py +244 -0
- llm_utils/lm/lm.py +390 -74
- llm_utils/lm/lm_json.py +72 -0
- llm_utils/scripts/README.md +48 -0
- llm_utils/scripts/example_vllm_client.py +269 -0
- llm_utils/scripts/requirements_example.txt +3 -0
- llm_utils/scripts/serve_script.sh +2 -0
- speedy_utils/__init__.py +96 -5
- speedy_utils/common/notebook_utils.py +63 -0
- speedy_utils/common/utils_cache.py +3 -3
- speedy_utils/common/utils_print.py +2 -65
- speedy_utils/multi_worker/process.py +7 -0
- speedy_utils/scripts/__init__.py +0 -0
- speedy_utils/scripts/mpython.py +3 -2
- speedy_utils/scripts/openapi_client_codegen.py +258 -0
- {speedy_utils-1.0.13.dist-info → speedy_utils-1.0.15.dist-info}/METADATA +1 -1
- {speedy_utils-1.0.13.dist-info → speedy_utils-1.0.15.dist-info}/RECORD +21 -12
- {speedy_utils-1.0.13.dist-info → speedy_utils-1.0.15.dist-info}/entry_points.txt +1 -0
- {speedy_utils-1.0.13.dist-info → speedy_utils-1.0.15.dist-info}/WHEEL +0 -0
llm_utils/scripts/example_vllm_client.py
ADDED
@@ -0,0 +1,269 @@
+"""
+Beautiful example script for interacting with VLLM server.
+
+This script demonstrates various ways to use the VLLM API server
+for text generation tasks.
+"""
+
+import asyncio
+import json
+from typing import Dict, List, Optional, Any
+
+import aiohttp
+from loguru import logger
+from pydantic import BaseModel, Field
+
+
+class VLLMRequest(BaseModel):
+    """Request model for VLLM API."""
+    prompt: str
+    max_tokens: int = Field(default=512, ge=1, le=8192)
+    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
+    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
+    stream: bool = False
+    stop: Optional[List[str]] = None
+
+
+class VLLMResponse(BaseModel):
+    """Response model from VLLM API."""
+    text: str
+    finish_reason: str
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class VLLMClient:
+    """Client for interacting with VLLM server."""
+
+    def __init__(self, base_url: str = 'http://localhost:8140'):
+        self.base_url = base_url
+        self.model_name = 'selfeval_8b'
+
+    async def generate_text(
+        self,
+        request: VLLMRequest
+    ) -> VLLMResponse:
+        """Generate text using VLLM API."""
+        url = f'{self.base_url}/v1/completions'
+
+        payload = {
+            'model': self.model_name,
+            'prompt': request.prompt,
+            'max_tokens': request.max_tokens,
+            'temperature': request.temperature,
+            'top_p': request.top_p,
+            'stream': request.stream,
+        }
+
+        if request.stop:
+            payload['stop'] = request.stop
+
+        async with aiohttp.ClientSession() as session:
+            try:
+                async with session.post(
+                    url,
+                    json=payload,
+                    timeout=aiohttp.ClientTimeout(total=60)
+                ) as response:
+                    response.raise_for_status()
+                    data = await response.json()
+
+                    choice = data['choices'][0]
+                    usage = data['usage']
+
+                    return VLLMResponse(
+                        text=choice['text'],
+                        finish_reason=choice['finish_reason'],
+                        prompt_tokens=usage['prompt_tokens'],
+                        completion_tokens=usage['completion_tokens'],
+                        total_tokens=usage['total_tokens']
+                    )
+
+            except aiohttp.ClientError as e:
+                logger.error(f'HTTP error: {e}')
+                raise
+            except Exception as e:
+                logger.error(f'Unexpected error: {e}')
+                raise
+
+    async def generate_batch(
+        self,
+        requests: List[VLLMRequest]
+    ) -> List[VLLMResponse]:
+        """Generate text for multiple requests concurrently."""
+        tasks = [self.generate_text(req) for req in requests]
+        return await asyncio.gather(*tasks, return_exceptions=True)
+
+    async def health_check(self) -> bool:
+        """Check if the VLLM server is healthy."""
+        url = f'{self.base_url}/health'
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(
+                    url,
+                    timeout=aiohttp.ClientTimeout(total=10)
+                ) as response:
+                    return response.status == 200
+        except Exception as e:
+            logger.warning(f'Health check failed: {e}')
+            return False
+
+
+async def example_basic_generation():
+    """Example: Basic text generation."""
+    logger.info('🚀 Running basic generation example')
+
+    client = VLLMClient()
+
+    # Check server health
+    if not await client.health_check():
+        logger.error('❌ Server is not healthy')
+        return
+
+    request = VLLMRequest(
+        prompt='Explain the concept of machine learning in simple terms:',
+        max_tokens=256,
+        temperature=0.7,
+        stop=['\n\n']
+    )
+
+    try:
+        response = await client.generate_text(request)
+
+        logger.success('✅ Generation completed')
+        logger.info(f'📝 Generated text:\n{response.text}')
+        logger.info(f'📊 Tokens: {response.total_tokens} total '
+                    f'({response.prompt_tokens} prompt + '
+                    f'{response.completion_tokens} completion)')
+
+    except Exception as e:
+        logger.error(f'❌ Generation failed: {e}')
+
+
+async def example_batch_generation():
+    """Example: Batch text generation."""
+    logger.info('🚀 Running batch generation example')
+
+    client = VLLMClient()
+
+    prompts = [
+        'What is artificial intelligence?',
+        'Explain quantum computing briefly:',
+        'What are the benefits of renewable energy?'
+    ]
+
+    requests = [
+        VLLMRequest(
+            prompt=prompt,
+            max_tokens=128,
+            temperature=0.8
+        ) for prompt in prompts
+    ]
+
+    try:
+        responses = await client.generate_batch(requests)
+
+        for i, response in enumerate(responses):
+            if isinstance(response, Exception):
+                logger.error(f'❌ Request {i+1} failed: {response}')
+            else:
+                logger.success(f'✅ Request {i+1} completed')
+                logger.info(f'📝 Response {i+1}:\n{response.text}\n')
+
+    except Exception as e:
+        logger.error(f'❌ Batch generation failed: {e}')
+
+
+async def example_creative_writing():
+    """Example: Creative writing with specific parameters."""
+    logger.info('🚀 Running creative writing example')
+
+    client = VLLMClient()
+
+    request = VLLMRequest(
+        prompt=(
+            'Write a short story about a robot discovering emotions. '
+            'The story should be exactly 3 paragraphs:\n\n'
+        ),
+        max_tokens=400,
+        temperature=1.2,  # Higher temperature for creativity
+        top_p=0.95,
+        stop=['THE END', '\n\n\n']
+    )
+
+    try:
+        response = await client.generate_text(request)
+
+        logger.success('✅ Creative writing completed')
+        logger.info(f'📚 Story:\n{response.text}')
+        logger.info(f'🎯 Finish reason: {response.finish_reason}')
+
+    except Exception as e:
+        logger.error(f'❌ Creative writing failed: {e}')
+
+
+async def example_code_generation():
+    """Example: Code generation."""
+    logger.info('🚀 Running code generation example')
+
+    client = VLLMClient()
+
+    request = VLLMRequest(
+        prompt=(
+            'Write a Python function that calculates the fibonacci '
+            'sequence up to n terms:\n\n```python\n'
+        ),
+        max_tokens=300,
+        temperature=0.2,  # Lower temperature for code
+        stop=['```', '\n\n\n']
+    )
+
+    try:
+        response = await client.generate_text(request)
+
+        logger.success('✅ Code generation completed')
+        logger.info(f'💻 Generated code:\n```python\n{response.text}\n```')
+
+    except Exception as e:
+        logger.error(f'❌ Code generation failed: {e}')
+
+
+async def main():
+    """Run all examples."""
+    logger.info('🎯 Starting VLLM Client Examples')
+    logger.info('=' * 50)
+
+    examples = [
+        example_basic_generation,
+        example_batch_generation,
+        example_creative_writing,
+        example_code_generation
+    ]
+
+    for example in examples:
+        await example()
+        logger.info('-' * 50)
+        await asyncio.sleep(1)  # Brief pause between examples
+
+    logger.info('🎉 All examples completed!')
+
+
+if __name__ == '__main__':
+    # Configure logger
+    logger.remove()
+    logger.add(
+        lambda msg: print(msg, end=''),
+        format='<green>{time:HH:mm:ss}</green> | '
+               '<level>{level: <8}</level> | '
+               '<cyan>{message}</cyan>',
+        level='INFO'
+    )
+
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        logger.info('\n👋 Goodbye!')
+    except Exception as e:
+        logger.error(f'❌ Script failed: {e}')
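For orientation, a minimal usage sketch of the client defined above — assuming the script is importable as `example_vllm_client` (a hypothetical import path; it ships as a standalone example) and a server is listening on the default `http://localhost:8140`:

```python
# Sketch only: drive VLLMClient outside the bundled examples.
import asyncio

from example_vllm_client import VLLMClient, VLLMRequest  # hypothetical import path

async def demo() -> None:
    client = VLLMClient(base_url='http://localhost:8140')
    if not await client.health_check():
        raise RuntimeError('vLLM server is not reachable')
    response = await client.generate_text(
        VLLMRequest(prompt='Say hello in one sentence:', max_tokens=32)
    )
    print(response.text, response.finish_reason)

asyncio.run(demo())
```

Note that `generate_batch` calls `asyncio.gather(..., return_exceptions=True)`, so despite its `List[VLLMResponse]` annotation it can return exception objects; the batch example in the script guards for this with `isinstance(response, Exception)`.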
llm_utils/scripts/serve_script.sh
ADDED
@@ -0,0 +1,2 @@
+HF_HOME=/home/anhvth5/.cache/huggingface CUDA_VISIBLE_DEVICES=0 /home/anhvth5/miniconda3/envs/unsloth_env/bin/vllm serve ./outputs/8B_selfeval_retranslate/Qwen3-8B_2025_05_30/ls_response_only_r8_a8_sq8192_lr5e_06_bz64_ep1_4/ --port 8140 --tensor-parallel 1 --gpu-memory-utilization 0.9 --dtype auto --max-model-len 8192 --enable-prefix-caching --disable-log-requests --served-model-name selfeval_8b
+Logging to /tmp/vllm_8140.txt
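This script starts vLLM's OpenAI-compatible server, which is what the client above targets via `/v1/completions` and `/health`. A quick connectivity check, sketched with `requests` (not a dependency of this package; shown for illustration only):

```python
# Sketch: confirm the server is up and serving the expected model name.
import requests

resp = requests.get('http://localhost:8140/v1/models', timeout=10)
resp.raise_for_status()
model_ids = [m['id'] for m in resp.json()['data']]
assert 'selfeval_8b' in model_ids, model_ids  # matches --served-model-name
```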
speedy_utils/__init__.py
CHANGED
@@ -33,9 +33,11 @@ from .common.utils_misc import (
 
 # Print utilities
 from .common.utils_print import (
-    display_pretty_table_html,
     flatten_dict,
     fprint,
+)
+from .common.notebook_utils import (
+    display_pretty_table_html,
     print_table,
 )
 
@@ -43,8 +45,98 @@ from .common.utils_print import (
 from .multi_worker.process import multi_process
 from .multi_worker.thread import multi_thread
 
+# notebook
+from .common.notebook_utils import change_dir
+
+# Standard library imports
+import copy
+import functools
+import gc
+import inspect
+import json
+import multiprocessing
+import os
+import os.path as osp
+import pickle
+import pprint
+import random
+import re
+import sys
+import textwrap
+import threading
+import time
+import traceback
+import uuid
+from collections import Counter, defaultdict
+from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from glob import glob
+from multiprocessing import Pool
+from pathlib import Path
+from threading import Lock
+from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
+
+# Third-party imports
+import numpy as np
+import pandas as pd
+import xxhash
+from IPython.core.getipython import get_ipython
+from IPython.display import HTML, display
+from loguru import logger
+from pydantic import BaseModel
+from tabulate import tabulate
+from tqdm import tqdm
+
 # Define __all__ explicitly
 __all__ = [
+    # Standard library
+    "random",
+    "copy",
+    "functools",
+    "gc",
+    "inspect",
+    "json",
+    "multiprocessing",
+    "os",
+    "osp",
+    "pickle",
+    "pprint",
+    "re",
+    "sys",
+    "textwrap",
+    "threading",
+    "time",
+    "traceback",
+    "uuid",
+    "Counter",
+    "ThreadPoolExecutor",
+    "as_completed",
+    "glob",
+    "Pool",
+    "Path",
+    "Lock",
+    "defaultdict",
+    # Typing
+    "Any",
+    "Callable",
+    "Dict",
+    "Generic",
+    "List",
+    "Literal",
+    "Optional",
+    "TypeVar",
+    "Union",
+    # Third-party
+    "pd",
+    "xxhash",
+    "get_ipython",
+    "HTML",
+    "display",
+    "logger",
+    "BaseModel",
+    "tabulate",
+    "tqdm",
+    "np",
     # Clock module
     "Clock",
     "speedy_timer",
@@ -79,7 +171,6 @@ __all__ = [
     # Multi-worker processing
     "multi_process",
     "multi_thread",
-
-
-
-# setup_logger('D')
+    # Notebook utilities
+    "change_dir",
+]
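The net effect of this block is that `speedy_utils` now re-exports a large set of standard-library and third-party names alongside its own helpers. A sketch of the intended one-stop import style (all names appear in the `__all__` above):

```python
# Sketch: pull common dependencies straight from the package root.
from speedy_utils import Path, defaultdict, logger, tqdm

counts = defaultdict(int)
for p in tqdm(list(Path('.').glob('*.py'))):
    counts[p.suffix] += 1
logger.info(f'found {sum(counts.values())} python files')
```

The tradeoff is import weight: `import speedy_utils` now eagerly loads numpy, pandas, IPython, and the rest.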
speedy_utils/common/notebook_utils.py
ADDED
@@ -0,0 +1,63 @@
+# jupyter notebook utilities
+import json
+import os
+import pathlib
+from typing import Any
+
+from IPython.display import HTML, display
+from tabulate import tabulate
+
+
+def change_dir(target_directory: str = 'POLY') -> None:
+    """Change directory to the first occurrence of x in the current path."""
+    cur_dir = pathlib.Path('./')
+    target_dir = str(cur_dir.absolute()).split(target_directory)[0] + target_directory
+    os.chdir(target_dir)
+    print(f'Current dir: {target_dir}')
+
+
+def display_pretty_table_html(data: dict) -> None:
+    """Display a pretty HTML table in Jupyter notebooks."""
+    table = "<table>"
+    for key, value in data.items():
+        table += f"<tr><td>{key}</td><td>{value}</td></tr>"
+    table += "</table>"
+    display(HTML(table))
+
+
+def print_table(data: Any, use_html: bool = True) -> None:
+    """Print data as a table. If use_html is True, display using IPython HTML."""
+
+    def __get_table(data: Any) -> str:
+        if isinstance(data, str):
+            try:
+                data = json.loads(data)
+            except json.JSONDecodeError as exc:
+                raise ValueError("String input could not be decoded as JSON") from exc
+
+        if isinstance(data, list):
+            if all(isinstance(item, dict) for item in data):
+                headers = list(data[0].keys())
+                rows = [list(item.values()) for item in data]
+                return tabulate(
+                    rows, headers=headers, tablefmt="html" if use_html else "grid"
+                )
+            else:
+                raise ValueError("List must contain dictionaries")
+
+        if isinstance(data, dict):
+            headers = ["Key", "Value"]
+            rows = list(data.items())
+            return tabulate(
+                rows, headers=headers, tablefmt="html" if use_html else "grid"
+            )
+
+        raise TypeError(
+            "Input data must be a list of dictionaries, a dictionary, or a JSON string"
+        )
+
+    table = __get_table(data)
+    if use_html:
+        display(HTML(table))
+    else:
+        print(table)
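A short usage sketch for the new module; `use_html=False` makes `print_table` fall back to a plain-text grid, so this also works outside Jupyter:

```python
from speedy_utils.common.notebook_utils import print_table

# A list of dicts renders with the dict keys as column headers.
print_table([{'name': 'a', 'score': 1}, {'name': 'b', 'score': 2}], use_html=False)

# A dict (or a JSON string, which is parsed first) renders as Key/Value rows.
print_table('{"epochs": 3, "lr": 5e-06}', use_html=False)
```

One caveat worth knowing: `change_dir('POLY')` splits the absolute path on the first occurrence of the target string, so it assumes the working directory is already somewhere beneath a directory whose name contains that string.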
speedy_utils/common/utils_cache.py
CHANGED
@@ -80,9 +80,9 @@ def identify(obj: Any, depth=0, max_depth=2) -> str:
     elif obj is None:
         return identify("None", depth + 1, max_depth)
     else:
-        primitive_types = [int, float, str, bool]
-        if not type(obj) in primitive_types:
-            logger.warning(f"Unknown type: {type(obj)}")
+        # primitive_types = [int, float, str, bool]
+        # if not type(obj) in primitive_types:
+        #     logger.warning(f"Unknown type: {type(obj)}")
         return xxhash.xxh64_hexdigest(fast_serialize(obj), seed=0)
 
 
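With the warning removed, `identify` hashes any object that `fast_serialize` can handle, primitive or not. A sketch of the contract (digest values depend on `fast_serialize`, so none are shown):

```python
# Sketch: identify() yields a deterministic xxhash64 hex digest per object value.
from speedy_utils.common.utils_cache import identify

a = identify({'x': 1, 'y': [1, 2, 3]})
b = identify({'x': 1, 'y': [1, 2, 3]})
assert a == b        # equal values hash identically
assert len(a) == 16  # xxh64_hexdigest returns 16 hex characters
```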
speedy_utils/common/utils_print.py
CHANGED
@@ -1,32 +1,13 @@
 # utils/utils_print.py
 
 import copy
-import inspect
-import json
 import pprint
-import re
-import sys
 import textwrap
-import
-from collections import OrderedDict
-from typing import Annotated, Any, Dict, List, Literal, Optional
+from typing import Any
 
-from IPython.display import HTML, display
-from loguru import logger
 from tabulate import tabulate
 
-from .
-
-
-def display_pretty_table_html(data: dict) -> None:
-    """
-    Display a pretty HTML table in Jupyter notebooks.
-    """
-    table = "<table>"
-    for key, value in data.items():
-        table += f"<tr><td>{key}</td><td>{value}</td></tr>"
-    table += "</table>"
-    display(HTML(table))
+from .notebook_utils import display_pretty_table_html
 
 
 # Flattening the dictionary using "." notation for keys
@@ -166,51 +147,7 @@ def fprint(
     printer.pprint(processed_data)
 
 
-def print_table(data: Any, use_html: bool = True) -> None:
-    """
-    Print data as a table. If use_html is True, display using IPython HTML.
-    """
-
-    def __get_table(data: Any) -> str:
-        if isinstance(data, str):
-            try:
-                data = json.loads(data)
-            except json.JSONDecodeError as exc:
-                raise ValueError("String input could not be decoded as JSON") from exc
-
-        if isinstance(data, list):
-            if all(isinstance(item, dict) for item in data):
-                headers = list(data[0].keys())
-                rows = [list(item.values()) for item in data]
-                return tabulate(
-                    rows, headers=headers, tablefmt="html" if use_html else "grid"
-                )
-            else:
-                raise ValueError("List must contain dictionaries")
-
-        if isinstance(data, dict):
-            headers = ["Key", "Value"]
-            rows = list(data.items())
-            return tabulate(
-                rows, headers=headers, tablefmt="html" if use_html else "grid"
-            )
-
-        raise TypeError(
-            "Input data must be a list of dictionaries, a dictionary, or a JSON string"
-        )
-
-    table = __get_table(data)
-    if use_html:
-        display(HTML(table))
-    else:
-        print(table)
-
-
 __all__ = [
-    "display_pretty_table_html",
     "flatten_dict",
     "fprint",
-    "print_table",
-    # "setup_logger",
-    # "log",
 ]
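For callers the upshot is an import move: `display_pretty_table_html` and `print_table` now live in `notebook_utils`, while `utils_print` keeps `flatten_dict` and `fprint`. A migration sketch:

```python
# Before (<= 1.0.13):
# from speedy_utils.common.utils_print import print_table

# After (1.0.15): the table helpers moved to the notebook module...
from speedy_utils.common.notebook_utils import print_table

# ...and the package root still re-exports them (see the __init__.py diff above),
# so `from speedy_utils import print_table` keeps working too.
```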
speedy_utils/multi_worker/process.py
CHANGED
@@ -75,6 +75,7 @@ def multi_process(
     timeout: float | None = None,
     stop_on_error: bool = True,
     process_update_interval=10,
+    for_loop: bool = False,
     **fixed_kwargs,
 ) -> list[Any]:
     """
@@ -95,6 +96,12 @@ def multi_process(
         substitute failing result with ``None``.
     **fixed_kwargs – static keyword args forwarded to every ``func()`` call.
     """
+    if for_loop:
+        ret = []
+        for arg in inputs:
+            ret.append(func(arg, **fixed_kwargs))
+        return ret
+
     if workers is None:
         workers = os.cpu_count() or 1
     if inflight is None:
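The new `for_loop` flag bypasses the worker pool entirely and runs the calls sequentially in the caller's process, which is handy for debugging with breakpoints or for deterministic profiling. A sketch (the names `func` and `inputs` come from the loop body above; their positional order is assumed):

```python
from speedy_utils import multi_process

def square(x: int) -> int:
    return x * x

# Sequential run: same results as the pooled path, but in a plain for loop.
assert multi_process(square, [1, 2, 3], for_loop=True) == [1, 4, 9]
```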
speedy_utils/scripts/mpython.py
CHANGED
@@ -85,6 +85,7 @@ def main():
 
     cpu_per_process = max(args.total_cpu // args.total_fold, 1)
     cmds = []
+    path_python = shutil.which("python")
     for i in range(args.total_fold):
         gpu = gpus[i % num_gpus]
         cpu_start = (i * cpu_per_process) % args.total_cpu
@@ -92,10 +93,10 @@ def main():
         ENV = f"CUDA_VISIBLE_DEVICES={gpu} MP_ID={i} MP_TOTAL={args.total_fold}"
         if taskset_path:
             fold_cmd = (
-                f"{ENV} {taskset_path} -c {cpu_start}-{cpu_end}
+                f"{ENV} {taskset_path} -c {cpu_start}-{cpu_end} {path_python} {cmd_str}"
             )
         else:
-            fold_cmd = f"{ENV}
+            fold_cmd = f"{ENV} {path_python} {cmd_str}"
 
         cmds.append(fold_cmd)
 
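The two previously truncated f-strings are completed so each fold command now pins the interpreter resolved by `shutil.which("python")`. A hypothetical rendering of one fold command, with made-up values throughout:

```python
import shutil

path_python = shutil.which('python')               # e.g. '/usr/bin/python' (machine-dependent)
ENV = 'CUDA_VISIBLE_DEVICES=0 MP_ID=0 MP_TOTAL=4'  # hypothetical fold 0 of 4
cmd_str = 'train.py --fold 0'                      # hypothetical user command
print(f'{ENV} taskset -c 0-7 {path_python} {cmd_str}')
# CUDA_VISIBLE_DEVICES=0 MP_ID=0 MP_TOTAL=4 taskset -c 0-7 /usr/bin/python train.py --fold 0
```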