dtflow 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/streaming.py ADDED
@@ -0,0 +1,661 @@
"""
Streaming processing module.

Lazily processes large files without loading them fully into memory.
Supported formats: JSONL, CSV, Parquet, Arrow
"""
import glob
import os
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union

import orjson
import polars as pl
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    MofNCompleteColumn,
)

# Formats supported for streaming
STREAMING_FORMATS = {".jsonl", ".csv", ".parquet", ".arrow", ".feather"}


def _count_rows_fast(filepath: str) -> Optional[int]:
    """Count a file's rows quickly (without loading the data)."""
    path = Path(filepath)
    ext = path.suffix.lower()

    try:
        if ext == ".jsonl":
            # JSONL: count non-empty lines directly
            with open(filepath, "rb") as f:
                return sum(1 for line in f if line.strip())
        elif ext == ".csv":
            # CSV: Polars LazyFrame
            return pl.scan_csv(filepath).select(pl.len()).collect().item()
        elif ext == ".parquet":
            # Parquet: Polars LazyFrame
            return pl.scan_parquet(filepath).select(pl.len()).collect().item()
        elif ext in (".arrow", ".feather"):
            # Arrow: Polars LazyFrame
            return pl.scan_ipc(filepath).select(pl.len()).collect().item()
    except Exception:
        pass
    return None


class StreamingTransformer:
    """
    Streaming data transformer.

    Uses generators for lazy processing, suitable for very large files.
    Memory usage is O(1) and does not grow with file size.

    Examples:
        >>> st = StreamingTransformer.load_stream("huge_100gb.jsonl")
        >>> (st
        ...     .filter(lambda x: x["score"] > 0.5)
        ...     .transform(lambda x: {"text": x["content"]})
        ...     .save("output.jsonl"))
    """

    def __init__(
        self,
        iterator: Iterator[Dict[str, Any]],
        source_path: Optional[str] = None,
        total: Optional[int] = None,
    ):
        """
        Initialize the streaming transformer.

        Args:
            iterator: Data iterator.
            source_path: Source file path (used for metadata).
            total: Total number of rows (used for the progress bar; optional).
        """
        self._iterator = iterator
        self._source_path = source_path
        self._total = total
        self._operations: List[Dict[str, Any]] = []

    @classmethod
    def load_stream(
        cls, filepath: str, batch_size: int = 10000
    ) -> "StreamingTransformer":
        """
        Load a file as a stream.

        Supports JSONL, CSV, Parquet, and Arrow formats.

        Args:
            filepath: File path.
            batch_size: Batch size for reading (CSV/Parquet/Arrow).

        Returns:
            StreamingTransformer instance.
        """
        path = Path(filepath)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {filepath}")

        ext = path.suffix.lower()
        if ext not in STREAMING_FORMATS:
            raise ValueError(f"Unsupported streaming format: {ext}, supported: {STREAMING_FORMATS}")

        # Quickly count the total rows (for the progress bar)
        total = _count_rows_fast(filepath)

        if ext == ".jsonl":
            return cls(_stream_jsonl(filepath), source_path=filepath, total=total)
        elif ext == ".csv":
            return cls(_stream_csv(filepath, batch_size), source_path=filepath, total=total)
        elif ext == ".parquet":
            return cls(_stream_parquet(filepath, batch_size), source_path=filepath, total=total)
        elif ext in (".arrow", ".feather"):
            return cls(_stream_arrow(filepath), source_path=filepath, total=total)
        else:
            raise ValueError(f"Unknown format: {ext}")

    @classmethod
    def load_sharded(
        cls, pattern: str, batch_size: int = 10000
    ) -> "StreamingTransformer":
        """
        Load sharded files (glob patterns supported).

        Supports JSONL, CSV, Parquet, and Arrow formats (detected from the extension).

        Args:
            pattern: Glob pattern, e.g. "data_*.jsonl" or "shards/part-*.parquet".
            batch_size: Batch size for reading (CSV/Parquet/Arrow).

        Returns:
            StreamingTransformer instance.

        Examples:
            >>> st = StreamingTransformer.load_sharded("data/train_*.jsonl")
            >>> st = StreamingTransformer.load_sharded("shards/part-*.parquet")
        """
        files = sorted(glob.glob(pattern))
        if not files:
            raise FileNotFoundError(f"No files match: {pattern}")

        def generator():
            for filepath in files:
                ext = Path(filepath).suffix.lower()
                if ext == ".jsonl":
                    yield from _stream_jsonl(filepath)
                elif ext == ".csv":
                    yield from _stream_csv(filepath, batch_size)
                elif ext == ".parquet":
                    yield from _stream_parquet(filepath, batch_size)
                elif ext in (".arrow", ".feather"):
                    yield from _stream_arrow(filepath)
                else:
                    # Fall back to treating the file as JSONL
                    yield from _stream_jsonl(filepath)

        return cls(generator(), source_path=pattern)

    def filter(self, func: Callable[[Dict], bool]) -> "StreamingTransformer":
        """
        Lazy filter.

        Args:
            func: Predicate; items for which it returns True are kept.

        Returns:
            A new StreamingTransformer (lazy, nothing is executed yet).
        """
        def filtered_iterator():
            for item in self._iterator:
                try:
                    if func(item):
                        yield item
                except Exception:
                    pass  # skip items that raise

        # The filtered count is unknown, so total is not propagated
        new_st = StreamingTransformer(filtered_iterator(), self._source_path, total=None)
        new_st._operations = self._operations + [{"type": "filter", "func": func}]
        return new_st

    def transform(self, func: Callable[[Dict], Dict]) -> "StreamingTransformer":
        """
        Lazy transform.

        Args:
            func: Transformation function.

        Returns:
            A new StreamingTransformer (lazy, nothing is executed yet).
        """
        def transformed_iterator():
            for item in self._iterator:
                try:
                    yield func(item)
                except Exception:
                    pass  # skip items that raise

        # transform is a 1:1 mapping, so total is preserved
        new_st = StreamingTransformer(transformed_iterator(), self._source_path, total=self._total)
        new_st._operations = self._operations + [{"type": "transform", "func": func}]
        return new_st

    def head(self, n: int) -> "StreamingTransformer":
        """
        Lazily take the first N records.

        Args:
            n: Number of records.

        Returns:
            A new StreamingTransformer.
        """
        def head_iterator():
            count = 0
            for item in self._iterator:
                if count >= n:
                    break
                yield item
                count += 1

        # The total of head(n) is min(n, original_total)
        new_total = min(n, self._total) if self._total is not None else n
        new_st = StreamingTransformer(head_iterator(), self._source_path, total=new_total)
        new_st._operations = self._operations + [{"type": "head", "n": n}]
        return new_st

    def skip(self, n: int) -> "StreamingTransformer":
        """
        Lazily skip the first N records.

        Args:
            n: Number of records to skip.

        Returns:
            A new StreamingTransformer.
        """
        def skip_iterator():
            count = 0
            for item in self._iterator:
                if count < n:
                    count += 1
                    continue
                yield item

        # The total of skip(n) is max(0, original_total - n)
        new_total = max(0, self._total - n) if self._total is not None else None
        new_st = StreamingTransformer(skip_iterator(), self._source_path, total=new_total)
        new_st._operations = self._operations + [{"type": "skip", "n": n}]
        return new_st
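
    # Usage sketch for the lazy operators above: filter/transform/head/skip only wrap
    # the underlying iterator, so no records flow until a terminal call such as
    # save(), collect(), or count(). Assumes a local reviews.jsonl with a numeric
    # "score" field (the file and field names are illustrative):
    #
    #     st = StreamingTransformer.load_stream("reviews.jsonl")
    #     pipeline = st.skip(100).head(10)   # still lazy, no records consumed yet
    #     sample = pipeline.collect()        # reads only the records it needs (~110), not the whole file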

    def batch(self, size: int) -> Generator[List[Dict], None, None]:
        """
        Iterate in batches (for batch-processing scenarios).

        Args:
            size: Batch size.

        Yields:
            Lists of records, one list per batch.

        Examples:
            >>> for batch in st.batch(1000):
            ...     process_batch(batch)
        """
        batch = []
        for item in self._iterator:
            batch.append(item)
            if len(batch) >= size:
                yield batch
                batch = []
        if batch:
            yield batch
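
    # Sketch of batch(): useful when work can be vectorised instead of handled per
    # item. The "text" field and the embed_texts() helper below are illustrative,
    # not part of dtflow:
    #
    #     for rows in st.batch(1000):
    #         vectors = embed_texts([r["text"] for r in rows])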

    def save(
        self, filepath: str, show_progress: bool = True, batch_size: int = 10000
    ) -> int:
        """
        Save the stream to a file.

        Supports JSONL, CSV, Parquet, and Arrow formats (detected from the extension).

        Args:
            filepath: Output file path.
            show_progress: Whether to show a progress bar.
            batch_size: Batch size for writing (CSV/Parquet/Arrow).

        Returns:
            Number of records written.
        """
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        ext = path.suffix.lower()

        if ext == ".jsonl":
            return self._save_jsonl(filepath, show_progress)
        elif ext == ".csv":
            return self._save_batched(filepath, "csv", batch_size, show_progress)
        elif ext == ".parquet":
            return self._save_batched(filepath, "parquet", batch_size, show_progress)
        elif ext in (".arrow", ".feather"):
            return self._save_batched(filepath, "arrow", batch_size, show_progress)
        else:
            # Default to JSONL
            return self._save_jsonl(filepath, show_progress)

    def _save_jsonl(self, filepath: str, show_progress: bool) -> int:
        """Save as JSONL, streaming line by line (uses orjson)."""
        count = 0

        if show_progress:
            # Choose the progress-bar layout depending on whether a total is known
            if self._total is not None:
                # Known total: show bar, percentage, and remaining time
                columns = [
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    BarColumn(),
                    TaskProgressColumn(),
                    MofNCompleteColumn(),
                    TimeElapsedColumn(),
                    TimeRemainingColumn(),
                ]
            else:
                # Unknown total: only show the number processed so far
                columns = [
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    MofNCompleteColumn(),
                    TimeElapsedColumn(),
                ]

            with Progress(*columns) as progress:
                task = progress.add_task("Processing", total=self._total)
                with open(filepath, "wb") as f:
                    for item in self._iterator:
                        f.write(orjson.dumps(item) + b"\n")
                        count += 1
                        progress.update(task, advance=1)
        else:
            with open(filepath, "wb") as f:
                for item in self._iterator:
                    f.write(orjson.dumps(item) + b"\n")
                    count += 1

        return count

    def _save_batched(
        self, filepath: str, fmt: str, batch_size: int, show_progress: bool
    ) -> int:
        """
        Batched save (CSV/Parquet/Arrow).

        Reading and processing stay streamed; writing collects everything and
        writes it in one go.
        """
        path = Path(filepath)
        all_items = []

        if show_progress:
            # Choose the progress-bar layout depending on whether a total is known
            if self._total is not None:
                columns = [
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    BarColumn(),
                    TaskProgressColumn(),
                    MofNCompleteColumn(),
                    TimeElapsedColumn(),
                    TimeRemainingColumn(),
                ]
            else:
                columns = [
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    MofNCompleteColumn(),
                    TimeElapsedColumn(),
                ]

            with Progress(*columns) as progress:
                task = progress.add_task("Processing", total=self._total)
                for item in self._iterator:
                    all_items.append(item)
                    progress.update(task, advance=1)
        else:
            for item in self._iterator:
                all_items.append(item)

        if all_items:
            df = pl.DataFrame(all_items)
            if fmt == "csv":
                df.write_csv(path)
            elif fmt == "parquet":
                df.write_parquet(path)
            elif fmt == "arrow":
                df.write_ipc(path)

        return len(all_items)
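
    # Sizing note: _save_batched() collects every record before writing, so saving to
    # CSV/Parquet/Arrow needs enough memory to hold the whole result. When the output
    # itself is too large for that, a row-oriented target keeps memory flat
    # (paths illustrative):
    #
    #     load_stream("big.parquet").save("big.jsonl")       # streamed line by line
    #     load_stream("big.parquet").save_sharded("out/")    # streamed into JSONL shards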

    def save_sharded(
        self,
        output_dir: str,
        shard_size: int = 100000,
        prefix: str = "part",
        show_progress: bool = True,
    ) -> List[str]:
        """
        Save the stream as shards.

        Args:
            output_dir: Output directory.
            shard_size: Number of records per shard.
            prefix: Shard file prefix.
            show_progress: Whether to show a progress bar.

        Returns:
            List of paths of the shard files that were written.

        Examples:
            >>> files = st.save_sharded("output/", shard_size=100000)
            >>> # Produces: output/part-00000.jsonl, output/part-00001.jsonl, ...
        """
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        shard_files = []
        shard_idx = 0
        count_in_shard = 0
        current_file = None

        def process_items(progress=None, task=None):
            nonlocal shard_idx, count_in_shard, current_file

            for item in self._iterator:
                # Start a new shard when needed
                if current_file is None or count_in_shard >= shard_size:
                    if current_file:
                        current_file.close()

                    shard_path = os.path.join(output_dir, f"{prefix}-{shard_idx:05d}.jsonl")
                    shard_files.append(shard_path)
                    current_file = open(shard_path, "wb")
                    shard_idx += 1
                    count_in_shard = 0
                    if progress is not None:
                        progress.update(task, description=f"Shard {shard_idx}")

                current_file.write(orjson.dumps(item) + b"\n")
                count_in_shard += 1
                if progress is not None:
                    progress.update(task, advance=1)

        try:
            if show_progress:
                if self._total is not None:
                    columns = [
                        SpinnerColumn(),
                        TextColumn("[progress.description]{task.description}"),
                        BarColumn(),
                        TaskProgressColumn(),
                        MofNCompleteColumn(),
                        TimeElapsedColumn(),
                        TimeRemainingColumn(),
                    ]
                else:
                    columns = [
                        SpinnerColumn(),
                        TextColumn("[progress.description]{task.description}"),
                        MofNCompleteColumn(),
                        TimeElapsedColumn(),
                    ]

                with Progress(*columns) as progress:
                    task = progress.add_task("Shard 1", total=self._total)
                    process_items(progress, task)
            else:
                process_items()
        finally:
            if current_file:
                current_file.close()

        return shard_files

    def collect(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Collect all records into memory (watch the memory footprint).

        Args:
            limit: Maximum number of records to collect; None means all.

        Returns:
            List of records.
        """
        result = []
        for item in self._iterator:
            result.append(item)
            if limit and len(result) >= limit:
                break
        return result

    def count(self) -> int:
        """
        Count the records (consumes the iterator).

        Returns:
            Number of records.
        """
        count = 0
        for _ in self._iterator:
            count += 1
        return count

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Support direct iteration."""
        return self._iterator
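

# A StreamingTransformer wraps a single-pass iterator: save(), save_sharded(),
# collect(), count(), and batch() all consume it, so each full pass needs a fresh
# load. Minimal sketch (file name illustrative):
#
#     n = load_stream("data.jsonl").count()                 # first pass; stream is now exhausted
#     rows = load_stream("data.jsonl").head(5).collect()    # reload for another pass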


# ============ Convenience functions ============


def load_stream(filepath: str, batch_size: int = 10000) -> StreamingTransformer:
    """
    Load a file as a stream.

    Supports JSONL, CSV, Parquet, and Arrow formats.

    Args:
        filepath: File path.
        batch_size: Batch size for reading (CSV/Parquet/Arrow).

    Returns:
        StreamingTransformer instance.

    Examples:
        >>> from dtflow import load_stream
        >>> (load_stream("huge.jsonl")
        ...     .filter(lambda x: x["score"] > 0.5)
        ...     .save("filtered.jsonl"))
        >>> (load_stream("data.csv")
        ...     .filter(lambda x: x["score"] > 0.5)
        ...     .save("output.parquet"))
    """
    return StreamingTransformer.load_stream(filepath, batch_size)


def load_sharded(pattern: str, batch_size: int = 10000) -> StreamingTransformer:
    """
    Load sharded files.

    Supports JSONL, CSV, Parquet, and Arrow formats.

    Args:
        pattern: Glob pattern.
        batch_size: Batch size for reading (CSV/Parquet/Arrow).

    Returns:
        StreamingTransformer instance.

    Examples:
        >>> from dtflow import load_sharded
        >>> load_sharded("data/*.jsonl").save("merged.jsonl")
        >>> load_sharded("data/*.parquet").save("merged.parquet")
    """
    return StreamingTransformer.load_sharded(pattern, batch_size)


def process_shards(
    input_pattern: str,
    output_dir: str,
    func: Callable[[Dict], Optional[Dict]],
    workers: int = 1,
    shard_size: int = 100000,
) -> List[str]:
    """
    Process sharded files in parallel.

    Args:
        input_pattern: Glob pattern for the input files.
        output_dir: Output directory.
        func: Processing function; returning None drops the record.
        workers: Number of parallel worker processes (currently only 1 is supported).
        shard_size: Size of the output shards.

    Returns:
        List of output files that were written.

    Examples:
        >>> def process(item):
        ...     if item["score"] > 0.5:
        ...         return {"text": item["content"]}
        ...     return None
        >>> process_shards("input/*.jsonl", "output/", process)
    """
    # Simple implementation: process serially
    # TODO: add multiprocessing support later

    def transform_func(item):
        result = func(item)
        return result

    return (
        load_sharded(input_pattern)
        .transform(transform_func)
        .filter(lambda x: x is not None)
        .save_sharded(output_dir, shard_size=shard_size)
    )


# ============ Streaming readers ============


def _stream_jsonl(filepath: str) -> Generator[Dict[str, Any], None, None]:
    """Stream a JSONL file (uses orjson)."""
    with open(filepath, "rb") as f:
        for line in f:
            line = line.strip()
            if line:
                yield orjson.loads(line)


def _stream_csv(
    filepath: str, batch_size: int = 10000
) -> Generator[Dict[str, Any], None, None]:
    """Stream a CSV file (uses the Polars BatchedCsvReader)."""
    reader = pl.read_csv_batched(filepath, batch_size=batch_size)
    while True:
        batches = reader.next_batches(1)
        if not batches:
            break
        for row in batches[0].to_dicts():
            yield row


def _stream_parquet(
    filepath: str, batch_size: int = 10000
) -> Generator[Dict[str, Any], None, None]:
    """Stream a Parquet file (uses PyArrow iter_batches)."""
    import pyarrow.parquet as pq

    pf = pq.ParquetFile(filepath)
    for batch in pf.iter_batches(batch_size=batch_size):
        df = pl.from_arrow(batch)
        for row in df.to_dicts():
            yield row


def _stream_arrow(filepath: str) -> Generator[Dict[str, Any], None, None]:
    """Stream an Arrow/Feather file (uses PyArrow IPC)."""
    import pyarrow as pa

    with pa.ipc.open_file(filepath) as reader:
        for i in range(reader.num_record_batches):
            batch = reader.get_batch(i)
            df = pl.from_arrow(batch)
            for row in df.to_dicts():
                yield row
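

# End-to-end sketch of the streaming API above; file names, field names, and the
# threshold are illustrative only:
#
#     from dtflow.streaming import load_stream
#
#     (load_stream("raw/events.parquet")                    # Parquet read in batches
#         .filter(lambda x: x.get("score", 0) > 0.5)        # lazy filter
#         .transform(lambda x: {"text": x["content"]})      # lazy 1:1 transform
#         .save_sharded("clean/", shard_size=50000))        # part-00000.jsonl, part-00001.jsonl, ...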