speedy-utils 1.1.33__py3-none-any.whl → 1.1.35__py3-none-any.whl

llm_utils/lm/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .lm_base import LMBase, get_model_name
 from .mixins import (
     ModelUtilsMixin,
     TemperatureRangeMixin,
+    TokenizationMixin,
     TwoStepPydanticMixin,
     VLLMMixin,
 )
@@ -14,19 +15,20 @@ from .signature import Input, InputField, Output, OutputField, Signature
 
 
 __all__ = [
-    "LMBase",
-    "LLM",
-    "AsyncLM",
-    "AsyncLLMTask",
-    "BasePromptBuilder",
-    "LLMSignature",
-    "Signature",
-    "InputField",
-    "OutputField",
-    "Input",
-    "Output",
-    "TemperatureRangeMixin",
-    "TwoStepPydanticMixin",
-    "VLLMMixin",
-    "ModelUtilsMixin",
+    'LMBase',
+    'LLM',
+    'AsyncLM',
+    'AsyncLLMTask',
+    'BasePromptBuilder',
+    'LLMSignature',
+    'Signature',
+    'InputField',
+    'OutputField',
+    'Input',
+    'Output',
+    'TemperatureRangeMixin',
+    'TwoStepPydanticMixin',
+    'VLLMMixin',
+    'ModelUtilsMixin',
+    'TokenizationMixin',
 ]
llm_utils/lm/llm.py CHANGED
@@ -20,6 +20,7 @@ from .base_prompt_builder import BasePromptBuilder
 from .mixins import (
     ModelUtilsMixin,
     TemperatureRangeMixin,
+    TokenizationMixin,
     TwoStepPydanticMixin,
     VLLMMixin,
 )
@@ -47,6 +48,7 @@ class LLM(
     TwoStepPydanticMixin,
     VLLMMixin,
     ModelUtilsMixin,
+    TokenizationMixin,
 ):
     """LLM task with structured input/output handling."""
 
llm_utils/lm/mixins.py CHANGED
@@ -396,6 +396,80 @@ class VLLMMixin:
         return _kill_vllm_on_port(port)
 
 
+class TokenizationMixin:
+    """Mixin for tokenization operations (encode/decode)."""
+
+    def encode(
+        self,
+        text: str,
+        *,
+        add_special_tokens: bool = True,
+        return_token_strs: bool = False,
+    ) -> list[int] | tuple[list[int], list[str]]:
+        """
+        Encode text to token IDs using the model's tokenizer.
+
+        Args:
+            text: Text to tokenize
+            add_special_tokens: Whether to add special tokens (e.g., BOS)
+            return_token_strs: If True, also return token strings
+
+        Returns:
+            List of token IDs, or tuple of (token IDs, token strings)
+        """
+        import requests
+
+        # Get base_url from client and remove /v1 suffix if present
+        # (tokenize endpoint is at root level, not under /v1)
+        base_url = str(self.client.base_url).rstrip('/')
+        if base_url.endswith('/v1'):
+            base_url = base_url[:-3]  # Remove '/v1'
+
+        response = requests.post(
+            f'{base_url}/tokenize',
+            json={
+                'prompt': text,
+                'add_special_tokens': add_special_tokens,
+                'return_token_strs': return_token_strs,
+            },
+        )
+        response.raise_for_status()
+        data = response.json()
+
+        if return_token_strs:
+            return data['tokens'], data.get('token_strs', [])
+        return data['tokens']
+
+    def decode(
+        self,
+        token_ids: list[int],
+    ) -> str:
+        """
+        Decode token IDs to text using the model's tokenizer.
+
+        Args:
+            token_ids: List of token IDs to decode
+
+        Returns:
+            Decoded text string
+        """
+        import requests
+
+        # Get base_url from client and remove /v1 suffix if present
+        # (detokenize endpoint is at root level, not under /v1)
+        base_url = str(self.client.base_url).rstrip('/')
+        if base_url.endswith('/v1'):
+            base_url = base_url[:-3]  # Remove '/v1'
+
+        response = requests.post(
+            f'{base_url}/detokenize',
+            json={'tokens': token_ids},
+        )
+        response.raise_for_status()
+        data = response.json()
+        return data['prompt']
+
+
 class ModelUtilsMixin:
     """Mixin for model utility methods."""
 
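Note on the new mixin: `encode`/`decode` are thin clients for vLLM's root-level `/tokenize` and `/detokenize` HTTP endpoints, which sit beside the OpenAI-compatible routes rather than under the `/v1` prefix (hence the suffix stripping above). A minimal sketch of the same round trip, assuming a vLLM server on localhost port 8000 (the port is illustrative):

```python
# Round trip against vLLM's tokenizer endpoints, mirroring what
# TokenizationMixin.encode()/decode() do under the hood.
import requests

base_url = 'http://localhost:8000'  # i.e. client.base_url with '/v1' stripped

# POST /tokenize mirrors encode()
r = requests.post(
    f'{base_url}/tokenize',
    json={'prompt': 'Hello world', 'add_special_tokens': True},
)
r.raise_for_status()
token_ids = r.json()['tokens']  # list[int], as returned by encode()

# POST /detokenize mirrors decode()
r = requests.post(f'{base_url}/detokenize', json={'tokens': token_ids})
r.raise_for_status()
print(r.json()['prompt'])  # decoded text, as returned by decode()
```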
llm_utils/lm/utils.py CHANGED
@@ -282,7 +282,11 @@ def get_base_client(
         return MOpenAI(
             base_url=f"http://localhost:{port}/v1", api_key=api_key, cache=cache
         )
-        raise ValueError("Either client or vllm_cmd must be provided.")
+        # Use default port 8000 when client is None
+        logger.info("No client specified, using default port 8000 at http://localhost:8000/v1")
+        return MOpenAI(
+            base_url="http://localhost:8000/v1", api_key=api_key, cache=cache
+        )
     if isinstance(client, int):
         return MOpenAI(
             base_url=f"http://localhost:{client}/v1", api_key=api_key, cache=cache
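Behavioral change worth noting: when called with neither `client` nor `vllm_cmd`, `get_base_client` previously raised `ValueError`, but it now logs and falls back to `http://localhost:8000/v1`, so misconfiguration is no longer surfaced as an exception. A hedged sketch of the dispatch visible in this hunk; the full signature is not shown in the diff, so the zero-argument call assumes the remaining parameters have defaults:

```python
# Assumed usage of get_base_client, per the branches visible in this hunk.
from llm_utils.lm.utils import get_base_client

lm_client = get_base_client()             # 1.1.35: defaults to http://localhost:8000/v1
lm_client = get_base_client(client=8001)  # int port -> http://localhost:8001/v1
```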
speedy_utils/multi_worker/process.py CHANGED
@@ -1,6 +1,3 @@
-# ray_multi_process.py
-
-
 from ..__imports import *
 
 
@@ -8,6 +5,12 @@ SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
 _SPEEDY_PROCESSES_LOCK = threading.Lock()
 
 
+# /mnt/data/anhvth8/venvs/Megatron-Bridge-Host/lib/python3.12/site-packages/ray/_private/worker.py:2046: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
+# turn off future warning and verbose task logs
+os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
+os.environ["RAY_DEDUP_LOGS"] = "0"
+os.environ["RAY_LOG_TO_STDERR"] = "0"
+
 def _prune_dead_processes() -> None:
     """Remove dead processes from tracking list."""
     with _SPEEDY_PROCESSES_LOCK:
@@ -78,6 +81,7 @@ def _track_multiprocessing_processes() -> None:
 
 def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
     """Build cache dir with function name + timestamp."""
+    import datetime
     func_name = getattr(func, '__name__', 'func')
     now = datetime.datetime.now()
     stamp = now.strftime('%m%d_%Hh%Mm%Ss')
@@ -85,9 +89,8 @@ def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
     path = Path('.cache') / run_id
     path.mkdir(parents=True, exist_ok=True)
     return path
-
-
-def wrap_dump(func: Callable, cache_dir: Path | None):
+_DUMP_THREADS = []
+def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
     """Wrap a function so results are dumped to .pkl when cache_dir is set."""
     if cache_dir is None:
         return func
@@ -95,8 +98,24 @@ def wrap_dump(func: Callable, cache_dir: Path | None):
     def wrapped(x, *args, **kwargs):
         res = func(x, *args, **kwargs)
         p = cache_dir / f'{uuid.uuid4().hex}.pkl'
-        with open(p, 'wb') as fh:
-            pickle.dump(res, fh)
+
+        def save():
+            with open(p, 'wb') as fh:
+                pickle.dump(res, fh)
+            # Clean trash to avoid bloating memory
+            # print(f'Thread count: {threading.active_count()}')
+            # print(f'Saved result to {p}')
+
+        if dump_in_thread:
+            thread = threading.Thread(target=save)
+            _DUMP_THREADS.append(thread)
+            # count thread
+            # print(f'Thread count: {threading.active_count()}')
+            while threading.active_count() > 16:
+                time.sleep(0.1)
+            thread.start()
+        else:
+            save()
         return str(p)
 
     return wrapped
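With `dump_in_thread=True` (the default) each pickle is written on a background `threading.Thread`, throttled so that no more than 16 threads are active at once; the wrapped function can therefore return its path before the corresponding file has been fully written. Pass `dump_in_thread=False` when the caller reads the file immediately. A small sketch under that assumption (module path taken from the RECORD listing below; the squaring function is illustrative):

```python
# Sketch of the wrap_dump contract with synchronous dumping.
import pickle
from pathlib import Path

from speedy_utils.multi_worker.process import wrap_dump

cache_dir = Path('.cache/demo')
cache_dir.mkdir(parents=True, exist_ok=True)

# dump_in_thread=False forces the synchronous path: the pickle is
# guaranteed to be on disk before the wrapped call returns its path.
square = wrap_dump(lambda x: x * x, cache_dir, dump_in_thread=False)
path = square(7)            # returns a str path to a .pkl file
with open(path, 'rb') as fh:
    assert pickle.load(fh) == 49
```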
@@ -109,20 +128,28 @@ RAY_WORKER = None
 
 def ensure_ray(workers: int, pbar: tqdm | None = None):
     """Initialize or reinitialize Ray with a given worker count, log to bar postfix."""
+    import ray as _ray_module
+    import logging
+
     global RAY_WORKER
-    if not ray.is_initialized() or workers != RAY_WORKER:
-        if ray.is_initialized() and pbar:
+    # shutdown when worker count changes or if Ray not initialized
+    if not _ray_module.is_initialized() or workers != RAY_WORKER:
+        if _ray_module.is_initialized() and pbar:
             pbar.set_postfix_str(f'Restarting Ray {workers} workers')
-        ray.shutdown()
+        _ray_module.shutdown()
         t0 = time.time()
-        ray.init(num_cpus=workers, ignore_reinit_error=True)
+        _ray_module.init(
+            num_cpus=workers,
+            ignore_reinit_error=True,
+            logging_level=logging.ERROR,
+            log_to_driver=False,
+        )
         took = time.time() - t0
         _track_ray_processes()  # Track Ray worker processes
         if pbar:
            pbar.set_postfix_str(f'ray.init {workers} took {took:.2f}s')
         RAY_WORKER = workers
 
-
 def multi_process(
     func: Callable[[Any], Any],
     items: Iterable[Any] | None = None,
@@ -134,6 +161,8 @@ def multi_process(
     # backend: str = "ray", # "seq", "ray", or "fastcore"
     backend: Literal['seq', 'ray', 'mp', 'threadpool', 'safe'] = 'mp',
     desc: str | None = None,
+    shared_kwargs: list[str] | None = None,
+    dump_in_thread: bool = True,
     **func_kwargs: Any,
 ) -> list[Any]:
     """
@@ -146,13 +175,55 @@ def multi_process(
     - "threadpool": run in parallel with thread pool
     - "safe": run in parallel with thread pool (explicitly safe for tests)
 
+    shared_kwargs:
+    - Optional list of kwarg names that should be shared via Ray's zero-copy object store
+    - Only works with Ray backend
+    - Useful for large objects (e.g., models, datasets) that should be shared across workers
+    - Example: shared_kwargs=['model', 'tokenizer'] for sharing large ML models
+
+    dump_in_thread:
+    - Whether to dump results to disk in a separate thread (default: True)
+    - If False, dumping is done synchronously, which may block but ensures data is saved before returning
+
     If lazy_output=True, every result is saved to .pkl and
     the returned list contains file paths.
     """
 
     # default backend selection
     if backend is None:
-        backend = 'ray' if _HAS_RAY else 'mp'
+        try:
+            import ray as _ray_module
+            backend = 'ray'
+        except ImportError:
+            backend = 'mp'
+
+    # Validate shared_kwargs
+    if shared_kwargs:
+        # Validate that all shared_kwargs are valid kwargs for the function
+        sig = inspect.signature(func)
+        valid_params = set(sig.parameters.keys())
+
+        for kw in shared_kwargs:
+            if kw not in func_kwargs:
+                raise ValueError(
+                    f"shared_kwargs key '{kw}' not found in provided func_kwargs"
+                )
+            # Check if parameter exists in function signature or if function accepts **kwargs
+            has_var_keyword = any(
+                p.kind == inspect.Parameter.VAR_KEYWORD
+                for p in sig.parameters.values()
+            )
+            if kw not in valid_params and not has_var_keyword:
+                raise ValueError(
+                    f"shared_kwargs key '{kw}' is not a valid parameter for function '{func.__name__}'. "
+                    f"Valid parameters: {valid_params}"
+                )
+
+        # Only allow shared_kwargs with Ray backend
+        if backend != 'ray':
+            raise ValueError(
+                f"shared_kwargs only supported with 'ray' backend, got '{backend}'"
+            )
 
     # unify items
     # unify items and coerce to concrete list so we can use len() and
@@ -169,7 +240,7 @@ def multi_process(
 
     # build cache dir + wrap func
     cache_dir = _build_cache_dir(func, items) if lazy_output else None
-    f_wrapped = wrap_dump(func, cache_dir)
+    f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)
 
     total = len(items)
     if desc:
@@ -181,8 +252,6 @@
     ) as pbar:
         # ---- sequential backend ----
         if backend == 'seq':
-            pbar.set_postfix_str('backend=seq')
-            results = []
             for x in items:
                 results.append(f_wrapped(x, **func_kwargs))
                 pbar.update(1)
@@ -190,19 +259,46 @@ def multi_process(
 
         # ---- ray backend ----
         if backend == 'ray':
-            pbar.set_postfix_str('backend=ray')
-            ensure_ray(workers, pbar)
+            import ray as _ray_module
 
-            @ray.remote
-            def _task(x):
-                return f_wrapped(x, **func_kwargs)
-
-            refs = [_task.remote(x) for x in items]
+            ensure_ray(workers, pbar)
+            shared_refs = {}
+            regular_kwargs = {}
+
+            if shared_kwargs:
+                for kw in shared_kwargs:
+                    # Put large objects in Ray's object store (zero-copy)
+                    shared_refs[kw] = _ray_module.put(func_kwargs[kw])
+                    pbar.set_postfix_str(f'ray: shared `{kw}` via object store')
+
+                # Remaining kwargs are regular
+                regular_kwargs = {
+                    k: v for k, v in func_kwargs.items()
+                    if k not in shared_kwargs
+                }
+            else:
+                regular_kwargs = func_kwargs
+
+            @_ray_module.remote
+            def _task(x, shared_refs_dict, regular_kwargs_dict):
+                # Dereference shared objects (zero-copy for numpy arrays)
+                import ray as _ray_in_task
+                dereferenced = {k: _ray_in_task.get(v) for k, v in shared_refs_dict.items()}
+                # Merge with regular kwargs
+                all_kwargs = {**dereferenced, **regular_kwargs_dict}
+                return f_wrapped(x, **all_kwargs)
+
+            refs = [
+                _task.remote(x, shared_refs, regular_kwargs) for x in items
+            ]
 
             results = []
+            t_start = time.time()
             for r in refs:
-                results.append(ray.get(r))
+                results.append(_ray_module.get(r))
                 pbar.update(1)
+            t_end = time.time()
+            print(f"Ray processing took {t_end - t_start:.2f}s for {total} items")
             return results
 
         # ---- fastcore backend ----
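Taken together, the new `shared_kwargs` path places the named keyword arguments in Ray's object store once via `ray.put` and dereferences them inside each task with `ray.get`, instead of pickling them into every remote call. A usage sketch under the new signature (the function and data are illustrative; the module path follows the RECORD listing):

```python
# Share one large array across Ray workers via the object store.
import numpy as np

from speedy_utils.multi_worker.process import multi_process

def row_sum(i, table):
    # 'table' arrives via ray.put/ray.get, not per-task serialization
    return float(table[i].sum())

table = np.random.rand(10_000, 512)
results = multi_process(
    row_sum,
    items=range(10_000),
    backend='ray',            # shared_kwargs requires the ray backend
    shared_kwargs=['table'],  # must also be passed as a kwarg below
    table=table,
)
```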
speedy_utils-1.1.35.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: speedy-utils
-Version: 1.1.33
+Version: 1.1.35
 Summary: Fast and easy-to-use package for data science
 Project-URL: Homepage, https://github.com/anhvth/speedy
 Project-URL: Repository, https://github.com/anhvth/speedy
speedy_utils-1.1.35.dist-info/RECORD CHANGED
@@ -4,15 +4,15 @@ llm_utils/chat_format/__init__.py,sha256=a7BKtBVktgLMq2Do4iNu3YfdDdTG1v9M_BkmaEo
 llm_utils/chat_format/display.py,sha256=Lffjzna9_vV3QgfiXZM2_tuVb3wqA-WxwrmoAjsJigw,17356
 llm_utils/chat_format/transform.py,sha256=PJ2g9KT1GSbWuAs7giEbTpTAffpU9QsIXyRlbfpTZUQ,5351
 llm_utils/chat_format/utils.py,sha256=M2EctZ6NeHXqFYufh26Y3CpSphN0bdZm5xoNaEJj5vg,1251
-llm_utils/lm/__init__.py,sha256=lFE2DZRpj6eRMo11kx7oRLyYOP2FuDmz08mAcq-cYew,730
+llm_utils/lm/__init__.py,sha256=4jYMy3wPH3tg-tHFyWEWOqrnmX4Tu32VZCdzRGMGQsI,778
 llm_utils/lm/base_prompt_builder.py,sha256=_TzYMsWr-SsbA_JNXptUVN56lV5RfgWWTrFi-E8LMy4,12337
-llm_utils/lm/llm.py,sha256=C8Z8l6Ljs7uVX-zabLcDCdTf3fpGxfljaYRM0patHUQ,16469
+llm_utils/lm/llm.py,sha256=yas7Khd0Djc8-GD8jL--B2oPteV9FC3PpfPbr9XCLOQ,16515
 llm_utils/lm/llm_signature.py,sha256=vV8uZgLLd6ZKqWbq0OPywWvXAfl7hrJQnbtBF-VnZRU,1244
 llm_utils/lm/lm_base.py,sha256=Bk3q34KrcCK_bC4Ryxbc3KqkiPL39zuVZaBQ1i6wJqs,9437
-llm_utils/lm/mixins.py,sha256=on83g-JO2SpZ0digOpU8mooqFBX6w7Bc-DeGzVoVCX8,14536
+llm_utils/lm/mixins.py,sha256=o0tZiaKW4u1BxBVlT_0yTwnO8h7KnY02HX5TuWipvr0,16735
 llm_utils/lm/openai_memoize.py,sha256=rYrSFPpgO7adsjK1lVdkJlhqqIw_13TCW7zU8eNwm3o,5185
 llm_utils/lm/signature.py,sha256=K1hvCAqoC5CmsQ0Y_ywnYy2fRb5JzmIK8OS-hjH-5To,9971
-llm_utils/lm/utils.py,sha256=t-RSR-ffOs2HT67JfgqRhbZX9wbxlJNtQMro4ZyuVVM,12461
+llm_utils/lm/utils.py,sha256=dEKFta8S6Mm4LjIctcpFlEGL9RnmLm5DHd2TA70UWuA,12649
 llm_utils/lm/async_lm/__init__.py,sha256=j0xK49ooZ0Dm5GstGGHbmPMrPjd3mOXoJ1H7eAL_Z4g,122
 llm_utils/lm/async_lm/_utils.py,sha256=mB-AueWJJatTx0PXqd_oWc6Kz36cfgDmDTKgiXafCJI,6106
 llm_utils/lm/async_lm/async_llm_task.py,sha256=2PWW4vPW2jYUiGmYFo4-DHrmX5Jm8Iw_1qo6EPL-ytE,18611
@@ -41,16 +41,16 @@ speedy_utils/common/utils_io.py,sha256=w9AxMD_8V3Wyo_0o9OtXjVQS8Z3KhxQiOkrl2p8Np
 speedy_utils/common/utils_misc.py,sha256=ZRJCS7OJxybpVm1sasoeCYRW2TaaGCXj4DySYlQeVR8,2227
 speedy_utils/common/utils_print.py,sha256=AGDB7mgJnO00QkJBH6kJb46738q3GzMUZPwtQ248vQw,4763
 speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-speedy_utils/multi_worker/process.py,sha256=O3BpGH7iL_2vh_ezwyHb28lvkVADABpTUnhKHbiEe8I,10542
+speedy_utils/multi_worker/process.py,sha256=jk2K3oNnul1jop4g2U7-6GAekJ4fCyXCbj39WWAwXWQ,14925
 speedy_utils/multi_worker/thread.py,sha256=k4Ff4R2W0Ehet1zJ5nHQOfcsvOjnJzU6A2I18qw7_6M,21320
 speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/scripts/mpython.py,sha256=aZvusJLKa3APVhabbFUAEo873VBm8Fym7HKGmVW4LyE,3843
 speedy_utils/scripts/openapi_client_codegen.py,sha256=GModmmhkvGnxljK4KczyixKDrk-VEcLaW5I0XT6tzWo,9657
 vision_utils/README.md,sha256=AIDZZj8jo_QNrEjFyHwd00iOO431s-js-M2dLtVTn3I,5740
-vision_utils/__init__.py,sha256=XsLxy1Fn33Zxu6hTFl3NEWfxGjuQQ-0Wmoh6lU9NZ_o,257
-vision_utils/io_utils.py,sha256=q41pffN632HbMmzcBzfg2Z7DvZZgoAQCdD9jHLqDgjc,26603
-vision_utils/plot.py,sha256=v73onfH8KbGHigw5KStUPqbLyJqIEOvvJaqtaoGKrls,12032
-speedy_utils-1.1.33.dist-info/METADATA,sha256=QaZU14x_OlpExaRMZp3RnhkxEdRfBlhf0mqhaPTr6x4,8048
-speedy_utils-1.1.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-speedy_utils-1.1.33.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
-speedy_utils-1.1.33.dist-info/RECORD,,
+vision_utils/__init__.py,sha256=hF54sT6FAxby8kDVhOvruy4yot8O-Ateey5n96O1pQM,284
+vision_utils/io_utils.py,sha256=pI0Va6miesBysJcllK6NXCay8HpGZsaMWwlsKB2DMgA,26510
+vision_utils/plot.py,sha256=HkNj3osA3moPuupP1VguXfPPOW614dZO5tvC-EFKpKM,12028
+speedy_utils-1.1.35.dist-info/METADATA,sha256=wsz89syaYNXEeGjJXV8zb0W2ZrTjpN2Lj47tE7LQeEI,8048
+speedy_utils-1.1.35.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+speedy_utils-1.1.35.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
+speedy_utils-1.1.35.dist-info/RECORD,,
vision_utils/__init__.py CHANGED
@@ -1,4 +1,11 @@
-from .io_utils import read_images, read_images_cpu, read_images_gpu, ImageMmap, ImageMmapDynamic
+from .io_utils import (
+    ImageMmap,
+    ImageMmapDynamic,
+    read_images,
+    read_images_cpu,
+    read_images_gpu,
+)
 from .plot import plot_images_notebook
 
-__all__ = ['plot_images_notebook', 'read_images_cpu', 'read_images_gpu', 'read_images', 'ImageMmap', 'ImageMmapDynamic']
+
+__all__ = ['plot_images_notebook', 'read_images_cpu', 'read_images_gpu', 'read_images', 'ImageMmap', 'ImageMmapDynamic']
vision_utils/io_utils.py CHANGED
@@ -3,14 +3,16 @@ from __future__ import annotations
 # type: ignore
 import os
 import time
-from pathlib import Path
-from typing import Sequence, Tuple, TYPE_CHECKING
 from multiprocessing import cpu_count
+from pathlib import Path
+from typing import TYPE_CHECKING, Sequence, Tuple
 
 import numpy as np
 from PIL import Image
+
 from speedy_utils import identify
 
+
 try:
     from torch.utils.data import Dataset
 except ImportError:
@@ -438,12 +440,11 @@ class ImageMmap(Dataset):
                 if img is None:
                     if self.safe:
                         raise ValueError(f"Failed to load image: {path}")
-                    else:
-                        # Failed to load, write zeros
-                        print(f"Warning: Failed to load {path}, using zeros")
-                        mm[global_idx] = np.zeros(
-                            (self.H, self.W, self.C), dtype=self.dtype
-                        )
+                    # Failed to load, write zeros
+                    print(f"Warning: Failed to load {path}, using zeros")
+                    mm[global_idx] = np.zeros(
+                        (self.H, self.W, self.C), dtype=self.dtype
+                    )
                 else:
                     # Clip to valid range and ensure correct dtype
                     if self.dtype == np.uint8:
@@ -625,9 +626,10 @@ class ImageMmapDynamic(Dataset):
        - data file: concatenated flattened images in path order
        - meta: JSON with offsets, shapes, dtype, total_elems, paths, n
        """
-        from tqdm import tqdm
        import json
 
+        from tqdm import tqdm
+
        print(f"Building dynamic mmap cache for {self.n} images...")
        # We don't know total size up front -> write sequentially
        offsets = np.zeros(self.n, dtype=np.int64)
@@ -660,11 +662,10 @@ class ImageMmapDynamic(Dataset):
             if img is None:
                 if self.safe:
                     raise ValueError(f"Failed to load image: {path}")
-                else:
-                    print(
-                        f"Warning: Failed to load {path}, storing 1x1x3 zeros"
-                    )
-                    img = np.zeros((1, 1, 3), dtype=self.dtype)
+                print(
+                    f"Warning: Failed to load {path}, storing 1x1x3 zeros"
+                )
+                img = np.zeros((1, 1, 3), dtype=self.dtype)
 
             # Clip to valid range for uint8
             if self.dtype == np.uint8:
vision_utils/plot.py CHANGED
@@ -1,8 +1,8 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 
 
 if TYPE_CHECKING:
@@ -311,7 +311,7 @@ def visualize_tensor(img_tensor, mode='hwc', normalize=True, max_cols=8):
     mpl_available, plt = _check_matplotlib_available()
     if not mpl_available:
         raise ImportError("matplotlib is required for plotting. Install it with: pip install matplotlib")
-
+    
     if mode == 'chw':
         img_tensor = img_tensor.permute(1, 2, 0)
     imgs = [img_tensor]