dtflow-0.2.0-py3-none-any.whl → dtflow-0.3.1-py3-none-any.whl
This diff shows the changes between the two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
- dtflow/__init__.py +36 -2
- dtflow/__main__.py +292 -239
- dtflow/cli/__init__.py +8 -2
- dtflow/cli/commands.py +1030 -92
- dtflow/converters.py +456 -0
- dtflow/core.py +96 -31
- dtflow/lineage.py +407 -0
- dtflow/mcp/cli.py +14 -14
- dtflow/pipeline.py +450 -0
- dtflow/storage/io.py +376 -370
- dtflow/streaming.py +661 -0
- dtflow/tokenizers.py +387 -31
- dtflow/utils/display.py +5 -4
- {dtflow-0.2.0.dist-info → dtflow-0.3.1.dist-info}/METADATA +234 -15
- dtflow-0.3.1.dist-info/RECORD +24 -0
- dtflow-0.2.0.dist-info/RECORD +0 -21
- {dtflow-0.2.0.dist-info → dtflow-0.3.1.dist-info}/WHEEL +0 -0
- {dtflow-0.2.0.dist-info → dtflow-0.3.1.dist-info}/entry_points.txt +0 -0
dtflow/streaming.py
ADDED
@@ -0,0 +1,661 @@
"""
Streaming processing module.

Lazily processes large files so they never have to be fully loaded into memory.
Supported formats: JSONL, CSV, Parquet, Arrow
"""
import glob
import os
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union

import orjson
import polars as pl
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    MofNCompleteColumn,
)

# Supported streaming formats
STREAMING_FORMATS = {".jsonl", ".csv", ".parquet", ".arrow", ".feather"}


def _count_rows_fast(filepath: str) -> Optional[int]:
    """Quickly count rows in a file (without loading the data)."""
    path = Path(filepath)
    ext = path.suffix.lower()

    try:
        if ext == ".jsonl":
            # JSONL: count non-empty lines directly
            with open(filepath, "rb") as f:
                return sum(1 for line in f if line.strip())
        elif ext == ".csv":
            # CSV: Polars LazyFrame
            return pl.scan_csv(filepath).select(pl.len()).collect().item()
        elif ext == ".parquet":
            # Parquet: Polars LazyFrame
            return pl.scan_parquet(filepath).select(pl.len()).collect().item()
        elif ext in (".arrow", ".feather"):
            # Arrow: Polars LazyFrame
            return pl.scan_ipc(filepath).select(pl.len()).collect().item()
    except Exception:
        pass
    return None


class StreamingTransformer:
    """
    Streaming data transformer.

    Uses generators for lazy processing, suited to very large files.
    Memory usage is O(1) and does not grow with the file size.

    Examples:
        >>> st = StreamingTransformer.load_stream("huge_100gb.jsonl")
        >>> (st
        ...     .filter(lambda x: x["score"] > 0.5)
        ...     .transform(lambda x: {"text": x["content"]})
        ...     .save("output.jsonl"))
    """

    def __init__(
        self,
        iterator: Iterator[Dict[str, Any]],
        source_path: Optional[str] = None,
        total: Optional[int] = None,
    ):
        """
        Initialize the streaming transformer.

        Args:
            iterator: data iterator
            source_path: source file path (used for metadata)
            total: total row count (optional; used for the progress bar)
        """
        self._iterator = iterator
        self._source_path = source_path
        self._total = total
        self._operations: List[Dict[str, Any]] = []

    @classmethod
    def load_stream(
        cls, filepath: str, batch_size: int = 10000
    ) -> "StreamingTransformer":
        """
        Load a file as a stream.

        Supports JSONL, CSV, Parquet, and Arrow formats.

        Args:
            filepath: path to the file
            batch_size: batch read size (CSV/Parquet/Arrow)

        Returns:
            a StreamingTransformer instance
        """
        path = Path(filepath)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {filepath}")

        ext = path.suffix.lower()
        if ext not in STREAMING_FORMATS:
            raise ValueError(f"Unsupported streaming format: {ext}, supported: {STREAMING_FORMATS}")

        # Quickly count total rows (for the progress bar)
        total = _count_rows_fast(filepath)

        if ext == ".jsonl":
            return cls(_stream_jsonl(filepath), source_path=filepath, total=total)
        elif ext == ".csv":
            return cls(_stream_csv(filepath, batch_size), source_path=filepath, total=total)
        elif ext == ".parquet":
            return cls(_stream_parquet(filepath, batch_size), source_path=filepath, total=total)
        elif ext in (".arrow", ".feather"):
            return cls(_stream_arrow(filepath), source_path=filepath, total=total)
        else:
            raise ValueError(f"Unknown format: {ext}")

    @classmethod
    def load_sharded(
        cls, pattern: str, batch_size: int = 10000
    ) -> "StreamingTransformer":
        """
        Load sharded files (glob patterns supported).

        Supports JSONL, CSV, Parquet, and Arrow formats (auto-detected from the extension).

        Args:
            pattern: glob pattern, e.g. "data_*.jsonl" or "shards/part-*.parquet"
            batch_size: batch read size (CSV/Parquet/Arrow)

        Returns:
            a StreamingTransformer instance

        Examples:
            >>> st = StreamingTransformer.load_sharded("data/train_*.jsonl")
            >>> st = StreamingTransformer.load_sharded("shards/part-*.parquet")
        """
        files = sorted(glob.glob(pattern))
        if not files:
            raise FileNotFoundError(f"No files match: {pattern}")

        def generator():
            for filepath in files:
                ext = Path(filepath).suffix.lower()
                if ext == ".jsonl":
                    yield from _stream_jsonl(filepath)
                elif ext == ".csv":
                    yield from _stream_csv(filepath, batch_size)
                elif ext == ".parquet":
                    yield from _stream_parquet(filepath, batch_size)
                elif ext in (".arrow", ".feather"):
                    yield from _stream_arrow(filepath)
                else:
                    # Treat as JSONL by default
                    yield from _stream_jsonl(filepath)

        return cls(generator(), source_path=pattern)

    def filter(self, func: Callable[[Dict], bool]) -> "StreamingTransformer":
        """
        Lazy filter.

        Args:
            func: predicate; items for which it returns True are kept

        Returns:
            a new StreamingTransformer (lazy; nothing is executed yet)
        """
        def filtered_iterator():
            for item in self._iterator:
                try:
                    if func(item):
                        yield item
                except Exception:
                    pass  # skip items that raise errors

        # Row count is unknown after filtering; do not propagate total
        new_st = StreamingTransformer(filtered_iterator(), self._source_path, total=None)
        new_st._operations = self._operations + [{"type": "filter", "func": func}]
        return new_st

    def transform(self, func: Callable[[Dict], Dict]) -> "StreamingTransformer":
        """
        Lazy transform.

        Args:
            func: transform function

        Returns:
            a new StreamingTransformer (lazy; nothing is executed yet)
        """
        def transformed_iterator():
            for item in self._iterator:
                try:
                    yield func(item)
                except Exception:
                    pass  # skip items that raise errors

        # transform is a 1:1 mapping, so keep total
        new_st = StreamingTransformer(transformed_iterator(), self._source_path, total=self._total)
        new_st._operations = self._operations + [{"type": "transform", "func": func}]
        return new_st

    def head(self, n: int) -> "StreamingTransformer":
        """
        Lazily take the first N records.

        Args:
            n: number of records

        Returns:
            a new StreamingTransformer
        """
        def head_iterator():
            count = 0
            for item in self._iterator:
                if count >= n:
                    break
                yield item
                count += 1

        # total for head(n) is min(n, original_total)
        new_total = min(n, self._total) if self._total is not None else n
        new_st = StreamingTransformer(head_iterator(), self._source_path, total=new_total)
        new_st._operations = self._operations + [{"type": "head", "n": n}]
        return new_st

    def skip(self, n: int) -> "StreamingTransformer":
        """
        Lazily skip the first N records.

        Args:
            n: number of records to skip

        Returns:
            a new StreamingTransformer
        """
        def skip_iterator():
            count = 0
            for item in self._iterator:
                if count < n:
                    count += 1
                    continue
                yield item

        # total for skip(n) is max(0, original_total - n)
        new_total = max(0, self._total - n) if self._total is not None else None
        new_st = StreamingTransformer(skip_iterator(), self._source_path, total=new_total)
        new_st._operations = self._operations + [{"type": "skip", "n": n}]
        return new_st

    def batch(self, size: int) -> Generator[List[Dict], None, None]:
        """
        Iterate in batches (for batch-processing scenarios).

        Args:
            size: batch size

        Yields:
            lists of records, one batch at a time

        Examples:
            >>> for batch in st.batch(1000):
            ...     process_batch(batch)
        """
        batch = []
        for item in self._iterator:
            batch.append(item)
            if len(batch) >= size:
                yield batch
                batch = []
        if batch:
            yield batch

    def save(
        self, filepath: str, show_progress: bool = True, batch_size: int = 10000
    ) -> int:
        """
        Save the stream to a file.

        Supports JSONL, CSV, Parquet, and Arrow formats (auto-detected from the extension).

        Args:
            filepath: output file path
            show_progress: whether to show a progress bar
            batch_size: batch write size (CSV/Parquet/Arrow)

        Returns:
            number of records written
        """
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        ext = path.suffix.lower()

        if ext == ".jsonl":
            return self._save_jsonl(filepath, show_progress)
        elif ext == ".csv":
            return self._save_batched(filepath, "csv", batch_size, show_progress)
        elif ext == ".parquet":
            return self._save_batched(filepath, "parquet", batch_size, show_progress)
        elif ext in (".arrow", ".feather"):
            return self._save_batched(filepath, "arrow", batch_size, show_progress)
        else:
            # Default to JSONL
            return self._save_jsonl(filepath, show_progress)

    def _save_jsonl(self, filepath: str, show_progress: bool) -> int:
        """Stream-save as JSONL, line by line (using orjson)."""
        count = 0

        if show_progress:
            # Choose the progress-bar style depending on whether the total is known
            if self._total is not None:
                # Known total: show a bar, percentage, and remaining time
                columns = [
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    BarColumn(),
                    TaskProgressColumn(),
                    MofNCompleteColumn(),
                    TimeElapsedColumn(),
                    TimeRemainingColumn(),
                ]
            else:
                # Unknown total: only show the processed count
                columns = [
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    MofNCompleteColumn(),
                    TimeElapsedColumn(),
                ]

            with Progress(*columns) as progress:
                task = progress.add_task("Processing", total=self._total)
                with open(filepath, "wb") as f:
                    for item in self._iterator:
                        f.write(orjson.dumps(item) + b"\n")
                        count += 1
                        progress.update(task, advance=1)
        else:
            with open(filepath, "wb") as f:
                for item in self._iterator:
                    f.write(orjson.dumps(item) + b"\n")
                    count += 1

        return count

    def _save_batched(
        self, filepath: str, fmt: str, batch_size: int, show_progress: bool
    ) -> int:
        """
        Batched save (CSV/Parquet/Arrow).

        Reading and processing remain streaming; records are collected and
        written out in a single pass at the end.
        """
        path = Path(filepath)
        all_items = []

        if show_progress:
            # Choose the progress-bar style depending on whether the total is known
            if self._total is not None:
                columns = [
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    BarColumn(),
                    TaskProgressColumn(),
                    MofNCompleteColumn(),
                    TimeElapsedColumn(),
                    TimeRemainingColumn(),
                ]
            else:
                columns = [
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    MofNCompleteColumn(),
                    TimeElapsedColumn(),
                ]

            with Progress(*columns) as progress:
                task = progress.add_task("Processing", total=self._total)
                for item in self._iterator:
                    all_items.append(item)
                    progress.update(task, advance=1)
        else:
            for item in self._iterator:
                all_items.append(item)

        if all_items:
            df = pl.DataFrame(all_items)
            if fmt == "csv":
                df.write_csv(path)
            elif fmt == "parquet":
                df.write_parquet(path)
            elif fmt == "arrow":
                df.write_ipc(path)

        return len(all_items)

    def save_sharded(
        self,
        output_dir: str,
        shard_size: int = 100000,
        prefix: str = "part",
        show_progress: bool = True,
    ) -> List[str]:
        """
        Save as shards.

        Args:
            output_dir: output directory
            shard_size: records per shard
            prefix: shard filename prefix
            show_progress: whether to show a progress bar

        Returns:
            list of generated shard file paths

        Examples:
            >>> files = st.save_sharded("output/", shard_size=100000)
            >>> # Produces: output/part-00000.jsonl, output/part-00001.jsonl, ...
        """
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        shard_files = []
        shard_idx = 0
        count_in_shard = 0
        current_file = None

        def process_items(progress=None, task=None):
            nonlocal shard_idx, count_in_shard, current_file

            for item in self._iterator:
                # Start a new shard when needed
                if current_file is None or count_in_shard >= shard_size:
                    if current_file:
                        current_file.close()

                    shard_path = os.path.join(output_dir, f"{prefix}-{shard_idx:05d}.jsonl")
                    shard_files.append(shard_path)
                    current_file = open(shard_path, "wb")
                    shard_idx += 1
                    count_in_shard = 0
                    if progress is not None:
                        progress.update(task, description=f"Shard {shard_idx}")

                current_file.write(orjson.dumps(item) + b"\n")
                count_in_shard += 1
                if progress is not None:
                    progress.update(task, advance=1)

        try:
            if show_progress:
                if self._total is not None:
                    columns = [
                        SpinnerColumn(),
                        TextColumn("[progress.description]{task.description}"),
                        BarColumn(),
                        TaskProgressColumn(),
                        MofNCompleteColumn(),
                        TimeElapsedColumn(),
                        TimeRemainingColumn(),
                    ]
                else:
                    columns = [
                        SpinnerColumn(),
                        TextColumn("[progress.description]{task.description}"),
                        MofNCompleteColumn(),
                        TimeElapsedColumn(),
                    ]

                with Progress(*columns) as progress:
                    task = progress.add_task("Shard 1", total=self._total)
                    process_items(progress, task)
            else:
                process_items()
        finally:
            if current_file:
                current_file.close()

        return shard_files

    def collect(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Collect all records into memory (mind the memory footprint).

        Args:
            limit: maximum number of records to collect; None means all

        Returns:
            list of records
        """
        result = []
        for item in self._iterator:
            result.append(item)
            if limit and len(result) >= limit:
                break
        return result

    def count(self) -> int:
        """
        Count records (consumes the iterator).

        Returns:
            number of records
        """
        count = 0
        for _ in self._iterator:
            count += 1
        return count

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Support direct iteration."""
        return self._iterator


# ============ Convenience functions ============


def load_stream(filepath: str, batch_size: int = 10000) -> StreamingTransformer:
    """
    Load a file as a stream.

    Supports JSONL, CSV, Parquet, and Arrow formats.

    Args:
        filepath: path to the file
        batch_size: batch read size (CSV/Parquet/Arrow)

    Returns:
        a StreamingTransformer instance

    Examples:
        >>> from dtflow import load_stream
        >>> (load_stream("huge.jsonl")
        ...     .filter(lambda x: x["score"] > 0.5)
        ...     .save("filtered.jsonl"))
        >>> (load_stream("data.csv")
        ...     .filter(lambda x: x["score"] > 0.5)
        ...     .save("output.parquet"))
    """
    return StreamingTransformer.load_stream(filepath, batch_size)


def load_sharded(pattern: str, batch_size: int = 10000) -> StreamingTransformer:
    """
    Load sharded files.

    Supports JSONL, CSV, Parquet, and Arrow formats.

    Args:
        pattern: glob pattern
        batch_size: batch read size (CSV/Parquet/Arrow)

    Returns:
        a StreamingTransformer instance

    Examples:
        >>> from dtflow import load_sharded
        >>> load_sharded("data/*.jsonl").save("merged.jsonl")
        >>> load_sharded("data/*.parquet").save("merged.parquet")
    """
    return StreamingTransformer.load_sharded(pattern, batch_size)


def process_shards(
    input_pattern: str,
    output_dir: str,
    func: Callable[[Dict], Optional[Dict]],
    workers: int = 1,
    shard_size: int = 100000,
) -> List[str]:
    """
    Process sharded files in parallel.

    Args:
        input_pattern: input file glob pattern
        output_dir: output directory
        func: processing function; returning None drops the record
        workers: number of parallel worker processes (currently only 1 is supported)
        shard_size: output shard size

    Returns:
        list of generated output files

    Examples:
        >>> def process(item):
        ...     if item["score"] > 0.5:
        ...         return {"text": item["content"]}
        ...     return None
        >>> process_shards("input/*.jsonl", "output/", process)
    """
    # Simple implementation: sequential processing
    # TODO: add multiprocessing support in the future

    def transform_func(item):
        result = func(item)
        return result

    return (
        load_sharded(input_pattern)
        .transform(transform_func)
        .filter(lambda x: x is not None)
        .save_sharded(output_dir, shard_size=shard_size)
    )


# ============ Streaming readers ============


def _stream_jsonl(filepath: str) -> Generator[Dict[str, Any], None, None]:
    """Stream-read JSONL (using orjson)."""
    with open(filepath, "rb") as f:
        for line in f:
            line = line.strip()
            if line:
                yield orjson.loads(line)


def _stream_csv(
    filepath: str, batch_size: int = 10000
) -> Generator[Dict[str, Any], None, None]:
    """Stream-read CSV (using Polars BatchedCsvReader)."""
    reader = pl.read_csv_batched(filepath, batch_size=batch_size)
    while True:
        batches = reader.next_batches(1)
        if not batches:
            break
        for row in batches[0].to_dicts():
            yield row


def _stream_parquet(
    filepath: str, batch_size: int = 10000
) -> Generator[Dict[str, Any], None, None]:
    """Stream-read Parquet (using PyArrow iter_batches)."""
    import pyarrow.parquet as pq

    pf = pq.ParquetFile(filepath)
    for batch in pf.iter_batches(batch_size=batch_size):
        df = pl.from_arrow(batch)
        for row in df.to_dicts():
            yield row


def _stream_arrow(filepath: str) -> Generator[Dict[str, Any], None, None]:
    """Stream-read Arrow/Feather (using PyArrow IPC)."""
    import pyarrow as pa

    with pa.ipc.open_file(filepath) as reader:
        for i in range(reader.num_record_batches):
            batch = reader.get_batch(i)
            df = pl.from_arrow(batch)
            for row in df.to_dicts():
                yield row
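For context, a minimal usage sketch of the new streaming API, assembled from the docstring examples in the added module. The file paths and the `score`/`content` fields are hypothetical, and the helpers are imported here from `dtflow.streaming` (the `dtflow/__init__.py` changes in this release suggest they may also be re-exported from `dtflow`):

from dtflow.streaming import load_stream, load_sharded

# Lazily filter and reshape a large JSONL file, writing records as they stream through.
written = (
    load_stream("train.jsonl")                      # hypothetical input file
    .filter(lambda x: x.get("score", 0) > 0.5)      # keep high-scoring records
    .transform(lambda x: {"text": x["content"]})    # project to a single field
    .save("filtered.jsonl")                         # returns the number of records written
)

# Merge existing shards and re-shard the stream into 100k-record JSONL parts.
shard_paths = load_sharded("shards/part-*.jsonl").save_sharded("merged/", shard_size=100_000)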