speedy-utils 1.1.23__py3-none-any.whl → 1.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/utils.py CHANGED
@@ -14,6 +14,7 @@ from openai import OpenAI
14
14
 
15
15
  try:
16
16
  import psutil
17
+
17
18
  HAS_PSUTIL = True
18
19
  except ImportError:
19
20
  HAS_PSUTIL = False
@@ -26,7 +27,7 @@ _VLLM_PROCESSES: List[subprocess.Popen] = []
26
27
 
27
28
  def _extract_port_from_vllm_cmd(vllm_cmd: str) -> int:
28
29
  """Extract port from VLLM command string."""
29
- port_match = re.search(r'--port\s+(\d+)', vllm_cmd)
30
+ port_match = re.search(r"--port\s+(\d+)", vllm_cmd)
30
31
  if port_match:
31
32
  return int(port_match.group(1))
32
33
  return 8000
@@ -34,39 +35,39 @@ def _extract_port_from_vllm_cmd(vllm_cmd: str) -> int:
34
35
 
35
36
  def _parse_env_vars_from_cmd(cmd: str) -> tuple[dict[str, str], str]:
36
37
  """Parse environment variables from command string.
37
-
38
+
38
39
  Args:
39
40
  cmd: Command string that may contain environment variables like 'VAR=value command...'
40
-
41
+
41
42
  Returns:
42
43
  Tuple of (env_dict, cleaned_cmd) where env_dict contains parsed env vars
43
44
  and cleaned_cmd is the command without the env vars.
44
45
  """
45
46
  import shlex
46
-
47
+
47
48
  # Split the command while preserving quoted strings
48
49
  parts = shlex.split(cmd)
49
-
50
+
50
51
  env_vars = {}
51
52
  cmd_parts = []
52
-
53
+
53
54
  for part in parts:
54
- if '=' in part and not part.startswith('-'):
55
+ if "=" in part and not part.startswith("-"):
55
56
  # Check if this looks like an environment variable
56
57
  # Should be KEY=VALUE format, not contain spaces (unless quoted), and KEY should be uppercase
57
- key_value = part.split('=', 1)
58
+ key_value = part.split("=", 1)
58
59
  if len(key_value) == 2:
59
60
  key, value = key_value
60
- if key.isupper() and key.replace('_', '').isalnum():
61
+ if key.isupper() and key.replace("_", "").isalnum():
61
62
  env_vars[key] = value
62
63
  continue
63
-
64
+
64
65
  # Not an env var, add to command parts
65
66
  cmd_parts.append(part)
66
-
67
+
67
68
  # Reconstruct the cleaned command
68
- cleaned_cmd = ' '.join(cmd_parts)
69
-
69
+ cleaned_cmd = " ".join(cmd_parts)
70
+
70
71
  return env_vars, cleaned_cmd
71
72
 
72
73
 
@@ -74,38 +75,33 @@ def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:
74
75
  """Start VLLM server and wait for ready."""
75
76
  # Parse environment variables from command
76
77
  env_vars, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
77
-
78
+
78
79
  port = _extract_port_from_vllm_cmd(cleaned_cmd)
79
-
80
+
80
81
  logger.info(f"Starting VLLM server: {cleaned_cmd}")
81
82
  if env_vars:
82
83
  logger.info(f"Environment variables: {env_vars}")
83
84
  logger.info(f"VLLM output logged to: /tmp/vllm_{port}.txt")
84
-
85
- with open(f'/tmp/vllm_{port}.txt', 'w') as log_file:
85
+
86
+ with open(f"/tmp/vllm_{port}.txt", "w") as log_file:
86
87
  log_file.write(f"VLLM Server started at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
87
88
  log_file.write(f"Command: {cleaned_cmd}\n")
88
89
  if env_vars:
89
90
  log_file.write(f"Environment: {env_vars}\n")
90
91
  log_file.write(f"Port: {port}\n")
91
92
  log_file.write("-" * 50 + "\n")
92
-
93
+
93
94
  # Prepare environment for subprocess
94
95
  env = os.environ.copy()
95
96
  env.update(env_vars)
96
97
 
97
- with open(f'/tmp/vllm_{port}.txt', 'a') as log_file:
98
+ with open(f"/tmp/vllm_{port}.txt", "a") as log_file:
98
99
  process = subprocess.Popen(
99
- cleaned_cmd.split(),
100
- stdout=log_file,
101
- stderr=subprocess.STDOUT,
102
- text=True,
103
- preexec_fn=os.setsid,
104
- env=env
100
+ cleaned_cmd.split(), stdout=log_file, stderr=subprocess.STDOUT, text=True, preexec_fn=os.setsid, env=env
105
101
  )
106
-
102
+
107
103
  _VLLM_PROCESSES.append(process)
108
-
104
+
109
105
  start_time = time.time()
110
106
  while time.time() - start_time < timeout:
111
107
  try:
@@ -115,26 +111,24 @@ def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:
115
111
  return process
116
112
  except requests.RequestException:
117
113
  pass
118
-
114
+
119
115
  if process.poll() is not None:
120
116
  stdout, stderr = process.communicate()
121
117
  raise RuntimeError(
122
- f"VLLM server terminated unexpectedly. "
123
- f"Return code: {process.returncode}, "
124
- f"stderr: {stderr[:200]}..."
118
+ f"VLLM server terminated unexpectedly. Return code: {process.returncode}, stderr: {stderr[:200]}..."
125
119
  )
126
-
120
+
127
121
  time.sleep(2)
128
-
122
+
129
123
  process.terminate()
130
124
  try:
131
125
  process.wait(timeout=5)
132
126
  except subprocess.TimeoutExpired:
133
127
  process.kill()
134
-
128
+
135
129
  if process in _VLLM_PROCESSES:
136
130
  _VLLM_PROCESSES.remove(process)
137
-
131
+
138
132
  raise RuntimeError(f"VLLM server failed to start within {timeout}s on port {port}")
139
133
 
140
134
 
@@ -142,7 +136,7 @@ def _kill_vllm_on_port(port: int) -> bool:
142
136
  """Kill VLLM server on port."""
143
137
  killed = False
144
138
  logger.info(f"Checking VLLM server on port {port}")
145
-
139
+
146
140
  processes_to_remove = []
147
141
  for process in _VLLM_PROCESSES:
148
142
  try:
@@ -151,8 +145,8 @@ def _kill_vllm_on_port(port: int) -> bool:
151
145
  if HAS_PSUTIL:
152
146
  try:
153
147
  proc = psutil.Process(process.pid)
154
- cmdline = ' '.join(proc.cmdline())
155
- if f'--port {port}' in cmdline or f'--port={port}' in cmdline:
148
+ cmdline = " ".join(proc.cmdline())
149
+ if f"--port {port}" in cmdline or f"--port={port}" in cmdline:
156
150
  logger.info(f"Killing tracked VLLM process {process.pid} on port {port}")
157
151
  os.killpg(os.getpgid(process.pid), signal.SIGTERM)
158
152
  try:
@@ -164,7 +158,7 @@ def _kill_vllm_on_port(port: int) -> bool:
164
158
  killed_process = True
165
159
  except (psutil.NoSuchProcess, psutil.AccessDenied):
166
160
  pass
167
-
161
+
168
162
  if not HAS_PSUTIL or not killed_process:
169
163
  logger.info(f"Killing tracked VLLM process {process.pid}")
170
164
  try:
@@ -177,24 +171,23 @@ def _kill_vllm_on_port(port: int) -> bool:
177
171
  killed = True
178
172
  except (ProcessLookupError, OSError):
179
173
  pass
180
-
174
+
181
175
  processes_to_remove.append(process)
182
176
  else:
183
177
  processes_to_remove.append(process)
184
178
  except (ProcessLookupError, OSError):
185
179
  processes_to_remove.append(process)
186
-
180
+
187
181
  for process in processes_to_remove:
188
182
  if process in _VLLM_PROCESSES:
189
183
  _VLLM_PROCESSES.remove(process)
190
-
184
+
191
185
  if not killed and HAS_PSUTIL:
192
186
  try:
193
- for proc in psutil.process_iter(['pid', 'cmdline']):
187
+ for proc in psutil.process_iter(["pid", "cmdline"]):
194
188
  try:
195
- cmdline = ' '.join(proc.info['cmdline'] or [])
196
- if ('vllm' in cmdline.lower() and
197
- (f'--port {port}' in cmdline or f'--port={port}' in cmdline)):
189
+ cmdline = " ".join(proc.info["cmdline"] or [])
190
+ if "vllm" in cmdline.lower() and (f"--port {port}" in cmdline or f"--port={port}" in cmdline):
198
191
  logger.info(f"Killing untracked VLLM process {proc.info['pid']} on port {port}")
199
192
  proc.terminate()
200
193
  try:
@@ -207,13 +200,13 @@ def _kill_vllm_on_port(port: int) -> bool:
207
200
  continue
208
201
  except Exception as e:
209
202
  logger.warning(f"Error searching processes on port {port}: {e}")
210
-
203
+
211
204
  if killed:
212
205
  logger.info(f"Killed VLLM server on port {port}")
213
206
  time.sleep(2)
214
207
  else:
215
208
  logger.info(f"No VLLM server on port {port}")
216
-
209
+
217
210
  return killed
218
211
 
219
212
 
@@ -262,32 +255,24 @@ def _is_server_running(port: int) -> bool:
262
255
  return False
263
256
 
264
257
 
265
- def get_base_client(
266
- client=None,
267
- cache: bool = True,
268
- api_key="abc",
269
- vllm_cmd=None,
270
- vllm_process=None
271
- ) -> OpenAI:
258
+ def get_base_client(client=None, cache: bool = True, api_key="abc", vllm_cmd=None, vllm_process=None) -> OpenAI:
272
259
  """Get OpenAI client from various inputs."""
273
260
  from llm_utils import MOpenAI
274
261
 
275
- open_ai_class = OpenAI if not cache else MOpenAI
276
-
277
262
  if client is None:
278
263
  if vllm_cmd is not None:
279
264
  # Parse environment variables from command to get clean command for port extraction
280
265
  _, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
281
266
  port = _extract_port_from_vllm_cmd(cleaned_cmd)
282
- return open_ai_class(base_url=f"http://localhost:{port}/v1", api_key=api_key)
267
+ return MOpenAI(base_url=f"http://localhost:{port}/v1", api_key=api_key, cache=cache)
283
268
  else:
284
- return open_ai_class()
269
+ raise ValueError("Either client or vllm_cmd must be provided.")
285
270
  elif isinstance(client, int):
286
- return open_ai_class(base_url=f"http://localhost:{client}/v1", api_key=api_key)
271
+ return MOpenAI(base_url=f"http://localhost:{client}/v1", api_key=api_key, cache=cache)
287
272
  elif isinstance(client, str):
288
- return open_ai_class(base_url=client, api_key=api_key)
273
+ return MOpenAI(base_url=client, api_key=api_key, cache=cache)
289
274
  elif isinstance(client, OpenAI):
290
- return client
275
+ return MOpenAI(base_url=client.base_url, api_key=api_key, cache=cache)
291
276
  else:
292
277
  raise ValueError("Invalid client type. Must be OpenAI, port (int), base_url (str), or None.")
293
278
 
@@ -296,17 +281,17 @@ def _is_lora_path(path: str) -> bool:
296
281
  """Check if path is LoRA adapter directory."""
297
282
  if not os.path.isdir(path):
298
283
  return False
299
- adapter_config_path = os.path.join(path, 'adapter_config.json')
284
+ adapter_config_path = os.path.join(path, "adapter_config.json")
300
285
  return os.path.isfile(adapter_config_path)
301
286
 
302
287
 
303
288
  def _get_port_from_client(client: OpenAI) -> Optional[int]:
304
289
  """Extract port from OpenAI client base_url."""
305
- if hasattr(client, 'base_url') and client.base_url:
290
+ if hasattr(client, "base_url") and client.base_url:
306
291
  base_url = str(client.base_url)
307
- if 'localhost:' in base_url:
292
+ if "localhost:" in base_url:
308
293
  try:
309
- port_part = base_url.split('localhost:')[1].split('/')[0]
294
+ port_part = base_url.split("localhost:")[1].split("/")[0]
310
295
  return int(port_part)
311
296
  except (IndexError, ValueError):
312
297
  pass
@@ -315,14 +300,14 @@ def _get_port_from_client(client: OpenAI) -> Optional[int]:
315
300
 
316
301
  def _load_lora_adapter(lora_path: str, port: int) -> str:
317
302
  """Load LoRA adapter from path."""
318
- lora_name = os.path.basename(lora_path.rstrip('/\\'))
303
+ lora_name = os.path.basename(lora_path.rstrip("/\\"))
319
304
  if not lora_name:
320
305
  lora_name = os.path.basename(os.path.dirname(lora_path))
321
-
306
+
322
307
  response = requests.post(
323
- f'http://localhost:{port}/v1/load_lora_adapter',
324
- headers={'accept': 'application/json', 'Content-Type': 'application/json'},
325
- json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)}
308
+ f"http://localhost:{port}/v1/load_lora_adapter",
309
+ headers={"accept": "application/json", "Content-Type": "application/json"},
310
+ json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)},
326
311
  )
327
312
  response.raise_for_status()
328
313
  return lora_name
@@ -331,14 +316,14 @@ def _load_lora_adapter(lora_path: str, port: int) -> str:
331
316
  def _unload_lora_adapter(lora_path: str, port: int) -> None:
332
317
  """Unload LoRA adapter."""
333
318
  try:
334
- lora_name = os.path.basename(lora_path.rstrip('/\\'))
319
+ lora_name = os.path.basename(lora_path.rstrip("/\\"))
335
320
  if not lora_name:
336
321
  lora_name = os.path.basename(os.path.dirname(lora_path))
337
-
322
+
338
323
  response = requests.post(
339
- f'http://localhost:{port}/v1/unload_lora_adapter',
340
- headers={'accept': 'application/json', 'Content-Type': 'application/json'},
341
- json={"lora_name": lora_name, "lora_int_id": 0}
324
+ f"http://localhost:{port}/v1/unload_lora_adapter",
325
+ headers={"accept": "application/json", "Content-Type": "application/json"},
326
+ json={"lora_name": lora_name, "lora_int_id": 0},
342
327
  )
343
328
  response.raise_for_status()
344
329
  except requests.RequestException as e:
speedy_utils/__init__.py CHANGED
@@ -79,7 +79,24 @@ from glob import glob
79
79
  from multiprocessing import Pool
80
80
  from pathlib import Path
81
81
  from threading import Lock
82
- from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
82
+ from typing import (
83
+ Any,
84
+ Awaitable,
85
+ Callable as TypingCallable,
86
+ Dict,
87
+ Generic,
88
+ Iterable,
89
+ List,
90
+ Literal,
91
+ Mapping,
92
+ Optional,
93
+ Sequence,
94
+ Set,
95
+ Tuple,
96
+ Type,
97
+ TypeVar,
98
+ Union,
99
+ )
83
100
 
84
101
  # Third-party imports
85
102
  import numpy as np
@@ -124,6 +141,7 @@ from .common.utils_io import (
124
141
  # Misc utilities
125
142
  from .common.utils_misc import (
126
143
  convert_to_builtin_python,
144
+ dedup,
127
145
  flatten_list,
128
146
  get_arg_names,
129
147
  is_notebook,
@@ -171,12 +189,20 @@ __all__ = [
171
189
  "defaultdict",
172
190
  # Typing
173
191
  "Any",
192
+ "Awaitable",
174
193
  "Callable",
194
+ "TypingCallable",
175
195
  "Dict",
176
196
  "Generic",
197
+ "Iterable",
177
198
  "List",
178
199
  "Literal",
200
+ "Mapping",
179
201
  "Optional",
202
+ "Sequence",
203
+ "Set",
204
+ "Tuple",
205
+ "Type",
180
206
  "TypeVar",
181
207
  "Union",
182
208
  # Third-party
@@ -214,6 +240,7 @@ __all__ = [
214
240
  "get_arg_names",
215
241
  "is_notebook",
216
242
  "convert_to_builtin_python",
243
+ "dedup",
217
244
  # Print utilities
218
245
  "display_pretty_table_html",
219
246
  "flatten_dict",
speedy_utils/all.py CHANGED
@@ -71,7 +71,24 @@ from glob import glob
71
71
  from multiprocessing import Pool
72
72
  from pathlib import Path
73
73
  from threading import Lock
74
- from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
74
+ from typing import (
75
+ Any,
76
+ Awaitable,
77
+ Callable as TypingCallable,
78
+ Dict,
79
+ Generic,
80
+ Iterable,
81
+ List,
82
+ Literal,
83
+ Mapping,
84
+ Optional,
85
+ Sequence,
86
+ Set,
87
+ Tuple,
88
+ Type,
89
+ TypeVar,
90
+ Union,
91
+ )
75
92
 
76
93
  # Third-party imports
77
94
  import numpy as np
@@ -115,6 +132,9 @@ from speedy_utils import ( # Clock module; Function decorators; Cache utilities
115
132
  timef,
116
133
  )
117
134
 
135
+
136
+ choice = random.choice
137
+
118
138
  # Define __all__ explicitly with all exports
119
139
  __all__ = [
120
140
  # Standard library
@@ -146,12 +166,20 @@ __all__ = [
146
166
  "defaultdict",
147
167
  # Typing
148
168
  "Any",
169
+ "Awaitable",
149
170
  "Callable",
171
+ "TypingCallable",
150
172
  "Dict",
151
173
  "Generic",
174
+ "Iterable",
152
175
  "List",
153
176
  "Literal",
177
+ "Mapping",
154
178
  "Optional",
179
+ "Sequence",
180
+ "Set",
181
+ "Tuple",
182
+ "Type",
155
183
  "TypeVar",
156
184
  "Union",
157
185
  # Third-party
@@ -199,4 +227,5 @@ __all__ = [
199
227
  # Multi-worker processing
200
228
  "multi_process",
201
229
  "multi_thread",
230
+ "choice",
202
231
  ]
@@ -29,9 +29,7 @@ def dump_jsonl(list_dictionaries: list[dict], file_name: str = "output.jsonl") -
29
29
  file.write(json.dumps(dictionary, ensure_ascii=False) + "\n")
30
30
 
31
31
 
32
- def dump_json_or_pickle(
33
- obj: Any, fname: str, ensure_ascii: bool = False, indent: int = 4
34
- ) -> None:
32
+ def dump_json_or_pickle(obj: Any, fname: str, ensure_ascii: bool = False, indent: int = 4) -> None:
35
33
  """
36
34
  Dump an object to a file, supporting both JSON and pickle formats.
37
35
  """
@@ -59,6 +57,7 @@ def dump_json_or_pickle(
59
57
  if isinstance(obj, BaseModel):
60
58
  data = obj.model_dump()
61
59
  from fastcore.all import dict2obj, obj2dict
60
+
62
61
  obj2 = dict2obj(data)
63
62
  with open(fname, "wb") as f:
64
63
  pickle.dump(obj2, f)
@@ -84,7 +83,8 @@ def load_json_or_pickle(fname: str, counter=0) -> Any:
84
83
  except EOFError:
85
84
  time.sleep(1)
86
85
  if counter > 5:
87
- print("Error: Ran out of input", fname)
86
+ # Keep message concise and actionable
87
+ print(f"Corrupted cache file {fname} removed; it will be regenerated on next access")
88
88
  os.remove(fname)
89
89
  raise
90
90
  return load_json_or_pickle(fname, counter + 1)
@@ -92,8 +92,6 @@ def load_json_or_pickle(fname: str, counter=0) -> Any:
92
92
  raise ValueError(f"Error {e} while loading {fname}") from e
93
93
 
94
94
 
95
-
96
-
97
95
  try:
98
96
  import orjson # type: ignore[import-not-found] # fastest JSON parser when available
99
97
  except Exception:
@@ -113,11 +111,11 @@ def fast_load_jsonl(
113
111
  use_orjson: bool = True,
114
112
  encoding: str = "utf-8",
115
113
  errors: str = "strict",
116
- on_error: str = "raise", # 'raise' | 'warn' | 'skip'
114
+ on_error: str = "raise", # 'raise' | 'warn' | 'skip'
117
115
  skip_empty: bool = True,
118
116
  max_lines: Optional[int] = None,
119
117
  use_multiworker: bool = True,
120
- multiworker_threshold: int = 50000,
118
+ multiworker_threshold: int = 1000000,
121
119
  workers: Optional[int] = None,
122
120
  ) -> Iterable[Any]:
123
121
  """
@@ -127,7 +125,7 @@ def fast_load_jsonl(
127
125
  - Optional tqdm progress over bytes (compressed size if gz/bz2/xz/zst).
128
126
  - Auto-detects compression by extension: .gz, .bz2, .xz/.lzma, .zst/.zstd.
129
127
  - Uses orjson if available (use_orjson=True), falls back to json.
130
- - Automatically uses multi-worker processing for large files (>50k lines).
128
+ - Automatically uses multi-worker processing for large files (line count above `multiworker_threshold`, default 1,000,000).
131
129
 
132
130
  Args:
133
131
  path_or_file: Path-like or file-like object. File-like can be binary or text.
@@ -140,11 +138,12 @@ def fast_load_jsonl(
140
138
  max_lines: Stop after reading this many lines (useful for sampling).
141
139
  use_multiworker: Enable multi-worker processing for large files.
142
140
  multiworker_threshold: Line count threshold to trigger multi-worker processing.
143
- workers: Number of worker threads (defaults to CPU count).
141
+ workers: Number of worker threads (defaults to 80% of CPU count, max 8).
144
142
 
145
143
  Yields:
146
144
  Parsed Python objects per line.
147
145
  """
146
+
148
147
  def _open_auto(pth_or_f) -> IO[Any]:
149
148
  if hasattr(pth_or_f, "read"):
150
149
  # ensure binary buffer for consistent byte-length progress
@@ -206,39 +205,47 @@ def fast_load_jsonl(
206
205
 
207
206
  # Check if we should use multi-worker processing
208
207
  should_use_multiworker = (
209
- use_multiworker
208
+ use_multiworker
210
209
  and not hasattr(path_or_file, "read") # Only for file paths, not file objects
211
210
  and max_lines is None # Don't use multiworker if we're limiting lines
212
211
  )
213
-
212
+
214
213
  if should_use_multiworker:
215
214
  line_count = _count_lines_fast(cast(Union[str, os.PathLike], path_or_file))
216
215
  if line_count > multiworker_threshold:
217
216
  # Use multi-worker processing
218
217
  from ..multi_worker.thread import multi_thread
219
218
 
219
+ # Calculate optimal worker count: 80% of CPU count, capped at 8
220
+ cpu_count = os.cpu_count() or 4
221
+ default_workers = min(int(cpu_count * 0.8), 8)
222
+ num_workers = workers if workers is not None else default_workers
223
+ num_workers = max(1, num_workers) # At least 1 worker
224
+
220
225
  # Read all lines into chunks
221
226
  f = _open_auto(path_or_file)
222
227
  all_lines = list(f)
223
228
  f.close()
224
-
225
- # Split into chunks for workers
226
- num_workers = workers or os.cpu_count() or 4
227
- chunk_size = max(len(all_lines) // num_workers, 1000)
229
+
230
+ # Split into chunks - aim for ~10k-20k lines per chunk minimum
231
+ min_chunk_size = 10000
232
+ chunk_size = max(len(all_lines) // num_workers, min_chunk_size)
228
233
  chunks = []
229
234
  for i in range(0, len(all_lines), chunk_size):
230
- chunks.append(all_lines[i:i + chunk_size])
231
-
235
+ chunks.append(all_lines[i : i + chunk_size])
236
+
232
237
  # Process chunks in parallel
233
238
  if progress:
234
- print(f"Processing {line_count} lines with {num_workers} workers...")
235
-
239
+ print(f"Processing {line_count} lines with {num_workers} workers ({len(chunks)} chunks)...")
240
+
236
241
  chunk_results = multi_thread(_process_chunk, chunks, workers=num_workers, progress=progress)
237
-
242
+
238
243
  # Flatten results and yield
239
- for chunk_result in chunk_results:
240
- for obj in chunk_result:
241
- yield obj
244
+ if chunk_results:
245
+ for chunk_result in chunk_results:
246
+ if chunk_result:
247
+ for obj in chunk_result:
248
+ yield obj
242
249
  return
243
250
 
244
251
  # Single-threaded processing (original logic)
@@ -266,7 +273,11 @@ def fast_load_jsonl(
266
273
  line_no += 1
267
274
  if pbar is not None:
268
275
  # raw_line is bytes here; if not, compute byte length
269
- nbytes = len(raw_line) if isinstance(raw_line, (bytes, bytearray)) else len(str(raw_line).encode(encoding, errors))
276
+ nbytes = (
277
+ len(raw_line)
278
+ if isinstance(raw_line, (bytes, bytearray))
279
+ else len(str(raw_line).encode(encoding, errors))
280
+ )
270
281
  pbar.update(nbytes)
271
282
 
272
283
  # Normalize to bytes -> str only if needed
@@ -322,7 +333,6 @@ def fast_load_jsonl(
322
333
  pass
323
334
 
324
335
 
325
-
326
336
  def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:
327
337
  """
328
338
  Load data based on file extension.
@@ -3,10 +3,12 @@
3
3
  import inspect
4
4
  import os
5
5
  from collections.abc import Callable
6
- from typing import Any
6
+ from typing import Any, TypeVar
7
7
 
8
8
  from pydantic import BaseModel
9
9
 
10
+ T = TypeVar("T")
11
+
10
12
 
11
13
  def mkdir_or_exist(dir_name: str) -> None:
12
14
  """Create a directory if it doesn't exist."""
@@ -50,10 +52,32 @@ def convert_to_builtin_python(input_data: Any) -> Any:
50
52
  raise ValueError(f"Unsupported type {type(input_data)}")
51
53
 
52
54
 
55
+ def dedup(items: list[T], key: Callable[[T], Any]) -> list[T]:
56
+ """
57
+ Deduplicate items in a list based on a key function.
58
+
59
+ Args:
60
+ items: The list of items.
61
+ key: A function that takes an item and returns a hashable key.
62
+
63
+ Returns:
64
+ A list with duplicates removed, preserving the first occurrence.
65
+ """
66
+ seen = set()
67
+ result = []
68
+ for item in items:
69
+ k = key(item)
70
+ if k not in seen:
71
+ seen.add(k)
72
+ result.append(item)
73
+ return result
74
+
75
+
53
76
  __all__ = [
54
77
  "mkdir_or_exist",
55
78
  "flatten_list",
56
79
  "get_arg_names",
57
80
  "is_notebook",
58
81
  "convert_to_builtin_python",
82
+ "dedup",
59
83
  ]