speedy-utils 1.1.32__py3-none-any.whl → 1.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/llm.py CHANGED
@@ -63,9 +63,13 @@ class LLM(
         vllm_cmd: str | None = None,
         vllm_timeout: int = 1200,
         vllm_reuse: bool = True,
+        verbose=False,
         **model_kwargs,
     ):
         """Initialize LLMTask."""
+        if verbose:
+            available_models = LLM.list_models(client=client)
+            logger.info(f'Available models: {available_models}')
         self.instruction = instruction
         self.input_model = input_model
         self.output_model = output_model
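The new `verbose` flag makes the constructor log which models the resolved client can see, which helps catch a wrong base URL early. A hypothetical usage sketch; argument values besides `verbose` are illustrative, not from the package docs:

```python
# Hypothetical usage of the new `verbose` flag; other arguments are illustrative.
from llm_utils.lm.llm import LLM

llm = LLM(
    instruction='Summarize the input text.',
    verbose=True,  # logs "Available models: [...]" via LLM.list_models(client=client)
)
```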
llm_utils/lm/utils.py CHANGED
@@ -282,7 +282,11 @@ def get_base_client(
         return MOpenAI(
             base_url=f"http://localhost:{port}/v1", api_key=api_key, cache=cache
         )
-    raise ValueError("Either client or vllm_cmd must be provided.")
+    # Use default port 8000 when client is None
+    logger.info("No client specified, using default port 8000 at http://localhost:8000/v1")
+    return MOpenAI(
+        base_url="http://localhost:8000/v1", api_key=api_key, cache=cache
+    )
     if isinstance(client, int):
         return MOpenAI(
             base_url=f"http://localhost:{client}/v1", api_key=api_key, cache=cache
speedy_utils/__imports.py CHANGED
@@ -34,7 +34,6 @@ import re
 import sys
 import textwrap
 import threading
-import time
 import traceback
 import types
 import uuid
@@ -77,6 +76,9 @@ from typing import (
 )
 
 import cachetools
+
+# Direct imports (previously lazy-loaded)
+import numpy as np
 import psutil
 from fastcore.parallel import parallel
 from json_repair import loads as jloads
@@ -84,52 +86,61 @@ from loguru import logger
 from tqdm import tqdm
 
 
-# Direct imports (previously lazy-loaded)
-import numpy as np
 tabulate = __import__('tabulate').tabulate
 import xxhash
 
+
 # Optional imports - lazy loaded for performance
 def _get_pandas():
     """Lazy import pandas."""
     try:
         import pandas as pd
+
         return pd
     except ImportError:
         return None
 
+
 def _get_ray():
     """Lazy import ray."""
     try:
         import ray
+
         return ray
     except ImportError:
         return None
 
+
 def _get_matplotlib():
     """Lazy import matplotlib."""
     try:
         import matplotlib
+
         return matplotlib
     except ImportError:
         return None
 
+
 def _get_matplotlib_pyplot():
     """Lazy import matplotlib.pyplot."""
     try:
         import matplotlib.pyplot as plt
+
         return plt
     except ImportError:
         return None
 
+
 def _get_ipython_core():
     """Lazy import IPython.core.getipython."""
     try:
         from IPython.core.getipython import get_ipython
-        return get_ipython
+
+        return get_ipython()
     except ImportError:
         return None
 
+
 # Cache for lazy imports
 _pandas_cache = None
 _ray_cache = None
@@ -137,9 +148,11 @@ _matplotlib_cache = None
 _plt_cache = None
 _get_ipython_cache = None
 
+
 # Lazy import classes for performance-critical modules
 class _LazyModule:
     """Lazy module loader that imports only when accessed."""
+
     def __init__(self, import_func, cache_var_name):
         self._import_func = import_func
         self._cache_var_name = cache_var_name
@@ -168,9 +181,10 @@ class _LazyModule:
 
     def __repr__(self):
         if self._module is None:
-            return f"<LazyModule: not loaded>"
+            return '<LazyModule: not loaded>'
         return repr(self._module)
 
+
 # Create lazy loaders for top slow imports (import only when accessed)
 pd = _LazyModule(_get_pandas, '_pandas_cache')
 ray = _LazyModule(_get_ray, '_ray_cache')
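`_LazyModule` defers the real import until first attribute access, so `import speedy_utils` stays fast even with pandas/ray/matplotlib in scope. A self-contained sketch of the same pattern (simplified: no shared cache variables as in the package version):

```python
# Standalone sketch of the lazy-module pattern; simplified from _LazyModule above.
class LazyModule:
    def __init__(self, import_func):
        self._import_func = import_func  # returns the module, or None if missing
        self._module = None

    def _load(self):
        if self._module is None:
            self._module = self._import_func()
            if self._module is None:
                raise ImportError('optional dependency is not installed')
        return self._module

    def __getattr__(self, name):
        # Triggered on first real use, e.g. pd.DataFrame(...)
        return getattr(self._load(), name)


def _get_pandas():
    try:
        import pandas as pd
        return pd
    except ImportError:
        return None


pd = LazyModule(_get_pandas)  # import cost is paid on first attribute access
```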
@@ -200,11 +214,12 @@ try:
 except ImportError:
     BaseModel = None
 if TYPE_CHECKING:
+    import matplotlib.pyplot as plt
     import numpy as np
     import pandas as pd
     import ray
     import torch
-    import matplotlib.pyplot as plt
+
 # xxhash
 import xxhash # type: ignore
 from IPython.core.getipython import get_ipython # type: ignore
speedy_utils/multi_worker/process.py CHANGED
@@ -1,6 +1,3 @@
-# ray_multi_process.py
-
-
 from ..__imports import *
 
 
@@ -8,6 +5,12 @@ SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
 _SPEEDY_PROCESSES_LOCK = threading.Lock()
 
 
+# /mnt/data/anhvth8/venvs/Megatron-Bridge-Host/lib/python3.12/site-packages/ray/_private/worker.py:2046: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
+# turn off future warning and verbose task logs
+os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
+os.environ["RAY_DEDUP_LOGS"] = "0"
+os.environ["RAY_LOG_TO_STDERR"] = "0"
+
 def _prune_dead_processes() -> None:
     """Remove dead processes from tracking list."""
     with _SPEEDY_PROCESSES_LOCK:
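Setting these variables at module import time matters: Ray reads them when it initializes, so they must be in `os.environ` before any `ray.init()` call. A small ordering sketch, assuming `ray` is installed:

```python
# Sketch: the env vars must be set before Ray initializes
# (importing ray is fine; the vars are read at init time).
import os

os.environ['RAY_DEDUP_LOGS'] = '0'  # disable log de-duplication

import ray

ray.init(num_cpus=2, ignore_reinit_error=True)
ray.shutdown()
```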
@@ -78,6 +81,7 @@ def _track_multiprocessing_processes() -> None:
 
 def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
     """Build cache dir with function name + timestamp."""
+    import datetime
     func_name = getattr(func, '__name__', 'func')
     now = datetime.datetime.now()
     stamp = now.strftime('%m%d_%Hh%Mm%Ss')
@@ -85,9 +89,8 @@ def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
     path = Path('.cache') / run_id
     path.mkdir(parents=True, exist_ok=True)
     return path
-
-
-def wrap_dump(func: Callable, cache_dir: Path | None):
+_DUMP_THREADS = []
+def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
     """Wrap a function so results are dumped to .pkl when cache_dir is set."""
     if cache_dir is None:
         return func
@@ -95,8 +98,24 @@ def wrap_dump(func: Callable, cache_dir: Path | None):
     def wrapped(x, *args, **kwargs):
         res = func(x, *args, **kwargs)
         p = cache_dir / f'{uuid.uuid4().hex}.pkl'
-        with open(p, 'wb') as fh:
-            pickle.dump(res, fh)
+
+        def save():
+            with open(p, 'wb') as fh:
+                pickle.dump(res, fh)
+            # Clean trash to avoid bloating memory
+            # print(f'Thread count: {threading.active_count()}')
+            # print(f'Saved result to {p}')
+
+        if dump_in_thread:
+            thread = threading.Thread(target=save)
+            _DUMP_THREADS.append(thread)
+            # count thread
+            # print(f'Thread count: {threading.active_count()}')
+            while threading.active_count() > 16:
+                time.sleep(0.1)
+            thread.start()
+        else:
+            save()
         return str(p)
 
     return wrapped
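The `while threading.active_count() > 16` loop throttles background writers so only a bounded number of dump threads run at once. The flag surfaces as `dump_in_thread` on `multi_process`; hypothetical call sites, assuming the module path from this wheel and the default `'mp'` backend:

```python
# Hypothetical calls showing the new dump_in_thread flag.
# lazy_output=True makes each returned item a path to a .pkl dump.
from speedy_utils.multi_worker.process import multi_process

paths = multi_process(str, items=range(100), lazy_output=True,
                      dump_in_thread=True)   # .pkl writes happen in background threads
paths = multi_process(str, items=range(100), lazy_output=True,
                      dump_in_thread=False)  # blocks until each .pkl is on disk
```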
@@ -109,20 +128,28 @@ RAY_WORKER = None
 
 def ensure_ray(workers: int, pbar: tqdm | None = None):
     """Initialize or reinitialize Ray with a given worker count, log to bar postfix."""
+    import ray as _ray_module
+    import logging
+
     global RAY_WORKER
-    if not ray.is_initialized() or workers != RAY_WORKER:
-        if ray.is_initialized() and pbar:
+    # shutdown when worker count changes or if Ray not initialized
+    if not _ray_module.is_initialized() or workers != RAY_WORKER:
+        if _ray_module.is_initialized() and pbar:
             pbar.set_postfix_str(f'Restarting Ray {workers} workers')
-        ray.shutdown()
+        _ray_module.shutdown()
         t0 = time.time()
-        ray.init(num_cpus=workers, ignore_reinit_error=True)
+        _ray_module.init(
+            num_cpus=workers,
+            ignore_reinit_error=True,
+            logging_level=logging.ERROR,
+            log_to_driver=False,
+        )
         took = time.time() - t0
        _track_ray_processes()  # Track Ray worker processes
        if pbar:
            pbar.set_postfix_str(f'ray.init {workers} took {took:.2f}s')
        RAY_WORKER = workers
 
-
 def multi_process(
     func: Callable[[Any], Any],
     items: Iterable[Any] | None = None,
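`ensure_ray` is idempotent per worker count: it restarts Ray only when the requested count changes, and the new `logging_level`/`log_to_driver` arguments keep the re-init quiet. A hypothetical usage sequence (import path taken from this wheel):

```python
# Hypothetical usage of ensure_ray's restart-on-change behavior.
from speedy_utils.multi_worker.process import ensure_ray

ensure_ray(4)  # first call: ray.init(num_cpus=4, logging_level=ERROR, log_to_driver=False)
ensure_ray(4)  # no-op: Ray already initialized with 4 workers
ensure_ray(8)  # ray.shutdown(), then re-init with 8 workers
```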
@@ -134,6 +161,8 @@ def multi_process(
     # backend: str = "ray", # "seq", "ray", or "fastcore"
     backend: Literal['seq', 'ray', 'mp', 'threadpool', 'safe'] = 'mp',
     desc: str | None = None,
+    shared_kwargs: list[str] | None = None,
+    dump_in_thread: bool = True,
     **func_kwargs: Any,
 ) -> list[Any]:
     """
@@ -146,13 +175,55 @@ def multi_process(
     - "threadpool": run in parallel with thread pool
     - "safe": run in parallel with thread pool (explicitly safe for tests)
 
+    shared_kwargs:
+    - Optional list of kwarg names that should be shared via Ray's zero-copy object store
+    - Only works with Ray backend
+    - Useful for large objects (e.g., models, datasets) that should be shared across workers
+    - Example: shared_kwargs=['model', 'tokenizer'] for sharing large ML models
+
+    dump_in_thread:
+    - Whether to dump results to disk in a separate thread (default: True)
+    - If False, dumping is done synchronously, which may block but ensures data is saved before returning
+
     If lazy_output=True, every result is saved to .pkl and
     the returned list contains file paths.
     """
 
     # default backend selection
     if backend is None:
-        backend = 'ray' if _HAS_RAY else 'mp'
+        try:
+            import ray as _ray_module
+            backend = 'ray'
+        except ImportError:
+            backend = 'mp'
+
+    # Validate shared_kwargs
+    if shared_kwargs:
+        # Validate that all shared_kwargs are valid kwargs for the function
+        sig = inspect.signature(func)
+        valid_params = set(sig.parameters.keys())
+
+        for kw in shared_kwargs:
+            if kw not in func_kwargs:
+                raise ValueError(
+                    f"shared_kwargs key '{kw}' not found in provided func_kwargs"
+                )
+            # Check if parameter exists in function signature or if function accepts **kwargs
+            has_var_keyword = any(
+                p.kind == inspect.Parameter.VAR_KEYWORD
+                for p in sig.parameters.values()
+            )
+            if kw not in valid_params and not has_var_keyword:
+                raise ValueError(
+                    f"shared_kwargs key '{kw}' is not a valid parameter for function '{func.__name__}'. "
+                    f"Valid parameters: {valid_params}"
+                )
+
+        # Only allow shared_kwargs with Ray backend
+        if backend != 'ray':
+            raise ValueError(
+                f"shared_kwargs only supported with 'ray' backend, got '{backend}'"
+            )
 
     # unify items
     # unify items and coerce to concrete list so we can use len() and
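Per the checks above, every name in `shared_kwargs` must be a real parameter of `func` (or covered by `**kwargs`), must actually be passed, and requires `backend='ray'`. A hypothetical end-to-end call; `embeddings` is an invented example kwarg:

```python
# Hypothetical shared_kwargs usage; the array is shipped once via the object store.
import numpy as np
from speedy_utils.multi_worker.process import multi_process

def norm_of(idx, embeddings):
    # `embeddings` arrives from Ray's object store, not as a per-task pickle copy
    return float(np.linalg.norm(embeddings[idx]))

emb = np.random.rand(10_000, 128)
out = multi_process(norm_of, items=range(100), backend='ray',
                    shared_kwargs=['embeddings'], embeddings=emb)
```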
@@ -169,7 +240,7 @@ def multi_process(
 
     # build cache dir + wrap func
     cache_dir = _build_cache_dir(func, items) if lazy_output else None
-    f_wrapped = wrap_dump(func, cache_dir)
+    f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)
 
     total = len(items)
     if desc:
@@ -181,8 +252,6 @@ def multi_process(
     ) as pbar:
         # ---- sequential backend ----
         if backend == 'seq':
-            pbar.set_postfix_str('backend=seq')
-            results = []
             for x in items:
                 results.append(f_wrapped(x, **func_kwargs))
                 pbar.update(1)
@@ -190,19 +259,46 @@ def multi_process(
 
         # ---- ray backend ----
         if backend == 'ray':
-            pbar.set_postfix_str('backend=ray')
-            ensure_ray(workers, pbar)
+            import ray as _ray_module
 
-            @ray.remote
-            def _task(x):
-                return f_wrapped(x, **func_kwargs)
-
-            refs = [_task.remote(x) for x in items]
+            ensure_ray(workers, pbar)
+            shared_refs = {}
+            regular_kwargs = {}
+
+            if shared_kwargs:
+                for kw in shared_kwargs:
+                    # Put large objects in Ray's object store (zero-copy)
+                    shared_refs[kw] = _ray_module.put(func_kwargs[kw])
+                    pbar.set_postfix_str(f'ray: shared `{kw}` via object store')
+
+                # Remaining kwargs are regular
+                regular_kwargs = {
+                    k: v for k, v in func_kwargs.items()
+                    if k not in shared_kwargs
+                }
+            else:
+                regular_kwargs = func_kwargs
+
+            @_ray_module.remote
+            def _task(x, shared_refs_dict, regular_kwargs_dict):
+                # Dereference shared objects (zero-copy for numpy arrays)
+                import ray as _ray_in_task
+                dereferenced = {k: _ray_in_task.get(v) for k, v in shared_refs_dict.items()}
+                # Merge with regular kwargs
+                all_kwargs = {**dereferenced, **regular_kwargs_dict}
+                return f_wrapped(x, **all_kwargs)
+
+            refs = [
+                _task.remote(x, shared_refs, regular_kwargs) for x in items
+            ]
 
             results = []
+            t_start = time.time()
             for r in refs:
-                results.append(ray.get(r))
+                results.append(_ray_module.get(r))
                 pbar.update(1)
+            t_end = time.time()
+            print(f"Ray processing took {t_end - t_start:.2f}s for {total} items")
             return results
 
         # ---- fastcore backend ----
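The explicit `ray.get` inside `_task` is needed because Ray only auto-dereferences `ObjectRef`s passed as top-level arguments; refs nested inside a dict (as `shared_refs` is here) arrive unresolved. A minimal standalone sketch of that mechanism, independent of `multi_process`:

```python
# Minimal sketch of the object-store pattern behind shared_kwargs.
import numpy as np
import ray

ray.init(num_cpus=2, ignore_reinit_error=True)

big = np.zeros((10_000, 1_000))  # ~80 MB, serialized once by ray.put
ref = ray.put(big)

@ray.remote
def total(refs):
    # Refs nested inside containers are NOT auto-dereferenced by Ray,
    # hence the explicit ray.get -- the same pattern as _task above.
    # The read is zero-copy for numpy arrays on the same node.
    arr = ray.get(refs['big'])
    return float(arr.sum())

print(ray.get([total.remote({'big': ref}) for _ in range(4)]))
ray.shutdown()
```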
speedy_utils-1.1.32.dist-info/METADATA → speedy_utils-1.1.34.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: speedy-utils
-Version: 1.1.32
+Version: 1.1.34
 Summary: Fast and easy-to-use package for data science
 Project-URL: Homepage, https://github.com/anhvth/speedy
 Project-URL: Repository, https://github.com/anhvth/speedy
speedy_utils-1.1.32.dist-info/RECORD → speedy_utils-1.1.34.dist-info/RECORD
@@ -6,13 +6,13 @@ llm_utils/chat_format/transform.py,sha256=PJ2g9KT1GSbWuAs7giEbTpTAffpU9QsIXyRlbf
 llm_utils/chat_format/utils.py,sha256=M2EctZ6NeHXqFYufh26Y3CpSphN0bdZm5xoNaEJj5vg,1251
 llm_utils/lm/__init__.py,sha256=lFE2DZRpj6eRMo11kx7oRLyYOP2FuDmz08mAcq-cYew,730
 llm_utils/lm/base_prompt_builder.py,sha256=_TzYMsWr-SsbA_JNXptUVN56lV5RfgWWTrFi-E8LMy4,12337
-llm_utils/lm/llm.py,sha256=Ryne4why4VgRCvTJW1SLJciv2pcYYROGP0AwkLVdGPg,16299
+llm_utils/lm/llm.py,sha256=C8Z8l6Ljs7uVX-zabLcDCdTf3fpGxfljaYRM0patHUQ,16469
 llm_utils/lm/llm_signature.py,sha256=vV8uZgLLd6ZKqWbq0OPywWvXAfl7hrJQnbtBF-VnZRU,1244
 llm_utils/lm/lm_base.py,sha256=Bk3q34KrcCK_bC4Ryxbc3KqkiPL39zuVZaBQ1i6wJqs,9437
 llm_utils/lm/mixins.py,sha256=on83g-JO2SpZ0digOpU8mooqFBX6w7Bc-DeGzVoVCX8,14536
 llm_utils/lm/openai_memoize.py,sha256=rYrSFPpgO7adsjK1lVdkJlhqqIw_13TCW7zU8eNwm3o,5185
 llm_utils/lm/signature.py,sha256=K1hvCAqoC5CmsQ0Y_ywnYy2fRb5JzmIK8OS-hjH-5To,9971
-llm_utils/lm/utils.py,sha256=t-RSR-ffOs2HT67JfgqRhbZX9wbxlJNtQMro4ZyuVVM,12461
+llm_utils/lm/utils.py,sha256=dEKFta8S6Mm4LjIctcpFlEGL9RnmLm5DHd2TA70UWuA,12649
 llm_utils/lm/async_lm/__init__.py,sha256=j0xK49ooZ0Dm5GstGGHbmPMrPjd3mOXoJ1H7eAL_Z4g,122
 llm_utils/lm/async_lm/_utils.py,sha256=mB-AueWJJatTx0PXqd_oWc6Kz36cfgDmDTKgiXafCJI,6106
 llm_utils/lm/async_lm/async_llm_task.py,sha256=2PWW4vPW2jYUiGmYFo4-DHrmX5Jm8Iw_1qo6EPL-ytE,18611
@@ -27,7 +27,7 @@ llm_utils/vector_cache/cli.py,sha256=MAvnmlZ7j7_0CvIcSyK4TvJlSRFWYkm4wE7zSq3KR8k
 llm_utils/vector_cache/core.py,sha256=VXuYJy1AX22NHKvIXRriETip5RrmQcNp73-g-ZT774o,30950
 llm_utils/vector_cache/types.py,sha256=CpMZanJSTeBVxQSqjBq6pBVWp7u2-JRcgY9t5jhykdQ,438
 llm_utils/vector_cache/utils.py,sha256=OsiRFydv8i8HiJtPL9hh40aUv8I5pYfg2zvmtDi4DME,1446
-speedy_utils/__imports.py,sha256=PhHqZWwVOKAbbXoWxZLVVyurGmZhui3boQ7Nji002cQ,7795
+speedy_utils/__imports.py,sha256=V0YzkDK4-QkK_IDXY1be6C6_STuNhXAKIp4_dM0coQs,7800
 speedy_utils/__init__.py,sha256=VkKqS4eHXd8YeDu2TAQ3Osqy70RSufUL1sECDoYzqvM,2685
 speedy_utils/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/common/clock.py,sha256=raLtMGIgzrRej5kUt7hOUm2ZZw2THVPo-q8dMvdZOxw,7354
@@ -41,16 +41,16 @@ speedy_utils/common/utils_io.py,sha256=w9AxMD_8V3Wyo_0o9OtXjVQS8Z3KhxQiOkrl2p8Np
 speedy_utils/common/utils_misc.py,sha256=ZRJCS7OJxybpVm1sasoeCYRW2TaaGCXj4DySYlQeVR8,2227
 speedy_utils/common/utils_print.py,sha256=AGDB7mgJnO00QkJBH6kJb46738q3GzMUZPwtQ248vQw,4763
 speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-speedy_utils/multi_worker/process.py,sha256=O3BpGH7iL_2vh_ezwyHb28lvkVADABpTUnhKHbiEe8I,10542
+speedy_utils/multi_worker/process.py,sha256=jk2K3oNnul1jop4g2U7-6GAekJ4fCyXCbj39WWAwXWQ,14925
 speedy_utils/multi_worker/thread.py,sha256=k4Ff4R2W0Ehet1zJ5nHQOfcsvOjnJzU6A2I18qw7_6M,21320
 speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/scripts/mpython.py,sha256=aZvusJLKa3APVhabbFUAEo873VBm8Fym7HKGmVW4LyE,3843
 speedy_utils/scripts/openapi_client_codegen.py,sha256=GModmmhkvGnxljK4KczyixKDrk-VEcLaW5I0XT6tzWo,9657
 vision_utils/README.md,sha256=AIDZZj8jo_QNrEjFyHwd00iOO431s-js-M2dLtVTn3I,5740
-vision_utils/__init__.py,sha256=XsLxy1Fn33Zxu6hTFl3NEWfxGjuQQ-0Wmoh6lU9NZ_o,257
-vision_utils/io_utils.py,sha256=q41pffN632HbMmzcBzfg2Z7DvZZgoAQCdD9jHLqDgjc,26603
-vision_utils/plot.py,sha256=v73onfH8KbGHigw5KStUPqbLyJqIEOvvJaqtaoGKrls,12032
-speedy_utils-1.1.32.dist-info/METADATA,sha256=ElLAOdGyTiqq33ON2WgHX4grgtlFqWRncupD0DivCBk,8048
-speedy_utils-1.1.32.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-speedy_utils-1.1.32.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
-speedy_utils-1.1.32.dist-info/RECORD,,
+vision_utils/__init__.py,sha256=hF54sT6FAxby8kDVhOvruy4yot8O-Ateey5n96O1pQM,284
+vision_utils/io_utils.py,sha256=pI0Va6miesBysJcllK6NXCay8HpGZsaMWwlsKB2DMgA,26510
+vision_utils/plot.py,sha256=HkNj3osA3moPuupP1VguXfPPOW614dZO5tvC-EFKpKM,12028
+speedy_utils-1.1.34.dist-info/METADATA,sha256=diZ6MTVGRDDhsbxoK9eBydHrbW2I6rvYG8lXXzJnJEU,8048
+speedy_utils-1.1.34.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+speedy_utils-1.1.34.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
+speedy_utils-1.1.34.dist-info/RECORD,,
vision_utils/__init__.py CHANGED
@@ -1,4 +1,11 @@
-from .io_utils import read_images, read_images_cpu, read_images_gpu, ImageMmap, ImageMmapDynamic
+from .io_utils import (
+    ImageMmap,
+    ImageMmapDynamic,
+    read_images,
+    read_images_cpu,
+    read_images_gpu,
+)
 from .plot import plot_images_notebook
 
-__all__ = ['plot_images_notebook', 'read_images_cpu', 'read_images_gpu', 'read_images', 'ImageMmap', 'ImageMmapDynamic']
+
+__all__ = ['plot_images_notebook', 'read_images_cpu', 'read_images_gpu', 'read_images', 'ImageMmap', 'ImageMmapDynamic']
vision_utils/io_utils.py CHANGED
@@ -3,14 +3,16 @@ from __future__ import annotations
 # type: ignore
 import os
 import time
-from pathlib import Path
-from typing import Sequence, Tuple, TYPE_CHECKING
 from multiprocessing import cpu_count
+from pathlib import Path
+from typing import TYPE_CHECKING, Sequence, Tuple
 
 import numpy as np
 from PIL import Image
+
 from speedy_utils import identify
 
+
 try:
     from torch.utils.data import Dataset
 except ImportError:
@@ -438,12 +440,11 @@ class ImageMmap(Dataset):
                 if img is None:
                     if self.safe:
                         raise ValueError(f"Failed to load image: {path}")
-                    else:
-                        # Failed to load, write zeros
-                        print(f"Warning: Failed to load {path}, using zeros")
-                        mm[global_idx] = np.zeros(
-                            (self.H, self.W, self.C), dtype=self.dtype
-                        )
+                    # Failed to load, write zeros
+                    print(f"Warning: Failed to load {path}, using zeros")
+                    mm[global_idx] = np.zeros(
+                        (self.H, self.W, self.C), dtype=self.dtype
+                    )
                 else:
                     # Clip to valid range and ensure correct dtype
                     if self.dtype == np.uint8:
@@ -625,9 +626,10 @@ class ImageMmapDynamic(Dataset):
         - data file: concatenated flattened images in path order
         - meta: JSON with offsets, shapes, dtype, total_elems, paths, n
         """
-        from tqdm import tqdm
         import json
 
+        from tqdm import tqdm
+
         print(f"Building dynamic mmap cache for {self.n} images...")
         # We don't know total size up front -> write sequentially
         offsets = np.zeros(self.n, dtype=np.int64)
@@ -660,11 +662,10 @@ class ImageMmapDynamic(Dataset):
                 if img is None:
                     if self.safe:
                         raise ValueError(f"Failed to load image: {path}")
-                    else:
-                        print(
-                            f"Warning: Failed to load {path}, storing 1x1x3 zeros"
-                        )
-                        img = np.zeros((1, 1, 3), dtype=self.dtype)
+                    print(
+                        f"Warning: Failed to load {path}, storing 1x1x3 zeros"
+                    )
+                    img = np.zeros((1, 1, 3), dtype=self.dtype)
 
                 # Clip to valid range for uint8
                 if self.dtype == np.uint8:
vision_utils/plot.py CHANGED
@@ -1,8 +1,8 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 
 
 if TYPE_CHECKING:
@@ -311,7 +311,7 @@ def visualize_tensor(img_tensor, mode='hwc', normalize=True, max_cols=8):
     mpl_available, plt = _check_matplotlib_available()
     if not mpl_available:
         raise ImportError("matplotlib is required for plotting. Install it with: pip install matplotlib")
-
+
     if mode == 'chw':
         img_tensor = img_tensor.permute(1, 2, 0)
     imgs = [img_tensor]