speedy-utils 1.1.27__py3-none-any.whl → 1.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. llm_utils/__init__.py +16 -4
  2. llm_utils/chat_format/__init__.py +10 -10
  3. llm_utils/chat_format/display.py +33 -21
  4. llm_utils/chat_format/transform.py +17 -19
  5. llm_utils/chat_format/utils.py +6 -4
  6. llm_utils/group_messages.py +17 -14
  7. llm_utils/lm/__init__.py +6 -5
  8. llm_utils/lm/async_lm/__init__.py +1 -0
  9. llm_utils/lm/async_lm/_utils.py +10 -9
  10. llm_utils/lm/async_lm/async_llm_task.py +141 -137
  11. llm_utils/lm/async_lm/async_lm.py +48 -42
  12. llm_utils/lm/async_lm/async_lm_base.py +59 -60
  13. llm_utils/lm/async_lm/lm_specific.py +4 -3
  14. llm_utils/lm/base_prompt_builder.py +93 -70
  15. llm_utils/lm/llm.py +126 -108
  16. llm_utils/lm/llm_signature.py +4 -2
  17. llm_utils/lm/lm_base.py +72 -73
  18. llm_utils/lm/mixins.py +102 -62
  19. llm_utils/lm/openai_memoize.py +124 -87
  20. llm_utils/lm/signature.py +105 -92
  21. llm_utils/lm/utils.py +42 -23
  22. llm_utils/scripts/vllm_load_balancer.py +23 -30
  23. llm_utils/scripts/vllm_serve.py +8 -7
  24. llm_utils/vector_cache/__init__.py +9 -3
  25. llm_utils/vector_cache/cli.py +1 -1
  26. llm_utils/vector_cache/core.py +59 -63
  27. llm_utils/vector_cache/types.py +7 -5
  28. llm_utils/vector_cache/utils.py +12 -8
  29. speedy_utils/__imports.py +244 -0
  30. speedy_utils/__init__.py +90 -194
  31. speedy_utils/all.py +125 -227
  32. speedy_utils/common/clock.py +37 -42
  33. speedy_utils/common/function_decorator.py +6 -12
  34. speedy_utils/common/logger.py +43 -52
  35. speedy_utils/common/notebook_utils.py +13 -21
  36. speedy_utils/common/patcher.py +21 -17
  37. speedy_utils/common/report_manager.py +42 -44
  38. speedy_utils/common/utils_cache.py +152 -169
  39. speedy_utils/common/utils_io.py +137 -103
  40. speedy_utils/common/utils_misc.py +15 -21
  41. speedy_utils/common/utils_print.py +22 -28
  42. speedy_utils/multi_worker/process.py +66 -79
  43. speedy_utils/multi_worker/thread.py +78 -155
  44. speedy_utils/scripts/mpython.py +38 -36
  45. speedy_utils/scripts/openapi_client_codegen.py +10 -10
  46. {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/METADATA +1 -1
  47. speedy_utils-1.1.28.dist-info/RECORD +57 -0
  48. vision_utils/README.md +202 -0
  49. vision_utils/__init__.py +5 -0
  50. vision_utils/io_utils.py +470 -0
  51. vision_utils/plot.py +345 -0
  52. speedy_utils-1.1.27.dist-info/RECORD +0 -52
  53. {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/WHEEL +0 -0
  54. {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/entry_points.txt +0 -0
@@ -1,97 +1,5 @@
1
- """
2
- # ============================================================================= #
3
- # THREAD-BASED PARALLEL EXECUTION WITH PROGRESS TRACKING AND ERROR HANDLING
4
- # ============================================================================= #
5
- #
6
- # Title & Intent:
7
- # High-performance thread pool utilities for parallel processing with comprehensive error handling
8
- #
9
- # High-level Summary:
10
- # This module provides robust thread-based parallel execution utilities designed for CPU-bound
11
- # and I/O-bound tasks requiring concurrent processing. It features intelligent worker management,
12
- # comprehensive error handling with detailed tracebacks, progress tracking with tqdm integration,
13
- # and flexible batching strategies. The module optimizes for both throughput and reliability,
14
- # making it suitable for data processing pipelines, batch operations, and concurrent API calls.
15
- #
16
- # Public API / Data Contracts:
17
- # • multi_thread(func, inputs, *, workers=None, **kwargs) -> list[Any] - Main executor
18
- # • multi_thread_standard(func, inputs, workers=4) -> list[Any] - Simple ordered helper
19
- # • kill_all_thread(exc_type=SystemExit, join_timeout=0.1) -> int - Emergency stop
20
- # • DEFAULT_WORKERS = (cpu_count * 2) - Default worker thread count
21
- # • T = TypeVar('T'), R = TypeVar('R') - Generic type variables for input/output typing
22
- # • _group_iter(src, size) -> Iterable[list[T]] - Utility for chunking iterables
23
- # • _worker(item, func, fixed_kwargs) -> R - Individual worker function wrapper
24
- # • _ResultCollector - Maintains ordered/unordered result aggregation
25
- #
26
- # Invariants / Constraints:
27
- # • Worker count MUST be positive integer, defaults to (CPU cores * 2)
28
- # • Input iterables MUST be finite and non-empty for meaningful processing
29
- # • Functions MUST be thread-safe when used with multiple workers
30
- # • Error handling MUST capture and log detailed tracebacks for debugging
31
- # • Progress tracking MUST be optional and gracefully handle tqdm unavailability
32
- # • Batch processing MUST maintain input order in results
33
- # • MUST handle keyboard interruption gracefully with resource cleanup
34
- # • Thread pool MUST be properly closed and joined after completion
35
- #
36
- # Usage Example:
37
- # ```python
38
- # from speedy_utils.multi_worker.thread import multi_thread, multi_thread_batch
39
- # import requests
40
- #
41
- # # Simple parallel processing
42
- # def square(x):
43
- # return x ** 2
44
- #
45
- # numbers = list(range(100))
46
- # results = multi_thread(square, numbers, num_workers=8)
47
- # print(f"Processed {len(results)} items")
48
- #
49
- # # Parallel API calls with error handling
50
- # def fetch_url(url):
51
- # response = requests.get(url, timeout=10)
52
- # return response.status_code, len(response.content)
53
- #
54
- # urls = ["http://example.com", "http://google.com", "http://github.com"]
55
- # results = multi_thread(fetch_url, urls, num_workers=3, progress=True)
56
- #
57
- # # Batched processing for memory efficiency
58
- # def process_batch(items):
59
- # return [item.upper() for item in items]
60
- #
61
- # large_dataset = ["item" + str(i) for i in range(10000)]
62
- # batched_results = multi_thread_batch(
63
- # process_batch,
64
- # large_dataset,
65
- # batch_size=100,
66
- # num_workers=4
67
- # )
68
- # ```
69
- #
70
- # TODO & Future Work:
71
- # • Add adaptive worker count based on task characteristics
72
- # • Implement priority queuing for time-sensitive tasks
73
- # • Add memory usage monitoring and automatic batch size adjustment
74
- # • Support for async function execution within thread pool
75
- # • Add detailed performance metrics and timing analysis
76
- # • Implement graceful degradation for resource-constrained environments
77
- #
78
- # ============================================================================= #
79
- """
80
-
81
- import ctypes
82
- import os
83
- import sys
84
- import threading
85
- import time
86
- import traceback
87
- from collections.abc import Callable, Iterable, Mapping, Sequence
88
- from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
89
- from heapq import heappop, heappush
90
- from itertools import islice
91
- from types import MappingProxyType
92
- from typing import Any, Generic, TypeVar, cast
93
-
94
- from loguru import logger
1
+ from ..__imports import *
2
+
95
3
 
96
4
  try:
97
5
  from tqdm import tqdm
@@ -101,8 +9,8 @@ except ImportError: # pragma: no cover
101
9
  # Sensible defaults
102
10
  DEFAULT_WORKERS = (os.cpu_count() or 4) * 2
103
11
 
104
- T = TypeVar("T")
105
- R = TypeVar("R")
12
+ T = TypeVar('T')
13
+ R = TypeVar('R')
106
14
 
107
15
  SPEEDY_RUNNING_THREADS: list[threading.Thread] = [] # cooperative shutdown tracking
108
16
  _SPEEDY_THREADS_LOCK = threading.Lock()
@@ -124,11 +32,11 @@ class UserFunctionError(Exception):
124
32
  self.user_traceback = user_traceback
125
33
 
126
34
  # Create a focused error message
127
- tb_str = "".join(traceback.format_list(user_traceback))
35
+ tb_str = ''.join(traceback.format_list(user_traceback))
128
36
  msg = (
129
37
  f'\nError in function "{func_name}" with input: {input_value!r}\n'
130
- f"\nUser code traceback:\n{tb_str}"
131
- f"{type(original_exception).__name__}: {original_exception}"
38
+ f'\nUser code traceback:\n{tb_str}'
39
+ f'{type(original_exception).__name__}: {original_exception}'
132
40
  )
133
41
  super().__init__(msg)
134
42
 
@@ -165,7 +73,7 @@ def _track_threads(threads: Iterable[threading.Thread]) -> None:
165
73
 
166
74
 
167
75
  def _track_executor_threads(pool: ThreadPoolExecutor) -> None:
168
- thread_set = getattr(pool, "_threads", None)
76
+ thread_set = getattr(pool, '_threads', None)
169
77
  if not thread_set:
170
78
  return
171
79
  _track_threads(tuple(thread_set))
@@ -188,9 +96,9 @@ def _worker(
188
96
  if not callable(func):
189
97
  func_type = type(func).__name__
190
98
  raise TypeError(
191
- f"\nmulti_thread: func parameter must be callable, "
192
- f"got {func_type}: {func!r}\n"
193
- f"Hint: Did you accidentally pass a {func_type} instead of a function?"
99
+ f'\nmulti_thread: func parameter must be callable, '
100
+ f'got {func_type}: {func!r}\n'
101
+ f'Hint: Did you accidentally pass a {func_type} instead of a function?'
194
102
  )
195
103
 
196
104
  try:
@@ -205,9 +113,9 @@ def _worker(
205
113
  # Filter to keep only user code frames
206
114
  user_frames = []
207
115
  skip_patterns = [
208
- "multi_worker/thread.py",
209
- "concurrent/futures/",
210
- "threading.py",
116
+ 'multi_worker/thread.py',
117
+ 'concurrent/futures/',
118
+ 'threading.py',
211
119
  ]
212
120
 
213
121
  for frame in tb_list:
@@ -216,7 +124,7 @@ def _worker(
216
124
 
217
125
  # If we have user frames, wrap in our custom exception
218
126
  if user_frames:
219
- func_name = getattr(func, "__name__", repr(func))
127
+ func_name = getattr(func, '__name__', repr(func))
220
128
  raise UserFunctionError(
221
129
  exc,
222
130
  func_name,
@@ -237,14 +145,14 @@ def _run_batch(
237
145
 
238
146
 
239
147
  def _attach_metadata(fut: Future[Any], idx: int, logical_size: int) -> None:
240
- setattr(fut, "_speedy_idx", idx)
241
- setattr(fut, "_speedy_size", logical_size)
148
+ fut._speedy_idx = idx
149
+ fut._speedy_size = logical_size
242
150
 
243
151
 
244
152
  def _future_meta(fut: Future[Any]) -> tuple[int, int]:
245
153
  return (
246
- getattr(fut, "_speedy_idx"),
247
- getattr(fut, "_speedy_size"),
154
+ fut._speedy_idx,
155
+ fut._speedy_size,
248
156
  )
249
157
 
250
158
 
@@ -292,7 +200,7 @@ def _resolve_worker_count(workers: int | None) -> int:
292
200
  if workers is None:
293
201
  return DEFAULT_WORKERS
294
202
  if workers <= 0:
295
- raise ValueError("workers must be a positive integer")
203
+ raise ValueError('workers must be a positive integer')
296
204
  return workers
297
205
 
298
206
 
@@ -300,18 +208,16 @@ def _normalize_batch_result(result: Any, logical_size: int) -> list[Any]:
300
208
  if logical_size == 1:
301
209
  return [result]
302
210
  if result is None:
303
- raise ValueError("batched callable returned None for a batch result")
211
+ raise ValueError('batched callable returned None for a batch result')
304
212
  if isinstance(result, (str, bytes, bytearray)):
305
- raise TypeError("batched callable must not return str/bytes when batching")
306
- if isinstance(result, Sequence):
307
- out = list(result)
308
- elif isinstance(result, Iterable):
213
+ raise TypeError('batched callable must not return str/bytes when batching')
214
+ if isinstance(result, (Sequence, Iterable)):
309
215
  out = list(result)
310
216
  else:
311
- raise TypeError("batched callable must return an iterable of results")
217
+ raise TypeError('batched callable must return an iterable of results')
312
218
  if len(out) != logical_size:
313
219
  raise ValueError(
314
- f"batched callable returned {len(out)} items, expected {logical_size}",
220
+ f'batched callable returned {len(out)} items, expected {logical_size}',
315
221
  )
316
222
  return out
317
223
 
@@ -398,7 +304,9 @@ def multi_thread(
398
304
  results: list[R | None] = []
399
305
 
400
306
  for proc_idx, chunk in enumerate(chunks):
401
- with tempfile.NamedTemporaryFile(delete=False, suffix="multi_thread.pkl") as fh:
307
+ with tempfile.NamedTemporaryFile(
308
+ delete=False, suffix='multi_thread.pkl'
309
+ ) as fh:
402
310
  file_pkl = fh.name
403
311
  assert isinstance(in_process_multi_thread, Callable)
404
312
  proc = in_process_multi_thread(
@@ -420,28 +328,28 @@ def multi_thread(
420
328
 
421
329
  for proc, file_pkl in procs:
422
330
  proc.join()
423
- logger.info("process finished: %s", proc)
331
+ logger.info('process finished: %s', proc)
424
332
  try:
425
333
  results.extend(load_by_ext(file_pkl))
426
334
  finally:
427
335
  try:
428
336
  os.unlink(file_pkl)
429
337
  except OSError as exc: # pragma: no cover - best effort cleanup
430
- logger.warning("failed to remove temp file %s: %s", file_pkl, exc)
338
+ logger.warning('failed to remove temp file %s: %s', file_pkl, exc)
431
339
  return results
432
340
 
433
341
  try:
434
342
  import pandas as pd
435
343
 
436
344
  if isinstance(inputs, pd.DataFrame):
437
- inputs = cast(Iterable[T], inputs.to_dict(orient="records"))
345
+ inputs = cast(Iterable[T], inputs.to_dict(orient='records'))
438
346
  except ImportError: # pragma: no cover - optional dependency
439
347
  pass
440
348
 
441
349
  if batch <= 0:
442
- raise ValueError("batch must be a positive integer")
350
+ raise ValueError('batch must be a positive integer')
443
351
  if prefetch_factor <= 0:
444
- raise ValueError("prefetch_factor must be a positive integer")
352
+ raise ValueError('prefetch_factor must be a positive integer')
445
353
 
446
354
  workers_val = _resolve_worker_count(workers)
447
355
  progress_update = max(progress_update, 1)
@@ -463,12 +371,19 @@ def multi_thread(
463
371
 
464
372
  bar = None
465
373
  last_bar_update = 0
466
- if progress and tqdm is not None and logical_total is not None and logical_total > 0:
374
+ if (
375
+ progress
376
+ and tqdm is not None
377
+ and logical_total is not None
378
+ and logical_total > 0
379
+ ):
467
380
  bar = tqdm(
468
381
  total=logical_total,
469
382
  ncols=128,
470
- colour="green",
471
- bar_format=("{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"),
383
+ colour='green',
384
+ bar_format=(
385
+ '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
386
+ ),
472
387
  )
473
388
 
474
389
  deadline = time.monotonic() + timeout if timeout is not None else None
@@ -482,9 +397,9 @@ def multi_thread(
482
397
  inflight: set[Future[Any]] = set()
483
398
  pool = ThreadPoolExecutor(
484
399
  max_workers=workers_val,
485
- thread_name_prefix="speedy-thread",
400
+ thread_name_prefix='speedy-thread',
486
401
  )
487
- shutdown_kwargs: dict[str, Any] = {"wait": True}
402
+ shutdown_kwargs: dict[str, Any] = {'wait': True}
488
403
 
489
404
  try:
490
405
 
@@ -517,7 +432,7 @@ def multi_thread(
517
432
  if remaining <= 0:
518
433
  _cancel_futures(inflight)
519
434
  raise TimeoutError(
520
- f"multi_thread timed out after {timeout} seconds",
435
+ f'multi_thread timed out after {timeout} seconds',
521
436
  )
522
437
  wait_timeout = max(remaining, 0.0)
523
438
 
@@ -530,7 +445,7 @@ def multi_thread(
530
445
  if not done:
531
446
  _cancel_futures(inflight)
532
447
  raise TimeoutError(
533
- f"multi_thread timed out after {timeout} seconds",
448
+ f'multi_thread timed out after {timeout} seconds',
534
449
  )
535
450
 
536
451
  for fut in done:
@@ -549,11 +464,11 @@ def multi_thread(
549
464
  orig_exc = exc.original_exception
550
465
 
551
466
  # Build new traceback from user frames only
552
- tb_str = "".join(traceback.format_list(exc.user_traceback))
467
+ tb_str = ''.join(traceback.format_list(exc.user_traceback))
553
468
  clean_msg = (
554
469
  f'\nError in "{exc.func_name}" '
555
- f"with input: {exc.input_value!r}\n\n{tb_str}"
556
- f"{type(orig_exc).__name__}: {orig_exc}"
470
+ f'with input: {exc.input_value!r}\n\n{tb_str}'
471
+ f'{type(orig_exc).__name__}: {orig_exc}'
557
472
  )
558
473
 
559
474
  # Raise a new instance of the original exception type
@@ -568,7 +483,7 @@ def multi_thread(
568
483
  if stop_on_error:
569
484
  _cancel_futures(inflight)
570
485
  raise
571
- logger.exception("multi_thread task failed", exc_info=exc)
486
+ logger.exception('multi_thread task failed', exc_info=exc)
572
487
  out_items = [None] * logical_size
573
488
  else:
574
489
  try:
@@ -576,7 +491,7 @@ def multi_thread(
576
491
  except Exception as exc:
577
492
  _cancel_futures(inflight)
578
493
  raise RuntimeError(
579
- "batched callable returned an unexpected shape",
494
+ 'batched callable returned an unexpected shape',
580
495
  ) from exc
581
496
 
582
497
  collector.add(idx, out_items)
@@ -588,10 +503,14 @@ def multi_thread(
588
503
  bar.update(delta)
589
504
  last_bar_update = completed_items
590
505
  submitted = next_logical_idx
591
- pending = max(logical_total - submitted, 0) if logical_total is not None else "-"
506
+ pending = (
507
+ max(logical_total - submitted, 0)
508
+ if logical_total is not None
509
+ else '-'
510
+ )
592
511
  postfix = {
593
- "processing": min(len(inflight), workers_val),
594
- "pending": pending,
512
+ 'processing': min(len(inflight), workers_val),
513
+ 'pending': pending,
595
514
  }
596
515
  bar.set_postfix(postfix)
597
516
 
@@ -604,7 +523,7 @@ def multi_thread(
604
523
  results = collector.finalize()
605
524
 
606
525
  except KeyboardInterrupt:
607
- shutdown_kwargs = {"wait": False, "cancel_futures": True}
526
+ shutdown_kwargs = {'wait': False, 'cancel_futures': True}
608
527
  _cancel_futures(inflight)
609
528
  kill_all_thread(SystemExit)
610
529
  raise KeyboardInterrupt() from None
@@ -612,27 +531,29 @@ def multi_thread(
612
531
  try:
613
532
  pool.shutdown(**shutdown_kwargs)
614
533
  except TypeError: # pragma: no cover - Python <3.9 fallback
615
- pool.shutdown(shutdown_kwargs.get("wait", True))
534
+ pool.shutdown(shutdown_kwargs.get('wait', True))
616
535
  if bar:
617
536
  delta = completed_items - last_bar_update
618
537
  if delta > 0:
619
538
  bar.update(delta)
620
539
  bar.close()
621
540
 
622
- results = collector.finalize() if "results" not in locals() else results
541
+ results = collector.finalize() if 'results' not in locals() else results
623
542
  if store_output_pkl_file:
624
543
  dump_json_or_pickle(results, store_output_pkl_file)
625
544
  _prune_dead_threads()
626
545
  return results
627
546
 
628
547
 
629
- def multi_thread_standard(fn: Callable[[T], R], items: Iterable[T], workers: int = 4) -> list[R]:
548
+ def multi_thread_standard(
549
+ fn: Callable[[T], R], items: Iterable[T], workers: int = 4
550
+ ) -> list[R]:
630
551
  """Execute ``fn`` across ``items`` while preserving submission order."""
631
552
 
632
553
  workers_val = _resolve_worker_count(workers)
633
554
  with ThreadPoolExecutor(
634
555
  max_workers=workers_val,
635
- thread_name_prefix="speedy-thread",
556
+ thread_name_prefix='speedy-thread',
636
557
  ) as executor:
637
558
  futures: list[Future[R]] = []
638
559
  for item in items:
@@ -647,17 +568,19 @@ def _async_raise(thread_id: int, exc_type: type[BaseException]) -> bool:
647
568
  if thread_id <= 0:
648
569
  return False
649
570
  if not issubclass(exc_type, BaseException):
650
- raise TypeError("exc_type must derive from BaseException")
571
+ raise TypeError('exc_type must derive from BaseException')
651
572
  res = _PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), ctypes.py_object(exc_type))
652
573
  if res == 0:
653
574
  return False
654
575
  if res > 1: # pragma: no cover - defensive branch
655
576
  _PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), None)
656
- raise SystemError("PyThreadState_SetAsyncExc failed")
577
+ raise SystemError('PyThreadState_SetAsyncExc failed')
657
578
  return True
658
579
 
659
580
 
660
- def kill_all_thread(exc_type: type[BaseException] = SystemExit, join_timeout: float = 0.1) -> int:
581
+ def kill_all_thread(
582
+ exc_type: type[BaseException] = SystemExit, join_timeout: float = 0.1
583
+ ) -> int:
661
584
  """Forcefully stop tracked worker threads (dangerous; use sparingly).
662
585
 
663
586
  Returns
@@ -682,17 +605,17 @@ def kill_all_thread(exc_type: type[BaseException] = SystemExit, join_timeout: fl
682
605
  terminated += 1
683
606
  thread.join(timeout=join_timeout)
684
607
  else:
685
- logger.warning("Unable to signal thread %s", thread.name)
608
+ logger.warning('Unable to signal thread %s', thread.name)
686
609
  except Exception as exc: # pragma: no cover - defensive
687
- logger.error("Failed to stop thread %s: %s", thread.name, exc)
610
+ logger.error('Failed to stop thread %s: %s', thread.name, exc)
688
611
  _prune_dead_threads()
689
612
  return terminated
690
613
 
691
614
 
692
615
  __all__ = [
693
- "SPEEDY_RUNNING_THREADS",
694
- "UserFunctionError",
695
- "multi_thread",
696
- "multi_thread_standard",
697
- "kill_all_thread",
616
+ 'SPEEDY_RUNNING_THREADS',
617
+ 'UserFunctionError',
618
+ 'multi_thread',
619
+ 'multi_thread_standard',
620
+ 'kill_all_thread',
698
621
  ]
@@ -6,101 +6,103 @@ import os
6
6
  import shlex # To properly escape command line arguments
7
7
  import shutil
8
8
 
9
- taskset_path = shutil.which("taskset")
9
+
10
+ taskset_path = shutil.which('taskset')
10
11
 
11
12
 
12
13
  def assert_script(python_path):
13
- code_str = open(python_path).read()
14
- if "MP_ID" not in code_str or "MP_TOTAL" not in code_str:
14
+ with open(python_path) as f:
15
+ code_str = f.read()
16
+ if 'MP_ID' not in code_str or 'MP_TOTAL' not in code_str:
15
17
  example_code = (
16
18
  'import os; MP_TOTAL = int(os.environ.get("MP_TOTAL"));MP_ID = int(os.environ.get("MP_ID"))\n'
17
- "inputs = list(inputs[MP_ID::MP_TOTAL])"
19
+ 'inputs = list(inputs[MP_ID::MP_TOTAL])'
18
20
  )
19
21
  # ANSI escape codes for coloring
20
- YELLOW = "\033[93m"
21
- RESET = "\033[0m"
22
+ YELLOW = '\033[93m'
23
+ RESET = '\033[0m'
22
24
  raise_msg = (
23
- f"MP_ID and MP_TOTAL not found in {python_path}, please add them.\n\n"
24
- f"Example:\n{YELLOW}{example_code}{RESET}"
25
+ f'MP_ID and MP_TOTAL not found in {python_path}, please add them.\n\n'
26
+ f'Example:\n{YELLOW}{example_code}{RESET}'
25
27
  )
26
28
  raise Exception(raise_msg)
27
29
 
28
30
 
29
31
  def run_in_tmux(commands_to_run, tmux_name, num_windows):
30
- with open("/tmp/start_multirun_tmux.sh", "w") as script_file:
32
+ with open('/tmp/start_multirun_tmux.sh', 'w') as script_file:
31
33
  # first cmd is to kill the session if it exists
32
34
 
33
- script_file.write("#!/bin/bash\n\n")
34
- script_file.write(f"tmux kill-session -t {tmux_name}\nsleep .1\n")
35
- script_file.write(f"tmux new-session -d -s {tmux_name}\n")
35
+ script_file.write('#!/bin/bash\n\n')
36
+ script_file.write(f'tmux kill-session -t {tmux_name}\nsleep .1\n')
37
+ script_file.write(f'tmux new-session -d -s {tmux_name}\n')
36
38
  for i, cmd in enumerate(itertools.cycle(commands_to_run)):
37
39
  if i >= num_windows:
38
40
  break
39
- window_name = f"{tmux_name}:{i}"
41
+ window_name = f'{tmux_name}:{i}'
40
42
  if i == 0:
41
43
  script_file.write(f"tmux send-keys -t {window_name} '{cmd}' C-m\n")
42
44
  else:
43
- script_file.write(f"tmux new-window -t {tmux_name}\n")
45
+ script_file.write(f'tmux new-window -t {tmux_name}\n')
44
46
  script_file.write(f"tmux send-keys -t {window_name} '{cmd}' C-m\n")
45
47
 
46
48
  # Make the script executable
47
- script_file.write("chmod +x /tmp/start_multirun_tmux.sh\n")
48
- print("Run /tmp/start_multirun_tmux.sh")
49
+ script_file.write('chmod +x /tmp/start_multirun_tmux.sh\n')
50
+ print('Run /tmp/start_multirun_tmux.sh')
49
51
 
50
52
 
51
53
  def main():
52
54
  # Assert that MP_ID and MP_TOTAL are not already set
53
55
 
54
- parser = argparse.ArgumentParser(description="Process fold arguments")
56
+ parser = argparse.ArgumentParser(description='Process fold arguments')
55
57
  parser.add_argument(
56
- "--total_fold", "-t", default=16, type=int, help="total number of folds"
58
+ '--total_fold', '-t', default=16, type=int, help='total number of folds'
57
59
  )
58
- parser.add_argument("--gpus", type=str, default="0,1,2,3,4,5,6,7")
59
- parser.add_argument("--ignore_gpus", "-ig", type=str, default="")
60
+ parser.add_argument('--gpus', type=str, default='0,1,2,3,4,5,6,7')
61
+ parser.add_argument('--ignore_gpus', '-ig', type=str, default='')
60
62
  parser.add_argument(
61
- "--total_cpu",
63
+ '--total_cpu',
62
64
  type=int,
63
65
  default=multiprocessing.cpu_count(),
64
- help="total number of cpu cores available",
66
+ help='total number of cpu cores available',
65
67
  )
66
68
  parser.add_argument(
67
- "cmd", nargs=argparse.REMAINDER
69
+ 'cmd', nargs=argparse.REMAINDER
68
70
  ) # This will gather the remaining unparsed arguments
69
71
 
70
72
  args = parser.parse_args()
71
- if not args.cmd or (args.cmd[0] == "--" and len(args.cmd) == 1):
72
- parser.error("Invalid command provided")
73
+ if not args.cmd or (args.cmd[0] == '--' and len(args.cmd) == 1):
74
+ parser.error('Invalid command provided')
73
75
  assert_script(args.cmd[0])
74
76
 
75
77
  cmd_str = None
76
- if args.cmd[0] == "--":
78
+ if args.cmd[0] == '--':
77
79
  cmd_str = shlex.join(args.cmd[1:])
78
80
  else:
79
81
  cmd_str = shlex.join(args.cmd)
80
82
 
81
- gpus = args.gpus.split(",")
82
- gpus = [gpu for gpu in gpus if gpu not in args.ignore_gpus.split(",")]
83
+ gpus = args.gpus.split(',')
84
+ gpus = [gpu for gpu in gpus if gpu not in args.ignore_gpus.split(',')]
83
85
  num_gpus = len(gpus)
84
86
 
85
87
  cpu_per_process = max(args.total_cpu // args.total_fold, 1)
86
88
  cmds = []
87
- path_python = shutil.which("python")
89
+ path_python = shutil.which('python')
88
90
  for i in range(args.total_fold):
89
91
  gpu = gpus[i % num_gpus]
90
92
  cpu_start = (i * cpu_per_process) % args.total_cpu
91
93
  cpu_end = ((i + 1) * cpu_per_process - 1) % args.total_cpu
92
- ENV = f"CUDA_VISIBLE_DEVICES={gpu} MP_ID={i} MP_TOTAL={args.total_fold}"
94
+ ENV = f'CUDA_VISIBLE_DEVICES={gpu} MP_ID={i} MP_TOTAL={args.total_fold}'
93
95
  if taskset_path:
94
- fold_cmd = f"{ENV} {taskset_path} -c {cpu_start}-{cpu_end} {path_python} {cmd_str}"
96
+ fold_cmd = f'{ENV} {taskset_path} -c {cpu_start}-{cpu_end} {path_python} {cmd_str}'
95
97
  else:
96
- fold_cmd = f"{ENV} {path_python} {cmd_str}"
98
+ fold_cmd = f'{ENV} {path_python} {cmd_str}'
97
99
 
98
100
  cmds.append(fold_cmd)
99
101
 
100
- run_in_tmux(cmds, "mpython", args.total_fold)
101
- os.chmod("/tmp/start_multirun_tmux.sh", 0o755) # Make the script executable
102
- os.system("/tmp/start_multirun_tmux.sh")
102
+ run_in_tmux(cmds, 'mpython', args.total_fold)
103
+ os.chmod('/tmp/start_multirun_tmux.sh', 0o755) # Make the script executable
104
+ os.system('/tmp/start_multirun_tmux.sh')
103
105
 
104
106
 
105
- if __name__ == "__main__":
107
+ if __name__ == '__main__':
106
108
  main()
@@ -33,7 +33,7 @@ def snake_case(s: str) -> str:
33
33
  return "".join(out)
34
34
 
35
35
 
36
- def map_openapi_type(prop: Dict[str, Any]) -> str:
36
+ def map_openapi_type(prop: dict[str, Any]) -> str:
37
37
  t = prop.get("type")
38
38
  if t == "string":
39
39
  fmt = prop.get("format")
@@ -50,8 +50,8 @@ def map_openapi_type(prop: Dict[str, Any]) -> str:
50
50
  return "Any"
51
51
 
52
52
 
53
- def generate_models(components: Dict[str, Any]) -> List[str]:
54
- lines: List[str] = []
53
+ def generate_models(components: dict[str, Any]) -> list[str]:
54
+ lines: list[str] = []
55
55
  schemas = components.get("schemas", {})
56
56
  for name, schema in schemas.items():
57
57
  if "enum" in schema:
@@ -77,10 +77,10 @@ def generate_models(components: Dict[str, Any]) -> List[str]:
77
77
  return lines
78
78
 
79
79
 
80
- def generate_client(spec: Dict[str, Any]) -> List[str]:
80
+ def generate_client(spec: dict[str, Any]) -> list[str]:
81
81
  paths = spec.get("paths", {})
82
82
  models = spec.get("components", {}).get("schemas", {})
83
- lines: List[str] = []
83
+ lines: list[str] = []
84
84
  lines.append("class GeneratedClient:")
85
85
  lines.append(' """Client generated from OpenAPI spec."""')
86
86
  lines.append("")
@@ -115,8 +115,8 @@ def generate_client(spec: Dict[str, Any]) -> List[str]:
115
115
  func_name = snake_case(op_id)
116
116
  summary = op.get("summary", "").strip()
117
117
  # collect parameters
118
- req_params: List[str] = ["self"]
119
- opt_params: List[str] = []
118
+ req_params: list[str] = ["self"]
119
+ opt_params: list[str] = []
120
120
  # path params (required)
121
121
  path_params = [p for p in op.get("parameters", []) if p.get("in") == "path"]
122
122
  for p in path_params:
@@ -219,15 +219,15 @@ def main() -> None:
219
219
 
220
220
  try:
221
221
  spec_src = args.spec
222
- if spec_src.startswith("http://") or spec_src.startswith("https://"):
222
+ if spec_src.startswith(("http://", "https://")):
223
223
  import httpx
224
224
 
225
225
  response = httpx.get(spec_src)
226
226
  spec = response.json()
227
227
  else:
228
- with open(spec_src, "r", encoding="utf-8") as f:
228
+ with open(spec_src, encoding="utf-8") as f:
229
229
  spec = json.load(f)
230
- out: List[str] = []
230
+ out: list[str] = []
231
231
  # imports
232
232
  out.append("from typing import Any, Dict, List, Optional")
233
233
  out.append("from datetime import datetime")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: speedy-utils
3
- Version: 1.1.27
3
+ Version: 1.1.28
4
4
  Summary: Fast and easy-to-use package for data science
5
5
  Project-URL: Homepage, https://github.com/anhvth/speedy
6
6
  Project-URL: Repository, https://github.com/anhvth/speedy