speedy-utils 1.1.27__py3-none-any.whl → 1.1.29__py3-none-any.whl

Files changed (54)
  1. llm_utils/__init__.py +16 -4
  2. llm_utils/chat_format/__init__.py +10 -10
  3. llm_utils/chat_format/display.py +33 -21
  4. llm_utils/chat_format/transform.py +17 -19
  5. llm_utils/chat_format/utils.py +6 -4
  6. llm_utils/group_messages.py +17 -14
  7. llm_utils/lm/__init__.py +6 -5
  8. llm_utils/lm/async_lm/__init__.py +1 -0
  9. llm_utils/lm/async_lm/_utils.py +10 -9
  10. llm_utils/lm/async_lm/async_llm_task.py +141 -137
  11. llm_utils/lm/async_lm/async_lm.py +48 -42
  12. llm_utils/lm/async_lm/async_lm_base.py +59 -60
  13. llm_utils/lm/async_lm/lm_specific.py +4 -3
  14. llm_utils/lm/base_prompt_builder.py +93 -70
  15. llm_utils/lm/llm.py +126 -108
  16. llm_utils/lm/llm_signature.py +4 -2
  17. llm_utils/lm/lm_base.py +72 -73
  18. llm_utils/lm/mixins.py +102 -62
  19. llm_utils/lm/openai_memoize.py +124 -87
  20. llm_utils/lm/signature.py +105 -92
  21. llm_utils/lm/utils.py +42 -23
  22. llm_utils/scripts/vllm_load_balancer.py +23 -30
  23. llm_utils/scripts/vllm_serve.py +8 -7
  24. llm_utils/vector_cache/__init__.py +9 -3
  25. llm_utils/vector_cache/cli.py +1 -1
  26. llm_utils/vector_cache/core.py +59 -63
  27. llm_utils/vector_cache/types.py +7 -5
  28. llm_utils/vector_cache/utils.py +12 -8
  29. speedy_utils/__imports.py +244 -0
  30. speedy_utils/__init__.py +90 -194
  31. speedy_utils/all.py +125 -227
  32. speedy_utils/common/clock.py +37 -42
  33. speedy_utils/common/function_decorator.py +6 -12
  34. speedy_utils/common/logger.py +43 -52
  35. speedy_utils/common/notebook_utils.py +13 -21
  36. speedy_utils/common/patcher.py +21 -17
  37. speedy_utils/common/report_manager.py +42 -44
  38. speedy_utils/common/utils_cache.py +152 -169
  39. speedy_utils/common/utils_io.py +137 -103
  40. speedy_utils/common/utils_misc.py +15 -21
  41. speedy_utils/common/utils_print.py +22 -28
  42. speedy_utils/multi_worker/process.py +66 -79
  43. speedy_utils/multi_worker/thread.py +78 -155
  44. speedy_utils/scripts/mpython.py +38 -36
  45. speedy_utils/scripts/openapi_client_codegen.py +10 -10
  46. {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/METADATA +1 -1
  47. speedy_utils-1.1.29.dist-info/RECORD +57 -0
  48. vision_utils/README.md +202 -0
  49. vision_utils/__init__.py +4 -0
  50. vision_utils/io_utils.py +735 -0
  51. vision_utils/plot.py +345 -0
  52. speedy_utils-1.1.27.dist-info/RECORD +0 -52
  53. {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/WHEEL +0 -0
  54. {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/entry_points.txt +0 -0
speedy_utils/common/utils_io.py +137 -103

@@ -1,95 +1,102 @@
 # utils/utils_io.py

-import bz2
-import gzip
-import io
-import json
-import lzma
-import os
-import os.path as osp
-import pickle
-import time
-import warnings
-from glob import glob
-from pathlib import Path
-from typing import IO, Any, Iterable, Optional, Union, cast
-
-from json_repair import loads as jloads
-from pydantic import BaseModel
-
+# import bz2
+# import contextlib
+# import gzip
+# import io
+# import json
+# import lzma
+# import os
+# import os.path as osp
+# import pickle
+# import time
+# import warnings
+# from collections.abc import Iterable
+# from glob import glob
+# from pathlib import Path
+# from typing import IO, Any, Optional, Union, cast
+
+
+# from pydantic import BaseModel
+
+from ..__imports import *
 from .utils_misc import mkdir_or_exist


-def dump_jsonl(list_dictionaries: list[dict], file_name: str = "output.jsonl") -> None:
+def dump_jsonl(list_dictionaries: list[dict], file_name: str = 'output.jsonl') -> None:
     """
     Dumps a list of dictionaries to a file in JSON Lines format.
     """
-    with open(file_name, "w", encoding="utf-8") as file:
+    with open(file_name, 'w', encoding='utf-8') as file:
         for dictionary in list_dictionaries:
-            file.write(json.dumps(dictionary, ensure_ascii=False) + "\n")
+            file.write(json.dumps(dictionary, ensure_ascii=False) + '\n')


-def dump_json_or_pickle(obj: Any, fname: str, ensure_ascii: bool = False, indent: int = 4) -> None:
+def dump_json_or_pickle(
+    obj: Any, fname: str, ensure_ascii: bool = False, indent: int = 4
+) -> None:
     """
     Dump an object to a file, supporting both JSON and pickle formats.
     """
     if isinstance(fname, Path):
         fname = str(fname)
     mkdir_or_exist(osp.abspath(os.path.dirname(osp.abspath(fname))))
-    if fname.endswith(".json"):
-        with open(fname, "w", encoding="utf-8") as f:
+    if fname.endswith('.json'):
+        with open(fname, 'w', encoding='utf-8') as f:
             try:
                 json.dump(obj, f, ensure_ascii=ensure_ascii, indent=indent)
             # TypeError: Object of type datetime is not JSON serializable
             except TypeError:
                 print(
-                    "Error: Object of type datetime is not JSON serializable",
+                    'Error: Object of type datetime is not JSON serializable',
                     str(obj)[:1000],
                 )
                 raise
-    elif fname.endswith(".jsonl"):
+    elif fname.endswith('.jsonl'):
         dump_jsonl(obj, fname)
-    elif fname.endswith(".pkl"):
+    elif fname.endswith('.pkl'):
         try:
-            with open(fname, "wb") as f:
+            with open(fname, 'wb') as f:
                 pickle.dump(obj, f)
         except Exception as e:
             if isinstance(obj, BaseModel):
-                data = obj.model_dump()
+                data = obj.model_dump()  # type: ignore
                 from fastcore.all import dict2obj, obj2dict

                 obj2 = dict2obj(data)
-                with open(fname, "wb") as f:
+                with open(fname, 'wb') as f:
                     pickle.dump(obj2, f)
             else:
-                raise ValueError(f"Error {e} while dumping {fname}") from e
+                raise ValueError(f'Error {e} while dumping {fname}') from e

     else:
-        raise NotImplementedError(f"File type {fname} not supported")
+        raise NotImplementedError(f'File type {fname} not supported')


 def load_json_or_pickle(fname: str, counter=0) -> Any:
     """
     Load an object from a file, supporting both JSON and pickle formats.
     """
-    if fname.endswith(".json") or fname.endswith(".jsonl"):
-        with open(fname, encoding="utf-8") as f:
+    if fname.endswith(('.json', '.jsonl')):
+        with open(fname, encoding='utf-8') as f:
             return json.load(f)
     else:
         try:
-            with open(fname, "rb") as f:
+            with open(fname, 'rb') as f:
                 return pickle.load(f)
         # EOFError: Ran out of input
         except EOFError:
             time.sleep(1)
             if counter > 5:
                 # Keep message concise and actionable
-                print(f"Corrupted cache file {fname} removed; it will be regenerated on next access")
+                print(
+                    f'Corrupted cache file {fname} removed; it will be regenerated on next access'
+                )
                 os.remove(fname)
                 raise
             return load_json_or_pickle(fname, counter + 1)
         except Exception as e:
-            raise ValueError(f"Error {e} while loading {fname}") from e
+            raise ValueError(f'Error {e} while loading {fname}') from e


 try:
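The hunk above is essentially a formatting pass over the JSON/pickle helpers (single quotes, wrapped signatures, a reflowed print); their signatures and behaviour are unchanged. For orientation, a minimal round-trip sketch, assuming the helpers are imported from speedy_utils.common.utils_io and using hypothetical file names:

```python
# Minimal usage sketch for the helpers changed above (paths are hypothetical).
from speedy_utils.common.utils_io import (
    dump_json_or_pickle,
    dump_jsonl,
    load_json_or_pickle,
)

records = [{"id": 1, "text": "hello"}, {"id": 2, "text": "world"}]

dump_jsonl(records, "records.jsonl")          # one JSON object per line
dump_json_or_pickle(records, "records.json")  # the extension picks the serializer
dump_json_or_pickle(records, "records.pkl")   # .pkl falls back to pickle

assert load_json_or_pickle("records.json") == records
```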
@@ -104,19 +111,19 @@ except Exception:


 def fast_load_jsonl(
-    path_or_file: Union[str, os.PathLike, IO],
+    path_or_file: str | os.PathLike | IO,
     *,
     progress: bool = False,
-    desc: str = "Reading JSONL",
+    desc: str = 'Reading JSONL',
     use_orjson: bool = True,
-    encoding: str = "utf-8",
-    errors: str = "strict",
-    on_error: str = "raise",  # 'raise' | 'warn' | 'skip'
+    encoding: str = 'utf-8',
+    errors: str = 'strict',
+    on_error: str = 'raise',  # 'raise' | 'warn' | 'skip'
     skip_empty: bool = True,
-    max_lines: Optional[int] = None,
+    max_lines: int | None = None,
     use_multiworker: bool = True,
     multiworker_threshold: int = 1000000,
-    workers: Optional[int] = None,
+    workers: int | None = None,
 ) -> Iterable[Any]:
     """
     Lazily iterate objects from a JSON Lines file.
@@ -144,31 +151,52 @@ def fast_load_jsonl(
         Parsed Python objects per line.
     """

+    class ZstdWrapper:
+        """Context manager wrapper for zstd decompression."""
+
+        def __init__(self, path):
+            self.path = path
+            self.fh = None
+            self.stream = None
+            self.buffered = None
+
+        def __enter__(self):
+            if zstd is None:
+                raise ImportError('zstandard package required for .zst files')
+            self.fh = open(self.path, 'rb')
+            dctx = zstd.ZstdDecompressor()
+            self.stream = dctx.stream_reader(self.fh)
+            self.buffered = io.BufferedReader(self.stream)
+            return self.buffered
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            if self.fh:
+                self.fh.close()
+
     def _open_auto(pth_or_f) -> IO[Any]:
-        if hasattr(pth_or_f, "read"):
+        if hasattr(pth_or_f, 'read'):
             # ensure binary buffer for consistent byte-length progress
             fobj = pth_or_f
             # If it's text, wrap it to binary via encoding; else just return
             if isinstance(fobj, io.TextIOBase):
                 # TextIO -> re-encode to bytes on the fly
-                return io.BufferedReader(io.BytesIO(fobj.read().encode(encoding, errors)))
+                return io.BufferedReader(
+                    io.BytesIO(fobj.read().encode(encoding, errors))
+                )
             return pth_or_f  # assume binary
         s = str(pth_or_f).lower()
-        if s.endswith(".gz"):
-            return gzip.open(pth_or_f, "rb")  # type: ignore
-        if s.endswith(".bz2"):
-            return bz2.open(pth_or_f, "rb")  # type: ignore
-        if s.endswith((".xz", ".lzma")):
-            return lzma.open(pth_or_f, "rb")  # type: ignore
-        if s.endswith((".zst", ".zstd")) and zstd is not None:
-            fh = open(pth_or_f, "rb")
-            dctx = zstd.ZstdDecompressor()
-            stream = dctx.stream_reader(fh)
-            return io.BufferedReader(stream)  # type: ignore
+        if s.endswith('.gz'):
+            return gzip.open(pth_or_f, 'rb')  # type: ignore
+        if s.endswith('.bz2'):
+            return bz2.open(pth_or_f, 'rb')  # type: ignore
+        if s.endswith(('.xz', '.lzma')):
+            return lzma.open(pth_or_f, 'rb')  # type: ignore
+        if s.endswith(('.zst', '.zstd')) and zstd is not None:
+            return ZstdWrapper(pth_or_f).__enter__()  # type: ignore
         # plain
-        return open(pth_or_f, "rb", buffering=1024 * 1024)
+        return open(pth_or_f, 'rb', buffering=1024 * 1024)

-    def _count_lines_fast(file_path: Union[str, os.PathLike]) -> int:
+    def _count_lines_fast(file_path: str | os.PathLike) -> int:
         """Quickly count lines in a file, handling compression."""
         try:
             f = _open_auto(file_path)
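The new ZstdWrapper gives the .zst/.zstd branch an owner for the underlying file handle instead of opening one that is never closed, mirroring what the gzip/bz2/lzma branches already get from the standard library. A sketch of how the reader is typically driven, assuming fast_load_jsonl is imported from speedy_utils.common.utils_io, the paths are hypothetical, and the optional zstandard package is installed for the .zst case:

```python
# Sketch: fast_load_jsonl picks a decompressor from the file extension.
from speedy_utils.common.utils_io import fast_load_jsonl

# gzip input is handled by the stdlib branch of _open_auto.
for obj in fast_load_jsonl("events.jsonl.gz"):
    print(obj)

# zstd input goes through ZstdWrapper; malformed lines are warned about and skipped.
rows = list(
    fast_load_jsonl(
        "events.jsonl.zst",
        on_error="warn",
        max_lines=1_000,  # stop early; this also keeps parsing single-threaded
    )
)
```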
@@ -187,7 +215,7 @@ def fast_load_jsonl(
         for line_bytes in chunk_lines:
             if skip_empty and not line_bytes.strip():
                 continue
-            line_bytes = line_bytes.rstrip(b"\r\n")
+            line_bytes = line_bytes.rstrip(b'\r\n')
             try:
                 if use_orjson and orjson is not None:
                     obj = orjson.loads(line_bytes)
@@ -195,10 +223,10 @@ def fast_load_jsonl(
                     obj = json.loads(line_bytes.decode(encoding, errors))
                 results.append(obj)
             except Exception as e:
-                if on_error == "raise":
+                if on_error == 'raise':
                     raise
-                if on_error == "warn":
-                    warnings.warn(f"Skipping malformed line: {e}")
+                if on_error == 'warn':
+                    warnings.warn(f'Skipping malformed line: {e}', stacklevel=2)
                 # 'skip' and 'warn' both skip the line
                 continue
         return results
@@ -206,12 +234,12 @@ def fast_load_jsonl(
     # Check if we should use multi-worker processing
     should_use_multiworker = (
         use_multiworker
-        and not hasattr(path_or_file, "read")  # Only for file paths, not file objects
+        and not hasattr(path_or_file, 'read')  # Only for file paths, not file objects
         and max_lines is None  # Don't use multiworker if we're limiting lines
     )

     if should_use_multiworker:
-        line_count = _count_lines_fast(cast(Union[str, os.PathLike], path_or_file))
+        line_count = _count_lines_fast(cast(str | os.PathLike, path_or_file))
         if line_count > multiworker_threshold:
             # Use multi-worker processing
             from ..multi_worker.thread import multi_thread
@@ -236,9 +264,13 @@ def fast_load_jsonl(

             # Process chunks in parallel
             if progress:
-                print(f"Processing {line_count} lines with {num_workers} workers ({len(chunks)} chunks)...")
+                print(
+                    f'Processing {line_count} lines with {num_workers} workers ({len(chunks)} chunks)...'
+                )

-            chunk_results = multi_thread(_process_chunk, chunks, workers=num_workers, progress=progress)
+            chunk_results = multi_thread(
+                _process_chunk, chunks, workers=num_workers, progress=progress
+            )

             # Flatten results and yield
             if chunk_results:
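The hunks above leave the multi-worker decision intact and only reflow the progress message and the multi_thread call: once the fast line count exceeds multiworker_threshold, the file is split into chunks that _process_chunk parses in parallel. A sketch of the knobs that steer this, with a hypothetical path and defaults taken from the signature earlier in this diff:

```python
# Sketch: choosing between the chunked multi-thread path and plain streaming.
from speedy_utils.common.utils_io import fast_load_jsonl

# Files above multiworker_threshold lines (default 1,000,000) are parsed in
# parallel chunks; workers caps the thread count.
rows = list(fast_load_jsonl("big.jsonl", workers=8))

# Passing a file object, setting max_lines, or use_multiworker=False keeps
# everything on the streaming path in the calling thread.
head = list(fast_load_jsonl("big.jsonl", use_multiworker=False, max_lines=10))
```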
@@ -257,15 +289,15 @@ def fast_load_jsonl(
         try:
             from tqdm import tqdm  # type: ignore
         except Exception as e:
-            raise ImportError("tqdm is required when progress=True") from e
+            raise ImportError('tqdm is required when progress=True') from e
         total = None
-        if not hasattr(path_or_file, "read"):
+        if not hasattr(path_or_file, 'read'):
             try:
-                path_for_size = cast(Union[str, os.PathLike], path_or_file)
+                path_for_size = cast(str | os.PathLike, path_or_file)
                 total = os.path.getsize(path_for_size)  # compressed size if any
             except Exception:
                 total = None
-        pbar = tqdm(total=total, unit="B", unit_scale=True, desc=desc)
+        pbar = tqdm(total=total, unit='B', unit_scale=True, desc=desc)

     line_no = 0
     try:
@@ -286,7 +318,7 @@ def fast_load_jsonl(
                 if max_lines and line_no >= max_lines:
                     break
                 continue
-            line_bytes = raw_line.rstrip(b"\r\n")
+            line_bytes = raw_line.rstrip(b'\r\n')
             # Parse
             try:
                 if use_orjson and orjson is not None:
@@ -294,10 +326,12 @@ def fast_load_jsonl(
                 else:
                     obj = json.loads(line_bytes.decode(encoding, errors))
             except Exception as e:
-                if on_error == "raise":
+                if on_error == 'raise':
                     raise
-                if on_error == "warn":
-                    warnings.warn(f"Skipping malformed line {line_no}: {e}")
+                if on_error == 'warn':
+                    warnings.warn(
+                        f'Skipping malformed line {line_no}: {e}', stacklevel=2
+                    )
                 # 'skip' and 'warn' both skip the line
                 if max_lines and line_no >= max_lines:
                     break
@@ -311,10 +345,12 @@ def fast_load_jsonl(
             try:
                 obj = json.loads(raw_line)
             except Exception as e:
-                if on_error == "raise":
+                if on_error == 'raise':
                     raise
-                if on_error == "warn":
-                    warnings.warn(f"Skipping malformed line {line_no}: {e}")
+                if on_error == 'warn':
+                    warnings.warn(
+                        f'Skipping malformed line {line_no}: {e}', stacklevel=2
+                    )
                 if max_lines and line_no >= max_lines:
                     break
                 continue
@@ -326,14 +362,12 @@ def fast_load_jsonl(
         if pbar is not None:
             pbar.close()
         # Close only if we opened it (i.e., not an external stream)
-        if not hasattr(path_or_file, "read"):
-            try:
+        if not hasattr(path_or_file, 'read'):
+            with contextlib.suppress(Exception):
                 f.close()
-            except Exception:
-                pass


-def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:
+def load_by_ext(fname: str | list[str], do_memoize: bool = False) -> Any:
     """
     Load data based on file extension.
     """
@@ -346,54 +380,54 @@ def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:
     )

     try:
-        if isinstance(fname, str) and "*" in fname:
+        if isinstance(fname, str) and '*' in fname:
             paths = glob(fname)
             paths = sorted(paths)
             return multi_process(load_by_ext, paths, workers=16)
-        elif isinstance(fname, list):
+        if isinstance(fname, list):
             paths = fname
             return multi_process(load_by_ext, paths, workers=16)

         def load_csv(path: str, **pd_kwargs) -> Any:
             import pandas as pd

-            return pd.read_csv(path, engine="pyarrow", **pd_kwargs)
+            return pd.read_csv(path, engine='pyarrow', **pd_kwargs)

         def load_txt(path: str) -> list[str]:
-            with open(path, encoding="utf-8") as f:
+            with open(path, encoding='utf-8') as f:
                 return f.read().splitlines()

         def load_default(path: str) -> Any:
-            if path.endswith(".jsonl"):
+            if path.endswith('.jsonl'):
                 return list(fast_load_jsonl(path, progress=True))
-            elif path.endswith(".json"):
+            if path.endswith('.json'):
                 try:
                     return load_json_or_pickle(path)
                 except json.JSONDecodeError as exc:
-                    raise ValueError("JSON decoding failed") from exc
+                    raise ValueError('JSON decoding failed') from exc
             return load_json_or_pickle(path)

         handlers = {
-            ".csv": load_csv,
-            ".tsv": load_csv,
-            ".txt": load_txt,
-            ".pkl": load_default,
-            ".json": load_default,
-            ".jsonl": load_default,
+            '.csv': load_csv,
+            '.tsv': load_csv,
+            '.txt': load_txt,
+            '.pkl': load_default,
+            '.json': load_default,
+            '.jsonl': load_default,
         }

         ext = os.path.splitext(fname)[-1]
         load_fn = handlers.get(ext)

         if not load_fn:
-            raise NotImplementedError(f"File type {ext} not supported")
+            raise NotImplementedError(f'File type {ext} not supported')

         if do_memoize:
             load_fn = memoize(load_fn)

         return load_fn(fname)
     except Exception as e:
-        raise ValueError(f"Error {e} while loading {fname}") from e
+        raise ValueError(f'Error {e} while loading {fname}') from e


 def jdumps(obj, ensure_ascii=False, indent=2, **kwargs):
@@ -403,10 +437,10 @@ def jdumps(obj, ensure_ascii=False, indent=2, **kwargs):
 load_jsonl = lambda path: list(fast_load_jsonl(path))

 __all__ = [
-    "dump_json_or_pickle",
-    "dump_jsonl",
-    "load_by_ext",
-    "load_json_or_pickle",
-    "jdumps",
-    "jloads",
+    'dump_json_or_pickle',
+    'dump_jsonl',
+    'load_by_ext',
+    'load_json_or_pickle',
+    'jdumps',
+    'jloads',
 ]
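Taken together, the utils_io.py changes are a style pass (single quotes, PEP 604 unions, contextlib.suppress) plus the consolidated `from ..__imports import *` header; load_by_ext keeps its extension-dispatch behaviour. A usage sketch with hypothetical paths, assuming the module import path below:

```python
# Sketch: load_by_ext dispatches on the file extension.
from speedy_utils.common.utils_io import load_by_ext

rows = load_by_ext("data/train.jsonl")       # streamed via fast_load_jsonl
table = load_by_ext("data/metrics.csv")      # pandas.read_csv(engine='pyarrow')
shards = load_by_ext("data/shard_*.jsonl")   # glob pattern fans out via multi_process
cached = load_by_ext("data/train.jsonl", do_memoize=True)  # wraps the loader in memoize
```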
speedy_utils/common/utils_misc.py +15 -21

@@ -1,13 +1,8 @@
 # utils/utils_misc.py
+from ..__imports import *

-import inspect
-import os
-from collections.abc import Callable
-from typing import Any, TypeVar

-from pydantic import BaseModel
-
-T = TypeVar("T")
+T = TypeVar('T')


 def mkdir_or_exist(dir_name: str) -> None:
@@ -27,10 +22,10 @@ def get_arg_names(func: Callable) -> list[str]:

 def is_notebook() -> bool:
     try:
-        if "get_ipython" in globals().keys():
-            get_ipython = globals()["get_ipython"]
+        if 'get_ipython' in globals():
+            get_ipython = globals()['get_ipython']
             shell = get_ipython().__class__.__name__
-            if shell == "ZMQInteractiveShell":
+            if shell == 'ZMQInteractiveShell':
                 return True  # Jupyter notebook or qtconsole
             return False  # Other type (?)
     except NameError:
@@ -41,15 +36,14 @@ def convert_to_builtin_python(input_data: Any) -> Any:
     """Convert input data to built-in Python types."""
     if isinstance(input_data, dict):
         return {k: convert_to_builtin_python(v) for k, v in input_data.items()}
-    elif isinstance(input_data, list):
+    if isinstance(input_data, list):
         return [convert_to_builtin_python(v) for v in input_data]
-    elif isinstance(input_data, (int, float, str, bool, type(None))):
+    if isinstance(input_data, (int, float, str, bool, type(None))):
         return input_data
-    elif isinstance(input_data, BaseModel):
+    if isinstance(input_data, BaseModel):
         data = input_data.model_dump_json()
         return convert_to_builtin_python(data)
-    else:
-        raise ValueError(f"Unsupported type {type(input_data)}")
+    raise ValueError(f'Unsupported type {type(input_data)}')


 def dedup(items: list[T], key: Callable[[T], Any]) -> list[T]:
@@ -74,10 +68,10 @@ def dedup(items: list[T], key: Callable[[T], Any]) -> list[T]:


 __all__ = [
-    "mkdir_or_exist",
-    "flatten_list",
-    "get_arg_names",
-    "is_notebook",
-    "convert_to_builtin_python",
-    "dedup",
+    'mkdir_or_exist',
+    'flatten_list',
+    'get_arg_names',
+    'is_notebook',
+    'convert_to_builtin_python',
+    'dedup',
 ]
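utils_misc.py now takes its imports from the shared __imports module and flattens the if/elif chains; behaviour is unchanged. A small sketch of the two helpers most touched by the rewrite, assuming the import path below; dedup's internals are not shown in this diff, so the comment only states what the name and signature imply:

```python
# Sketch: helpers from speedy_utils/common/utils_misc.py.
from pydantic import BaseModel

from speedy_utils.common.utils_misc import convert_to_builtin_python, dedup


class Point(BaseModel):
    x: int
    y: int


# Dicts and lists recurse; BaseModel instances go through model_dump_json().
print(convert_to_builtin_python({"origin": Point(x=0, y=0), "scale": 2}))

# dedup collapses items that share the same key value.
points = [Point(x=1, y=1), Point(x=1, y=2), Point(x=2, y=2)]
print(dedup(points, key=lambda p: p.x))
```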
speedy_utils/common/utils_print.py +22 -28

@@ -1,17 +1,11 @@
 # utils/utils_print.py

-import copy
-import pprint
-import textwrap
-from typing import Any, Union
-
-from tabulate import tabulate
-
+from ..__imports import *
 from .notebook_utils import display_pretty_table_html


 # Flattening the dictionary using "." notation for keys
-def flatten_dict(d, parent_key="", sep="."):
+def flatten_dict(d, parent_key='', sep='.'):
     items = []
     for k, v in d.items():
         new_key = parent_key + sep + k if parent_key else k
@@ -24,22 +18,22 @@ def flatten_dict(d, parent_key="", sep="."):

 def fprint(
     input_data: Any,
-    key_ignore: Union[list[str], None] = None,
-    key_keep: Union[list[str], None] = None,
+    key_ignore: list[str] | None = None,
+    key_keep: list[str] | None = None,
     max_width: int = 100,
     indent: int = 2,
-    depth: Union[int, None] = None,
-    table_format: str = "grid",
+    depth: int | None = None,
+    table_format: str = 'grid',
     str_wrap_width: int = 80,
     grep=None,
     is_notebook=None,
     f=print,
-) -> Union[None, str]:
+) -> None | str:
     """
     Pretty print structured data.
     """
     if isinstance(input_data, list):
-        for i, item in enumerate(input_data):
+        for _i, item in enumerate(input_data):
             fprint(
                 item,
                 key_ignore,
@@ -53,7 +47,7 @@ def fprint(
                 is_notebook,
                 f,
             )
-            print("\n" + "-" * 100 + "\n")
+            print('\n' + '-' * 100 + '\n')

     from speedy_utils import is_notebook as is_interactive

@@ -61,24 +55,24 @@ def fprint(
     if is_notebook is None:
         is_notebook = is_interactive()
     if isinstance(input_data, list):
-        if all(hasattr(item, "toDict") for item in input_data):
+        if all(hasattr(item, 'toDict') for item in input_data):
             input_data = [item.toDict() for item in input_data]
-    elif hasattr(input_data, "toDict"):
+    elif hasattr(input_data, 'toDict'):
         input_data = input_data.toDict()

     if isinstance(input_data, list):
-        if all(hasattr(item, "to_dict") for item in input_data):
+        if all(hasattr(item, 'to_dict') for item in input_data):
             input_data = [item.to_dict() for item in input_data]
-    elif hasattr(input_data, "to_dict"):
+    elif hasattr(input_data, 'to_dict'):
         input_data = input_data.to_dict()

     if isinstance(input_data, list):
-        if all(hasattr(item, "model_dump") for item in input_data):
+        if all(hasattr(item, 'model_dump') for item in input_data):
             input_data = [item.model_dump() for item in input_data]
-    elif hasattr(input_data, "model_dump"):
+    elif hasattr(input_data, 'model_dump'):
         input_data = input_data.model_dump()
     if not isinstance(input_data, (dict, str)):
-        raise ValueError("Input data must be a dictionary or string")
+        raise ValueError('Input data must be a dictionary or string')

     if isinstance(input_data, dict):
         input_data = flatten_dict(input_data)
@@ -89,7 +83,7 @@ def fprint(
     def remove_keys(d: dict, keys: list[str]) -> dict:
         """Remove specified keys from a dictionary."""
         for key in keys:
-            parts = key.split(".")
+            parts = key.split('.')
             sub_dict = d
             for part in parts[:-1]:
                 sub_dict = sub_dict.get(part, {})
@@ -100,7 +94,7 @@ def fprint(
         """Keep only specified keys in a dictionary."""
         result = {}
         for key in keys:
-            parts = key.split(".")
+            parts = key.split('.')
             sub_source = d
             sub_result = result
             for part in parts[:-1]:
@@ -112,7 +106,7 @@ def fprint(
             sub_result[parts[-1]] = copy.deepcopy(sub_source.get(parts[-1]))
         return result

-    if hasattr(input_data, "to_dict") and not isinstance(input_data, str):
+    if hasattr(input_data, 'to_dict') and not isinstance(input_data, str):
         input_data = input_data.to_dict()

     processed_data = copy.deepcopy(input_data)
@@ -132,7 +126,7 @@ def fprint(
         f(
             tabulate(
                 table,
-                headers=["Key", "Value"],
+                headers=['Key', 'Value'],
                 tablefmt=table_format,
                 maxcolwidths=[None, max_width],
             )
@@ -148,6 +142,6 @@ def fprint(


 __all__ = [
-    "flatten_dict",
-    "fprint",
+    'flatten_dict',
+    'fprint',
 ]
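utils_print.py gets the same treatment: shared imports, single quotes, and PEP 604 unions, with fprint's behaviour intact. Dictionaries are flattened into dotted keys and rendered as a two-column table via tabulate. A short sketch, assuming the import path below; the exact table layout depends on table_format:

```python
# Sketch: flatten_dict and fprint from speedy_utils/common/utils_print.py.
from speedy_utils.common.utils_print import flatten_dict, fprint

cfg = {"model": {"name": "demo", "dim": 128}, "lr": 3e-4}

print(flatten_dict(cfg))
# {'model.name': 'demo', 'model.dim': 128, 'lr': 0.0003}

fprint(cfg, table_format="grid")  # dotted keys and values in a tabulate grid
```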