dtflow 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/cli/commands.py CHANGED
@@ -1,7 +1,10 @@
 """
 CLI command implementations
 """
-import json
+import orjson
+import os
+import shutil
+import tempfile
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional
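
A note on the json → orjson swap above: orjson.dumps returns bytes rather than str and is configured through option flags instead of keyword arguments, which is why the call sites in this diff decode before printing. A minimal sketch of the pattern (the record is made up for illustration):

    import orjson

    record = {"question": "1+1?", "answer": "2"}
    # orjson.dumps -> bytes; decode before handing the result to print()
    print(orjson.dumps(record, option=orjson.OPT_INDENT_2).decode("utf-8"))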
@@ -9,11 +12,22 @@ from typing import Any, Dict, List, Literal, Optional
 from ..core import DataTransformer, DictWrapper
 from ..presets import get_preset, list_presets
 from ..storage.io import load_data, save_data, sample_file
+from ..pipeline import run_pipeline, validate_pipeline
+from ..lineage import load_lineage, format_lineage_report, has_lineage, get_lineage_chain
+from ..streaming import load_stream
 
 
 # Supported file formats
 SUPPORTED_FORMATS = {".csv", ".jsonl", ".json", ".xlsx", ".xls", ".parquet", ".arrow", ".feather"}
 
+# Formats that support streaming (kept in sync with streaming.py)
+STREAMING_FORMATS = {".jsonl", ".csv", ".parquet", ".arrow", ".feather"}
+
+
+def _is_streaming_supported(filepath: Path) -> bool:
+    """Check whether a file supports streaming."""
+    return filepath.suffix.lower() in STREAMING_FORMATS
+
 
 def _check_file_format(filepath: Path) -> bool:
     """Check whether the file format is supported; print an error and return False if not."""
@@ -28,11 +42,12 @@ def _check_file_format(filepath: Path) -> bool:
 def sample(
     filename: str,
     num: int = 10,
-    sample_type: Literal["random", "head", "tail"] = "head",
+    type: Literal["random", "head", "tail"] = "head",
     output: Optional[str] = None,
     seed: Optional[int] = None,
     by: Optional[str] = None,
     uniform: bool = False,
+    fields: Optional[str] = None,
 ) -> None:
     """
     Sample a specified number of records from a data file.
@@ -43,20 +58,22 @@ def sample(
             - num > 0: sample the given number of records
             - num = 0: sample all records
             - num < 0: Python slice style (e.g. -1 means the last record, -10 the last 10)
-        sample_type: sampling mode, one of random/head/tail; defaults to head
+        type: sampling mode, one of random/head/tail; defaults to head
         output: output file path; prints to the console if not given
-        seed: random seed (only effective when sample_type=random)
+        seed: random seed (only effective when type=random)
         by: field name for stratified sampling; records are grouped by this field's values
         uniform: uniform sampling mode (requires --by); each group contributes the same count
+        fields: only display the given fields (comma-separated); effective in preview mode only
 
     Examples:
         dt sample data.jsonl 5
-        dt sample data.csv 100 --sample_type=head
+        dt sample data.csv 100 --type=head
        dt sample data.xlsx 50 --output=sampled.jsonl
         dt sample data.jsonl 0    # sample all records
         dt sample data.jsonl -10  # the last 10 records
         dt sample data.jsonl 1000 --by=category            # proportional stratified sampling
         dt sample data.jsonl 1000 --by=category --uniform  # uniform stratified sampling
+        dt sample data.jsonl --fields=question,answer      # only display the given fields
     """
     filepath = Path(filename)
 
@@ -76,7 +93,7 @@ def sample(
     if by:
         try:
             sampled = _stratified_sample(
-                filepath, num, by, uniform, seed, sample_type
+                filepath, num, by, uniform, seed, type
             )
         except Exception as e:
             print(f"Error: {e}")
@@ -87,7 +104,7 @@ def sample(
         sampled = sample_file(
             str(filepath),
             num=num,
-            sample_type=sample_type,
+            sample_type=type,
             seed=seed,
             output=None,  # don't save yet; saving is handled uniformly at the end
         )
@@ -100,7 +117,11 @@ def sample(
         save_data(sampled, output)
         print(f"Saved {len(sampled)} records to {output}")
     else:
-        _print_samples(sampled)
+        # Get the file's total row count for the overview display
+        total_count = _get_file_row_count(filepath)
+        # Parse the fields argument
+        field_list = _parse_field_list(fields) if fields else None
+        _print_samples(sampled, filepath.name, total_count, field_list)
 
 
 def _stratified_sample(
@@ -225,9 +246,10 @@ def head(
     filename: str,
     num: int = 10,
     output: Optional[str] = None,
+    fields: Optional[str] = None,
 ) -> None:
     """
-    Show the first N records of a file (shortcut for dt sample --sample_type=head).
+    Show the first N records of a file (shortcut for dt sample --type=head).
 
     Args:
         filename: input file path; csv/excel/jsonl/json/parquet/arrow/feather are supported
@@ -236,23 +258,26 @@ def head(
             - num = 0: show all records
             - num < 0: Python slice style (e.g. -10 means the last 10)
         output: output file path; prints to the console if not given
+        fields: only display the given fields (comma-separated); effective in preview mode only
 
     Examples:
         dt head data.jsonl                # show the first 10 records
         dt head data.jsonl 20             # show the first 20 records
         dt head data.csv 0                # show all records
         dt head data.xlsx --output=head.jsonl
+        dt head data.jsonl --fields=question,answer
     """
-    sample(filename, num=num, sample_type="head", output=output)
+    sample(filename, num=num, type="head", output=output, fields=fields)
 
 
 def tail(
     filename: str,
     num: int = 10,
     output: Optional[str] = None,
+    fields: Optional[str] = None,
 ) -> None:
     """
-    Show the last N records of a file (shortcut for dt sample --sample_type=tail).
+    Show the last N records of a file (shortcut for dt sample --type=tail).
 
     Args:
         filename: input file path; csv/excel/jsonl/json/parquet/arrow/feather are supported
@@ -261,58 +286,244 @@ def tail(
             - num = 0: show all records
             - num < 0: Python slice style (e.g. -10 means the last 10)
         output: output file path; prints to the console if not given
+        fields: only display the given fields (comma-separated); effective in preview mode only
 
     Examples:
         dt tail data.jsonl                # show the last 10 records
         dt tail data.jsonl 20             # show the last 20 records
         dt tail data.csv 0                # show all records
         dt tail data.xlsx --output=tail.jsonl
+        dt tail data.jsonl --fields=question,answer
     """
-    sample(filename, num=num, sample_type="tail", output=output)
+    sample(filename, num=num, type="tail", output=output, fields=fields)
 
 
-def _print_samples(samples: list) -> None:
-    """Print the sampled records."""
+def _get_file_row_count(filepath: Path) -> Optional[int]:
+    """
+    Quickly get the row count of a file (without loading all the data).
+
+    JSONL files are counted line by line; other formats return None.
+    """
+    ext = filepath.suffix.lower()
+    if ext == ".jsonl":
+        try:
+            with open(filepath, "rb") as f:
+                return sum(1 for _ in f)
+        except Exception:
+            return None
+    # Fast counting is not supported for other formats yet
+    return None
+
+
+def _format_value(value: Any, max_len: int = 80) -> str:
+    """Format a single value, truncating long text."""
+    if value is None:
+        return "[dim]null[/dim]"
+    if isinstance(value, bool):
+        return "[cyan]true[/cyan]" if value else "[cyan]false[/cyan]"
+    if isinstance(value, (int, float)):
+        return f"[cyan]{value}[/cyan]"
+    if isinstance(value, str):
+        # Handle multi-line text
+        if "\n" in value:
+            lines = value.split("\n")
+            if len(lines) > 3:
+                preview = lines[0][:max_len] + f"... [dim]({len(lines)} lines)[/dim]"
+            else:
+                preview = value.replace("\n", "\\n")
+                if len(preview) > max_len:
+                    preview = preview[:max_len] + "..."
+            return f'"{preview}"'
+        if len(value) > max_len:
+            return f'"{value[:max_len]}..." [dim]({len(value)} chars)[/dim]'
+        return f'"{value}"'
+    return str(value)
+
+
+def _format_nested(
+    value: Any,
+    indent: str = "",
+    is_last: bool = True,
+    max_len: int = 80,
+) -> List[str]:
+    """
+    Recursively format a nested structure, returning a list of lines.
+
+    Tree-drawing characters show the structure:
+    ├─ intermediate item
+    └─ last item
+    """
+    lines = []
+    branch = "└─ " if is_last else "├─ "
+    cont = "   " if is_last else "│  "
+
+    if isinstance(value, dict):
+        items = list(value.items())
+        for i, (k, v) in enumerate(items):
+            is_last_item = (i == len(items) - 1)
+            b = "└─ " if is_last_item else "├─ "
+            c = "   " if is_last_item else "│  "
+
+            if isinstance(v, (dict, list)) and v:
+                # Nested structure
+                if isinstance(v, list):
+                    # Detect the messages format
+                    is_messages = (
+                        v and isinstance(v[0], dict)
+                        and "role" in v[0] and "content" in v[0]
+                    )
+                    if is_messages:
+                        lines.append(f"{indent}{b}[green]{k}[/green]: ({len(v)} items) [dim]→ \\[role]: content[/dim]")
+                    else:
+                        lines.append(f"{indent}{b}[green]{k}[/green]: ({len(v)} items)")
+                else:
+                    lines.append(f"{indent}{b}[green]{k}[/green]:")
+                lines.extend(_format_nested(v, indent + c, True, max_len))
+            else:
+                # Simple value
+                lines.append(f"{indent}{b}[green]{k}[/green]: {_format_value(v, max_len)}")
+
+    elif isinstance(value, list):
+        for i, item in enumerate(value):
+            is_last_item = (i == len(value) - 1)
+            b = "└─ " if is_last_item else "├─ "
+            c = "   " if is_last_item else "│  "
+
+            if isinstance(item, dict):
+                # Dict item inside a list: detect the messages format
+                if "role" in item and "content" in item:
+                    role = item.get("role", "")
+                    content = item.get("content", "")
+                    # Truncate long content
+                    if len(content) > max_len:
+                        content = content[:max_len].replace("\n", "\\n") + "..."
+                    else:
+                        content = content.replace("\n", "\\n")
+                    # Escape with \[ so rich does not parse it as markup
+                    lines.append(f"{indent}{b}[yellow]\\[{role}]:[/yellow] {content}")
+                else:
+                    # Plain dict
+                    lines.append(f"{indent}{b}[dim]{{...}}[/dim]")
+                    lines.extend(_format_nested(item, indent + c, True, max_len))
+            elif isinstance(item, list):
+                lines.append(f"{indent}{b}[dim][{len(item)} items][/dim]")
+                lines.extend(_format_nested(item, indent + c, True, max_len))
+            else:
+                lines.append(f"{indent}{b}{_format_value(item, max_len)}")
+
+    return lines
+
+
+def _is_simple_data(samples: List[Dict]) -> bool:
+    """Decide whether the data fits a table display (no nested structures)."""
+    if not samples or not isinstance(samples[0], dict):
+        return False
+    keys = list(samples[0].keys())
+    if len(keys) > 6:
+        return False
+    for s in samples[:3]:
+        for k in keys:
+            v = s.get(k)
+            if isinstance(v, (dict, list)):
+                return False
+            if isinstance(v, str) and len(v) > 80:
+                return False
+    return True
+
+
+def _print_samples(
+    samples: list,
+    filename: Optional[str] = None,
+    total_count: Optional[int] = None,
+    fields: Optional[List[str]] = None,
+) -> None:
+    """
+    Print the sampled records.
+
+    Args:
+        samples: list of sampled records
+        filename: file name (for the overview header)
+        total_count: total row count of the file (for the overview header)
+        fields: only display the given fields
+    """
     if not samples:
         print("No data")
         return
 
+    # Filter fields
+    if fields and isinstance(samples[0], dict):
+        field_set = set(fields)
+        samples = [{k: v for k, v in item.items() if k in field_set} for item in samples]
+
     try:
         from rich.console import Console
-        from rich.json import JSON
         from rich.table import Table
+        from rich.panel import Panel
 
         console = Console()
 
-        # Try to display as a table
-        if isinstance(samples[0], dict):
+        # Show the data overview header
+        if filename:
+            all_fields = set()
+            for item in samples:
+                if isinstance(item, dict):
+                    all_fields.update(item.keys())
+            field_names = ", ".join(sorted(all_fields))
+
+            if total_count is not None:
+                info = f"Total rows: {total_count:,} | Sampled: {len(samples)} | Fields: {len(all_fields)}"
+            else:
+                info = f"Sampled: {len(samples)} | Fields: {len(all_fields)}"
+
+            console.print(Panel(
+                f"[dim]{info}[/dim]\n[dim]Fields: {field_names}[/dim]",
+                title=f"[bold]📊 {filename}[/bold]",
+                expand=False,
+                border_style="dim",
+            ))
+            console.print()
+
+        # Simple data is shown as a table
+        if _is_simple_data(samples):
             keys = list(samples[0].keys())
-            # Table-friendly: not too many fields and no overly long values
-            if len(keys) <= 5 and all(
-                len(str(s.get(k, ""))) < 100 for s in samples[:3] for k in keys
-            ):
-                table = Table(title=f"Sampled records ({len(samples)})")
-                for key in keys:
-                    table.add_column(key, overflow="fold")
-                for item in samples:
-                    table.add_row(*[str(item.get(k, "")) for k in keys])
-                console.print(table)
-                return
-
-        # Display as JSON
+            table = Table(show_header=True, header_style="bold cyan")
+            for key in keys:
+                table.add_column(key, overflow="fold")
+            for item in samples:
+                table.add_row(*[str(item.get(k, "")) for k in keys])
+            console.print(table)
+            return
+
+        # Nested data is shown as a tree
         for i, item in enumerate(samples, 1):
-            console.print(f"\n[bold cyan]--- Record {i} ---[/bold cyan]")
-            console.print(JSON.from_data(item))
+            console.print(f"[bold cyan]--- Record {i} ---[/bold cyan]")
+            if isinstance(item, dict):
+                for line in _format_nested(item):
+                    console.print(line)
+            else:
+                console.print(_format_value(item))
+            console.print()
 
     except ImportError:
         # rich is not available; fall back to plain printing
-        import json
+        if filename:
+            all_fields = set()
+            for item in samples:
+                if isinstance(item, dict):
+                    all_fields.update(item.keys())
+
+            print(f"\n📊 {filename}")
+            if total_count is not None:
+                print(f"  Total rows: {total_count:,} | Sampled: {len(samples)} | Fields: {len(all_fields)}")
+            else:
+                print(f"  Sampled: {len(samples)} | Fields: {len(all_fields)}")
+            print(f"  Fields: {', '.join(sorted(all_fields))}")
+            print()
 
         for i, item in enumerate(samples, 1):
-            print(f"\n--- Record {i} ---")
-            print(json.dumps(item, ensure_ascii=False, indent=2))
-
-        print(f"\nTotal: {len(samples)} records")
+            print(f"--- Record {i} ---")
+            print(orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8"))
+            print()
 
 
 # ============ Transform Command ============
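
For reference, the new tree renderer turns a messages-style record into output like the following (a made-up input; the [green]/[yellow] tags are rich markup and render as colors when printed through rich's Console):

    item = {
        "id": 1,
        "messages": [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
        ],
    }
    for line in _format_nested(item):
        print(line)
    # ├─ [green]id[/green]: [cyan]1[/cyan]
    # └─ [green]messages[/green]: (2 items) [dim]→ \[role]: content[/dim]
    #    ├─ [yellow]\[user]:[/yellow] hi
    #    └─ [yellow]\[assistant]:[/yellow] hello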
@@ -522,17 +733,16 @@ def _format_example_value(value: Any, max_len: int = 50) -> str:
         # Truncate long strings
         if len(value) > max_len:
             value = value[:max_len] + "..."
-        # Escape and add quotes
-        escaped = value.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
-        return f'"{escaped}"'
+        # Use repr() to handle all escape characters automatically
+        return repr(value)
     if isinstance(value, bool):
         return str(value)
     if isinstance(value, (int, float)):
         return str(value)
     if isinstance(value, (list, dict)):
-        s = json.dumps(value, ensure_ascii=False)
+        s = orjson.dumps(value).decode("utf-8")
         if len(s) > max_len:
-            return f"{s[:max_len]}..."
+            return repr(s[:max_len] + "...")
         return s
     return '""'
 
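
The switch to repr() above folds the three manual escape steps into a single call; a quick check:

    s = 'He said "hi"\nbye\\'
    print(repr(s))  # 'He said "hi"\nbye\\' (quotes, newlines and backslashes all handled)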
@@ -579,7 +789,7 @@ def _execute_transform(
     output_override: Optional[str],
     num: Optional[int],
 ) -> None:
-    """Run the data transformation"""
+    """Run the data transformation (streaming by default)"""
     print(f"📂 Loading config: {config_path}")
 
     # Dynamically load the config file
599
809
  # 获取输出路径
600
810
  output_path = output_override or config_ns.get("output", "output.jsonl")
601
811
 
602
- # 加载数据并使用 DataTransformer 执行转换
812
+ # 对于 JSONL 文件使用流式处理
813
+ if _is_streaming_supported(input_path):
814
+ print(f"📊 流式加载: {input_path}")
815
+ print("🔄 执行转换...")
816
+ try:
817
+ # 包装转换函数以支持属性访问(配置文件中定义的 Item 类)
818
+ def wrapped_transform(item):
819
+ return transform_func(DictWrapper(item))
820
+
821
+ st = load_stream(str(input_path))
822
+ if num:
823
+ st = st.head(num)
824
+ count = st.transform(wrapped_transform).save(output_path)
825
+ print(f"💾 保存结果: {output_path}")
826
+ print(f"\n✅ 完成! 已转换 {count} 条数据到 {output_path}")
827
+ except Exception as e:
828
+ print(f"错误: 转换失败 - {e}")
829
+ import traceback
830
+ traceback.print_exc()
831
+ return
832
+
833
+ # 非 JSONL 文件使用传统方式
603
834
  print(f"📊 加载数据: {input_path}")
604
835
  try:
605
836
  dt = DataTransformer.load(str(input_path))
@@ -641,7 +872,7 @@ def _execute_preset_transform(
     output_override: Optional[str],
     num: Optional[int],
 ) -> None:
-    """Run a transformation using a preset template"""
+    """Run a transformation using a preset template (streaming by default)"""
     print(f"📂 Using preset: {preset_name}")
 
     # Get the preset function
@@ -652,7 +883,57 @@ def _execute_preset_transform(
         print(f"Available presets: {', '.join(list_presets())}")
         return
 
-    # Load the data
+    output_path = output_override or f"{input_path.stem}_{preset_name}.jsonl"
+
+    # Check whether input and output are the same file
+    input_resolved = input_path.resolve()
+    output_resolved = Path(output_path).resolve()
+    use_temp_file = input_resolved == output_resolved
+
+    # Use streaming for formats that support it
+    if _is_streaming_supported(input_path):
+        print(f"📊 Streaming load: {input_path}")
+        print("🔄 Running the transformation...")
+
+        # If input and output are the same, write to a temporary file
+        if use_temp_file:
+            print("⚠ Output file is the same as the input file; writing to a temporary file")
+            temp_fd, temp_path = tempfile.mkstemp(
+                suffix=output_resolved.suffix,
+                prefix=".tmp_",
+                dir=output_resolved.parent,
+            )
+            os.close(temp_fd)
+            actual_output = temp_path
+        else:
+            actual_output = output_path
+
+        try:
+            # Wrap the transform function so items support attribute access
+            def wrapped_transform(item):
+                return transform_func(DictWrapper(item))
+
+            st = load_stream(str(input_path))
+            if num:
+                st = st.head(num)
+            count = st.transform(wrapped_transform).save(actual_output)
+
+            # If a temporary file was used, move it into place
+            if use_temp_file:
+                shutil.move(temp_path, output_path)
+
+            print(f"💾 Saved results: {output_path}")
+            print(f"\n✅ Done! Transformed {count} records to {output_path}")
+        except Exception as e:
+            # Clean up the temporary file
+            if use_temp_file and os.path.exists(temp_path):
+                os.unlink(temp_path)
+            print(f"Error: transformation failed - {e}")
+            import traceback
+            traceback.print_exc()
+        return
+
+    # Other formats fall back to the traditional full-load path
     print(f"📊 Loading data: {input_path}")
     try:
         dt = DataTransformer.load(str(input_path))
678
959
  return
679
960
 
680
961
  # 保存结果
681
- output_path = output_override or f"{input_path.stem}_{preset_name}.jsonl"
682
962
  print(f"💾 保存结果: {output_path}")
683
963
  try:
684
964
  save_data(results, output_path)
@@ -809,7 +1089,7 @@ def concat(
     strict: bool = False,
 ) -> None:
     """
-    Concatenate multiple data files.
+    Concatenate multiple data files (streaming, O(1) memory).
 
     Args:
         *files: input file paths; csv/excel/jsonl/json/parquet/arrow/feather are supported
@@ -832,7 +1112,7 @@ def concat(
     # Validate all files
     file_paths = []
     for f in files:
-        filepath = Path(f)
+        filepath = Path(f).resolve()  # use absolute paths for the comparison
         if not filepath.exists():
             print(f"Error: file does not exist - {f}")
             return
@@ -840,31 +1120,37 @@ def concat(
             return
         file_paths.append(filepath)
 
-    # Analyze each file's fields
+    # Check whether the output file conflicts with one of the inputs
+    output_path = Path(output).resolve()
+    use_temp_file = output_path in file_paths
+    if use_temp_file:
+        print("⚠ Output file is the same as an input file; writing to a temporary file")
+
+    # Streaming field analysis (reads only the first row of each file)
     print("📊 File field analysis:")
-    file_infos = []  # [(filepath, data, fields, count)]
+    file_fields = []  # [(filepath, fields)]
 
     for filepath in file_paths:
         try:
-            data = load_data(str(filepath))
+            # Read only the first row to discover the fields
+            first_row = load_stream(str(filepath)).head(1).collect()
+            if not first_row:
+                print(f"Warning: file is empty - {filepath}")
+                fields = set()
+            else:
+                fields = set(first_row[0].keys())
         except Exception as e:
             print(f"Error: cannot read file {filepath} - {e}")
             return
 
-        if not data:
-            print(f"Warning: file is empty - {filepath}")
-            fields = set()
-        else:
-            fields = set(data[0].keys())
-
-        file_infos.append((filepath, data, fields, len(data)))
+        file_fields.append((filepath, fields))
         fields_str = ", ".join(sorted(fields)) if fields else "(empty)"
-        print(f"  {filepath.name}: {fields_str} ({len(data)} records)")
+        print(f"  {filepath.name}: {fields_str}")
 
     # Analyze field differences
     all_fields = set()
     common_fields = None
-    for _, _, fields, _ in file_infos:
+    for _, fields in file_fields:
         all_fields.update(fields)
         if common_fields is None:
             common_fields = fields.copy()
@@ -883,25 +1169,72 @@ def concat(
     else:
         print(f"\n⚠ Field differences: {', '.join(sorted(diff_fields))} exist only in some files")
 
-    # Run the concatenation
-    print("\n🔄 Concatenating...")
-    all_data = []
-    for _, data, _, _ in file_infos:
-        all_data.extend(data)
+    # Streaming concatenation
+    print("\n🔄 Streaming concatenation...")
+
+    # If the output conflicts with an input file, use a temporary file (in the same directory as the output)
+    if use_temp_file:
+        output_dir = output_path.parent
+        temp_fd, temp_path = tempfile.mkstemp(
+            suffix=output_path.suffix,
+            prefix=".tmp_",
+            dir=output_dir,
+        )
+        os.close(temp_fd)
+        actual_output = temp_path
+        print(f"💾 Writing to temporary file: {temp_path}")
+    else:
+        actual_output = output
+        print(f"💾 Saving results: {output}")
 
-    # Save the results
-    print(f"💾 Saving results: {output}")
     try:
-        save_data(all_data, output)
+        total_count = _concat_streaming(file_paths, actual_output)
+
+        # If a temporary file was used, rename it to the target file
+        if use_temp_file:
+            shutil.move(temp_path, output)
+            print(f"💾 Moved to target file: {output}")
     except Exception as e:
-        print(f"Error: cannot save file - {e}")
+        # Clean up the temporary file
+        if use_temp_file and os.path.exists(temp_path):
+            os.unlink(temp_path)
+        print(f"Error: concatenation failed - {e}")
         return
 
-    total_count = len(all_data)
     file_count = len(files)
     print(f"\n✅ Done! Merged {file_count} files, {total_count} records in total, into {output}")
 
 
+def _concat_streaming(file_paths: List[Path], output: str) -> int:
+    """Concatenate multiple files as a stream"""
+    from ..streaming import StreamingTransformer, _stream_jsonl, _stream_csv, _stream_parquet, _stream_arrow
+
+    def generator():
+        for filepath in file_paths:
+            ext = filepath.suffix.lower()
+            if ext == ".jsonl":
+                yield from _stream_jsonl(str(filepath))
+            elif ext == ".csv":
+                yield from _stream_csv(str(filepath))
+            elif ext == ".parquet":
+                yield from _stream_parquet(str(filepath))
+            elif ext in (".arrow", ".feather"):
+                yield from _stream_arrow(str(filepath))
+            elif ext in (".json",):
+                # JSON requires a full load
+                data = load_data(str(filepath))
+                yield from data
+            elif ext in (".xlsx", ".xls"):
+                # Excel requires a full load
+                data = load_data(str(filepath))
+                yield from data
+            else:
+                yield from _stream_jsonl(str(filepath))
+
+    st = StreamingTransformer(generator())
+    return st.save(output, show_progress=True)
+
+
 # ============ Stats Command ============
 
 
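
The generator inside _concat_streaming amounts to lazily chaining per-file row iterators, so only one record is in memory at a time no matter how many files are merged. A self-contained sketch of the same idea for JSONL inputs:

    from itertools import chain

    import orjson

    def stream_jsonl(path):
        # Yield parsed rows one at a time; the file is never loaded whole.
        with open(path, "rb") as f:
            for line in f:
                if line.strip():
                    yield orjson.loads(line)

    def stream_many(paths):
        return chain.from_iterable(stream_jsonl(p) for p in paths)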
@@ -992,8 +1325,8 @@ def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
 
     # Type-specific statistics
     if non_null:
-        # Unique value count
-        stat["unique"] = len(set(str(v) for v in non_null))
+        # Unique value count (hash complex types to save memory)
+        stat["unique"] = _count_unique(non_null, field_type)
 
         # String type: compute length statistics
         if field_type == "str":
@@ -1025,6 +1358,28 @@ def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
     return stats_list
 
 
+def _count_unique(values: List[Any], field_type: str) -> int:
+    """
+    Count the number of unique values.
+
+    Simple types are compared directly; list/dict values are hashed to save memory.
+    """
+    if field_type in ("list", "dict"):
+        # Complex types: serialize with orjson, then hash
+        import hashlib
+
+        import orjson
+
+        seen = set()
+        for v in values:
+            h = hashlib.md5(orjson.dumps(v, option=orjson.OPT_SORT_KEYS)).digest()
+            seen.add(h)
+        return len(seen)
+    else:
+        # Simple types: compare directly
+        return len(set(values))
+
+
 def _infer_type(values: List[Any]) -> str:
     """Infer the field type"""
     if not values:
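
The point of hashing in _count_unique: a set of 16-byte MD5 digests stays small even when each list/dict value serializes to kilobytes, whereas the old set(str(v) for v in non_null) kept every full string alive. A quick illustration with made-up data:

    import hashlib

    import orjson

    values = [{"messages": ["x" * 1000], "id": i % 7} for i in range(10_000)]
    digests = {
        hashlib.md5(orjson.dumps(v, option=orjson.OPT_SORT_KEYS)).digest()
        for v in values
    }
    print(len(digests))  # 7: duplicates collapse to 16-byte digests, not kilobyte strings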
@@ -1200,7 +1555,7 @@ def clean(
     output: Optional[str] = None,
 ) -> None:
     """
-    Data cleaning.
+    Data cleaning (streaming by default).
 
     Args:
         filename: input file path; csv/excel/jsonl/json/parquet/arrow/feather are supported
@@ -1233,29 +1588,19 @@ def clean(
     if not _check_file_format(filepath):
         return
 
-    # Load the data
-    print(f"📊 Loading data: {filepath}")
-    try:
-        dt = DataTransformer.load(str(filepath))
-    except Exception as e:
-        print(f"Error: cannot read file - {e}")
-        return
-
-    original_count = len(dt)
-    print(f"  {original_count} records in total")
-
-    # Parse arguments (fire may parse comma-separated values as tuples)
+    # Parse arguments
     min_len_field, min_len_value = _parse_len_param(min_len) if min_len else (None, None)
     max_len_field, max_len_value = _parse_len_param(max_len) if max_len else (None, None)
     keep_fields = _parse_field_list(keep) if keep else None
-    drop_fields = _parse_field_list(drop) if drop else None
+    drop_fields_set = set(_parse_field_list(drop)) if drop else None
+    keep_set = set(keep_fields) if keep_fields else None
 
     # Build the cleaning config
     empty_fields = None
     if drop_empty is not None:
         if drop_empty == "" or drop_empty is True:
             print("🔄 Dropping records where any field is empty...")
-            empty_fields = []  # an empty list means: check all fields
+            empty_fields = []
         else:
             empty_fields = _parse_field_list(drop_empty)
             print(f"🔄 Dropping records with empty fields: {', '.join(empty_fields)}")
@@ -1268,8 +1613,72 @@ def clean(
         print(f"🔄 Filtering records where {max_len_field} length > {max_len_value}...")
     if keep_fields:
         print(f"🔄 Keeping only fields: {', '.join(keep_fields)}")
-    if drop_fields:
-        print(f"🔄 Dropping fields: {', '.join(drop_fields)}")
+    if drop_fields_set:
+        print(f"🔄 Dropping fields: {', '.join(drop_fields_set)}")
+
+    output_path = output or str(filepath)
+
+    # Check whether input and output are the same (streaming needs a temporary file)
+    input_resolved = filepath.resolve()
+    output_resolved = Path(output_path).resolve()
+    use_temp_file = input_resolved == output_resolved
+
+    # Use streaming for formats that support it
+    if _is_streaming_supported(filepath):
+        print(f"📊 Streaming load: {filepath}")
+
+        # If input and output are the same, write to a temporary file
+        if use_temp_file:
+            print("⚠ Output file is the same as the input file; writing to a temporary file")
+            temp_fd, temp_path = tempfile.mkstemp(
+                suffix=output_resolved.suffix,
+                prefix=".tmp_",
+                dir=output_resolved.parent,
+            )
+            os.close(temp_fd)
+            actual_output = temp_path
+        else:
+            actual_output = output_path
+
+        try:
+            count = _clean_streaming(
+                str(filepath),
+                actual_output,
+                strip=strip,
+                empty_fields=empty_fields,
+                min_len_field=min_len_field,
+                min_len_value=min_len_value,
+                max_len_field=max_len_field,
+                max_len_value=max_len_value,
+                keep_set=keep_set,
+                drop_fields_set=drop_fields_set,
+            )
+
+            # If a temporary file was used, move it into place
+            if use_temp_file:
+                shutil.move(temp_path, output_path)
+
+            print(f"💾 Saved results: {output_path}")
+            print(f"\n✅ Done! {count} records after cleaning")
+        except Exception as e:
+            # Clean up the temporary file
+            if use_temp_file and os.path.exists(temp_path):
+                os.unlink(temp_path)
+            print(f"Error: cleaning failed - {e}")
+            import traceback
+            traceback.print_exc()
+        return
+
+    # Other formats fall back to the traditional full-load path
+    print(f"📊 Loading data: {filepath}")
+    try:
+        dt = DataTransformer.load(str(filepath))
+    except Exception as e:
+        print(f"Error: cannot read file - {e}")
+        return
+
+    original_count = len(dt)
+    print(f"  {original_count} records in total")
 
     # Run all cleaning operations in a single pass
     data, step_stats = _clean_data_single_pass(
@@ -1281,12 +1690,11 @@ def clean(
         max_len_field=max_len_field,
         max_len_value=max_len_value,
         keep_fields=keep_fields,
-        drop_fields=set(drop_fields) if drop_fields else None,
+        drop_fields=drop_fields_set,
     )
 
     # Save the results
     final_count = len(data)
-    output_path = output or str(filepath)
     print(f"💾 Saving results: {output_path}")
 
     try:
@@ -1438,3 +1846,533 @@ def _clean_data_single_pass(
         step_stats.append(f"drop: {len(drop_fields)} fields")
 
     return result, step_stats
+
+
+def _clean_streaming(
+    input_path: str,
+    output_path: str,
+    strip: bool = False,
+    empty_fields: Optional[List[str]] = None,
+    min_len_field: Optional[str] = None,
+    min_len_value: Optional[int] = None,
+    max_len_field: Optional[str] = None,
+    max_len_value: Optional[int] = None,
+    keep_set: Optional[set] = None,
+    drop_fields_set: Optional[set] = None,
+) -> int:
+    """
+    Clean data as a stream.
+
+    Returns:
+        The number of records after processing
+    """
+    def clean_filter(item: Dict) -> bool:
+        """Filter function: return True to keep, False to drop"""
+        # Empty-value filtering
+        if empty_fields is not None:
+            if len(empty_fields) == 0:
+                if any(_is_empty_value(v) for v in item.values()):
+                    return False
+            else:
+                if any(_is_empty_value(item.get(f)) for f in empty_fields):
+                    return False
+
+        # Minimum-length filtering
+        if min_len_field is not None:
+            if _get_value_len(item.get(min_len_field, "")) < min_len_value:
+                return False
+
+        # Maximum-length filtering
+        if max_len_field is not None:
+            if _get_value_len(item.get(max_len_field, "")) > max_len_value:
+                return False
+
+        return True
+
+    def clean_transform(item: Dict) -> Dict:
+        """Transform function: strip + field management"""
+        # strip handling
+        if strip:
+            item = {k: v.strip() if isinstance(v, str) else v for k, v in item.items()}
+
+        # Field management
+        if keep_set is not None:
+            item = {k: v for k, v in item.items() if k in keep_set}
+        elif drop_fields_set is not None:
+            item = {k: v for k, v in item.items() if k not in drop_fields_set}
+
+        return item
+
+    # Build the streaming pipeline
+    st = load_stream(input_path)
+
+    # If strip is requested, apply it first (before filtering, so empty-value detection is more accurate)
+    if strip:
+        st = st.transform(lambda x: {k: v.strip() if isinstance(v, str) else v for k, v in x.items()})
+
+    # Apply filtering
+    if empty_fields is not None or min_len_field is not None or max_len_field is not None:
+        st = st.filter(clean_filter)
+
+    # Apply field management (needed here as well when strip is off)
+    if keep_set is not None or drop_fields_set is not None:
+        def field_transform(item):
+            if keep_set is not None:
+                return {k: v for k, v in item.items() if k in keep_set}
+            elif drop_fields_set is not None:
+                return {k: v for k, v in item.items() if k not in drop_fields_set}
+            return item
+        st = st.transform(field_transform)
+
+    return st.save(output_path)
+
+
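
_clean_streaming composes its pipeline lazily; nothing is read from disk until save() drains the stream. Assuming the package is importable as dtflow and exposes the streaming module used in the imports above, usage looks roughly like this (the file names and the question field are illustrative):

    from dtflow.streaming import load_stream

    count = (
        load_stream("data.jsonl")
        .transform(lambda x: {k: v.strip() if isinstance(v, str) else v for k, v in x.items()})
        .filter(lambda x: len(x.get("question", "")) >= 5)
        .save("cleaned.jsonl")
    )
    print(f"{count} records written")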
+# ============ Run Command ============
+
+
+def run(
+    config: str,
+    input: Optional[str] = None,
+    output: Optional[str] = None,
+) -> None:
+    """
+    Execute a pipeline config file.
+
+    Args:
+        config: pipeline YAML config file path
+        input: input file path (overrides input in the config)
+        output: output file path (overrides output in the config)
+
+    Examples:
+        dt run pipeline.yaml
+        dt run pipeline.yaml --input=new_data.jsonl
+        dt run pipeline.yaml --input=data.jsonl --output=result.jsonl
+    """
+    config_path = Path(config)
+
+    if not config_path.exists():
+        print(f"Error: config file does not exist - {config}")
+        return
+
+    if config_path.suffix.lower() not in (".yaml", ".yml"):
+        print("Error: the config file must be in YAML format (.yaml or .yml)")
+        return
+
+    # Validate the config
+    errors = validate_pipeline(config)
+    if errors:
+        print("❌ Config validation failed:")
+        for err in errors:
+            print(f"  - {err}")
+        return
+
+    # Run the pipeline
+    try:
+        run_pipeline(config, input_file=input, output_file=output, verbose=True)
+    except Exception as e:
+        print(f"Error: {e}")
+        import traceback
+        traceback.print_exc()
+
+
+ # ============ Token Stats Command ============
1979
+
1980
+
1981
+ def token_stats(
1982
+ filename: str,
1983
+ field: str = "messages",
1984
+ model: str = "cl100k_base",
1985
+ detailed: bool = False,
1986
+ ) -> None:
1987
+ """
1988
+ 统计数据集的 Token 信息。
1989
+
1990
+ Args:
1991
+ filename: 输入文件路径
1992
+ field: 要统计的字段(默认 messages)
1993
+ model: 分词器: cl100k_base (默认), qwen2.5, llama3, gpt-4 等
1994
+ detailed: 是否显示详细统计
1995
+
1996
+ Examples:
1997
+ dt token-stats data.jsonl
1998
+ dt token-stats data.jsonl --field=text --model=qwen2.5
1999
+ dt token-stats data.jsonl --detailed
2000
+ """
2001
+ filepath = Path(filename)
2002
+
2003
+ if not filepath.exists():
2004
+ print(f"错误: 文件不存在 - {filename}")
2005
+ return
2006
+
2007
+ if not _check_file_format(filepath):
2008
+ return
2009
+
2010
+ # 加载数据
2011
+ print(f"📊 加载数据: {filepath}")
2012
+ try:
2013
+ data = load_data(str(filepath))
2014
+ except Exception as e:
2015
+ print(f"错误: 无法读取文件 - {e}")
2016
+ return
2017
+
2018
+ if not data:
2019
+ print("文件为空")
2020
+ return
2021
+
2022
+ total = len(data)
2023
+ print(f" 共 {total} 条数据")
2024
+ print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
2025
+
2026
+ # 检查字段类型并选择合适的统计方法
2027
+ sample = data[0]
2028
+ field_value = sample.get(field)
2029
+
2030
+ try:
2031
+ if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
2032
+ # messages 格式
2033
+ from ..tokenizers import messages_token_stats
2034
+ stats = messages_token_stats(data, messages_field=field, model=model)
2035
+ _print_messages_token_stats(stats, detailed)
2036
+ else:
2037
+ # 普通文本字段
2038
+ from ..tokenizers import token_stats as compute_token_stats
2039
+ stats = compute_token_stats(data, fields=field, model=model)
2040
+ _print_text_token_stats(stats, detailed)
2041
+ except ImportError as e:
2042
+ print(f"错误: {e}")
2043
+ return
2044
+ except Exception as e:
2045
+ print(f"错误: 统计失败 - {e}")
2046
+ import traceback
2047
+ traceback.print_exc()
2048
+
2049
+
+def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
+    """Print token statistics for the messages format"""
+    try:
+        from rich.console import Console
+        from rich.table import Table
+        from rich.panel import Panel
+
+        console = Console()
+
+        # Overview
+        overview = (
+            f"[bold]Samples:[/bold] {stats['count']:,}\n"
+            f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
+            f"[bold]Average tokens:[/bold] {stats['avg_tokens']:,}\n"
+            f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+            f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
+        )
+        console.print(Panel(overview, title="📊 Token statistics overview", expand=False))
+
+        if detailed:
+            # Detailed statistics
+            table = Table(title="📋 Per-role statistics")
+            table.add_column("Role", style="cyan")
+            table.add_column("Tokens", justify="right")
+            table.add_column("Share", justify="right")
+
+            total = stats['total_tokens']
+            for role, key in [("User", "user_tokens"), ("Assistant", "assistant_tokens"), ("System", "system_tokens")]:
+                tokens = stats.get(key, 0)
+                pct = tokens / total * 100 if total > 0 else 0
+                table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
+
+            console.print(table)
+            console.print(f"\nAverage conversation turns: {stats.get('avg_turns', 0)}")
+
+    except ImportError:
+        # rich is not available; fall back to plain printing
+        print(f"\n{'=' * 40}")
+        print("📊 Token statistics overview")
+        print(f"{'=' * 40}")
+        print(f"Samples: {stats['count']:,}")
+        print(f"Total tokens: {stats['total_tokens']:,}")
+        print(f"Average tokens: {stats['avg_tokens']:,}")
+        print(f"Median: {stats['median_tokens']:,}")
+        print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
+
+        if detailed:
+            print(f"\n{'=' * 40}")
+            print("📋 Per-role statistics")
+            print(f"{'=' * 40}")
+            total = stats['total_tokens']
+            for role, key in [("User", "user_tokens"), ("Assistant", "assistant_tokens"), ("System", "system_tokens")]:
+                tokens = stats.get(key, 0)
+                pct = tokens / total * 100 if total > 0 else 0
+                print(f"{role}: {tokens:,} ({pct:.1f}%)")
+            print(f"\nAverage conversation turns: {stats.get('avg_turns', 0)}")
+
+
+def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
+    """Print token statistics for plain text"""
+    try:
+        from rich.console import Console
+        from rich.panel import Panel
+
+        console = Console()
+
+        overview = (
+            f"[bold]Samples:[/bold] {stats['count']:,}\n"
+            f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
+            f"[bold]Average tokens:[/bold] {stats['avg_tokens']:.1f}\n"
+            f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+            f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
+        )
+        console.print(Panel(overview, title="📊 Token statistics", expand=False))
+
+    except ImportError:
+        print(f"\n{'=' * 40}")
+        print("📊 Token statistics")
+        print(f"{'=' * 40}")
+        print(f"Samples: {stats['count']:,}")
+        print(f"Total tokens: {stats['total_tokens']:,}")
+        print(f"Average tokens: {stats['avg_tokens']:.1f}")
+        print(f"Median: {stats['median_tokens']:,}")
+        print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
+
+
+# ============ Diff Command ============
+
+
+def diff(
+    file1: str,
+    file2: str,
+    key: Optional[str] = None,
+    output: Optional[str] = None,
+) -> None:
+    """
+    Compare two datasets and report the differences.
+
+    Args:
+        file1: first file path
+        file2: second file path
+        key: key field used for matching (optional)
+        output: output path for the diff report (optional)
+
+    Examples:
+        dt diff v1/train.jsonl v2/train.jsonl
+        dt diff a.jsonl b.jsonl --key=id
+        dt diff a.jsonl b.jsonl --output=diff_report.json
+    """
+    path1 = Path(file1)
+    path2 = Path(file2)
+
+    # Validate the files
+    for p, name in [(path1, "file1"), (path2, "file2")]:
+        if not p.exists():
+            print(f"Error: file does not exist - {p}")
+            return
+        if not _check_file_format(p):
+            return
+
+    # Load the data
+    print("📊 Loading data...")
+    try:
+        data1 = load_data(str(path1))
+        data2 = load_data(str(path2))
+    except Exception as e:
+        print(f"Error: cannot read file - {e}")
+        return
+
+    print(f"  File 1: {path1.name} ({len(data1)} records)")
+    print(f"  File 2: {path2.name} ({len(data2)} records)")
+
+    # Compute the differences
+    print("🔍 Computing differences...")
+    diff_result = _compute_diff(data1, data2, key)
+
+    # Print the diff report
+    _print_diff_report(diff_result, path1.name, path2.name)
+
+    # Save the report
+    if output:
+        print(f"\n💾 Saving report: {output}")
+        save_data([diff_result], output)
+
+
+def _compute_diff(
+    data1: List[Dict],
+    data2: List[Dict],
+    key: Optional[str] = None,
+) -> Dict[str, Any]:
+    """Compute the differences between two datasets"""
+    result = {
+        "summary": {
+            "file1_count": len(data1),
+            "file2_count": len(data2),
+            "added": 0,
+            "removed": 0,
+            "modified": 0,
+            "unchanged": 0,
+        },
+        "field_changes": {},
+        "details": {
+            "added": [],
+            "removed": [],
+            "modified": [],
+        },
+    }
+
+    if key:
+        # Exact matching based on the key field
+        dict1 = {item.get(key): item for item in data1 if item.get(key) is not None}
+        dict2 = {item.get(key): item for item in data2 if item.get(key) is not None}
+
+        keys1 = set(dict1.keys())
+        keys2 = set(dict2.keys())
+
+        # Added
+        added_keys = keys2 - keys1
+        result["summary"]["added"] = len(added_keys)
+        result["details"]["added"] = [dict2[k] for k in list(added_keys)[:10]]  # show at most 10
+
+        # Removed
+        removed_keys = keys1 - keys2
+        result["summary"]["removed"] = len(removed_keys)
+        result["details"]["removed"] = [dict1[k] for k in list(removed_keys)[:10]]
+
+        # Modified / unchanged
+        common_keys = keys1 & keys2
+        for k in common_keys:
+            if dict1[k] == dict2[k]:
+                result["summary"]["unchanged"] += 1
+            else:
+                result["summary"]["modified"] += 1
+                if len(result["details"]["modified"]) < 10:
+                    result["details"]["modified"].append({
+                        "key": k,
+                        "before": dict1[k],
+                        "after": dict2[k],
+                    })
+    else:
+        # Hash-based comparison
+        def _hash_item(item):
+            return orjson.dumps(item, option=orjson.OPT_SORT_KEYS)
+
+        set1 = {_hash_item(item) for item in data1}
+        set2 = {_hash_item(item) for item in data2}
+
+        added = set2 - set1
+        removed = set1 - set2
+        unchanged = set1 & set2
+
+        result["summary"]["added"] = len(added)
+        result["summary"]["removed"] = len(removed)
+        result["summary"]["unchanged"] = len(unchanged)
+
+        # Details
+        result["details"]["added"] = [orjson.loads(h) for h in list(added)[:10]]
+        result["details"]["removed"] = [orjson.loads(h) for h in list(removed)[:10]]
+
+    # Field change analysis
+    fields1 = set()
+    fields2 = set()
+    for item in data1[:1000]:  # sampled analysis
+        fields1.update(item.keys())
+    for item in data2[:1000]:
+        fields2.update(item.keys())
+
+    result["field_changes"] = {
+        "added_fields": list(fields2 - fields1),
+        "removed_fields": list(fields1 - fields2),
+        "common_fields": list(fields1 & fields2),
+    }
+
+    return result
+
+
+def _print_diff_report(diff_result: Dict[str, Any], name1: str, name2: str) -> None:
+    """Print the diff report"""
+    summary = diff_result["summary"]
+    field_changes = diff_result["field_changes"]
+
+    try:
+        from rich.console import Console
+        from rich.table import Table
+        from rich.panel import Panel
+
+        console = Console()
+
+        # Overview
+        overview = (
+            f"[bold]{name1}:[/bold] {summary['file1_count']:,} records\n"
+            f"[bold]{name2}:[/bold] {summary['file2_count']:,} records\n"
+            f"\n"
+            f"[green]+ Added:[/green] {summary['added']:,}\n"
+            f"[red]- Removed:[/red] {summary['removed']:,}\n"
+            f"[yellow]~ Modified:[/yellow] {summary['modified']:,}\n"
+            f"[dim]= Unchanged:[/dim] {summary['unchanged']:,}"
+        )
+        console.print(Panel(overview, title="📊 Diff overview", expand=False))
+
+        # Field changes
+        if field_changes["added_fields"] or field_changes["removed_fields"]:
+            console.print("\n[bold]📋 Field changes:[/bold]")
+            if field_changes["added_fields"]:
+                console.print(f"  [green]+ Added fields:[/green] {', '.join(field_changes['added_fields'])}")
+            if field_changes["removed_fields"]:
+                console.print(f"  [red]- Removed fields:[/red] {', '.join(field_changes['removed_fields'])}")
+
+    except ImportError:
+        print(f"\n{'=' * 50}")
+        print("📊 Diff overview")
+        print(f"{'=' * 50}")
+        print(f"{name1}: {summary['file1_count']:,} records")
+        print(f"{name2}: {summary['file2_count']:,} records")
+        print()
+        print(f"+ Added: {summary['added']:,}")
+        print(f"- Removed: {summary['removed']:,}")
+        print(f"~ Modified: {summary['modified']:,}")
+        print(f"= Unchanged: {summary['unchanged']:,}")
+
+        if field_changes["added_fields"] or field_changes["removed_fields"]:
+            print("\n📋 Field changes:")
+            if field_changes["added_fields"]:
+                print(f"  + Added fields: {', '.join(field_changes['added_fields'])}")
+            if field_changes["removed_fields"]:
+                print(f"  - Removed fields: {', '.join(field_changes['removed_fields'])}")
+
+
+# ============ History Command ============
+
+
+def history(
+    filename: str,
+    json: bool = False,
+) -> None:
+    """
+    Show the lineage history of a data file.
+
+    Args:
+        filename: data file path
+        json: output in JSON format
+
+    Examples:
+        dt history data.jsonl
+        dt history data.jsonl --json
+    """
+    filepath = Path(filename)
+
+    if not filepath.exists():
+        print(f"Error: file does not exist - {filename}")
+        return
+
+    if not has_lineage(str(filepath)):
+        print(f"File {filename} has no lineage record")
+        print("\nHint: load the data with track_lineage=True and save with lineage=True to record lineage")
+        print("Example:")
+        print("  dt = DataTransformer.load('data.jsonl', track_lineage=True)")
+        print("  dt.filter(...).transform(...).save('output.jsonl', lineage=True)")
+        return
+
+    if json:
+        # JSON output
+        chain = get_lineage_chain(str(filepath))
+        output = [record.to_dict() for record in chain]
+        print(orjson.dumps(output, option=orjson.OPT_INDENT_2).decode("utf-8"))
+    else:
+        # Formatted report
+        report = format_lineage_report(str(filepath))
+        print(report)
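
A closing note on _compute_diff's hash-based branch: two records count as identical exactly when their canonical serializations match, and OPT_SORT_KEYS makes dict key order irrelevant (list order still matters). A quick check:

    import orjson

    a = orjson.dumps({"x": 1, "y": 2}, option=orjson.OPT_SORT_KEYS)
    b = orjson.dumps({"y": 2, "x": 1}, option=orjson.OPT_SORT_KEYS)
    assert a == b  # same canonical bytes, so diff reports the record as unchanged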