dtflow 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +1 -1
- dtflow/__main__.py +6 -3
- dtflow/cli/clean.py +486 -0
- dtflow/cli/commands.py +53 -2637
- dtflow/cli/common.py +384 -0
- dtflow/cli/io_ops.py +385 -0
- dtflow/cli/lineage.py +49 -0
- dtflow/cli/pipeline.py +54 -0
- dtflow/cli/sample.py +294 -0
- dtflow/cli/stats.py +589 -0
- dtflow/cli/transform.py +486 -0
- dtflow/core.py +35 -0
- dtflow/storage/io.py +49 -6
- dtflow/streaming.py +25 -4
- {dtflow-0.4.2.dist-info → dtflow-0.4.3.dist-info}/METADATA +12 -1
- dtflow-0.4.3.dist-info/RECORD +33 -0
- dtflow-0.4.2.dist-info/RECORD +0 -25
- {dtflow-0.4.2.dist-info → dtflow-0.4.3.dist-info}/WHEEL +0 -0
- {dtflow-0.4.2.dist-info → dtflow-0.4.3.dist-info}/entry_points.txt +0 -0
dtflow/cli/commands.py
CHANGED
@@ -1,2640 +1,56 @@
 """
-CLI
+Unified export entry for the CLI commands.
+
+Each command has been split into its own module by function:
+- sample.py     sampling (sample, head, tail)
+- transform.py  transformation (transform)
+- stats.py      statistics (stats, token_stats)
+- clean.py      cleaning (clean, dedupe)
+- io_ops.py     IO operations (concat, diff)
+- pipeline.py   Pipeline (run)
+- lineage.py    lineage tracking (history)
+- common.py     shared utility functions
 """
 
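The body of the new 56-line module is not shown here, and the import lines on the removed side were truncated by the diff viewer. For orientation only, a unified export entry of this shape would plausibly re-export each command from the modules its docstring lists; the import list below is an assumption, not the package's verbatim code:

    # Hypothetical sketch of the re-export body of the new commands.py
    from .clean import clean, dedupe
    from .io_ops import concat, diff
    from .lineage import history
    from .pipeline import run
    from .sample import head, sample, tail
    from .stats import stats, token_stats
    from .transform import transform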
-
-import
-
-
-from
-
-
-import
-
-
-from
-
-
-from
-
-
-
-
-
-
-
-
-
-""
-
-
-
-
-""
-
-
-
-
-
-
-
-
-
-
-
-    type: Literal["random", "head", "tail"] = "head",
-    output: Optional[str] = None,
-    seed: Optional[int] = None,
-    by: Optional[str] = None,
-    uniform: bool = False,
-    fields: Optional[str] = None,
-) -> None:
-    """
-    Sample a given number of records from a data file.
-
-    Args:
-        filename: input file path; csv/excel/jsonl/json/parquet/arrow/feather formats are supported
-        num: sample size, default 10
-            - num > 0: sample the given count
-            - num = 0: sample all data
-            - num < 0: Python slice style (e.g. -1 means the last record, -10 the last 10)
-        type: sampling mode, one of random/head/tail, default head
-        output: output file path; prints to the console if omitted
-        seed: random seed (only effective when type=random)
-        by: field name for stratified sampling; records are grouped by that field's values
-        uniform: uniform sampling mode (requires --by); every group gets the same count
-        fields: only show the given fields (comma-separated); preview mode only
-
-    Examples:
-        dt sample data.jsonl 5
-        dt sample data.csv 100 --type=head
-        dt sample data.xlsx 50 --output=sampled.jsonl
-        dt sample data.jsonl 0  # sample all data
-        dt sample data.jsonl -10  # the last 10 records
-        dt sample data.jsonl 1000 --by=category  # proportional stratified sampling
-        dt sample data.jsonl 1000 --by=category --uniform  # uniform stratified sampling
-        dt sample data.jsonl --fields=question,answer  # only show the given fields
-    """
-    filepath = Path(filename)
-
-    if not filepath.exists():
-        print(f"Error: file not found - {filename}")
-        return
-
-    if not _check_file_format(filepath):
-        return
-
-    # uniform requires by
-    if uniform and not by:
-        print("Error: --uniform requires --by")
-        return
-
-    # stratified sampling mode
-    if by:
-        try:
-            sampled = _stratified_sample(filepath, num, by, uniform, seed, type)
-        except Exception as e:
-            print(f"Error: {e}")
-            return
-    else:
-        # plain sampling
-        try:
-            sampled = sample_file(
-                str(filepath),
-                num=num,
-                sample_type=type,
-                seed=seed,
-                output=None,  # don't save yet; output is handled uniformly below
-            )
-        except Exception as e:
-            print(f"Error: {e}")
-            return
-
-    # emit the result
-    if output:
-        save_data(sampled, output)
-        print(f"Saved {len(sampled)} records to {output}")
-    else:
-        # fetch the file's total row count for the display
-        total_count = _get_file_row_count(filepath)
-        # parse the fields argument
-        field_list = _parse_field_list(fields) if fields else None
-        _print_samples(sampled, filepath.name, total_count, field_list)
-
-
-def _stratified_sample(
-    filepath: Path,
-    num: int,
-    stratify_field: str,
-    uniform: bool,
-    seed: Optional[int],
-    sample_type: str,
-) -> List[Dict]:
-    """
-    Stratified sampling implementation.
-
-    Args:
-        filepath: file path
-        num: target total sample count
-        stratify_field: stratification field; nested path syntax is supported:
-            - meta.source        nested field
-            - messages[0].role   array index
-            - messages[-1].role  negative index
-            - messages.#         array length
-            - messages[*].role   expand all elements (optionally with :join/:unique modes)
-        uniform: whether to sample uniformly (same count per group)
-        seed: random seed
-        sample_type: sampling mode (used inside each group)
-
-    Returns:
-        the sampled records
-    """
-    import random
-    from collections import defaultdict
-
-    if seed is not None:
-        random.seed(seed)
-
-    # load the data
-    data = load_data(str(filepath))
-    total = len(data)
-
-    if num <= 0 or num > total:
-        num = total
-
-    # group by the field (nested path syntax is supported)
-    groups: Dict[Any, List[Dict]] = defaultdict(list)
-    for item in data:
-        key = get_field_with_spec(item, stratify_field, default="__null__")
-        # make sure the key is hashable
-        if isinstance(key, list):
-            key = tuple(key)
-        groups[key].append(item)
-
-    group_keys = list(groups.keys())
-    num_groups = len(group_keys)
-
-    # print the group info
-    print(f"📊 Stratified sampling: field={stratify_field}, {num_groups} groups")
-    for key in sorted(group_keys, key=lambda x: -len(groups[x])):
-        count = len(groups[key])
-        pct = count / total * 100
-        display_key = key if key != "__null__" else "[null]"
-        print(f"  {display_key}: {count} records ({pct:.1f}%)")
-
-    # work out each group's sample count
-    if uniform:
-        # uniform sampling: equal counts per group
-        per_group = num // num_groups
-        remainder = num % num_groups
-        sample_counts = {key: per_group for key in group_keys}
-        # hand the remainder to the largest groups
-        for key in sorted(group_keys, key=lambda x: -len(groups[x]))[:remainder]:
-            sample_counts[key] += 1
-    else:
-        # proportional sampling: keep the original ratios
-        sample_counts = {}
-        allocated = 0
-        # walk the groups in descending size order so small groups still get a share
-        sorted_keys = sorted(group_keys, key=lambda x: -len(groups[x]))
-        for i, key in enumerate(sorted_keys):
-            if i == len(sorted_keys) - 1:
-                # the last group takes whatever is left
-                sample_counts[key] = num - allocated
-            else:
-                # allocate proportionally
-                ratio = len(groups[key]) / total
-                count = int(num * ratio)
-                # guarantee at least 1 (if the group has data)
-                count = max(1, count) if groups[key] else 0
-                sample_counts[key] = count
-                allocated += count
-
-    # run the per-group sampling
-    result = []
-    print(f"🔄 Sampling...")
-    for key in group_keys:
-        group_data = groups[key]
-        target = min(sample_counts[key], len(group_data))
-
-        if target <= 0:
-            continue
-
-        # sample inside the group
-        if sample_type == "random":
-            sampled = random.sample(group_data, target)
-        elif sample_type == "head":
-            sampled = group_data[:target]
-        else:  # tail
-            sampled = group_data[-target:]
-
-        result.extend(sampled)
-
-    # print the sampling result
-    print(f"\n📋 Sampling result:")
-    result_groups: Dict[Any, int] = defaultdict(int)
-    for item in result:
-        key = item.get(stratify_field, "__null__")
-        result_groups[key] += 1
-
-    for key in sorted(group_keys, key=lambda x: -len(groups[x])):
-        orig = len(groups[key])
-        sampled_count = result_groups.get(key, 0)
-        display_key = key if key != "__null__" else "[null]"
-        print(f"  {display_key}: {orig} → {sampled_count}")
-
-    print(f"\n✅ Total: {total} → {len(result)} records")
-
-    return result
-
-
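The allocation rules in _stratified_sample above are easiest to see on concrete numbers. A standalone sketch mirroring that logic (group names and sizes are made up; nothing is imported from dtflow):

    # Proportional vs. uniform allocation, as in _stratified_sample above
    from typing import Dict

    def allocate(sizes: Dict[str, int], num: int, uniform: bool) -> Dict[str, int]:
        keys = sorted(sizes, key=lambda k: -sizes[k])
        total = sum(sizes.values())
        if uniform:
            per, rem = divmod(num, len(keys))
            counts = {k: per for k in keys}
            for k in keys[:rem]:  # the remainder goes to the largest groups
                counts[k] += 1
            return counts
        counts, allocated = {}, 0
        for i, k in enumerate(keys):
            if i == len(keys) - 1:  # the last (smallest) group absorbs the leftovers
                counts[k] = num - allocated
            else:
                counts[k] = max(1, int(num * sizes[k] / total))
                allocated += counts[k]
        return counts

    print(allocate({"news": 70, "wiki": 25, "qa": 5}, 20, uniform=False))  # {'news': 14, 'wiki': 5, 'qa': 1}
    print(allocate({"news": 70, "wiki": 25, "qa": 5}, 20, uniform=True))   # {'news': 7, 'wiki': 7, 'qa': 6}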
-def head(
-    filename: str,
-    num: int = 10,
-    output: Optional[str] = None,
-    fields: Optional[str] = None,
-) -> None:
-    """
-    Show the first N records of a file (shortcut for dt sample --type=head).
-
-    Args:
-        filename: input file path; csv/excel/jsonl/json/parquet/arrow/feather formats are supported
-        num: number of records to show, default 10
-            - num > 0: show the given count
-            - num = 0: show all data
-            - num < 0: Python slice style (e.g. -10 means the last 10)
-        output: output file path; prints to the console if omitted
-        fields: only show the given fields (comma-separated); preview mode only
-
-    Examples:
-        dt head data.jsonl  # show the first 10
-        dt head data.jsonl 20  # show the first 20
-        dt head data.csv 0  # show all data
-        dt head data.xlsx --output=head.jsonl
-        dt head data.jsonl --fields=question,answer
-    """
-    sample(filename, num=num, type="head", output=output, fields=fields)
-
-
-def tail(
-    filename: str,
-    num: int = 10,
-    output: Optional[str] = None,
-    fields: Optional[str] = None,
-) -> None:
-    """
-    Show the last N records of a file (shortcut for dt sample --type=tail).
-
-    Args:
-        filename: input file path; csv/excel/jsonl/json/parquet/arrow/feather formats are supported
-        num: number of records to show, default 10
-            - num > 0: show the given count
-            - num = 0: show all data
-            - num < 0: Python slice style (e.g. -10 means the last 10)
-        output: output file path; prints to the console if omitted
-        fields: only show the given fields (comma-separated); preview mode only
-
-    Examples:
-        dt tail data.jsonl  # show the last 10
-        dt tail data.jsonl 20  # show the last 20
-        dt tail data.csv 0  # show all data
-        dt tail data.xlsx --output=tail.jsonl
-        dt tail data.jsonl --fields=question,answer
-    """
-    sample(filename, num=num, type="tail", output=output, fields=fields)
-
-
-def _get_file_row_count(filepath: Path) -> Optional[int]:
-    """
-    Quickly get a file's row count (without loading all the data).
-
-    For JSONL files, count the lines directly; other formats return None.
-    """
-    ext = filepath.suffix.lower()
-    if ext == ".jsonl":
-        try:
-            with open(filepath, "rb") as f:
-                return sum(1 for _ in f)
-        except Exception:
-            return None
-    # fast counting is not supported for other formats yet
-    return None
-
-
-def _format_value(value: Any, max_len: int = 80) -> str:
-    """Format a single value, truncating long text."""
-    if value is None:
-        return "[dim]null[/dim]"
-    if isinstance(value, bool):
-        return "[cyan]true[/cyan]" if value else "[cyan]false[/cyan]"
-    if isinstance(value, (int, float)):
-        return f"[cyan]{value}[/cyan]"
-    if isinstance(value, str):
-        # handle multi-line text
-        if "\n" in value:
-            lines = value.split("\n")
-            if len(lines) > 3:
-                preview = lines[0][:max_len] + f"... [dim]({len(lines)} lines)[/dim]"
-            else:
-                preview = value.replace("\n", "\\n")
-                if len(preview) > max_len:
-                    preview = preview[:max_len] + "..."
-            return f'"{preview}"'
-        if len(value) > max_len:
-            return f'"{value[:max_len]}..." [dim]({len(value)} chars)[/dim]'
-        return f'"{value}"'
-    return str(value)
-
-
-def _format_nested(
-    value: Any,
-    indent: str = "",
-    is_last: bool = True,
-    max_len: int = 80,
-) -> List[str]:
-    """
-    Recursively format a nested structure, returning a list of lines.
-
-    Uses tree glyphs to show the structure:
-    ├─ middle item
-    └─ last item
-    """
-    lines = []
-    branch = "└─ " if is_last else "├─ "
-    cont = "   " if is_last else "│  "
-
-    if isinstance(value, dict):
-        items = list(value.items())
-        for i, (k, v) in enumerate(items):
-            is_last_item = i == len(items) - 1
-            b = "└─ " if is_last_item else "├─ "
-            c = "   " if is_last_item else "│  "
-
-            if isinstance(v, (dict, list)) and v:
-                # nested structure
-                if isinstance(v, list):
-                    # detect the messages format
-                    is_messages = (
-                        v and isinstance(v[0], dict) and "role" in v[0] and "content" in v[0]
-                    )
-                    if is_messages:
-                        lines.append(
-                            f"{indent}{b}[green]{k}[/green]: ({len(v)} items) [dim]→ \\[role]: content[/dim]"
-                        )
-                    else:
-                        lines.append(f"{indent}{b}[green]{k}[/green]: ({len(v)} items)")
-                else:
-                    lines.append(f"{indent}{b}[green]{k}[/green]:")
-                lines.extend(_format_nested(v, indent + c, True, max_len))
-            else:
-                # simple value
-                lines.append(f"{indent}{b}[green]{k}[/green]: {_format_value(v, max_len)}")
-
-    elif isinstance(value, list):
-        for i, item in enumerate(value):
-            is_last_item = i == len(value) - 1
-            b = "└─ " if is_last_item else "├─ "
-            c = "   " if is_last_item else "│  "
-
-            if isinstance(item, dict):
-                # a dict item inside a list - detect the messages format
-                if "role" in item and "content" in item:
-                    role = item.get("role", "")
-                    content = item.get("content", "")
-                    # truncate long content
-                    if len(content) > max_len:
-                        content = content[:max_len].replace("\n", "\\n") + "..."
-                    else:
-                        content = content.replace("\n", "\\n")
-                    # escape with \\[ so rich does not parse it as style markup
-                    lines.append(f"{indent}{b}[yellow]\\[{role}]:[/yellow] {content}")
-                else:
-                    # plain dict
-                    lines.append(f"{indent}{b}[dim]{{...}}[/dim]")
-                    lines.extend(_format_nested(item, indent + c, True, max_len))
-            elif isinstance(item, list):
-                lines.append(f"{indent}{b}[dim][{len(item)} items][/dim]")
-                lines.extend(_format_nested(item, indent + c, True, max_len))
-            else:
-                lines.append(f"{indent}{b}{_format_value(item, max_len)}")
-
-    return lines
-
-
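For reference, a record with a messages list renders under the scheme above roughly as the following tree (illustrative output; the rich color markup is omitted):

    ├─ question: "What is 2+2?"
    └─ messages: (2 items) → [role]: content
       ├─ [user]: What is 2+2?
       └─ [assistant]: 4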
-def _is_simple_data(samples: List[Dict]) -> bool:
-    """Decide whether the data suits a table display (no nested structures)."""
-    if not samples or not isinstance(samples[0], dict):
-        return False
-    keys = list(samples[0].keys())
-    if len(keys) > 6:
-        return False
-    for s in samples[:3]:
-        for k in keys:
-            v = s.get(k)
-            if isinstance(v, (dict, list)):
-                return False
-            if isinstance(v, str) and len(v) > 80:
-                return False
-    return True
-
-
-def _print_samples(
-    samples: list,
-    filename: Optional[str] = None,
-    total_count: Optional[int] = None,
-    fields: Optional[List[str]] = None,
-) -> None:
-    """
-    Print the sampling result.
-
-    Args:
-        samples: the sampled records
-        filename: file name (for the overview header)
-        total_count: the file's total row count (for the overview header)
-        fields: only show the given fields
-    """
-    if not samples:
-        print("No data")
-        return
-
-    # filter the fields
-    if fields and isinstance(samples[0], dict):
-        field_set = set(fields)
-        samples = [{k: v for k, v in item.items() if k in field_set} for item in samples]
-
-    try:
-        from rich.console import Console
-        from rich.panel import Panel
-        from rich.table import Table
-
-        console = Console()
-
-        # show the data overview header
-        if filename:
-            all_fields = set()
-            for item in samples:
-                if isinstance(item, dict):
-                    all_fields.update(item.keys())
-            field_names = ", ".join(sorted(all_fields))
-
-            if total_count is not None:
-                info = f"Total rows: {total_count:,} | Sampled: {len(samples)} | Fields: {len(all_fields)}"
-            else:
-                info = f"Sampled: {len(samples)} | Fields: {len(all_fields)}"
-
-            console.print(
-                Panel(
-                    f"[dim]{info}[/dim]\n[dim]Fields: {field_names}[/dim]",
-                    title=f"[bold]📊 {filename}[/bold]",
-                    expand=False,
-                    border_style="dim",
-                )
-            )
-            console.print()
-
-        # simple data goes into a table
-        if _is_simple_data(samples):
-            keys = list(samples[0].keys())
-            table = Table(show_header=True, header_style="bold cyan")
-            for key in keys:
-                table.add_column(key, overflow="fold")
-            for item in samples:
-                table.add_row(*[str(item.get(k, "")) for k in keys])
-            console.print(table)
-            return
-
-        # nested data gets the tree view
-        for i, item in enumerate(samples, 1):
-            console.print(f"[bold cyan]--- Record {i} ---[/bold cyan]")
-            if isinstance(item, dict):
-                for line in _format_nested(item):
-                    console.print(line)
-            else:
-                console.print(_format_value(item))
-            console.print()
-
-    except ImportError:
-        # rich unavailable; fall back to plain printing
-        if filename:
-            all_fields = set()
-            for item in samples:
-                if isinstance(item, dict):
-                    all_fields.update(item.keys())
-
-            print(f"\n📊 {filename}")
-            if total_count is not None:
-                print(
-                    f"  Total rows: {total_count:,} | Sampled: {len(samples)} | Fields: {len(all_fields)}"
-                )
-            else:
-                print(f"  Sampled: {len(samples)} | Fields: {len(all_fields)}")
-            print(f"  Fields: {', '.join(sorted(all_fields))}")
-            print()
-
-        for i, item in enumerate(samples, 1):
-            print(f"--- Record {i} ---")
-            print(orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8"))
-            print()
-
-
-# ============ Transform Command ============
-
-CONFIG_DIR = ".dt"
-
-
-def _get_config_path(input_path: Path, config_override: Optional[str] = None) -> Path:
-    """Get the config file path"""
-    if config_override:
-        return Path(config_override)
-
-    # use the input file name (without its extension) as the config file name
-    config_name = input_path.stem + ".py"
-    return input_path.parent / CONFIG_DIR / config_name
-
-
-def transform(
-    filename: str,
-    num: Optional[int] = None,
-    preset: Optional[str] = None,
-    config: Optional[str] = None,
-    output: Optional[str] = None,
-) -> None:
-    """
-    Transform the data format.
-
-    Two ways to use it:
-    1. Config-file mode (default): auto-generate a config file, edit it, then run again
-    2. Preset mode: transform directly with --preset
-
-    Args:
-        filename: input file path; csv/excel/jsonl/json/parquet/arrow/feather formats are supported
-        num: only transform the first N records (optional)
-        preset: use a preset template (openai_chat, alpaca, sharegpt, dpo_pair, simple_qa)
-        config: config file path (optional; defaults to .dt/<filename>.py)
-        output: output file path
-
-    Examples:
-        dt transform data.jsonl  # generates the config on the first run
-        dt transform data.jsonl 10  # only transform the first 10
-        dt transform data.jsonl --preset=openai_chat  # use a preset
-        dt transform data.jsonl 100 --preset=alpaca  # preset + record limit
-    """
-    filepath = Path(filename)
-    if not filepath.exists():
-        print(f"Error: file not found - {filename}")
-        return
-
-    if not _check_file_format(filepath):
-        return
-
-    # preset mode: transform directly with the preset
-    if preset:
-        _execute_preset_transform(filepath, preset, output, num)
-        return
-
-    # config-file mode
-    config_path = _get_config_path(filepath, config)
-
-    if not config_path.exists():
-        _generate_config(filepath, config_path)
-    else:
-        _execute_transform(filepath, config_path, output, num)
-
-
-def _generate_config(input_path: Path, config_path: Path) -> None:
-    """Analyze the input data and generate a config file"""
-    print(f"📊 Analyzing input data: {input_path}")
-
-    # read the data
-    try:
-        data = load_data(str(input_path))
-    except Exception as e:
-        print(f"Error: cannot read file - {e}")
-        return
-
-    if not data:
-        print("Error: file is empty")
-        return
-
-    total_count = len(data)
-    sample_item = data[0]
-
-    print(f"  Detected {total_count} records")
-
-    # generate the config content
-    config_content = _build_config_content(sample_item, input_path.name, total_count)
-
-    # make sure the config directory exists
-    config_path.parent.mkdir(parents=True, exist_ok=True)
-
-    # write the config file
-    config_path.write_text(config_content, encoding="utf-8")
-
-    print(f"\n📝 Generated config file: {config_path}")
-    print("\n👉 Next steps:")
-    print(f"  1. Edit {config_path} and define the transform function")
-    print(f"  2. Run dt transform {input_path.name} again to finish the transform")
-
-
-def _build_config_content(sample: Dict[str, Any], filename: str, total: int) -> str:
-    """Build the config file content"""
-    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-
-    # generate the field definitions for the Item class
-    fields_def = _generate_fields_definition(sample)
-
-    # generate a default transform function (simple renaming)
-    field_names = list(sample.keys())
-
-    # generate sanitized field names for the examples
-    safe_field1 = _sanitize_field_name(field_names[0])[0] if field_names else "field1"
-    safe_field2 = _sanitize_field_name(field_names[1])[0] if len(field_names) > 1 else "field2"
-
-    # generate the default output file name
-    base_name = Path(filename).stem
-    output_filename = f"{base_name}_output.jsonl"
-
-    config = f'''"""
-DataTransformer config file
-Generated at: {now}
-Input file: {filename} ({total} records)
-"""
-
-
-# ===== Input data structure (auto-generated; IDE can autocomplete) =====
-
-class Item:
-{fields_def}
-
-
-# ===== Define the transform logic =====
-# Tip: after typing item. your IDE will autocomplete the available fields
-
-def transform(item: Item):
-    return {{
-{_generate_default_transform(field_names)}
-    }}
-
-
-# output file path
-output = "{output_filename}"
-
-
-# ===== Examples =====
-#
-# Example 1: build the OpenAI Chat format
-# def transform(item: Item):
-#     return {{
-#         "messages": [
-#             {{"role": "user", "content": item.{safe_field1}}},
-#             {{"role": "assistant", "content": item.{safe_field2}}},
-#         ]
-#     }}
-#
-# Example 2: the Alpaca format
-# def transform(item: Item):
-#     return {{
-#         "instruction": item.{safe_field1},
-#         "input": "",
-#         "output": item.{safe_field2},
-#     }}
-'''
-    return config
-
-
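Concretely, for an input data.jsonl whose first record is {"q": "What is 2+2?", "a": "4"}, the template above would emit a config roughly like this (timestamp, counts, and values illustrative):

    """
    DataTransformer config file
    Generated at: 2024-01-01 00:00:00
    Input file: data.jsonl (1000 records)
    """


    # ===== Input data structure (auto-generated; IDE can autocomplete) =====

    class Item:
        q: str = 'What is 2+2?'
        a: str = '4'


    # ===== Define the transform logic =====
    # Tip: after typing item. your IDE will autocomplete the available fields

    def transform(item: Item):
        return {
            "q": item.q,
            "a": item.a,
        }


    # output file path
    output = "data_output.jsonl"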
-def _generate_fields_definition(sample: Dict[str, Any], indent: int = 4) -> str:
-    """Generate the field definitions for the Item class"""
-    lines = []
-    prefix = " " * indent
-
-    for key, value in sample.items():
-        type_name = _get_type_name(value)
-        example = _format_example_value(value)
-        safe_key, changed = _sanitize_field_name(key)
-        comment = f"  # original field name: {key}" if changed else ""
-        lines.append(f"{prefix}{safe_key}: {type_name} = {example}{comment}")
-
-    return "\n".join(lines) if lines else f"{prefix}pass"
-
-
-def _get_type_name(value: Any) -> str:
-    """Get a value's type name"""
-    if value is None:
-        return "str"
-    if isinstance(value, str):
-        return "str"
-    if isinstance(value, bool):
-        return "bool"
-    if isinstance(value, int):
-        return "int"
-    if isinstance(value, float):
-        return "float"
-    if isinstance(value, list):
-        return "list"
-    if isinstance(value, dict):
-        return "dict"
-    return "str"
-
-
-def _format_example_value(value: Any, max_len: int = 50) -> str:
-    """Format an example value"""
-    if value is None:
-        return '""'
-    if isinstance(value, str):
-        # truncate long strings
-        if len(value) > max_len:
-            value = value[:max_len] + "..."
-        # repr() handles all the escape characters automatically
-        return repr(value)
-    if isinstance(value, bool):
-        return str(value)
-    if isinstance(value, (int, float)):
-        return str(value)
-    if isinstance(value, (list, dict)):
-        s = orjson.dumps(value).decode("utf-8")
-        if len(s) > max_len:
-            return repr(s[:max_len] + "...")
-        return s
-    return '""'
-
-
-def _sanitize_field_name(name: str) -> tuple:
-    """
-    Normalize a field name into a valid Python identifier.
-
-    Returns:
-        tuple: (the normalized name, whether it was changed)
-    """
-    if name.isidentifier():
-        return name, False
-
-    # replace the common illegal characters
-    sanitized = name.replace("-", "_").replace(" ", "_").replace(".", "_")
-
-    # prefix it if it starts with a digit
-    if sanitized and sanitized[0].isdigit():
-        sanitized = "f_" + sanitized
-
-    # strip the remaining illegal characters
-    sanitized = "".join(c if c.isalnum() or c == "_" else "_" for c in sanitized)
-
-    # make sure it is not empty
-    if not sanitized:
-        sanitized = "field"
-
-    return sanitized, True
-
-
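A few representative inputs and the results the rules above imply:

    # Expected behavior of _sanitize_field_name, derived from the rules above
    assert _sanitize_field_name("question") == ("question", False)       # already a valid identifier
    assert _sanitize_field_name("user-name") == ("user_name", True)      # "-" becomes "_"
    assert _sanitize_field_name("2col") == ("f_2col", True)              # digit-leading names get an "f_" prefix
    assert _sanitize_field_name("price (usd)") == ("price__usd_", True)  # other illegal chars become "_"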
-def _generate_default_transform(field_names: List[str]) -> str:
-    """Generate the default transform function body"""
-    lines = []
-    for name in field_names[:5]:  # show at most 5 fields
-        safe_name, _ = _sanitize_field_name(name)
-        lines.append(f'        "{name}": item.{safe_name},')
-    return "\n".join(lines) if lines else "        # define the output fields here"
-
-
-def _unwrap(obj: Any) -> Any:
-    """Recursively convert DictWrapper objects into plain dicts"""
-    if hasattr(obj, "to_dict"):
-        return _unwrap(obj.to_dict())
-    if isinstance(obj, dict):
-        return {k: _unwrap(v) for k, v in obj.items()}
-    if isinstance(obj, list):
-        return [_unwrap(v) for v in obj]
-    return obj
-
-
-def _execute_transform(
-    input_path: Path,
-    config_path: Path,
-    output_override: Optional[str],
-    num: Optional[int],
-) -> None:
-    """Run the data transform (streaming by default)"""
-    print(f"📂 Loading config: {config_path}")
-
-    # dynamically load the config file
-    try:
-        config_ns = _load_config(config_path)
-    except Exception as e:
-        print(f"Error: cannot load config file - {e}")
-        return
-
-    # fetch the transform function
-    if "transform" not in config_ns:
-        print("Error: no transform function defined in the config file")
-        return
-
-    transform_func = config_ns["transform"]
-
-    # get the output path
-    output_path = output_override or config_ns.get("output", "output.jsonl")
-
-    # use streaming for JSONL files
-    if _is_streaming_supported(input_path):
-        print(f"📊 Streaming load: {input_path}")
-        print("🔄 Transforming...")
-        try:
-            # wrap the transform function to support attribute access (the Item class defined in the config)
-            def wrapped_transform(item):
-                result = transform_func(DictWrapper(item))
-                return _unwrap(result)
-
-            st = load_stream(str(input_path))
-            if num:
-                st = st.head(num)
-            count = st.transform(wrapped_transform).save(output_path)
-            print(f"💾 Saving result: {output_path}")
-            print(f"\n✅ Done! Transformed {count} records into {output_path}")
-        except Exception as e:
-            print(f"Error: transform failed - {e}")
-            import traceback
-
-            traceback.print_exc()
-        return
-
-    # non-JSONL files use the traditional path
-    print(f"📊 Loading data: {input_path}")
-    try:
-        dt = DataTransformer.load(str(input_path))
-    except Exception as e:
-        print(f"Error: cannot read file - {e}")
-        return
-
-    total = len(dt)
-    if num:
-        dt = DataTransformer(dt.data[:num])
-        print(f"  Processing the first {len(dt)}/{total} records")
-    else:
-        print(f"  {total} records in total")
-
-    # run the transform (via Core's to method, which supports attribute access)
-    print("🔄 Transforming...")
-    try:
-        results = dt.to(transform_func)
-    except Exception as e:
-        print(f"Error: transform failed - {e}")
-        import traceback
-
-        traceback.print_exc()
-        return
-
-    # save the result
-    print(f"💾 Saving result: {output_path}")
-    try:
-        save_data(results, output_path)
-    except Exception as e:
-        print(f"Error: cannot save file - {e}")
-        return
-
-    print(f"\n✅ Done! Transformed {len(results)} records into {output_path}")
-
-
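_execute_transform relies on DictWrapper (imported from elsewhere in the package) so that config-defined transform functions can use attribute access. A minimal stand-in showing just that behavior; dtflow's actual class is defined elsewhere and may differ:

    # Hypothetical minimal DictWrapper, for illustration only
    class DictWrapper:
        def __init__(self, data: dict):
            self._data = data

        def __getattr__(self, name):
            value = self._data[name]
            return DictWrapper(value) if isinstance(value, dict) else value

        def to_dict(self) -> dict:
            return self._data

    item = DictWrapper({"question": "2+2?", "meta": {"source": "math"}})
    print(item.question)     # 2+2?
    print(item.meta.source)  # math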
-def _execute_preset_transform(
-    input_path: Path,
-    preset_name: str,
-    output_override: Optional[str],
-    num: Optional[int],
-) -> None:
-    """Run the transform with a preset template (streaming by default)"""
-    print(f"📂 Using preset: {preset_name}")
-
-    # fetch the preset function
-    try:
-        transform_func = get_preset(preset_name)
-    except ValueError as e:
-        print(f"Error: {e}")
-        print(f"Available presets: {', '.join(list_presets())}")
-        return
-
-    output_path = output_override or f"{input_path.stem}_{preset_name}.jsonl"
-
-    # check whether input and output are the same file
-    input_resolved = input_path.resolve()
-    output_resolved = Path(output_path).resolve()
-    use_temp_file = input_resolved == output_resolved
-
-    # use streaming for JSONL files
-    if _is_streaming_supported(input_path):
-        print(f"📊 Streaming load: {input_path}")
-        print("🔄 Transforming...")
-
-        # if input and output are the same, use a temp file
-        if use_temp_file:
-            print("⚠ The output file is the same as the input file; a temp file will be used")
-            temp_fd, temp_path = tempfile.mkstemp(
-                suffix=output_resolved.suffix,
-                prefix=".tmp_",
-                dir=output_resolved.parent,
-            )
-            os.close(temp_fd)
-            actual_output = temp_path
-        else:
-            actual_output = output_path
-
-        try:
-            # wrap the transform function to support attribute access
-            def wrapped_transform(item):
-                result = transform_func(DictWrapper(item))
-                return _unwrap(result)
-
-            st = load_stream(str(input_path))
-            if num:
-                st = st.head(num)
-            count = st.transform(wrapped_transform).save(actual_output)
-
-            # if a temp file was used, move it into place
-            if use_temp_file:
-                shutil.move(temp_path, output_path)
-
-            print(f"💾 Saving result: {output_path}")
-            print(f"\n✅ Done! Transformed {count} records into {output_path}")
-        except Exception as e:
-            # clean up the temp file
-            if use_temp_file and os.path.exists(temp_path):
-                os.unlink(temp_path)
-            print(f"Error: transform failed - {e}")
-            import traceback
-
-            traceback.print_exc()
-        return
-
-    # non-JSONL files use the traditional path
-    print(f"📊 Loading data: {input_path}")
-    try:
-        dt = DataTransformer.load(str(input_path))
-    except Exception as e:
-        print(f"Error: cannot read file - {e}")
-        return
-
-    total = len(dt)
-    if num:
-        dt = DataTransformer(dt.data[:num])
-        print(f"  Processing the first {len(dt)}/{total} records")
-    else:
-        print(f"  {total} records in total")
-
-    # run the transform
-    print("🔄 Transforming...")
-    try:
-        results = dt.to(transform_func)
-    except Exception as e:
-        print(f"Error: transform failed - {e}")
-        import traceback
-
-        traceback.print_exc()
-        return
-
-    # save the result
-    print(f"💾 Saving result: {output_path}")
-    try:
-        save_data(results, output_path)
-    except Exception as e:
-        print(f"Error: cannot save file - {e}")
-        return
-
-    print(f"\n✅ Done! Transformed {len(results)} records into {output_path}")
-
-
-def _load_config(config_path: Path) -> Dict[str, Any]:
-    """Dynamically load a Python config file"""
-    import importlib.util
-
-    spec = importlib.util.spec_from_file_location("dt_config", config_path)
-    module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
-
-    return {name: getattr(module, name) for name in dir(module) if not name.startswith("_")}
-
-
-# ============ Dedupe Command ============
-
-
-def dedupe(
-    filename: str,
-    key: Optional[str] = None,
-    similar: Optional[float] = None,
-    output: Optional[str] = None,
-) -> None:
-    """
-    Deduplicate data.
-
-    Two modes are supported:
-    1. Exact dedupe (default): only identical records are deduplicated
-    2. Similarity dedupe: uses the MinHash+LSH algorithm; records above the threshold are deduplicated
-
-    Args:
-        filename: input file path; csv/excel/jsonl/json/parquet/arrow/feather formats are supported
-        key: field to dedupe on; nested path syntax is supported:
-            - meta.source            nested field
-            - messages[0].role       array index
-            - messages[-1].content   negative index
-            - messages.#             array length
-            - messages[*].role:join  expand all elements
-            Separate multiple fields with commas. Omit for whole-record dedupe
-        similar: similarity threshold (0-1); enables similarity mode and requires --key
-        output: output file path; overwrites the original file if omitted
-
-    Examples:
-        dt dedupe data.jsonl  # exact whole-record dedupe
-        dt dedupe data.jsonl --key=text  # exact dedupe on the text field
-        dt dedupe data.jsonl --key=user,timestamp  # exact dedupe on a field combination
-        dt dedupe data.jsonl --key=meta.id  # dedupe on a nested field
-        dt dedupe data.jsonl --key=messages[0].content  # dedupe on the first message's content
-        dt dedupe data.jsonl --key=text --similar=0.8  # similarity dedupe
-    """
-    filepath = Path(filename)
-
-    if not filepath.exists():
-        print(f"Error: file not found - {filename}")
-        return
-
-    if not _check_file_format(filepath):
-        return
-
-    # similarity mode requires a key
-    if similar is not None and not key:
-        print("Error: similarity dedupe requires the --key argument")
-        return
-
-    if similar is not None and (similar <= 0 or similar > 1):
-        print("Error: --similar must be between 0 and 1")
-        return
-
-    # load the data
-    print(f"📊 Loading data: {filepath}")
-    try:
-        dt = DataTransformer.load(str(filepath))
-    except Exception as e:
-        print(f"Error: cannot read file - {e}")
-        return
-
-    original_count = len(dt)
-    print(f"  {original_count} records in total")
-
-    # run the dedupe
-    if similar is not None:
-        # similarity mode
-        print(f"🔑 Similarity dedupe: field={key}, threshold={similar}")
-        print("🔄 Deduplicating (MinHash+LSH)...")
-        try:
-            result = dt.dedupe_similar(key, threshold=similar)
-        except ImportError as e:
-            print(f"Error: {e}")
-            return
-    else:
-        # exact mode
-        dedupe_key: Any = None
-        if key:
-            keys = [k.strip() for k in key.split(",")]
-            if len(keys) == 1:
-                dedupe_key = keys[0]
-                print(f"🔑 Exact dedupe on field: {dedupe_key}")
-            else:
-                dedupe_key = keys
-                print(f"🔑 Exact dedupe on field combination: {', '.join(dedupe_key)}")
-        else:
-            print("🔑 Exact whole-record dedupe")
-
-        print("🔄 Deduplicating...")
-        result = dt.dedupe(dedupe_key)
-
-    dedupe_count = len(result)
-    removed_count = original_count - dedupe_count
-
-    # save the result
-    output_path = output or str(filepath)
-    print(f"💾 Saving result: {output_path}")
-    try:
-        result.save(output_path)
-    except Exception as e:
-        print(f"Error: cannot save file - {e}")
-        return
-
-    print(f"\n✅ Done! Removed {removed_count} duplicates, {dedupe_count} records left")
-
-
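dedupe_similar is implemented on DataTransformer elsewhere in the package; as background, the MinHash+LSH approach it names can be sketched standalone with the datasketch library (whether dtflow uses datasketch internally is an assumption):

    # Standalone MinHash+LSH dedupe sketch; not dtflow's implementation
    from datasketch import MinHash, MinHashLSH

    def minhash(text: str, num_perm: int = 128) -> MinHash:
        m = MinHash(num_perm=num_perm)
        for token in text.split():
            m.update(token.encode("utf-8"))
        return m

    texts = ["the quick brown fox", "the quick brown fox!", "something else entirely"]
    lsh = MinHashLSH(threshold=0.8, num_perm=128)
    kept = []
    for i, text in enumerate(texts):
        m = minhash(text)
        if not lsh.query(m):  # nothing sufficiently similar kept so far
            lsh.insert(f"doc-{i}", m)
            kept.append(text)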
-def concat(
-    *files: str,
-    output: Optional[str] = None,
-    strict: bool = False,
-) -> None:
-    """
-    Concatenate multiple data files (streaming; O(1) memory).
-
-    Args:
-        *files: input file paths; csv/excel/jsonl/json/parquet/arrow/feather formats are supported
-        output: output file path; required
-        strict: strict mode; the fields must match exactly or an error is raised
-
-    Examples:
-        dt concat a.jsonl b.jsonl -o merged.jsonl
-        dt concat data1.csv data2.csv data3.csv -o all.jsonl
-        dt concat a.jsonl b.jsonl --strict -o merged.jsonl
-    """
-    if len(files) < 2:
-        print("Error: at least two files are required")
-        return
-
-    if not output:
-        print("Error: an output file must be specified (-o/--output)")
-        return
-
-    # validate all the files
-    file_paths = []
-    for f in files:
-        filepath = Path(f).resolve()  # compare with absolute paths
-        if not filepath.exists():
-            print(f"Error: file not found - {f}")
-            return
-        if not _check_file_format(filepath):
-            return
-        file_paths.append(filepath)
-
-    # check whether the output file conflicts with an input file
-    output_path = Path(output).resolve()
-    use_temp_file = output_path in file_paths
-    if use_temp_file:
-        print("⚠ The output file is the same as an input file; a temp file will be used")
-
-    # streaming field analysis (only reads each file's first row)
-    print("📊 File field analysis:")
-    file_fields = []  # [(filepath, fields)]
-
-    for filepath in file_paths:
-        try:
-            # only read the first row to get the fields (pick the loader by format)
-            if _is_streaming_supported(filepath):
-                first_row = load_stream(str(filepath)).head(1).collect()
-            else:
-                # non-streaming formats (e.g. .json, .xlsx) use a full load
-                data = load_data(str(filepath))
-                first_row = data[:1] if data else []
-            if not first_row:
-                print(f"Warning: file is empty - {filepath}")
-                fields = set()
-            else:
-                fields = set(first_row[0].keys())
-        except Exception as e:
-            print(f"Error: cannot read file {filepath} - {e}")
-            return
-
-        file_fields.append((filepath, fields))
-        fields_str = ", ".join(sorted(fields)) if fields else "(empty)"
-        print(f"  {filepath.name}: {fields_str}")
-
-    # analyze the field differences
-    all_fields = set()
-    common_fields = None
-    for _, fields in file_fields:
-        all_fields.update(fields)
-        if common_fields is None:
-            common_fields = fields.copy()
-        else:
-            common_fields &= fields
-
-    common_fields = common_fields or set()
-    diff_fields = all_fields - common_fields
-
-    if diff_fields:
-        if strict:
-            print(f"\n❌ Strict mode: the fields do not match")
-            print(f"  Common fields: {', '.join(sorted(common_fields)) or '(none)'}")
-            print(f"  Differing fields: {', '.join(sorted(diff_fields))}")
-            return
-        else:
-            print(f"\n⚠ Field differences: {', '.join(sorted(diff_fields))} exist only in some files")
-
-    # streaming concat
-    print("\n🔄 Streaming concat...")
-
-    # if the output conflicts with an input, use a temp file (in the output file's directory)
-    if use_temp_file:
-        output_dir = output_path.parent
-        temp_fd, temp_path = tempfile.mkstemp(
-            suffix=output_path.suffix,
-            prefix=".tmp_",
-            dir=output_dir,
-        )
-        os.close(temp_fd)
-        actual_output = temp_path
-        print(f"💾 Writing temp file: {temp_path}")
-    else:
-        actual_output = output
-        print(f"💾 Saving result: {output}")
-
-    try:
-        total_count = _concat_streaming(file_paths, actual_output)
-
-        # if a temp file was used, rename it to the target file
-        if use_temp_file:
-            shutil.move(temp_path, output)
-            print(f"💾 Moved to the target file: {output}")
-    except Exception as e:
-        # clean up the temp file
-        if use_temp_file and os.path.exists(temp_path):
-            os.unlink(temp_path)
-        print(f"Error: concat failed - {e}")
-        return
-
-    file_count = len(files)
-    print(f"\n✅ Done! Merged {file_count} files, {total_count} records into {output}")
-
-
-def _concat_streaming(file_paths: List[Path], output: str) -> int:
-    """Stream-concatenate multiple files"""
-    from ..streaming import (
-        StreamingTransformer,
-        _stream_arrow,
-        _stream_csv,
-        _stream_jsonl,
-        _stream_parquet,
-    )
-
-    def generator():
-        for filepath in file_paths:
-            ext = filepath.suffix.lower()
-            if ext == ".jsonl":
-                yield from _stream_jsonl(str(filepath))
-            elif ext == ".csv":
-                yield from _stream_csv(str(filepath))
-            elif ext == ".parquet":
-                yield from _stream_parquet(str(filepath))
-            elif ext in (".arrow", ".feather"):
-                yield from _stream_arrow(str(filepath))
-            elif ext in (".json",):
-                # JSON needs a full load
-                data = load_data(str(filepath))
-                yield from data
-            elif ext in (".xlsx", ".xls"):
-                # Excel needs a full load
-                data = load_data(str(filepath))
-                yield from data
-            else:
-                yield from _stream_jsonl(str(filepath))
-
-    st = StreamingTransformer(generator())
-    return st.save(output, show_progress=True)
-
-
-# ============ Stats Command ============
-
-
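The streaming pattern behind _concat_streaming, in isolation: per-file record generators are chained lazily, so only one record is in memory at a time regardless of total size (a standalone sketch for JSONL only; file names illustrative):

    import orjson

    def stream_jsonl(path: str):
        with open(path, "rb") as f:
            for line in f:
                if line.strip():
                    yield orjson.loads(line)

    def concat_records(paths):
        for path in paths:
            yield from stream_jsonl(path)

    # the consumer drains the chained stream record by record, e.g.:
    # total = sum(1 for _ in concat_records(["a.jsonl", "b.jsonl"]))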
-def stats(
-    filename: str,
-    top: int = 10,
-    full: bool = False,
-) -> None:
-    """
-    Show statistics for a data file.
-
-    Quick mode by default: only counts the rows and the field structure.
-    Full mode (--full): value distributions, unique counts, lengths, and other details.
-
-    Args:
-        filename: input file path; csv/excel/jsonl/json/parquet/arrow/feather formats are supported
-        top: show the N most frequent values, default 10 (full mode only)
-        full: full mode; computes value distributions, unique counts, and other details
-
-    Examples:
-        dt stats data.jsonl  # quick mode (default)
-        dt stats data.jsonl --full  # full mode
-        dt stats data.csv -f --top=5  # full mode, show the Top 5
-    """
-    filepath = Path(filename)
-
-    if not filepath.exists():
-        print(f"Error: file not found - {filename}")
-        return
-
-    if not _check_file_format(filepath):
-        return
-
-    if not full:
-        _quick_stats(filepath)
-        return
-
-    # load the data
-    try:
-        data = load_data(str(filepath))
-    except Exception as e:
-        print(f"Error: cannot read file - {e}")
-        return
-
-    if not data:
-        print("File is empty")
-        return
-
-    # compute the statistics
-    total = len(data)
-    field_stats = _compute_field_stats(data, top)
-
-    # print the statistics
-    _print_stats(filepath.name, total, field_stats)
-
-
-def _quick_stats(filepath: Path) -> None:
-    """
-    Quick stats mode: only counts the rows and the field structure without scanning all the data.
-
-    Characteristics:
-    - streams the row count instead of loading everything into memory
-    - only reads the first few records to infer the field structure
-    - skips the expensive stats such as value distributions and unique counts
-    """
-    import orjson
-
-    from ..streaming import _count_rows_fast
-
-    ext = filepath.suffix.lower()
-    file_size = filepath.stat().st_size
-
-    # format the file size
-    def format_size(size: int) -> str:
-        for unit in ["B", "KB", "MB", "GB"]:
-            if size < 1024:
-                return f"{size:.1f} {unit}"
-            size /= 1024
-        return f"{size:.1f} TB"
-
-    # fast row count
-    total = _count_rows_fast(str(filepath))
-    if total is None:
-        # fallback: count manually
-        total = 0
-        try:
-            with open(filepath, "rb") as f:
-                for line in f:
-                    if line.strip():
-                        total += 1
-        except Exception:
-            total = -1
-
-    # read the first few records to infer the field structure
-    sample_data = []
-    sample_size = 5
-    try:
-        if ext == ".jsonl":
-            with open(filepath, "rb") as f:
-                for i, line in enumerate(f):
-                    if i >= sample_size:
-                        break
-                    line = line.strip()
-                    if line:
-                        sample_data.append(orjson.loads(line))
-        elif ext == ".csv":
-            import polars as pl
-
-            df = pl.scan_csv(str(filepath)).head(sample_size).collect()
-            sample_data = df.to_dicts()
-        elif ext == ".parquet":
-            import polars as pl
-
-            df = pl.scan_parquet(str(filepath)).head(sample_size).collect()
-            sample_data = df.to_dicts()
-        elif ext in (".arrow", ".feather"):
-            import polars as pl
-
-            df = pl.scan_ipc(str(filepath)).head(sample_size).collect()
-            sample_data = df.to_dicts()
-        elif ext == ".json":
-            with open(filepath, "rb") as f:
-                data = orjson.loads(f.read())
-            if isinstance(data, list):
-                sample_data = data[:sample_size]
-    except Exception:
-        pass
-
-    # analyze the field structure
-    fields = []
-    if sample_data:
-        all_keys = set()
-        for item in sample_data:
-            all_keys.update(item.keys())
-
-        for key in sorted(all_keys):
-            # infer the type from the sampled data
-            sample_values = [item.get(key) for item in sample_data if key in item]
-            non_null = [v for v in sample_values if v is not None]
-            if non_null:
-                field_type = _infer_type(non_null)
-            else:
-                field_type = "unknown"
-            fields.append({"field": key, "type": field_type})
-
-    # output
-    try:
-        from rich.console import Console
-        from rich.panel import Panel
-        from rich.table import Table
-
-        console = Console()
-
-        # overview
-        console.print(
-            Panel(
-                f"[bold]File:[/bold] {filepath.name}\n"
-                f"[bold]Size:[/bold] {format_size(file_size)}\n"
-                f"[bold]Total:[/bold] {total:,} records\n"
-                f"[bold]Fields:[/bold] {len(fields)}",
-                title="📊 Quick stats",
-                expand=False,
-            )
-        )
-
-        if fields:
-            table = Table(title="📋 Field structure", show_header=True, header_style="bold cyan")
-            table.add_column("#", style="dim", justify="right")
-            table.add_column("Field", style="green")
-            table.add_column("Type", style="yellow")
-
-            for i, f in enumerate(fields, 1):
-                table.add_row(str(i), f["field"], f["type"])
-
-            console.print(table)
-
-    except ImportError:
-        # rich unavailable; fall back to plain printing
-        print(f"\n{'=' * 40}")
-        print("📊 Quick stats")
-        print(f"{'=' * 40}")
-        print(f"File: {filepath.name}")
-        print(f"Size: {format_size(file_size)}")
-        print(f"Total: {total:,} records")
-        print(f"Fields: {len(fields)}")
-
-        if fields:
-            print(f"\n📋 Field structure:")
-            for i, f in enumerate(fields, 1):
-                print(f"  {i}. {f['field']} ({f['type']})")
-
-
-def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
-    """
-    Compute per-field statistics in a single pass.
-
-    Optimization: merges what would be several passes into one, collecting all
-    the statistics while iterating.
-    """
-    from collections import Counter, defaultdict
-
-    if not data:
-        return []
-
-    total = len(data)
-
-    # collect every field's values and statistics in a single pass
-    field_values = defaultdict(list)  # stores each field's values
-    field_counters = defaultdict(Counter)  # stores each field's value frequencies (for the top N)
-
-    for item in data:
-        for k, v in item.items():
-            field_values[k].append(v)
-            # count on the truncated value (for the top N display)
-            displayable = _truncate(v if v is not None else "", 30)
-            field_counters[k][displayable] += 1
-
-    # compute the statistics from the collected data
-    stats_list = []
-    for field in sorted(field_values.keys()):
-        values = field_values[field]
-        non_null = [v for v in values if v is not None and v != ""]
-        non_null_count = len(non_null)
-
-        # infer the type (from the first non-null value)
-        field_type = _infer_type(non_null)
-
-        # basic statistics
-        stat = {
-            "field": field,
-            "non_null": non_null_count,
-            "null_rate": f"{(total - non_null_count) / total * 100:.1f}%",
-            "type": field_type,
-        }
-
-        # type-specific statistics
-        if non_null:
-            # unique count (hash complex types to save memory)
-            stat["unique"] = _count_unique(non_null, field_type)
-
-            # strings: length statistics
-            if field_type == "str":
-                lengths = [len(str(v)) for v in non_null]
-                stat["len_min"] = min(lengths)
-                stat["len_max"] = max(lengths)
-                stat["len_avg"] = sum(lengths) / len(lengths)
-
-            # numbers: numeric statistics
-            elif field_type in ("int", "float"):
-                nums = [float(v) for v in non_null if _is_numeric(v)]
-                if nums:
-                    stat["min"] = min(nums)
-                    stat["max"] = max(nums)
-                    stat["avg"] = sum(nums) / len(nums)
-
-            # lists: length statistics
-            elif field_type == "list":
-                lengths = [len(v) if isinstance(v, list) else 0 for v in non_null]
-                stat["len_min"] = min(lengths)
-                stat["len_max"] = max(lengths)
-                stat["len_avg"] = sum(lengths) / len(lengths)
-
-            # top N values (already collected during the pass)
-            stat["top_values"] = field_counters[field].most_common(top)
-
-        stats_list.append(stat)
-
-    return stats_list
-
-
-def _count_unique(values: List[Any], field_type: str) -> int:
-    """
-    Count the unique values.
-
-    Compares simple types directly; uses hashes for list/dict or mixed types.
-    """
-    if field_type in ("list", "dict"):
-        return _count_unique_by_hash(values)
-    else:
-        # simple types: try direct comparison, falling back to hashes on failure
-        try:
-            return len(set(values))
-        except TypeError:
-            # mixed types (e.g. a field holding both str and dict); fall back to hashes
-            return _count_unique_by_hash(values)
-
-
-def _count_unique_by_hash(values: List[Any]) -> int:
-    """Count unique values by hashing their orjson serialization"""
-    import hashlib
-
-    import orjson
-
-    seen = set()
-    for v in values:
-        try:
-            h = hashlib.md5(orjson.dumps(v, option=orjson.OPT_SORT_KEYS)).digest()
-            seen.add(h)
-        except TypeError:
-            # fall back to repr for unserializable values
-            seen.add(repr(v))
-    return len(seen)
-
-
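The hash fallback matters because list and dict values are unhashable, so len(set(...)) raises. A quick standalone check:

    import orjson

    rows = [{"a": 1}, {"a": 1}, {"a": 2}]
    # len(set(rows))  # would raise TypeError: unhashable type: 'dict'
    digests = {orjson.dumps(r, option=orjson.OPT_SORT_KEYS) for r in rows}
    print(len(digests))  # 2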
-def _infer_type(values: List[Any]) -> str:
-    """Infer a field's type"""
-    if not values:
-        return "unknown"
-
-    sample = values[0]
-    if isinstance(sample, bool):
-        return "bool"
-    if isinstance(sample, int):
-        return "int"
-    if isinstance(sample, float):
-        return "float"
-    if isinstance(sample, list):
-        return "list"
-    if isinstance(sample, dict):
-        return "dict"
-    return "str"
-
-
-def _is_numeric(v: Any) -> bool:
-    """Check whether a value is numeric"""
-    if isinstance(v, (int, float)) and not isinstance(v, bool):
-        return True
-    return False
-
-
-def _truncate(v: Any, max_width: int) -> str:
-    """Truncate a value by display width (CJK characters count as width 2)"""
-    s = str(v)
-    width = 0
-    result = []
-    for char in s:
-        # CJK character ranges
-        if (
-            "\u4e00" <= char <= "\u9fff"
-            or "\u3000" <= char <= "\u303f"
-            or "\uff00" <= char <= "\uffef"
-        ):
-            char_width = 2
-        else:
-            char_width = 1
-        if width + char_width > max_width - 3:  # reserve width for the "..." suffix
-            return "".join(result) + "..."
-        result.append(char)
-        width += char_width
-    return s
-
-
-def _display_width(s: str) -> int:
-    """Compute a string's display width (CJK characters count as 2, ASCII as 1)"""
-    width = 0
-    for char in s:
-        # CJK character ranges
-        if (
-            "\u4e00" <= char <= "\u9fff"
-            or "\u3000" <= char <= "\u303f"
-            or "\uff00" <= char <= "\uffef"
-        ):
-            width += 2
-        else:
-            width += 1
-    return width
-
-
-def _pad_to_width(s: str, target_width: int) -> str:
-    """Pad a string to the given display width"""
-    current_width = _display_width(s)
-    if current_width >= target_width:
-        return s
-    return s + " " * (target_width - current_width)
-
-
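These width helpers exist because CJK characters occupy two terminal cells, so padding by str length misaligns columns. Expected behavior of the functions above:

    print(_display_width("data"))          # 4
    print(_display_width("数据"))          # 4 (2 characters x width 2)
    print(_pad_to_width("数据", 8) + "|")  # four trailing spaces before "|"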
1661
|
-
def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
|
|
1662
|
-
"""打印统计信息"""
|
|
1663
|
-
try:
|
|
1664
|
-
from rich.console import Console
|
|
1665
|
-
from rich.panel import Panel
|
|
1666
|
-
from rich.table import Table
|
|
1667
|
-
|
|
1668
|
-
console = Console()
|
|
1669
|
-
|
|
1670
|
-
# 概览
|
|
1671
|
-
console.print(
|
|
1672
|
-
Panel(
|
|
1673
|
-
f"[bold]文件:[/bold] {filename}\n"
|
|
1674
|
-
f"[bold]总数:[/bold] {total:,} 条\n"
|
|
1675
|
-
f"[bold]字段:[/bold] {len(field_stats)} 个",
|
|
1676
|
-
title="📊 数据概览",
|
|
1677
|
-
expand=False,
|
|
1678
|
-
)
|
|
1679
|
-
)
|
|
1680
|
-
|
|
1681
|
-
# 字段统计表
|
|
1682
|
-
table = Table(title="📋 字段统计", show_header=True, header_style="bold cyan")
|
|
1683
|
-
table.add_column("字段", style="green")
|
|
1684
|
-
table.add_column("类型", style="yellow")
|
|
1685
|
-
table.add_column("非空率", justify="right")
|
|
1686
|
-
table.add_column("唯一值", justify="right")
|
|
1687
|
-
table.add_column("统计", style="dim")
|
|
1688
|
-
|
|
1689
|
-
for stat in field_stats:
|
|
1690
|
-
non_null_rate = f"{stat['non_null'] / total * 100:.0f}%"
|
|
1691
|
-
unique = str(stat.get("unique", "-"))
|
|
1692
|
-
|
|
1693
|
-
# 构建统计信息字符串
|
|
1694
|
-
extra = []
|
|
1695
|
-
if "len_avg" in stat:
|
|
1696
|
-
extra.append(
|
|
1697
|
-
f"长度: {stat['len_min']}-{stat['len_max']} (avg {stat['len_avg']:.0f})"
|
|
1698
|
-
)
|
|
1699
|
-
if "avg" in stat:
|
|
1700
|
-
if stat["type"] == "int":
|
|
1701
|
-
extra.append(
|
|
1702
|
-
f"范围: {int(stat['min'])}-{int(stat['max'])} (avg {stat['avg']:.1f})"
|
|
1703
|
-
)
|
|
1704
|
-
else:
|
|
1705
|
-
extra.append(
|
|
1706
|
-
f"范围: {stat['min']:.2f}-{stat['max']:.2f} (avg {stat['avg']:.2f})"
|
|
1707
|
-
)
|
|
1708
|
-
|
|
1709
|
-
table.add_row(
|
|
1710
|
-
stat["field"],
|
|
1711
|
-
stat["type"],
|
|
1712
|
-
non_null_rate,
|
|
1713
|
-
unique,
|
|
1714
|
-
"; ".join(extra) if extra else "-",
|
|
1715
|
-
)
|
|
1716
|
-
|
|
1717
|
-
console.print(table)
|
|
1718
|
-
|
|
1719
|
-
# Top 值统计(仅显示有意义的字段)
|
|
1720
|
-
for stat in field_stats:
|
|
1721
|
-
top_values = stat.get("top_values", [])
|
|
1722
|
-
if not top_values:
|
|
1723
|
-
continue
|
|
1724
|
-
|
|
1725
|
-
# 跳过数值类型(min/max/avg 已足够)
|
|
1726
|
-
if stat["type"] in ("int", "float"):
|
|
1727
|
-
continue
|
|
1728
|
-
|
|
1729
|
-
# 跳过唯一值过多的字段(基本都是唯一的)
|
|
1730
|
-
unique_ratio = stat.get("unique", 0) / total if total > 0 else 0
|
|
1731
|
-
if unique_ratio > 0.9 and stat.get("unique", 0) > 100:
|
|
1732
|
-
continue
|
|
1733
|
-
|
|
1734
|
-
console.print(
|
|
1735
|
-
f"\n[bold cyan]{stat['field']}[/bold cyan] 值分布 (Top {len(top_values)}):"
|
|
1736
|
-
)
|
|
1737
|
-
max_count = max(c for _, c in top_values) if top_values else 1
|
|
1738
|
-
for value, count in top_values:
|
|
1739
|
-
pct = count / total * 100
|
|
1740
|
-
bar_len = int(count / max_count * 20) # 按相对比例,最长 20 字符
|
|
1741
|
-
bar = "█" * bar_len
|
|
1742
|
-
display_value = value if value else "[空]"
|
|
1743
|
-
# 使用显示宽度对齐(处理中文字符)
|
|
1744
|
-
padded_value = _pad_to_width(display_value, 32)
|
|
1745
|
-
console.print(f" {padded_value} {count:>6} ({pct:>5.1f}%) {bar}")
|
|
1746
|
-
|
|
1747
|
-
except ImportError:
|
|
1748
|
-
# 没有 rich,使用普通打印
|
|
1749
|
-
print(f"\n{'=' * 50}")
|
|
1750
|
-
print(f"📊 数据概览")
|
|
1751
|
-
print(f"{'=' * 50}")
|
|
1752
|
-
print(f"文件: {filename}")
|
|
1753
|
-
print(f"总数: {total:,} 条")
|
|
1754
|
-
print(f"字段: {len(field_stats)} 个")
|
|
1755
|
-
|
|
1756
|
-
print(f"\n{'=' * 50}")
|
|
1757
|
-
print(f"📋 字段统计")
|
|
1758
|
-
print(f"{'=' * 50}")
|
|
1759
|
-
print(f"{'字段':<20} {'类型':<8} {'非空率':<8} {'唯一值':<8}")
|
|
1760
|
-
print("-" * 50)
|
|
1761
|
-
|
|
1762
|
-
for stat in field_stats:
|
|
1763
|
-
non_null_rate = f"{stat['non_null'] / total * 100:.0f}%"
|
|
1764
|
-
unique = str(stat.get("unique", "-"))
|
|
1765
|
-
print(f"{stat['field']:<20} {stat['type']:<8} {non_null_rate:<8} {unique:<8}")
|
|
1766
|
-
|
|
1767
|
-
|
-# ============ Clean Command ============
-
-
-def clean(
-    filename: str,
-    drop_empty: Optional[str] = None,
-    min_len: Optional[str] = None,
-    max_len: Optional[str] = None,
-    keep: Optional[str] = None,
-    drop: Optional[str] = None,
-    strip: bool = False,
-    output: Optional[str] = None,
-) -> None:
-    """
-    Clean a dataset (streaming by default).
-
-    Args:
-        filename: Input file path; csv/excel/jsonl/json/parquet/arrow/feather formats are supported
-        drop_empty: Drop records with empty values; nested-path syntax is supported
-            - Without a value: drop records where any field is empty
-            - With field names (comma-separated): drop records where those fields are empty
-        min_len: Minimum-length filter in the form "field:length"; the field may be a nested path
-        max_len: Maximum-length filter in the form "field:length"; the field may be a nested path
-        keep: Keep only the named fields (comma-separated, top-level fields only)
-        drop: Drop the named fields (comma-separated, top-level fields only)
-        strip: Strip leading/trailing whitespace from all string fields
-        output: Output file path; the input file is overwritten if omitted
-
-    Examples:
-        dt clean data.jsonl --drop-empty                        # drop records with any empty field
-        dt clean data.jsonl --drop-empty=text,answer            # drop records whose named fields are empty
-        dt clean data.jsonl --drop-empty=meta.source            # drop records whose nested field is empty
-        dt clean data.jsonl --min-len=text:10                   # "text" must be at least 10 characters
-        dt clean data.jsonl --min-len=messages.#:2              # at least 2 messages
-        dt clean data.jsonl --max-len=messages[-1].content:500  # last message at most 500 characters
-        dt clean data.jsonl --keep=question,answer              # keep only these fields
-        dt clean data.jsonl --drop=metadata,timestamp           # drop these fields
-        dt clean data.jsonl --strip                             # strip surrounding whitespace from strings
-    """
-    filepath = Path(filename)
-
-    if not filepath.exists():
-        print(f"错误: 文件不存在 - {filename}")
-        return
-
-    if not _check_file_format(filepath):
-        return
-
-    # Parse the parameters
-    min_len_field, min_len_value = _parse_len_param(min_len) if min_len else (None, None)
-    max_len_field, max_len_value = _parse_len_param(max_len) if max_len else (None, None)
-    keep_fields = _parse_field_list(keep) if keep else None
-    drop_fields_set = set(_parse_field_list(drop)) if drop else None
-    keep_set = set(keep_fields) if keep_fields else None
-
-    # Build the cleaning configuration
-    empty_fields = None
-    if drop_empty is not None:
-        if drop_empty == "" or drop_empty is True:
-            print("🔄 删除任意字段为空的记录...")
-            empty_fields = []
-        else:
-            empty_fields = _parse_field_list(drop_empty)
-            print(f"🔄 删除字段为空的记录: {', '.join(empty_fields)}")
-
-    if strip:
-        print("🔄 去除字符串首尾空白...")
-    if min_len_field:
-        print(f"🔄 过滤 {min_len_field} 长度 < {min_len_value} 的记录...")
-    if max_len_field:
-        print(f"🔄 过滤 {max_len_field} 长度 > {max_len_value} 的记录...")
-    if keep_fields:
-        print(f"🔄 只保留字段: {', '.join(keep_fields)}")
-    if drop_fields_set:
-        print(f"🔄 删除字段: {', '.join(drop_fields_set)}")
-
-    output_path = output or str(filepath)
-
-    # Check whether input and output are the same file (streaming then needs a temp file)
-    input_resolved = filepath.resolve()
-    output_resolved = Path(output_path).resolve()
-    use_temp_file = input_resolved == output_resolved
-
-    # JSONL files are processed in streaming mode
-    if _is_streaming_supported(filepath):
-        print(f"📊 流式加载: {filepath}")
-
-        # If input and output are the same file, write to a temp file first
-        if use_temp_file:
-            print("⚠ 检测到输出文件与输入文件相同,将使用临时文件")
-            temp_fd, temp_path = tempfile.mkstemp(
-                suffix=output_resolved.suffix,
-                prefix=".tmp_",
-                dir=output_resolved.parent,
-            )
-            os.close(temp_fd)
-            actual_output = temp_path
-        else:
-            actual_output = output_path
-
-        try:
-            count = _clean_streaming(
-                str(filepath),
-                actual_output,
-                strip=strip,
-                empty_fields=empty_fields,
-                min_len_field=min_len_field,
-                min_len_value=min_len_value,
-                max_len_field=max_len_field,
-                max_len_value=max_len_value,
-                keep_set=keep_set,
-                drop_fields_set=drop_fields_set,
-            )
-
-            # If a temp file was used, move it into place
-            if use_temp_file:
-                shutil.move(temp_path, output_path)
-
-            print(f"💾 保存结果: {output_path}")
-            print(f"\n✅ 完成! 清洗后 {count} 条数据")
-        except Exception as e:
-            # Clean up the temp file
-            if use_temp_file and os.path.exists(temp_path):
-                os.unlink(temp_path)
-            print(f"错误: 清洗失败 - {e}")
-            import traceback
-
-            traceback.print_exc()
-        return
-
-    # Non-JSONL files fall back to the traditional in-memory path
-    print(f"📊 加载数据: {filepath}")
-    try:
-        dt = DataTransformer.load(str(filepath))
-    except Exception as e:
-        print(f"错误: 无法读取文件 - {e}")
-        return
-
-    original_count = len(dt)
-    print(f" 共 {original_count} 条数据")
-
-    # Run all cleaning operations in a single pass
-    data, step_stats = _clean_data_single_pass(
-        dt.data,
-        strip=strip,
-        empty_fields=empty_fields,
-        min_len_field=min_len_field,
-        min_len_value=min_len_value,
-        max_len_field=max_len_field,
-        max_len_value=max_len_value,
-        keep_fields=keep_fields,
-        drop_fields=drop_fields_set,
-    )
-
-    # Save the result
-    final_count = len(data)
-    print(f"💾 保存结果: {output_path}")
-
-    try:
-        save_data(data, output_path)
-    except Exception as e:
-        print(f"错误: 无法保存文件 - {e}")
-        return
-
-    # Print the statistics
-    removed_count = original_count - final_count
-    print(f"\n✅ 完成!")
-    print(f" 原始: {original_count} 条 -> 清洗后: {final_count} 条 (删除 {removed_count} 条)")
-    if step_stats:
-        print(f" 步骤: {' | '.join(step_stats)}")
-
-
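The in-place case above follows the standard write-to-temp-then-move pattern so a failed run never truncates the input file. A stripped-down sketch of that pattern (illustrative names, not dtflow API):

    import os
    import shutil
    import tempfile
    from pathlib import Path

    def rewrite_in_place(path: str, transform_line) -> None:
        """Stream `path` through `transform_line`, replacing the file only on success."""
        src = Path(path).resolve()
        fd, tmp = tempfile.mkstemp(suffix=src.suffix, prefix=".tmp_", dir=src.parent)
        os.close(fd)
        try:
            with open(src, encoding="utf-8") as fin, open(tmp, "w", encoding="utf-8") as fout:
                for line in fin:
                    out = transform_line(line)
                    if out is not None:
                        fout.write(out)
            shutil.move(tmp, src)  # a plain rename when both paths share a filesystem
        except Exception:
            if os.path.exists(tmp):
                os.unlink(tmp)
            raise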
-def _parse_len_param(param: str) -> tuple:
-    """Parse a length parameter of the form 'field:length'."""
-    if ":" not in param:
-        raise ValueError(f"长度参数格式错误: {param},应为 '字段:长度'")
-    parts = param.split(":", 1)
-    field = parts[0].strip()
-    try:
-        length = int(parts[1].strip())
-    except ValueError:
-        raise ValueError(f"长度必须是整数: {parts[1]}")
-    return field, length
-
-
-def _parse_field_list(value: Any) -> List[str]:
-    """Parse a field-list argument (fire may parse comma-separated values into a tuple)."""
-    if isinstance(value, (list, tuple)):
-        return [str(f).strip() for f in value]
-    elif isinstance(value, str):
-        return [f.strip() for f in value.split(",")]
-    else:
-        return [str(value)]
-
-
-def _is_empty_value(v: Any) -> bool:
-    """Decide whether a value counts as empty."""
-    if v is None:
-        return True
-    if isinstance(v, str) and v.strip() == "":
-        return True
-    if isinstance(v, (list, dict)) and len(v) == 0:
-        return True
-    return False
-
-
-def _get_value_len(value: Any) -> int:
-    """
-    Get the length of a value.
-
-    - str/list/dict: return len()
-    - int/float: return the number itself (for count-returning specs such as messages.#)
-    - None: return 0
-    - anything else: return the length of its string form
-    """
-    if value is None:
-        return 0
-    if isinstance(value, (int, float)):
-        return int(value)
-    if isinstance(value, (str, list, dict)):
-        return len(value)
-    return len(str(value))
-
-
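Together these two helpers define exactly what --drop-empty and the length filters test. A few illustrative cases, restated standalone for demonstration only:

    # Standalone restatements of the emptiness/length rules above.
    def is_empty(v):
        return v is None or (isinstance(v, str) and v.strip() == "") \
            or (isinstance(v, (list, dict)) and len(v) == 0)

    def value_len(v):
        if v is None:
            return 0
        if isinstance(v, (int, float)):
            return int(v)  # e.g. a "messages.#" count passes straight through
        if isinstance(v, (str, list, dict)):
            return len(v)
        return len(str(v))

    assert is_empty("   ") and is_empty([]) and not is_empty(0)  # 0 is a value, not "empty"
    assert value_len("abc") == 3 and value_len(5) == 5 and value_len(None) == 0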
-def _clean_data_single_pass(
-    data: List[Dict],
-    strip: bool = False,
-    empty_fields: Optional[List[str]] = None,
-    min_len_field: Optional[str] = None,
-    min_len_value: Optional[int] = None,
-    max_len_field: Optional[str] = None,
-    max_len_value: Optional[int] = None,
-    keep_fields: Optional[List[str]] = None,
-    drop_fields: Optional[set] = None,
-) -> tuple:
-    """
-    Run all cleaning operations in a single pass.
-
-    Args:
-        data: The raw record list
-        strip: Whether to strip surrounding whitespace from strings
-        empty_fields: Fields to check for emptiness (nested paths supported); an empty list
-            means check every field, None means skip the check
-        min_len_field: Field for the minimum-length check (nested paths supported)
-        min_len_value: Minimum length
-        max_len_field: Field for the maximum-length check (nested paths supported)
-        max_len_value: Maximum length
-        keep_fields: Fields to keep (top-level fields only)
-        drop_fields: Set of fields to drop (top-level fields only)
-
-    Returns:
-        (cleaned data, list of per-step statistics)
-    """
-    result = []
-    stats = {
-        "drop_empty": 0,
-        "min_len": 0,
-        "max_len": 0,
-    }
-
-    # Precompute the keep_fields set (if any)
-    keep_set = set(keep_fields) if keep_fields else None
-
-    for item in data:
-        # 1. strip first (before filtering, so emptiness detection is more accurate)
-        if strip:
-            item = {k: v.strip() if isinstance(v, str) else v for k, v in item.items()}
-
-        # 2. Empty-value filtering
-        if empty_fields is not None:
-            if len(empty_fields) == 0:
-                # Check every field
-                if any(_is_empty_value(v) for v in item.values()):
-                    stats["drop_empty"] += 1
-                    continue
-            else:
-                # Check the named fields (nested paths supported)
-                if any(_is_empty_value(get_field_with_spec(item, f)) for f in empty_fields):
-                    stats["drop_empty"] += 1
-                    continue
-
-        # 3. Minimum-length filtering (nested paths supported)
-        if min_len_field is not None:
-            if _get_value_len(get_field_with_spec(item, min_len_field, default="")) < min_len_value:
-                stats["min_len"] += 1
-                continue
-
-        # 4. Maximum-length filtering (nested paths supported)
-        if max_len_field is not None:
-            if _get_value_len(get_field_with_spec(item, max_len_field, default="")) > max_len_value:
-                stats["max_len"] += 1
-                continue
-
-        # 5. Field management (keep/drop)
-        if keep_set is not None:
-            item = {k: v for k, v in item.items() if k in keep_set}
-        elif drop_fields is not None:
-            item = {k: v for k, v in item.items() if k not in drop_fields}
-
-        result.append(item)
-
-    # Build the list of per-step statistics strings
-    step_stats = []
-    if strip:
-        step_stats.append("strip")
-    if stats["drop_empty"] > 0:
-        step_stats.append(f"drop-empty: -{stats['drop_empty']}")
-    if stats["min_len"] > 0:
-        step_stats.append(f"min-len: -{stats['min_len']}")
-    if stats["max_len"] > 0:
-        step_stats.append(f"max-len: -{stats['max_len']}")
-    if keep_fields:
-        step_stats.append(f"keep: {len(keep_fields)} 字段")
-    if drop_fields:
-        step_stats.append(f"drop: {len(drop_fields)} 字段")
-
-    return result, step_stats
-
-
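Folding every rule into one loop keeps the cleanup O(n) with a single copy of the data, instead of one pass per rule. A compact sketch of the same composition idea (generic, not dtflow's implementation):

    def clean_single_pass(records, predicates, transforms):
        """Keep a record only if all predicates pass; apply all transforms to survivors."""
        out = []
        for rec in records:
            if all(p(rec) for p in predicates):
                for t in transforms:
                    rec = t(rec)
                out.append(rec)
        return out

    rows = [{"text": " hi "}, {"text": ""}, {"text": "hello world"}]
    cleaned = clean_single_pass(
        rows,
        predicates=[lambda r: r["text"].strip() != ""],
        transforms=[lambda r: {**r, "text": r["text"].strip()}],
    )
    print(cleaned)  # [{'text': 'hi'}, {'text': 'hello world'}]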
-def _clean_streaming(
-    input_path: str,
-    output_path: str,
-    strip: bool = False,
-    empty_fields: Optional[List[str]] = None,
-    min_len_field: Optional[str] = None,
-    min_len_value: Optional[int] = None,
-    max_len_field: Optional[str] = None,
-    max_len_value: Optional[int] = None,
-    keep_set: Optional[set] = None,
-    drop_fields_set: Optional[set] = None,
-) -> int:
-    """
-    Clean data in streaming mode.
-
-    Returns:
-        The number of records after cleaning
-    """
-
-    def clean_filter(item: Dict) -> bool:
-        """Filter function: True keeps the record, False drops it (nested paths supported)."""
-        # Empty-value filtering
-        if empty_fields is not None:
-            if len(empty_fields) == 0:
-                if any(_is_empty_value(v) for v in item.values()):
-                    return False
-            else:
-                # Nested paths supported
-                if any(_is_empty_value(get_field_with_spec(item, f)) for f in empty_fields):
-                    return False
-
-        # Minimum-length filtering (nested paths supported)
-        if min_len_field is not None:
-            if _get_value_len(get_field_with_spec(item, min_len_field, default="")) < min_len_value:
-                return False
-
-        # Maximum-length filtering (nested paths supported)
-        if max_len_field is not None:
-            if _get_value_len(get_field_with_spec(item, max_len_field, default="")) > max_len_value:
-                return False
-
-        return True
-
-    def clean_transform(item: Dict) -> Dict:
-        """Transform function: strip + field management."""
-        # strip handling
-        if strip:
-            item = {k: v.strip() if isinstance(v, str) else v for k, v in item.items()}
-
-        # Field management
-        if keep_set is not None:
-            item = {k: v for k, v in item.items() if k in keep_set}
-        elif drop_fields_set is not None:
-            item = {k: v for k, v in item.items() if k not in drop_fields_set}
-
-        return item
-
-    # Build the streaming pipeline
-    st = load_stream(input_path)
-
-    # If stripping, run it first (before filtering, so emptiness detection is more accurate)
-    if strip:
-        st = st.transform(
-            lambda x: {k: v.strip() if isinstance(v, str) else v for k, v in x.items()}
-        )
-
-    # Apply the filters
-    if empty_fields is not None or min_len_field is not None or max_len_field is not None:
-        st = st.filter(clean_filter)
-
-    # Apply field management (also needed here when strip is off)
-    if keep_set is not None or drop_fields_set is not None:
-
-        def field_transform(item):
-            if keep_set is not None:
-                return {k: v for k, v in item.items() if k in keep_set}
-            elif drop_fields_set is not None:
-                return {k: v for k, v in item.items() if k not in drop_fields_set}
-            return item
-
-        st = st.transform(field_transform)
-
-    return st.save(output_path)
-
-
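The load_stream chain evaluates lazily: each .filter/.transform wraps the previous iterator, and nothing is materialized until .save drains it. A generator-based sketch of the same shape (illustrative only; dtflow's real streaming lives in dtflow/streaming.py):

    import json

    class Stream:
        def __init__(self, it):
            self._it = it

        def filter(self, pred):
            return Stream(x for x in self._it if pred(x))

        def transform(self, fn):
            return Stream(fn(x) for x in self._it)

        def save(self, path):
            n = 0
            with open(path, "w", encoding="utf-8") as f:
                for item in self._it:
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")
                    n += 1
            return n

    def load_stream(path):
        def rows():
            with open(path, encoding="utf-8") as f:
                for line in f:
                    if line.strip():
                        yield json.loads(line)
        return Stream(rows())

    # load_stream("in.jsonl").filter(lambda r: r.get("text")).save("out.jsonl")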
-# ============ Run Command ============
-
-
-def run(
-    config: str,
-    input: Optional[str] = None,
-    output: Optional[str] = None,
-) -> None:
-    """
-    Execute a pipeline configuration file.
-
-    Args:
-        config: Path to the pipeline YAML configuration file
-        input: Input file path (overrides the input in the config)
-        output: Output file path (overrides the output in the config)
-
-    Examples:
-        dt run pipeline.yaml
-        dt run pipeline.yaml --input=new_data.jsonl
-        dt run pipeline.yaml --input=data.jsonl --output=result.jsonl
-    """
-    config_path = Path(config)
-
-    if not config_path.exists():
-        print(f"错误: 配置文件不存在 - {config}")
-        return
-
-    if config_path.suffix.lower() not in (".yaml", ".yml"):
-        print(f"错误: 配置文件必须是 YAML 格式 (.yaml 或 .yml)")
-        return
-
-    # Validate the configuration
-    errors = validate_pipeline(config)
-    if errors:
-        print("❌ 配置文件验证失败:")
-        for err in errors:
-            print(f" - {err}")
-        return
-
-    # Execute the pipeline
-    try:
-        run_pipeline(config, input_file=input, output_file=output, verbose=True)
-    except Exception as e:
-        print(f"错误: {e}")
-        import traceback
-
-        traceback.print_exc()
-
-
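The command is a thin validate-then-execute wrapper: validation returns a list of error strings, and execution only happens when that list is empty, so a broken config never half-runs. A sketch of the same guard pattern (generic; `validate` and `execute` here are stand-ins, not dtflow functions):

    from typing import Callable, List

    def guarded_run(config: str,
                    validate: Callable[[str], List[str]],
                    execute: Callable[[str], None]) -> bool:
        """Run `execute` only when `validate` reports no errors."""
        errors = validate(config)
        if errors:
            for err in errors:
                print(f" - {err}")
            return False
        execute(config)
        return True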
-# ============ Token Stats Command ============
-
-
-def token_stats(
-    filename: str,
-    field: str = "messages",
-    model: str = "cl100k_base",
-    detailed: bool = False,
-) -> None:
-    """
-    Compute token statistics for a dataset.
-
-    Args:
-        filename: Input file path
-        field: The field to analyze (default "messages"); nested-path syntax is supported
-        model: Tokenizer: cl100k_base (default), qwen2.5, llama3, gpt-4, etc.
-        detailed: Whether to show detailed statistics
-
-    Examples:
-        dt token-stats data.jsonl
-        dt token-stats data.jsonl --field=text --model=qwen2.5
-        dt token-stats data.jsonl --field=conversation.messages
-        dt token-stats data.jsonl --field=messages[-1].content  # count the last message only
-        dt token-stats data.jsonl --detailed
-    """
-    filepath = Path(filename)
-
-    if not filepath.exists():
-        print(f"错误: 文件不存在 - {filename}")
-        return
-
-    if not _check_file_format(filepath):
-        return
-
-    # Load the data
-    print(f"📊 加载数据: {filepath}")
-    try:
-        data = load_data(str(filepath))
-    except Exception as e:
-        print(f"错误: 无法读取文件 - {e}")
-        return
-
-    if not data:
-        print("文件为空")
-        return
-
-    total = len(data)
-    print(f" 共 {total} 条数据")
-    print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
-
-    # Inspect the field type and pick the right statistics routine (nested paths supported)
-    sample = data[0]
-    field_value = get_field_with_spec(sample, field)
-
-    try:
-        if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
-            # messages format
-            from ..tokenizers import messages_token_stats
-
-            stats = messages_token_stats(data, messages_field=field, model=model)
-            _print_messages_token_stats(stats, detailed)
-        else:
-            # Plain text field
-            from ..tokenizers import token_stats as compute_token_stats
-
-            stats = compute_token_stats(data, fields=field, model=model)
-            _print_text_token_stats(stats, detailed)
-    except ImportError as e:
-        print(f"错误: {e}")
-        return
-    except Exception as e:
-        print(f"错误: 统计失败 - {e}")
-        import traceback
-
-        traceback.print_exc()
-
-
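For the default cl100k_base tokenizer, the underlying counting reduces to tiktoken's encode. A minimal sketch of the summary numbers printed below (assumes the tiktoken package is installed; dtflow's own wrapper lives in dtflow/tokenizers.py):

    import statistics

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    texts = ["Hello world", "数据清洗与统计", "A longer example sentence."]
    counts = [len(enc.encode(t)) for t in texts]
    print({
        "count": len(counts),
        "total_tokens": sum(counts),
        "avg_tokens": sum(counts) / len(counts),
        "median_tokens": statistics.median(counts),
        "min_tokens": min(counts),
        "max_tokens": max(counts),
    })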
-def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
-    """Print token statistics for messages-format data."""
-    try:
-        from rich.console import Console
-        from rich.panel import Panel
-        from rich.table import Table
-
-        console = Console()
-
-        # Overview
-        overview = (
-            f"[bold]总样本数:[/bold] {stats['count']:,}\n"
-            f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
-            f"[bold]平均 Token:[/bold] {stats['avg_tokens']:,}\n"
-            f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
-            f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
-        )
-        console.print(Panel(overview, title="📊 Token 统计概览", expand=False))
-
-        if detailed:
-            # Detailed statistics
-            table = Table(title="📋 分角色统计")
-            table.add_column("角色", style="cyan")
-            table.add_column("Token 数", justify="right")
-            table.add_column("占比", justify="right")
-
-            total = stats["total_tokens"]
-            for role, key in [
-                ("User", "user_tokens"),
-                ("Assistant", "assistant_tokens"),
-                ("System", "system_tokens"),
-            ]:
-                tokens = stats.get(key, 0)
-                pct = tokens / total * 100 if total > 0 else 0
-                table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
-
-            console.print(table)
-            console.print(f"\n平均对话轮数: {stats.get('avg_turns', 0)}")
-
-    except ImportError:
-        # rich is unavailable; fall back to plain printing
-        print(f"\n{'=' * 40}")
-        print("📊 Token 统计概览")
-        print(f"{'=' * 40}")
-        print(f"总样本数: {stats['count']:,}")
-        print(f"总 Token: {stats['total_tokens']:,}")
-        print(f"平均 Token: {stats['avg_tokens']:,}")
-        print(f"中位数: {stats['median_tokens']:,}")
-        print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
-
-        if detailed:
-            print(f"\n{'=' * 40}")
-            print("📋 分角色统计")
-            print(f"{'=' * 40}")
-            total = stats["total_tokens"]
-            for role, key in [
-                ("User", "user_tokens"),
-                ("Assistant", "assistant_tokens"),
-                ("System", "system_tokens"),
-            ]:
-                tokens = stats.get(key, 0)
-                pct = tokens / total * 100 if total > 0 else 0
-                print(f"{role}: {tokens:,} ({pct:.1f}%)")
-            print(f"\n平均对话轮数: {stats.get('avg_turns', 0)}")
-
-
-def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
-    """Print token statistics for plain-text fields."""
-    try:
-        from rich.console import Console
-        from rich.panel import Panel
-
-        console = Console()
-
-        overview = (
-            f"[bold]总样本数:[/bold] {stats['count']:,}\n"
-            f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
-            f"[bold]平均 Token:[/bold] {stats['avg_tokens']:.1f}\n"
-            f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
-            f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
-        )
-        console.print(Panel(overview, title="📊 Token 统计", expand=False))
-
-    except ImportError:
-        print(f"\n{'=' * 40}")
-        print("📊 Token 统计")
-        print(f"{'=' * 40}")
-        print(f"总样本数: {stats['count']:,}")
-        print(f"总 Token: {stats['total_tokens']:,}")
-        print(f"平均 Token: {stats['avg_tokens']:.1f}")
-        print(f"中位数: {stats['median_tokens']:,}")
-        print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
-
-
-# ============ Diff Command ============
-
-
-def diff(
-    file1: str,
-    file2: str,
-    key: Optional[str] = None,
-    output: Optional[str] = None,
-) -> None:
-    """
-    Compare two datasets and report the differences.
-
-    Args:
-        file1: Path to the first file
-        file2: Path to the second file
-        key: Key field used for matching; nested-path syntax is supported (optional)
-        output: Path for the diff report (optional)
-
-    Examples:
-        dt diff v1/train.jsonl v2/train.jsonl
-        dt diff a.jsonl b.jsonl --key=id
-        dt diff a.jsonl b.jsonl --key=meta.uuid  # match on a nested field
-        dt diff a.jsonl b.jsonl --output=diff_report.json
-    """
-    path1 = Path(file1)
-    path2 = Path(file2)
-
-    # Validate the files
-    for p, name in [(path1, "file1"), (path2, "file2")]:
-        if not p.exists():
-            print(f"错误: 文件不存在 - {p}")
-            return
-        if not _check_file_format(p):
-            return
-
-    # Load the data
-    print(f"📊 加载数据...")
-    try:
-        data1 = load_data(str(path1))
-        data2 = load_data(str(path2))
-    except Exception as e:
-        print(f"错误: 无法读取文件 - {e}")
-        return
-
-    print(f" 文件1: {path1.name} ({len(data1)} 条)")
-    print(f" 文件2: {path2.name} ({len(data2)} 条)")
-
-    # Compute the differences
-    print("🔍 计算差异...")
-    diff_result = _compute_diff(data1, data2, key)
-
-    # Print the diff report
-    _print_diff_report(diff_result, path1.name, path2.name)
-
-    # Save the report
-    if output:
-        print(f"\n💾 保存报告: {output}")
-        save_data([diff_result], output)
-
-
-def _compute_diff(
-    data1: List[Dict],
-    data2: List[Dict],
-    key: Optional[str] = None,
-) -> Dict[str, Any]:
-    """Compute the differences between two datasets."""
-    result = {
-        "summary": {
-            "file1_count": len(data1),
-            "file2_count": len(data2),
-            "added": 0,
-            "removed": 0,
-            "modified": 0,
-            "unchanged": 0,
-        },
-        "field_changes": {},
-        "details": {
-            "added": [],
-            "removed": [],
-            "modified": [],
-        },
-    }
-
-    if key:
-        # Exact matching on the key (nested paths supported)
-        dict1 = {get_field_with_spec(item, key): item for item in data1 if get_field_with_spec(item, key) is not None}
-        dict2 = {get_field_with_spec(item, key): item for item in data2 if get_field_with_spec(item, key) is not None}
-
-        keys1 = set(dict1.keys())
-        keys2 = set(dict2.keys())
-
-        # Added
-        added_keys = keys2 - keys1
-        result["summary"]["added"] = len(added_keys)
-        result["details"]["added"] = [dict2[k] for k in list(added_keys)[:10]]  # show at most 10
-
-        # Removed
-        removed_keys = keys1 - keys2
-        result["summary"]["removed"] = len(removed_keys)
-        result["details"]["removed"] = [dict1[k] for k in list(removed_keys)[:10]]
-
-        # Modified / unchanged
-        common_keys = keys1 & keys2
-        for k in common_keys:
-            if dict1[k] == dict2[k]:
-                result["summary"]["unchanged"] += 1
-            else:
-                result["summary"]["modified"] += 1
-                if len(result["details"]["modified"]) < 10:
-                    result["details"]["modified"].append(
-                        {
-                            "key": k,
-                            "before": dict1[k],
-                            "after": dict2[k],
-                        }
-                    )
-    else:
-        # Hash-based comparison
-        def _hash_item(item):
-            return orjson.dumps(item, option=orjson.OPT_SORT_KEYS)
-
-        set1 = {_hash_item(item) for item in data1}
-        set2 = {_hash_item(item) for item in data2}
-
-        added = set2 - set1
-        removed = set1 - set2
-        unchanged = set1 & set2
-
-        result["summary"]["added"] = len(added)
-        result["summary"]["removed"] = len(removed)
-        result["summary"]["unchanged"] = len(unchanged)
-
-        # Details
-        result["details"]["added"] = [orjson.loads(h) for h in list(added)[:10]]
-        result["details"]["removed"] = [orjson.loads(h) for h in list(removed)[:10]]
-
-    # Field-change analysis
-    fields1 = set()
-    fields2 = set()
-    for item in data1[:1000]:  # sample-based analysis
-        fields1.update(item.keys())
-    for item in data2[:1000]:
-        fields2.update(item.keys())
-
-    result["field_changes"] = {
-        "added_fields": list(fields2 - fields1),
-        "removed_fields": list(fields1 - fields2),
-        "common_fields": list(fields1 & fields2),
-    }
-
-    return result
-
-
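Without a key, record identity is just the canonical orjson serialization, so the whole diff collapses to three set operations. A compact sketch of that keyless path:

    import orjson

    def set_diff(rows_a, rows_b):
        """Diff two record lists by canonical JSON identity (key order is ignored)."""
        a = {orjson.dumps(r, option=orjson.OPT_SORT_KEYS) for r in rows_a}
        b = {orjson.dumps(r, option=orjson.OPT_SORT_KEYS) for r in rows_b}
        return {
            "added": [orjson.loads(h) for h in b - a],
            "removed": [orjson.loads(h) for h in a - b],
            "unchanged": len(a & b),
        }

    print(set_diff([{"id": 1}, {"id": 2}], [{"id": 2}, {"id": 3}]))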
-def _print_diff_report(diff_result: Dict[str, Any], name1: str, name2: str) -> None:
-    """Print the diff report."""
-    summary = diff_result["summary"]
-    field_changes = diff_result["field_changes"]
-
-    try:
-        from rich.console import Console
-        from rich.panel import Panel
-        from rich.table import Table
-
-        console = Console()
-
-        # Overview
-        overview = (
-            f"[bold]{name1}:[/bold] {summary['file1_count']:,} 条\n"
-            f"[bold]{name2}:[/bold] {summary['file2_count']:,} 条\n"
-            f"\n"
-            f"[green]+ 新增:[/green] {summary['added']:,} 条\n"
-            f"[red]- 删除:[/red] {summary['removed']:,} 条\n"
-            f"[yellow]~ 修改:[/yellow] {summary['modified']:,} 条\n"
-            f"[dim]= 未变:[/dim] {summary['unchanged']:,} 条"
-        )
-        console.print(Panel(overview, title="📊 差异概览", expand=False))
-
-        # Field changes
-        if field_changes["added_fields"] or field_changes["removed_fields"]:
-            console.print("\n[bold]📋 字段变化:[/bold]")
-            if field_changes["added_fields"]:
-                console.print(
-                    f" [green]+ 新增字段:[/green] {', '.join(field_changes['added_fields'])}"
-                )
-            if field_changes["removed_fields"]:
-                console.print(
-                    f" [red]- 删除字段:[/red] {', '.join(field_changes['removed_fields'])}"
-                )
-
-    except ImportError:
-        print(f"\n{'=' * 50}")
-        print("📊 差异概览")
-        print(f"{'=' * 50}")
-        print(f"{name1}: {summary['file1_count']:,} 条")
-        print(f"{name2}: {summary['file2_count']:,} 条")
-        print()
-        print(f"+ 新增: {summary['added']:,} 条")
-        print(f"- 删除: {summary['removed']:,} 条")
-        print(f"~ 修改: {summary['modified']:,} 条")
-        print(f"= 未变: {summary['unchanged']:,} 条")
-
-        if field_changes["added_fields"] or field_changes["removed_fields"]:
-            print(f"\n📋 字段变化:")
-            if field_changes["added_fields"]:
-                print(f" + 新增字段: {', '.join(field_changes['added_fields'])}")
-            if field_changes["removed_fields"]:
-                print(f" - 删除字段: {', '.join(field_changes['removed_fields'])}")
-
-
-# ============ History Command ============
-
-
-def history(
-    filename: str,
-    json: bool = False,
-) -> None:
-    """
-    Show the lineage history of a data file.
-
-    Args:
-        filename: Data file path
-        json: Emit JSON output
-
-    Examples:
-        dt history data.jsonl
-        dt history data.jsonl --json
-    """
-    filepath = Path(filename)
-
-    if not filepath.exists():
-        print(f"错误: 文件不存在 - {filename}")
-        return
-
-    if not has_lineage(str(filepath)):
-        print(f"文件 {filename} 没有血缘记录")
-        print("\n提示: 使用 track_lineage=True 加载数据,并在保存时使用 lineage=True 来记录血缘")
-        print("示例:")
-        print("  dt = DataTransformer.load('data.jsonl', track_lineage=True)")
-        print("  dt.filter(...).transform(...).save('output.jsonl', lineage=True)")
-        return
-
-    if json:
-        # JSON output
-        chain = get_lineage_chain(str(filepath))
-        output = [record.to_dict() for record in chain]
-        print(orjson.dumps(output, option=orjson.OPT_INDENT_2).decode("utf-8"))
-    else:
-        # Formatted report
-        report = format_lineage_report(str(filepath))
-        print(report)
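The lineage hint in the fallback branch is the whole workflow: opt in at load time, opt in again at save time, then query. A sketch of both sides, using the same entry points the command calls (the dtflow.lineage module path is an assumption; the DataTransformer calls come from the hint above):

    from dtflow import DataTransformer
    from dtflow.lineage import format_lineage_report, has_lineage  # module path assumed

    # Record lineage while transforming...
    dt = DataTransformer.load("data.jsonl", track_lineage=True)
    dt.filter(lambda r: r.get("text")).save("output.jsonl", lineage=True)

    # ...then inspect it later, much as `dt history output.jsonl` does.
    if has_lineage("output.jsonl"):
        print(format_lineage_report("output.jsonl"))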
+# Sampling commands
+from .sample import head, sample, tail
+
+# Transform commands
+from .transform import transform
+
+# Statistics commands
+from .stats import stats, token_stats
+
+# Cleaning commands
+from .clean import clean, dedupe
+
+# IO commands
+from .io_ops import concat, diff
+
+# Pipeline command
+from .pipeline import run
+
+# Lineage command
+from .lineage import history
+
+__all__ = [
+    # Sampling
+    "sample",
+    "head",
+    "tail",
+    # Transform
+    "transform",
+    # Statistics
+    "stats",
+    "token_stats",
+    # Cleaning
+    "clean",
+    "dedupe",
+    # IO
+    "concat",
+    "diff",
+    # Pipeline
+    "run",
+    # Lineage
+    "history",
+]
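The net effect of the refactor is that commands.py shrinks to a facade: the implementations move to sibling modules while the old import surface keeps working. For example, after the 0.4.3 split both of these resolve to the same function:

    from dtflow.cli.commands import clean  # old import path, still valid via the re-export
    from dtflow.cli.clean import clean     # new home of the implementation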