speedy-utils 1.0.13__py3-none-any.whl → 1.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,269 @@
1
+ """
2
+ Beautiful example script for interacting with VLLM server.
3
+
4
+ This script demonstrates various ways to use the VLLM API server
5
+ for text generation tasks.
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ from typing import Dict, List, Optional, Any
11
+
12
+ import aiohttp
13
+ from loguru import logger
14
+ from pydantic import BaseModel, Field
15
+
16
+
17
class VLLMRequest(BaseModel):
    """Request model for VLLM API.

    Bundles the prompt and sampling parameters that
    ``VLLMClient.generate_text`` copies into the JSON payload of a
    completions request. The numeric fields are range-checked by pydantic
    through the ``Field`` constraints below.
    """
    # Text sent as the ``prompt`` entry of the payload.
    prompt: str
    # Generation length limit; must lie in [1, 8192], defaults to 512.
    max_tokens: int = Field(default=512, ge=1, le=8192)
    # Sampling temperature; must lie in [0.0, 2.0], defaults to 0.7.
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    # Value sent as ``top_p``; must lie in [0.0, 1.0], defaults to 0.9.
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    # Forwarded unchanged as the ``stream`` flag of the payload.
    stream: bool = False
    # Optional stop sequences; only added to the payload when non-empty.
    stop: Optional[List[str]] = None
25
+
26
+
27
class VLLMResponse(BaseModel):
    """Response model from VLLM API.

    ``VLLMClient.generate_text`` fills this from the first entry of the
    reply's ``choices`` list and from its ``usage`` object.
    """
    # ``text`` of the first choice.
    text: str
    # ``finish_reason`` of the first choice.
    finish_reason: str
    # Token counts copied from the reply's ``usage`` object.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
34
+
35
+
36
class VLLMClient:
    """Client for interacting with VLLM server.

    Every call opens its own ``aiohttp.ClientSession``; the instance keeps
    no connection state, only the server URL and the model name.

    Args:
        base_url: Root URL of the server. Trailing slashes are removed so
            endpoint paths can be appended without producing ``//``.
        model_name: Value sent as ``model`` in every completions request.

    NOTE(review): ``generate_text`` always parses the reply with
    ``response.json()``; a request with ``stream=True`` presumably yields
    an event stream that cannot be parsed this way -- confirm before
    relying on streaming.
    """

    def __init__(
        self,
        base_url: str = 'http://localhost:8140',
        model_name: str = 'selfeval_8b',
    ):
        # Normalise once so f'{self.base_url}/v1/completions' is well formed
        # even when the caller passes 'http://host:8140/'.
        self.base_url = base_url.rstrip('/')
        self.model_name = model_name

    async def generate_text(
        self,
        request: VLLMRequest
    ) -> VLLMResponse:
        """Generate text using VLLM API.

        Posts the request to ``/v1/completions`` with a 60 second total
        timeout and converts the first choice of the reply.

        Args:
            request: Prompt and sampling parameters to send.

        Returns:
            The first completion together with the reported token usage.

        Raises:
            aiohttp.ClientError: Logged as an HTTP error and re-raised
                (this includes the error raised by ``raise_for_status``).
            Exception: Anything else, for example a reply without
                ``choices``/``usage``, is logged as unexpected and re-raised.
        """
        url = f'{self.base_url}/v1/completions'

        payload = {
            'model': self.model_name,
            'prompt': request.prompt,
            'max_tokens': request.max_tokens,
            'temperature': request.temperature,
            'top_p': request.top_p,
            'stream': request.stream,
        }

        # Omit the key entirely when no stop sequences were supplied.
        if request.stop:
            payload['stop'] = request.stop

        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(
                    url,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=60)
                ) as response:
                    response.raise_for_status()
                    data = await response.json()

                    choice = data['choices'][0]
                    usage = data['usage']

                    return VLLMResponse(
                        text=choice['text'],
                        finish_reason=choice['finish_reason'],
                        prompt_tokens=usage['prompt_tokens'],
                        completion_tokens=usage['completion_tokens'],
                        total_tokens=usage['total_tokens']
                    )

            except aiohttp.ClientError as e:
                logger.error(f'HTTP error: {e}')
                raise
            except Exception as e:
                logger.error(f'Unexpected error: {e}')
                raise

    async def generate_batch(
        self,
        requests: List[VLLMRequest]
    ) -> List[Any]:
        """Generate text for multiple requests concurrently.

        Args:
            requests: Requests to run; one ``generate_text`` call each.

        Returns:
            One entry per request, in input order. Because the calls are
            gathered with ``return_exceptions=True``, an entry is either a
            ``VLLMResponse`` or the exception raised for that request, so
            callers must check each entry before using it.
        """
        tasks = [self.generate_text(req) for req in requests]
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def health_check(self) -> bool:
        """Check if the VLLM server is healthy.

        Returns:
            True when ``GET /health`` answers with status 200 within
            10 seconds; False on any other status or on any exception
            (which is logged as a warning, never raised).
        """
        url = f'{self.base_url}/health'

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    url,
                    timeout=aiohttp.ClientTimeout(total=10)
                ) as response:
                    return response.status == 200
        except Exception as e:
            logger.warning(f'Health check failed: {e}')
            return False
112
+
113
+
114
async def example_basic_generation():
    """Example: Basic text generation.

    Verifies the server answers the health check, then requests a single
    completion and logs the text and token counts. Failures are logged,
    never raised.
    """
    logger.info('🚀 Running basic generation example')

    vllm = VLLMClient()

    # Bail out early when the server does not answer the health probe.
    server_ok = await vllm.health_check()
    if not server_ok:
        logger.error('❌ Server is not healthy')
        return

    generation_request = VLLMRequest(
        prompt='Explain the concept of machine learning in simple terms:',
        max_tokens=256,
        temperature=0.7,
        stop=['\n\n'],
    )

    try:
        result = await vllm.generate_text(generation_request)

        logger.success('✅ Generation completed')
        logger.info(f'📝 Generated text:\n{result.text}')
        token_summary = (
            f'📊 Tokens: {result.total_tokens} total '
            f'({result.prompt_tokens} prompt + '
            f'{result.completion_tokens} completion)'
        )
        logger.info(token_summary)
    except Exception as err:
        logger.error(f'❌ Generation failed: {err}')
143
+
144
+
145
async def example_batch_generation():
    """Example: Batch text generation.

    Sends three prompts concurrently and logs each outcome; an entry that
    came back as an exception is reported as a failed request.
    """
    logger.info('🚀 Running batch generation example')

    vllm = VLLMClient()

    questions = [
        'What is artificial intelligence?',
        'Explain quantum computing briefly:',
        'What are the benefits of renewable energy?'
    ]

    batch = []
    for question in questions:
        batch.append(
            VLLMRequest(prompt=question, max_tokens=128, temperature=0.8)
        )

    try:
        outcomes = await vllm.generate_batch(batch)

        # Number the requests from 1 for the log output.
        for number, outcome in enumerate(outcomes, start=1):
            if isinstance(outcome, Exception):
                logger.error(f'❌ Request {number} failed: {outcome}')
                continue
            logger.success(f'✅ Request {number} completed')
            logger.info(f'📝 Response {number}:\n{outcome.text}\n')
    except Exception as err:
        logger.error(f'❌ Batch generation failed: {err}')
177
+
178
+
179
async def example_creative_writing():
    """Example: Creative writing with specific parameters.

    Requests a short story with a higher temperature and logs the story
    and the finish reason. Failures are logged, never raised.
    """
    logger.info('🚀 Running creative writing example')

    vllm = VLLMClient()

    story_prompt = (
        'Write a short story about a robot discovering emotions. '
        'The story should be exactly 3 paragraphs:\n\n'
    )
    story_request = VLLMRequest(
        prompt=story_prompt,
        max_tokens=400,
        temperature=1.2,  # Higher temperature for creativity
        top_p=0.95,
        stop=['THE END', '\n\n\n'],
    )

    try:
        story = await vllm.generate_text(story_request)

        logger.success('✅ Creative writing completed')
        logger.info(f'📚 Story:\n{story.text}')
        logger.info(f'🎯 Finish reason: {story.finish_reason}')
    except Exception as err:
        logger.error(f'❌ Creative writing failed: {err}')
205
+
206
+
207
async def example_code_generation():
    """Example: Code generation.

    Asks for a Fibonacci function with a low temperature and logs the
    generated code inside a fenced block. Failures are logged, never raised.
    """
    logger.info('🚀 Running code generation example')

    vllm = VLLMClient()

    code_prompt = (
        'Write a Python function that calculates the fibonacci '
        'sequence up to n terms:\n\n```python\n'
    )
    code_request = VLLMRequest(
        prompt=code_prompt,
        max_tokens=300,
        temperature=0.2,  # Lower temperature for code
        stop=['```', '\n\n\n'],
    )

    try:
        generated = await vllm.generate_text(code_request)

        logger.success('✅ Code generation completed')
        logger.info(f'💻 Generated code:\n```python\n{generated.text}\n```')
    except Exception as err:
        logger.error(f'❌ Code generation failed: {err}')
231
+
232
+
233
async def main():
    """Run all examples.

    Executes the four example coroutines one after another, logging a
    separator line and pausing for one second after each.
    """
    logger.info('🎯 Starting VLLM Client Examples')
    logger.info('=' * 50)

    demos = (
        example_basic_generation,
        example_batch_generation,
        example_creative_writing,
        example_code_generation,
    )

    for demo in demos:
        await demo()
        logger.info('-' * 50)
        # Brief pause between examples
        await asyncio.sleep(1)

    logger.info('🎉 All examples completed!')
251
+
252
+
253
if __name__ == '__main__':
    # Replace loguru's default handler with a print-based sink that emits
    # messages at INFO and above in a compact, coloured layout.
    log_format = (
        '<green>{time:HH:mm:ss}</green> | '
        '<level>{level: <8}</level> | '
        '<cyan>{message}</cyan>'
    )
    logger.remove()
    logger.add(
        lambda record: print(record, end=''),
        format=log_format,
        level='INFO',
    )

    # Run the examples; Ctrl-C and unexpected failures are logged instead of
    # surfacing as a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info('\n👋 Goodbye!')
    except Exception as err:
        logger.error(f'❌ Script failed: {err}')
@@ -0,0 +1,3 @@
1
+ aiohttp>=3.8.0
2
+ loguru>=0.6.0
3
+ pydantic>=2.0.0
@@ -0,0 +1,2 @@
1
+ HF_HOME=/home/anhvth5/.cache/huggingface CUDA_VISIBLE_DEVICES=0 /home/anhvth5/miniconda3/envs/unsloth_env/bin/vllm serve ./outputs/8B_selfeval_retranslate/Qwen3-8B_2025_05_30/ls_response_only_r8_a8_sq8192_lr5e_06_bz64_ep1_4/ --port 8140 --tensor-parallel 1 --gpu-memory-utilization 0.9 --dtype auto --max-model-len 8192 --enable-prefix-caching --disable-log-requests --served-model-name selfeval_8b
2
+ Logging to /tmp/vllm_8140.txt
speedy_utils/__init__.py CHANGED
@@ -33,9 +33,11 @@ from .common.utils_misc import (
33
33
 
34
34
  # Print utilities
35
35
  from .common.utils_print import (
36
- display_pretty_table_html,
37
36
  flatten_dict,
38
37
  fprint,
38
+ )
39
+ from .common.notebook_utils import (
40
+ display_pretty_table_html,
39
41
  print_table,
40
42
  )
41
43
 
@@ -43,8 +45,98 @@ from .common.utils_print import (
43
45
  from .multi_worker.process import multi_process
44
46
  from .multi_worker.thread import multi_thread
45
47
 
48
+ # notebook
49
+ from .common.notebook_utils import change_dir
50
+
51
+ # Standard library imports
52
+ import copy
53
+ import functools
54
+ import gc
55
+ import inspect
56
+ import json
57
+ import multiprocessing
58
+ import os
59
+ import os.path as osp
60
+ import pickle
61
+ import pprint
62
+ import random
63
+ import re
64
+ import sys
65
+ import textwrap
66
+ import threading
67
+ import time
68
+ import traceback
69
+ import uuid
70
+ from collections import Counter, defaultdict
71
+ from collections.abc import Callable
72
+ from concurrent.futures import ThreadPoolExecutor, as_completed
73
+ from glob import glob
74
+ from multiprocessing import Pool
75
+ from pathlib import Path
76
+ from threading import Lock
77
+ from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
78
+
79
+ # Third-party imports
80
+ import numpy as np
81
+ import pandas as pd
82
+ import xxhash
83
+ from IPython.core.getipython import get_ipython
84
+ from IPython.display import HTML, display
85
+ from loguru import logger
86
+ from pydantic import BaseModel
87
+ from tabulate import tabulate
88
+ from tqdm import tqdm
89
+
46
90
  # Define __all__ explicitly
47
91
  __all__ = [
92
+ # Standard library
93
+ "random",
94
+ "copy",
95
+ "functools",
96
+ "gc",
97
+ "inspect",
98
+ "json",
99
+ "multiprocessing",
100
+ "os",
101
+ "osp",
102
+ "pickle",
103
+ "pprint",
104
+ "re",
105
+ "sys",
106
+ "textwrap",
107
+ "threading",
108
+ "time",
109
+ "traceback",
110
+ "uuid",
111
+ "Counter",
112
+ "ThreadPoolExecutor",
113
+ "as_completed",
114
+ "glob",
115
+ "Pool",
116
+ "Path",
117
+ "Lock",
118
+ "defaultdict",
119
+ # Typing
120
+ "Any",
121
+ "Callable",
122
+ "Dict",
123
+ "Generic",
124
+ "List",
125
+ "Literal",
126
+ "Optional",
127
+ "TypeVar",
128
+ "Union",
129
+ # Third-party
130
+ "pd",
131
+ "xxhash",
132
+ "get_ipython",
133
+ "HTML",
134
+ "display",
135
+ "logger",
136
+ "BaseModel",
137
+ "tabulate",
138
+ "tqdm",
139
+ "np",
48
140
  # Clock module
49
141
  "Clock",
50
142
  "speedy_timer",
@@ -79,7 +171,6 @@ __all__ = [
79
171
  # Multi-worker processing
80
172
  "multi_process",
81
173
  "multi_thread",
82
- ]
83
-
84
- # Setup default logger
85
- # setup_logger('D')
174
+ # Notebook utilities
175
+ "change_dir",
176
+ ]
@@ -0,0 +1,63 @@
1
+ # jupyter notebook utilities
2
+ import json
3
+ import os
4
+ import pathlib
5
+ from typing import Any
6
+
7
+ from IPython.display import HTML, display
8
+ from tabulate import tabulate
9
+
10
+
11
def change_dir(target_directory: str = 'POLY') -> None:
    """Change the working directory to an ancestor found by name.

    The absolute current path is cut right after the first occurrence of
    ``target_directory`` (a plain substring match) and the process changes
    into the resulting path, which is then printed.

    Args:
        target_directory: Text to look for in the current absolute path.

    Raises:
        FileNotFoundError: If ``target_directory`` does not occur in the
            current path. Without this check the text would be appended to
            the current path and ``os.chdir`` would fail on a path that
            never existed, hiding the real cause.
        ValueError: If ``target_directory`` is empty.
    """
    current_path = str(pathlib.Path('./').absolute())
    prefix, match, _ = current_path.partition(target_directory)
    if not match:
        raise FileNotFoundError(
            f'{target_directory!r} does not occur in the current path: '
            f'{current_path}'
        )

    target_dir = prefix + target_directory
    os.chdir(target_dir)
    print(f'Current dir: {target_dir}')
17
+
18
+
19
def display_pretty_table_html(data: dict) -> None:
    """Display a pretty HTML table in Jupyter notebooks.

    Each key/value pair of ``data`` becomes one two-cell row; keys and
    values are interpolated into the markup as-is (no escaping).
    """
    row_markup = "".join(
        f"<tr><td>{key}</td><td>{value}</td></tr>"
        for key, value in data.items()
    )
    display(HTML("<table>" + row_markup + "</table>"))
26
+
27
+
28
def print_table(data: Any, use_html: bool = True) -> None:
    """Print data as a table. If use_html is True, display using IPython HTML.

    Args:
        data: A dictionary (rendered as Key/Value rows), a list of
            dictionaries (one row per dictionary), or a JSON string that
            decodes to one of those.
        use_html: Render with tabulate's ``html`` format and show it via
            ``display(HTML(...))`` when True; otherwise print a ``grid``
            table to stdout.

    Raises:
        ValueError: If a string cannot be decoded as JSON, or a list
            contains non-dictionary items.
        TypeError: If the (decoded) data is neither a list nor a dict.
    """
    table_format = "html" if use_html else "grid"

    def _build_table(payload: Any) -> str:
        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except json.JSONDecodeError as exc:
                raise ValueError("String input could not be decoded as JSON") from exc

        if isinstance(payload, list):
            if not all(isinstance(item, dict) for item in payload):
                raise ValueError("List must contain dictionaries")
            # Column names in first-seen order across all rows. This keeps
            # an empty list from failing on payload[0] and keeps cells under
            # the right header when rows differ in key order or key set.
            headers = list(dict.fromkeys(key for item in payload for key in item))
            rows = [[item.get(header) for header in headers] for item in payload]
            return tabulate(rows, headers=headers, tablefmt=table_format)

        if isinstance(payload, dict):
            return tabulate(
                list(payload.items()),
                headers=["Key", "Value"],
                tablefmt=table_format,
            )

        raise TypeError(
            "Input data must be a list of dictionaries, a dictionary, or a JSON string"
        )

    table = _build_table(data)
    if use_html:
        display(HTML(table))
    else:
        print(table)
@@ -80,9 +80,9 @@ def identify(obj: Any, depth=0, max_depth=2) -> str:
80
80
  elif obj is None:
81
81
  return identify("None", depth + 1, max_depth)
82
82
  else:
83
- primitive_types = [int, float, str, bool]
84
- if not type(obj) in primitive_types:
85
- logger.warning(f"Unknown type: {type(obj)}")
83
+ # primitive_types = [int, float, str, bool]
84
+ # if not type(obj) in primitive_types:
85
+ # logger.warning(f"Unknown type: {type(obj)}")
86
86
  return xxhash.xxh64_hexdigest(fast_serialize(obj), seed=0)
87
87
 
88
88
 
@@ -1,32 +1,13 @@
1
1
  # utils/utils_print.py
2
2
 
3
3
  import copy
4
- import inspect
5
- import json
6
4
  import pprint
7
- import re
8
- import sys
9
5
  import textwrap
10
- import time
11
- from collections import OrderedDict
12
- from typing import Annotated, Any, Dict, List, Literal, Optional
6
+ from typing import Any
13
7
 
14
- from IPython.display import HTML, display
15
- from loguru import logger
16
8
  from tabulate import tabulate
17
9
 
18
- from .utils_misc import is_notebook
19
-
20
-
21
- def display_pretty_table_html(data: dict) -> None:
22
- """
23
- Display a pretty HTML table in Jupyter notebooks.
24
- """
25
- table = "<table>"
26
- for key, value in data.items():
27
- table += f"<tr><td>{key}</td><td>{value}</td></tr>"
28
- table += "</table>"
29
- display(HTML(table))
10
+ from .notebook_utils import display_pretty_table_html
30
11
 
31
12
 
32
13
  # Flattening the dictionary using "." notation for keys
@@ -166,51 +147,7 @@ def fprint(
166
147
  printer.pprint(processed_data)
167
148
 
168
149
 
169
- def print_table(data: Any, use_html: bool = True) -> None:
170
- """
171
- Print data as a table. If use_html is True, display using IPython HTML.
172
- """
173
-
174
- def __get_table(data: Any) -> str:
175
- if isinstance(data, str):
176
- try:
177
- data = json.loads(data)
178
- except json.JSONDecodeError as exc:
179
- raise ValueError("String input could not be decoded as JSON") from exc
180
-
181
- if isinstance(data, list):
182
- if all(isinstance(item, dict) for item in data):
183
- headers = list(data[0].keys())
184
- rows = [list(item.values()) for item in data]
185
- return tabulate(
186
- rows, headers=headers, tablefmt="html" if use_html else "grid"
187
- )
188
- else:
189
- raise ValueError("List must contain dictionaries")
190
-
191
- if isinstance(data, dict):
192
- headers = ["Key", "Value"]
193
- rows = list(data.items())
194
- return tabulate(
195
- rows, headers=headers, tablefmt="html" if use_html else "grid"
196
- )
197
-
198
- raise TypeError(
199
- "Input data must be a list of dictionaries, a dictionary, or a JSON string"
200
- )
201
-
202
- table = __get_table(data)
203
- if use_html:
204
- display(HTML(table))
205
- else:
206
- print(table)
207
-
208
-
209
150
  __all__ = [
210
- "display_pretty_table_html",
211
151
  "flatten_dict",
212
152
  "fprint",
213
- "print_table",
214
- # "setup_logger",
215
- # "log",
216
153
  ]
@@ -75,6 +75,7 @@ def multi_process(
75
75
  timeout: float | None = None,
76
76
  stop_on_error: bool = True,
77
77
  process_update_interval=10,
78
+ for_loop: bool = False,
78
79
  **fixed_kwargs,
79
80
  ) -> list[Any]:
80
81
  """
@@ -95,6 +96,12 @@ def multi_process(
95
96
  substitute failing result with ``None``.
96
97
  **fixed_kwargs – static keyword args forwarded to every ``func()`` call.
97
98
  """
99
+ if for_loop:
100
+ ret = []
101
+ for arg in inputs:
102
+ ret.append(func(arg, **fixed_kwargs))
103
+ return ret
104
+
98
105
  if workers is None:
99
106
  workers = os.cpu_count() or 1
100
107
  if inflight is None:
File without changes
@@ -85,6 +85,7 @@ def main():
85
85
 
86
86
  cpu_per_process = max(args.total_cpu // args.total_fold, 1)
87
87
  cmds = []
88
+ path_python = shutil.which("python")
88
89
  for i in range(args.total_fold):
89
90
  gpu = gpus[i % num_gpus]
90
91
  cpu_start = (i * cpu_per_process) % args.total_cpu
@@ -92,10 +93,10 @@ def main():
92
93
  ENV = f"CUDA_VISIBLE_DEVICES={gpu} MP_ID={i} MP_TOTAL={args.total_fold}"
93
94
  if taskset_path:
94
95
  fold_cmd = (
95
- f"{ENV} {taskset_path} -c {cpu_start}-{cpu_end} python {cmd_str}"
96
+ f"{ENV} {taskset_path} -c {cpu_start}-{cpu_end} {path_python} {cmd_str}"
96
97
  )
97
98
  else:
98
- fold_cmd = f"{ENV} python {cmd_str}"
99
+ fold_cmd = f"{ENV} {path_python} {cmd_str}"
99
100
 
100
101
  cmds.append(fold_cmd)
101
102