dtflow 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/cli/commands.py CHANGED
@@ -1,19 +1,34 @@
  """
  CLI command implementations
  """
- import json
+
+ import os
+ import shutil
+ import tempfile
  from datetime import datetime
  from pathlib import Path
  from typing import Any, Dict, List, Literal, Optional

+ import orjson
+
  from ..core import DataTransformer, DictWrapper
+ from ..lineage import format_lineage_report, get_lineage_chain, has_lineage, load_lineage
+ from ..pipeline import run_pipeline, validate_pipeline
  from ..presets import get_preset, list_presets
- from ..storage.io import load_data, save_data, sample_file
-
+ from ..storage.io import load_data, sample_file, save_data
+ from ..streaming import load_stream

  # Supported file formats
  SUPPORTED_FORMATS = {".csv", ".jsonl", ".json", ".xlsx", ".xls", ".parquet", ".arrow", ".feather"}

+ # Formats that support streaming (kept in sync with streaming.py)
+ STREAMING_FORMATS = {".jsonl", ".csv", ".parquet", ".arrow", ".feather"}
+
+
+ def _is_streaming_supported(filepath: Path) -> bool:
+     """Check whether a file supports streaming."""
+     return filepath.suffix.lower() in STREAMING_FORMATS
+

  def _check_file_format(filepath: Path) -> bool:
      """Check whether the file format is supported; print an error and return False if not."""
@@ -28,11 +43,12 @@ def _check_file_format(filepath: Path) -> bool:
  def sample(
      filename: str,
      num: int = 10,
-     sample_type: Literal["random", "head", "tail"] = "head",
+     type: Literal["random", "head", "tail"] = "head",
      output: Optional[str] = None,
      seed: Optional[int] = None,
      by: Optional[str] = None,
      uniform: bool = False,
+     fields: Optional[str] = None,
  ) -> None:
      """
      Sample a specified number of records from a data file.
@@ -43,20 +59,22 @@ def sample(
          - num > 0: sample the given number of records
          - num = 0: sample all data
          - num < 0: Python slice style (e.g. -1 means the last record, -10 the last 10)
-         sample_type: sampling mode, one of random/head/tail, default head
+         type: sampling mode, one of random/head/tail, default head
          output: output file path; prints to the console if not given
-         seed: random seed (only effective when sample_type=random)
+         seed: random seed (only effective when type=random)
          by: field name for stratified sampling, grouping by the field's values
          uniform: uniform sampling mode (requires --by), same count per group
+         fields: only show the given fields (comma-separated); preview mode only

      Examples:
          dt sample data.jsonl 5
-         dt sample data.csv 100 --sample_type=head
+         dt sample data.csv 100 --type=head
          dt sample data.xlsx 50 --output=sampled.jsonl
          dt sample data.jsonl 0  # sample all data
          dt sample data.jsonl -10  # the last 10 records
          dt sample data.jsonl 1000 --by=category  # proportional stratified sampling
          dt sample data.jsonl 1000 --by=category --uniform  # uniform stratified sampling
+         dt sample data.jsonl --fields=question,answer  # only show the given fields
      """
      filepath = Path(filename)

@@ -75,9 +93,7 @@ def sample(
      # Stratified sampling mode
      if by:
          try:
-             sampled = _stratified_sample(
-                 filepath, num, by, uniform, seed, sample_type
-             )
+             sampled = _stratified_sample(filepath, num, by, uniform, seed, type)
          except Exception as e:
              print(f"Error: {e}")
              return
@@ -87,7 +103,7 @@ def sample(
          sampled = sample_file(
              str(filepath),
              num=num,
-             sample_type=sample_type,
+             sample_type=type,
              seed=seed,
              output=None,  # don't save yet; handled uniformly at the end
          )
@@ -100,7 +116,11 @@ def sample(
          save_data(sampled, output)
          print(f"Saved {len(sampled)} records to {output}")
      else:
-         _print_samples(sampled)
+         # Get the file's total row count for display
+         total_count = _get_file_row_count(filepath)
+         # Parse the fields argument
+         field_list = _parse_field_list(fields) if fields else None
+         _print_samples(sampled, filepath.name, total_count, field_list)


  def _stratified_sample(
@@ -225,9 +245,10 @@ def head(
      filename: str,
      num: int = 10,
      output: Optional[str] = None,
+     fields: Optional[str] = None,
  ) -> None:
      """
-     Show the first N records of a file (shortcut for dt sample --sample_type=head).
+     Show the first N records of a file (shortcut for dt sample --type=head).

      Args:
          filename: input file path; supports csv/excel/jsonl/json/parquet/arrow/feather formats
@@ -236,23 +257,26 @@ def head(
          - num = 0: show all data
          - num < 0: Python slice style (e.g. -10 means the last 10)
          output: output file path; prints to the console if not given
+         fields: only show the given fields (comma-separated); preview mode only

      Examples:
          dt head data.jsonl  # show the first 10
          dt head data.jsonl 20  # show the first 20
          dt head data.csv 0  # show all data
          dt head data.xlsx --output=head.jsonl
+         dt head data.jsonl --fields=question,answer
      """
-     sample(filename, num=num, sample_type="head", output=output)
+     sample(filename, num=num, type="head", output=output, fields=fields)


  def tail(
      filename: str,
      num: int = 10,
      output: Optional[str] = None,
+     fields: Optional[str] = None,
  ) -> None:
      """
-     Show the last N records of a file (shortcut for dt sample --sample_type=tail).
+     Show the last N records of a file (shortcut for dt sample --type=tail).

      Args:
          filename: input file path; supports csv/excel/jsonl/json/parquet/arrow/feather formats
@@ -261,58 +285,249 @@ def tail(
          - num = 0: show all data
          - num < 0: Python slice style (e.g. -10 means the last 10)
          output: output file path; prints to the console if not given
+         fields: only show the given fields (comma-separated); preview mode only

      Examples:
          dt tail data.jsonl  # show the last 10
          dt tail data.jsonl 20  # show the last 20
          dt tail data.csv 0  # show all data
          dt tail data.xlsx --output=tail.jsonl
+         dt tail data.jsonl --fields=question,answer
+     """
+     sample(filename, num=num, type="tail", output=output, fields=fields)
+
+
+ def _get_file_row_count(filepath: Path) -> Optional[int]:
+     """
+     Quickly get a file's row count (without loading all the data).
+
+     For JSONL files, count lines directly; other formats return None.
+     """
+     ext = filepath.suffix.lower()
+     if ext == ".jsonl":
+         try:
+             with open(filepath, "rb") as f:
+                 return sum(1 for _ in f)
+         except Exception:
+             return None
+     # Fast counting is not yet supported for other formats
+     return None
+
+
+ def _format_value(value: Any, max_len: int = 80) -> str:
+     """Format a single value, truncating long text."""
+     if value is None:
+         return "[dim]null[/dim]"
+     if isinstance(value, bool):
+         return "[cyan]true[/cyan]" if value else "[cyan]false[/cyan]"
+     if isinstance(value, (int, float)):
+         return f"[cyan]{value}[/cyan]"
+     if isinstance(value, str):
+         # Handle multi-line text
+         if "\n" in value:
+             lines = value.split("\n")
+             if len(lines) > 3:
+                 preview = lines[0][:max_len] + f"... [dim]({len(lines)} lines)[/dim]"
+             else:
+                 preview = value.replace("\n", "\\n")
+             if len(preview) > max_len:
+                 preview = preview[:max_len] + "..."
+             return f'"{preview}"'
+         if len(value) > max_len:
+             return f'"{value[:max_len]}..." [dim]({len(value)} chars)[/dim]'
+         return f'"{value}"'
+     return str(value)
+
+
+ def _format_nested(
+     value: Any,
+     indent: str = "",
+     is_last: bool = True,
+     max_len: int = 80,
+ ) -> List[str]:
      """
-     sample(filename, num=num, sample_type="tail", output=output)
+     Recursively format a nested structure, returning a list of lines.
+
+     Uses tree symbols to show the structure:
+     ├─ intermediate item
+     └─ last item
+     """
+     lines = []
+     branch = "└─ " if is_last else "├─ "
+     cont = "   " if is_last else "│  "
+
+     if isinstance(value, dict):
+         items = list(value.items())
+         for i, (k, v) in enumerate(items):
+             is_last_item = i == len(items) - 1
+             b = "└─ " if is_last_item else "├─ "
+             c = "   " if is_last_item else "│  "
+
+             if isinstance(v, (dict, list)) and v:
+                 # Nested structure
+                 if isinstance(v, list):
+                     # Detect the messages format
+                     is_messages = (
+                         v and isinstance(v[0], dict) and "role" in v[0] and "content" in v[0]
+                     )
+                     if is_messages:
+                         lines.append(
+                             f"{indent}{b}[green]{k}[/green]: ({len(v)} items) [dim]→ \\[role]: content[/dim]"
+                         )
+                     else:
+                         lines.append(f"{indent}{b}[green]{k}[/green]: ({len(v)} items)")
+                 else:
+                     lines.append(f"{indent}{b}[green]{k}[/green]:")
+                 lines.extend(_format_nested(v, indent + c, True, max_len))
+             else:
+                 # Simple value
+                 lines.append(f"{indent}{b}[green]{k}[/green]: {_format_value(v, max_len)}")
+
+     elif isinstance(value, list):
+         for i, item in enumerate(value):
+             is_last_item = i == len(value) - 1
+             b = "└─ " if is_last_item else "├─ "
+             c = "   " if is_last_item else "│  "
+
+             if isinstance(item, dict):
+                 # Dict items in a list - detect the messages format
+                 if "role" in item and "content" in item:
+                     role = item.get("role", "")
+                     content = item.get("content", "")
+                     # Truncate long content
+                     if len(content) > max_len:
+                         content = content[:max_len].replace("\n", "\\n") + "..."
+                     else:
+                         content = content.replace("\n", "\\n")
+                     # Escape with \[ so rich does not parse it as markup
+                     lines.append(f"{indent}{b}[yellow]\\[{role}]:[/yellow] {content}")
+                 else:
+                     # Plain dict
+                     lines.append(f"{indent}{b}[dim]{{...}}[/dim]")
+                     lines.extend(_format_nested(item, indent + c, True, max_len))
+             elif isinstance(item, list):
+                 lines.append(f"{indent}{b}[dim][{len(item)} items][/dim]")
+                 lines.extend(_format_nested(item, indent + c, True, max_len))
+             else:
+                 lines.append(f"{indent}{b}{_format_value(item, max_len)}")
+
+     return lines
+

+ def _is_simple_data(samples: List[Dict]) -> bool:
+     """Decide whether the data suits table display (no nested structures)."""
+     if not samples or not isinstance(samples[0], dict):
+         return False
+     keys = list(samples[0].keys())
+     if len(keys) > 6:
+         return False
+     for s in samples[:3]:
+         for k in keys:
+             v = s.get(k)
+             if isinstance(v, (dict, list)):
+                 return False
+             if isinstance(v, str) and len(v) > 80:
+                 return False
+     return True

- def _print_samples(samples: list) -> None:
-     """Print sampled results."""
+
+ def _print_samples(
+     samples: list,
+     filename: Optional[str] = None,
+     total_count: Optional[int] = None,
+     fields: Optional[List[str]] = None,
+ ) -> None:
+     """
+     Print sampled results.
+
+     Args:
+         samples: list of sampled records
+         filename: file name (used for the overview)
+         total_count: total row count of the file (used for the overview)
+         fields: only show the given fields
+     """
      if not samples:
          print("No data")
          return

+     # Filter fields
+     if fields and isinstance(samples[0], dict):
+         field_set = set(fields)
+         samples = [{k: v for k, v in item.items() if k in field_set} for item in samples]
+
      try:
          from rich.console import Console
-         from rich.json import JSON
+         from rich.panel import Panel
          from rich.table import Table

          console = Console()

-         # Try to display as a table
-         if isinstance(samples[0], dict):
+         # Show the data overview header
+         if filename:
+             all_fields = set()
+             for item in samples:
+                 if isinstance(item, dict):
+                     all_fields.update(item.keys())
+             field_names = ", ".join(sorted(all_fields))
+
+             if total_count is not None:
+                 info = f"Total rows: {total_count:,} | Sampled: {len(samples)} | Fields: {len(all_fields)}"
+             else:
+                 info = f"Sampled: {len(samples)} | Fields: {len(all_fields)}"
+
+             console.print(
+                 Panel(
+                     f"[dim]{info}[/dim]\n[dim]Fields: {field_names}[/dim]",
+                     title=f"[bold]📊 {filename}[/bold]",
+                     expand=False,
+                     border_style="dim",
+                 )
+             )
+             console.print()
+
+         # Show simple data as a table
+         if _is_simple_data(samples):
              keys = list(samples[0].keys())
-             # Suitable for a table: not too many fields and values not too long
-             if len(keys) <= 5 and all(
-                 len(str(s.get(k, ""))) < 100 for s in samples[:3] for k in keys
-             ):
-                 table = Table(title=f"Sampled results ({len(samples)})")
-                 for key in keys:
-                     table.add_column(key, overflow="fold")
-                 for item in samples:
-                     table.add_row(*[str(item.get(k, "")) for k in keys])
-                 console.print(table)
-                 return
-
-         # Display as JSON
+             table = Table(show_header=True, header_style="bold cyan")
+             for key in keys:
+                 table.add_column(key, overflow="fold")
+             for item in samples:
+                 table.add_row(*[str(item.get(k, "")) for k in keys])
+             console.print(table)
+             return
+
+         # Show nested data as a tree
          for i, item in enumerate(samples, 1):
-             console.print(f"\n[bold cyan]--- Record {i} ---[/bold cyan]")
-             console.print(JSON.from_data(item))
+             console.print(f"[bold cyan]--- Record {i} ---[/bold cyan]")
+             if isinstance(item, dict):
+                 for line in _format_nested(item):
+                     console.print(line)
+             else:
+                 console.print(_format_value(item))
+             console.print()

      except ImportError:
          # rich not installed; fall back to plain printing
-         import json
+         if filename:
+             all_fields = set()
+             for item in samples:
+                 if isinstance(item, dict):
+                     all_fields.update(item.keys())
+
+             print(f"\n📊 {filename}")
+             if total_count is not None:
+                 print(
+                     f"  Total rows: {total_count:,} | Sampled: {len(samples)} | Fields: {len(all_fields)}"
+                 )
+             else:
+                 print(f"  Sampled: {len(samples)} | Fields: {len(all_fields)}")
+             print(f"  Fields: {', '.join(sorted(all_fields))}")
+             print()

          for i, item in enumerate(samples, 1):
-             print(f"\n--- Record {i} ---")
-             print(json.dumps(item, ensure_ascii=False, indent=2))
-
-         print(f"\n{len(samples)} records in total")
+             print(f"--- Record {i} ---")
+             print(orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8"))
+             print()


  # ============ Transform Command ============
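Note on the preview rework in the hunk above: `_format_nested` replaces the old `rich.json.JSON` rendering with a tree renderer. A minimal sketch of what it is expected to emit for a messages-style record, based on the logic shown (rich markup tags omitted for readability; the exact widths of the continuation strings are an assumption, since the diff viewer collapsed runs of spaces):

    record = {"id": 1, "messages": [{"role": "user", "content": "hi"}]}
    # _format_nested(record) yields lines roughly like:
    #   ├─ id: 1
    #   └─ messages: (1 items) → [role]: content
    #      └─ [user]: hi

The messages detection ("role" and "content" keys on the first element) is what collapses chat data into one line per turn instead of a full nested dump.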
@@ -522,17 +737,16 @@ def _format_example_value(value: Any, max_len: int = 50) -> str:
          # Truncate long strings
          if len(value) > max_len:
              value = value[:max_len] + "..."
-         # Escape and quote
-         escaped = value.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
-         return f'"{escaped}"'
+         # Use repr() to handle all escape characters automatically
+         return repr(value)
      if isinstance(value, bool):
          return str(value)
      if isinstance(value, (int, float)):
          return str(value)
      if isinstance(value, (list, dict)):
-         s = json.dumps(value, ensure_ascii=False)
+         s = orjson.dumps(value).decode("utf-8")
          if len(s) > max_len:
-             return f"{s[:max_len]}..."
+             return repr(s[:max_len] + "...")
          return s
      return '""'

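The switch to `repr()` above works because Python's `repr` of a string produces a quoted literal with all escapes already handled, which is what the old manual `replace` chain approximated. For instance:

    # repr() quotes and escapes in one step:
    assert repr('a\nb"c') == '\'a\\nb"c\''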
@@ -570,7 +784,7 @@ def _generate_default_transform(field_names: List[str]) -> str:
      for name in field_names[:5]:  # show at most 5 fields
          safe_name, _ = _sanitize_field_name(name)
          lines.append(f'        "{name}": item.{safe_name},')
-     return "\n".join(lines) if lines else '        # define output fields here'
+     return "\n".join(lines) if lines else "        # define output fields here"
@@ -579,7 +793,7 @@ def _execute_transform(
      output_override: Optional[str],
      num: Optional[int],
  ) -> None:
-     """Execute the data transform"""
+     """Execute the data transform (streaming by default)"""
      print(f"📂 Loading config: {config_path}")

      # Dynamically load the config file
@@ -599,7 +813,29 @@ def _execute_transform(
      # Get the output path
      output_path = output_override or config_ns.get("output", "output.jsonl")

-     # Load the data and run the transform with DataTransformer
+     # Use streaming for JSONL files
+     if _is_streaming_supported(input_path):
+         print(f"📊 Streaming load: {input_path}")
+         print("🔄 Running transform...")
+         try:
+             # Wrap the transform function to support attribute access (the Item class defined in the config)
+             def wrapped_transform(item):
+                 return transform_func(DictWrapper(item))
+
+             st = load_stream(str(input_path))
+             if num:
+                 st = st.head(num)
+             count = st.transform(wrapped_transform).save(output_path)
+             print(f"💾 Saving results: {output_path}")
+             print(f"\n✅ Done! Transformed {count} records to {output_path}")
+         except Exception as e:
+             print(f"Error: transform failed - {e}")
+             import traceback
+
+             traceback.print_exc()
+         return
+
+     # Non-JSONL files use the traditional path
      print(f"📊 Loading data: {input_path}")
      try:
          dt = DataTransformer.load(str(input_path))
@@ -621,6 +857,7 @@ def _execute_transform(
      except Exception as e:
          print(f"Error: transform failed - {e}")
          import traceback
+
          traceback.print_exc()
          return

@@ -641,7 +878,7 @@ def _execute_preset_transform(
      output_override: Optional[str],
      num: Optional[int],
  ) -> None:
-     """Run a transform using a preset template"""
+     """Run a transform using a preset template (streaming by default)"""
      print(f"📂 Using preset: {preset_name}")

      # Get the preset function
@@ -652,7 +889,58 @@ def _execute_preset_transform(
          print(f"Available presets: {', '.join(list_presets())}")
          return

-     # Load data
+     output_path = output_override or f"{input_path.stem}_{preset_name}.jsonl"
+
+     # Check whether input and output are the same
+     input_resolved = input_path.resolve()
+     output_resolved = Path(output_path).resolve()
+     use_temp_file = input_resolved == output_resolved
+
+     # Use streaming for JSONL files
+     if _is_streaming_supported(input_path):
+         print(f"📊 Streaming load: {input_path}")
+         print("🔄 Running transform...")
+
+         # If input and output are the same, use a temp file
+         if use_temp_file:
+             print("⚠ Detected that the output file is the same as the input file; a temporary file will be used")
+             temp_fd, temp_path = tempfile.mkstemp(
+                 suffix=output_resolved.suffix,
+                 prefix=".tmp_",
+                 dir=output_resolved.parent,
+             )
+             os.close(temp_fd)
+             actual_output = temp_path
+         else:
+             actual_output = output_path
+
+         try:
+             # Wrap the transform function to support attribute access
+             def wrapped_transform(item):
+                 return transform_func(DictWrapper(item))
+
+             st = load_stream(str(input_path))
+             if num:
+                 st = st.head(num)
+             count = st.transform(wrapped_transform).save(actual_output)
+
+             # If a temp file was used, move it into place
+             if use_temp_file:
+                 shutil.move(temp_path, output_path)
+
+             print(f"💾 Saving results: {output_path}")
+             print(f"\n✅ Done! Transformed {count} records to {output_path}")
+         except Exception as e:
+             # Clean up the temp file
+             if use_temp_file and os.path.exists(temp_path):
+                 os.unlink(temp_path)
+             print(f"Error: transform failed - {e}")
+             import traceback
+
+             traceback.print_exc()
+         return
+
+     # Non-JSONL files use the traditional path
      print(f"📊 Loading data: {input_path}")
      try:
          dt = DataTransformer.load(str(input_path))
@@ -674,11 +962,11 @@ def _execute_preset_transform(
      except Exception as e:
          print(f"Error: transform failed - {e}")
          import traceback
+
          traceback.print_exc()
          return

      # Save results
-     output_path = output_override or f"{input_path.stem}_{preset_name}.jsonl"
      print(f"💾 Saving results: {output_path}")
      try:
          save_data(results, output_path)
@@ -809,7 +1097,7 @@ def concat(
      strict: bool = False,
  ) -> None:
      """
-     Concatenate multiple data files.
+     Concatenate multiple data files (streaming, O(1) memory).

      Args:
          *files: list of input file paths; supports csv/excel/jsonl/json/parquet/arrow/feather formats
@@ -832,7 +1120,7 @@ def concat(
      # Validate all files
      file_paths = []
      for f in files:
-         filepath = Path(f)
+         filepath = Path(f).resolve()  # use absolute paths for comparison
          if not filepath.exists():
              print(f"Error: file not found - {f}")
              return
@@ -840,31 +1128,42 @@ def concat(
          return
      file_paths.append(filepath)

-     # Analyze each file's fields
+     # Check whether the output file conflicts with an input file
+     output_path = Path(output).resolve()
+     use_temp_file = output_path in file_paths
+     if use_temp_file:
+         print("⚠ Detected that the output file is the same as the input file; a temporary file will be used")
+
+     # Analyze fields in a streaming way (read only the first row of each file)
      print("📊 File field analysis:")
-     file_infos = []  # [(filepath, data, fields, count)]
+     file_fields = []  # [(filepath, fields)]

      for filepath in file_paths:
          try:
-             data = load_data(str(filepath))
+             # Read only the first row to get the fields (loading depends on the format)
+             if _is_streaming_supported(filepath):
+                 first_row = load_stream(str(filepath)).head(1).collect()
+             else:
+                 # Non-streaming formats (e.g. .json, .xlsx) are loaded in full
+                 data = load_data(str(filepath))
+                 first_row = data[:1] if data else []
+             if not first_row:
+                 print(f"Warning: file is empty - {filepath}")
+                 fields = set()
+             else:
+                 fields = set(first_row[0].keys())
          except Exception as e:
              print(f"Error: failed to read file {filepath} - {e}")
              return

-         if not data:
-             print(f"Warning: file is empty - {filepath}")
-             fields = set()
-         else:
-             fields = set(data[0].keys())
-
-         file_infos.append((filepath, data, fields, len(data)))
+         file_fields.append((filepath, fields))
          fields_str = ", ".join(sorted(fields)) if fields else "(empty)"
-         print(f"  {filepath.name}: {fields_str} ({len(data)} records)")
+         print(f"  {filepath.name}: {fields_str}")

      # Analyze field differences
      all_fields = set()
      common_fields = None
-     for _, _, fields, _ in file_infos:
+     for _, fields in file_fields:
          all_fields.update(fields)
          if common_fields is None:
              common_fields = fields.copy()
@@ -883,25 +1182,78 @@ def concat(
      else:
          print(f"\n⚠ Field differences: {', '.join(sorted(diff_fields))} exist only in some files")

-     # Perform concatenation
-     print("\n🔄 Concatenating...")
-     all_data = []
-     for _, data, _, _ in file_infos:
-         all_data.extend(data)
+     # Streaming concatenation
+     print("\n🔄 Streaming concatenation...")
+
+     # If the output file conflicts with an input file, use a temp file (in the output file's directory)
+     if use_temp_file:
+         output_dir = output_path.parent
+         temp_fd, temp_path = tempfile.mkstemp(
+             suffix=output_path.suffix,
+             prefix=".tmp_",
+             dir=output_dir,
+         )
+         os.close(temp_fd)
+         actual_output = temp_path
+         print(f"💾 Writing to temp file: {temp_path}")
+     else:
+         actual_output = output
+         print(f"💾 Saving results: {output}")

-     # Save results
-     print(f"💾 Saving results: {output}")
      try:
-         save_data(all_data, output)
+         total_count = _concat_streaming(file_paths, actual_output)
+
+         # If a temp file was used, rename it to the target file
+         if use_temp_file:
+             shutil.move(temp_path, output)
+             print(f"💾 Moved to target file: {output}")
      except Exception as e:
-         print(f"Error: failed to save file - {e}")
+         # Clean up the temp file
+         if use_temp_file and os.path.exists(temp_path):
+             os.unlink(temp_path)
+         print(f"Error: concatenation failed - {e}")
          return

-     total_count = len(all_data)
      file_count = len(files)
      print(f"\n✅ Done! Merged {file_count} files, {total_count} records in total, to {output}")


+ def _concat_streaming(file_paths: List[Path], output: str) -> int:
+     """Concatenate multiple files in a streaming fashion"""
+     from ..streaming import (
+         StreamingTransformer,
+         _stream_arrow,
+         _stream_csv,
+         _stream_jsonl,
+         _stream_parquet,
+     )
+
+     def generator():
+         for filepath in file_paths:
+             ext = filepath.suffix.lower()
+             if ext == ".jsonl":
+                 yield from _stream_jsonl(str(filepath))
+             elif ext == ".csv":
+                 yield from _stream_csv(str(filepath))
+             elif ext == ".parquet":
+                 yield from _stream_parquet(str(filepath))
+             elif ext in (".arrow", ".feather"):
+                 yield from _stream_arrow(str(filepath))
+             elif ext in (".json",):
+                 # JSON must be loaded in full
+                 data = load_data(str(filepath))
+                 yield from data
+             elif ext in (".xlsx", ".xls"):
+                 # Excel must be loaded in full
+                 data = load_data(str(filepath))
+                 yield from data
+             else:
+                 yield from _stream_jsonl(str(filepath))
+
+     st = StreamingTransformer(generator())
+     return st.save(output, show_progress=True)
+
+
  # ============ Stats Command ============

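`_concat_streaming` above chains one generator per input file, so only a single record is materialized at a time regardless of how many files are merged. A self-contained sketch of the same pattern for JSONL inputs (independent of dtflow's internal `_stream_*` helpers):

    from typing import Dict, Iterator, List

    import orjson

    def chain_jsonl(paths: List[str]) -> Iterator[Dict]:
        # Exhaust each file in turn; memory stays O(1) in the number of records.
        for path in paths:
            with open(path, "rb") as f:
                for line in f:
                    if line.strip():
                        yield orjson.loads(line)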
@@ -992,8 +1344,8 @@ def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:

         # Type-specific stats
         if non_null:
-             # Unique value count
-             stat["unique"] = len(set(str(v) for v in non_null))
+             # Unique value count (hash complex types to save memory)
+             stat["unique"] = _count_unique(non_null, field_type)

             # String type: compute length stats
             if field_type == "str":
@@ -1025,6 +1377,28 @@ def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
      return stats_list


+ def _count_unique(values: List[Any], field_type: str) -> int:
+     """
+     Count the number of unique values.
+
+     Simple types are compared directly; list/dict are hashed to save memory.
+     """
+     if field_type in ("list", "dict"):
+         # Complex types: serialize with orjson, then hash
+         import hashlib
+
+         import orjson
+
+         seen = set()
+         for v in values:
+             h = hashlib.md5(orjson.dumps(v, option=orjson.OPT_SORT_KEYS)).digest()
+             seen.add(h)
+         return len(seen)
+     else:
+         # Simple types: compare directly
+         return len(set(values))
+
+
  def _infer_type(values: List[Any]) -> str:
      """Infer the field type"""
      if not values:
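The memory saving in `_count_unique` above comes from storing 16-byte MD5 digests instead of full serialized values, and `orjson.OPT_SORT_KEYS` makes the digest insensitive to key order. A small check of that property:

    import hashlib

    import orjson

    def digest(v):
        return hashlib.md5(orjson.dumps(v, option=orjson.OPT_SORT_KEYS)).digest()

    a = {"x": 1, "y": [2, 3]}
    b = {"y": [2, 3], "x": 1}  # same content, different key order
    assert digest(a) == digest(b)  # counted as a single unique value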
@@ -1058,12 +1432,16 @@ def _truncate(v: Any, max_width: int) -> str:
      result = []
      for char in s:
          # CJK character ranges
-         if '\u4e00' <= char <= '\u9fff' or '\u3000' <= char <= '\u303f' or '\uff00' <= char <= '\uffef':
+         if (
+             "\u4e00" <= char <= "\u9fff"
+             or "\u3000" <= char <= "\u303f"
+             or "\uff00" <= char <= "\uffef"
+         ):
              char_width = 2
          else:
              char_width = 1
          if width + char_width > max_width - 3:  # reserve width for the "..."
-             return ''.join(result) + "..."
+             return "".join(result) + "..."
          result.append(char)
          width += char_width
      return s
@@ -1074,7 +1452,11 @@ def _display_width(s: str) -> int:
      width = 0
      for char in s:
          # CJK character ranges
-         if '\u4e00' <= char <= '\u9fff' or '\u3000' <= char <= '\u303f' or '\uff00' <= char <= '\uffef':
+         if (
+             "\u4e00" <= char <= "\u9fff"
+             or "\u3000" <= char <= "\u303f"
+             or "\uff00" <= char <= "\uffef"
+         ):
              width += 2
          else:
              width += 1
@@ -1086,26 +1468,28 @@ def _pad_to_width(s: str, target_width: int) -> str:
      current_width = _display_width(s)
      if current_width >= target_width:
          return s
-     return s + ' ' * (target_width - current_width)
+     return s + " " * (target_width - current_width)


  def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
      """Print statistics"""
      try:
          from rich.console import Console
-         from rich.table import Table
          from rich.panel import Panel
+         from rich.table import Table

          console = Console()

          # Overview
-         console.print(Panel(
-             f"[bold]File:[/bold] {filename}\n"
-             f"[bold]Total:[/bold] {total:,} records\n"
-             f"[bold]Fields:[/bold] {len(field_stats)}",
-             title="📊 Data Overview",
-             expand=False,
-         ))
+         console.print(
+             Panel(
+                 f"[bold]File:[/bold] {filename}\n"
+                 f"[bold]Total:[/bold] {total:,} records\n"
+                 f"[bold]Fields:[/bold] {len(field_stats)}",
+                 title="📊 Data Overview",
+                 expand=False,
+             )
+         )

          # Field statistics table
          table = Table(title="📋 Field Statistics", show_header=True, header_style="bold cyan")
@@ -1122,12 +1506,18 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
          # Build the stats info string
          extra = []
          if "len_avg" in stat:
-             extra.append(f"Length: {stat['len_min']}-{stat['len_max']} (avg {stat['len_avg']:.0f})")
+             extra.append(
+                 f"Length: {stat['len_min']}-{stat['len_max']} (avg {stat['len_avg']:.0f})"
+             )
          if "avg" in stat:
              if stat["type"] == "int":
-                 extra.append(f"Range: {int(stat['min'])}-{int(stat['max'])} (avg {stat['avg']:.1f})")
+                 extra.append(
+                     f"Range: {int(stat['min'])}-{int(stat['max'])} (avg {stat['avg']:.1f})"
+                 )
              else:
-                 extra.append(f"Range: {stat['min']:.2f}-{stat['max']:.2f} (avg {stat['avg']:.2f})")
+                 extra.append(
+                     f"Range: {stat['min']:.2f}-{stat['max']:.2f} (avg {stat['avg']:.2f})"
+                 )

          table.add_row(
              stat["field"],
@@ -1154,7 +1544,9 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
          if unique_ratio > 0.9 and stat.get("unique", 0) > 100:
              continue

-         console.print(f"\n[bold cyan]{stat['field']}[/bold cyan] value distribution (Top {len(top_values)}):")
+         console.print(
+             f"\n[bold cyan]{stat['field']}[/bold cyan] value distribution (Top {len(top_values)}):"
+         )
          max_count = max(c for _, c in top_values) if top_values else 1
          for value, count in top_values:
              pct = count / total * 100
@@ -1200,7 +1592,7 @@ def clean(
      output: Optional[str] = None,
  ) -> None:
      """
-     Data cleaning.
+     Data cleaning (streaming by default).

      Args:
          filename: input file path; supports csv/excel/jsonl/json/parquet/arrow/feather formats
@@ -1233,29 +1625,19 @@ def clean(
      if not _check_file_format(filepath):
          return

-     # Load data
-     print(f"📊 Loading data: {filepath}")
-     try:
-         dt = DataTransformer.load(str(filepath))
-     except Exception as e:
-         print(f"Error: failed to read file - {e}")
-         return
-
-     original_count = len(dt)
-     print(f"  {original_count} records in total")
-
-     # Parse arguments (fire may parse comma-separated values as tuples)
+     # Parse arguments
      min_len_field, min_len_value = _parse_len_param(min_len) if min_len else (None, None)
      max_len_field, max_len_value = _parse_len_param(max_len) if max_len else (None, None)
      keep_fields = _parse_field_list(keep) if keep else None
-     drop_fields = _parse_field_list(drop) if drop else None
+     drop_fields_set = set(_parse_field_list(drop)) if drop else None
+     keep_set = set(keep_fields) if keep_fields else None

      # Build the cleaning config
      empty_fields = None
      if drop_empty is not None:
          if drop_empty == "" or drop_empty is True:
              print("🔄 Dropping records where any field is empty...")
-             empty_fields = []  # an empty list means check all fields
+             empty_fields = []
          else:
              empty_fields = _parse_field_list(drop_empty)
              print(f"🔄 Dropping records with empty fields: {', '.join(empty_fields)}")
@@ -1268,8 +1650,73 @@ def clean(
          print(f"🔄 Filtering records where {max_len_field} length > {max_len_value}...")
      if keep_fields:
          print(f"🔄 Keeping only fields: {', '.join(keep_fields)}")
-     if drop_fields:
-         print(f"🔄 Dropping fields: {', '.join(drop_fields)}")
+     if drop_fields_set:
+         print(f"🔄 Dropping fields: {', '.join(drop_fields_set)}")
+
+     output_path = output or str(filepath)
+
+     # Check whether input and output are the same (streaming needs a temp file)
+     input_resolved = filepath.resolve()
+     output_resolved = Path(output_path).resolve()
+     use_temp_file = input_resolved == output_resolved
+
+     # Use streaming for JSONL files
+     if _is_streaming_supported(filepath):
+         print(f"📊 Streaming load: {filepath}")
+
+         # If input and output are the same, use a temp file
+         if use_temp_file:
+             print("⚠ Detected that the output file is the same as the input file; a temporary file will be used")
+             temp_fd, temp_path = tempfile.mkstemp(
+                 suffix=output_resolved.suffix,
+                 prefix=".tmp_",
+                 dir=output_resolved.parent,
+             )
+             os.close(temp_fd)
+             actual_output = temp_path
+         else:
+             actual_output = output_path
+
+         try:
+             count = _clean_streaming(
+                 str(filepath),
+                 actual_output,
+                 strip=strip,
+                 empty_fields=empty_fields,
+                 min_len_field=min_len_field,
+                 min_len_value=min_len_value,
+                 max_len_field=max_len_field,
+                 max_len_value=max_len_value,
+                 keep_set=keep_set,
+                 drop_fields_set=drop_fields_set,
+             )
+
+             # If a temp file was used, move it into place
+             if use_temp_file:
+                 shutil.move(temp_path, output_path)
+
+             print(f"💾 Saving results: {output_path}")
+             print(f"\n✅ Done! {count} records after cleaning")
+         except Exception as e:
+             # Clean up the temp file
+             if use_temp_file and os.path.exists(temp_path):
+                 os.unlink(temp_path)
+             print(f"Error: cleaning failed - {e}")
+             import traceback
+
+             traceback.print_exc()
+         return
+
+     # Non-JSONL files use the traditional path
+     print(f"📊 Loading data: {filepath}")
+     try:
+         dt = DataTransformer.load(str(filepath))
+     except Exception as e:
+         print(f"Error: failed to read file - {e}")
+         return
+
+     original_count = len(dt)
+     print(f"  {original_count} records in total")

      # Run all cleaning operations in a single pass
      data, step_stats = _clean_data_single_pass(
@@ -1281,12 +1728,11 @@ def clean(
          max_len_field=max_len_field,
          max_len_value=max_len_value,
          keep_fields=keep_fields,
-         drop_fields=set(drop_fields) if drop_fields else None,
+         drop_fields=drop_fields_set,
      )

      # Save results
      final_count = len(data)
-     output_path = output or str(filepath)
      print(f"💾 Saving results: {output_path}")

      try:
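clean, concat, and the preset transform above all use the same in-place write pattern: create a temp file with `tempfile.mkstemp` in the destination directory, stream into it, and `shutil.move` it over the original only after a complete write. A sketch of the pattern in isolation, assuming newline-delimited JSON output (`write_inplace` and its parameters are illustrative, not part of dtflow):

    import os
    import shutil
    import tempfile
    from pathlib import Path

    import orjson

    def write_inplace(dest: str, rows) -> None:
        dest_path = Path(dest).resolve()
        # Same directory as the destination, so the final move stays on one filesystem.
        fd, tmp = tempfile.mkstemp(suffix=dest_path.suffix, prefix=".tmp_", dir=dest_path.parent)
        try:
            with os.fdopen(fd, "wb") as f:
                for row in rows:
                    f.write(orjson.dumps(row) + b"\n")
            shutil.move(tmp, dest)  # replace the original only after a complete write
        except BaseException:
            if os.path.exists(tmp):
                os.unlink(tmp)
            raise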
@@ -1438,3 +1884,556 @@ def _clean_data_single_pass(
          step_stats.append(f"drop: {len(drop_fields)} fields")

      return result, step_stats
+
+
+ def _clean_streaming(
+     input_path: str,
+     output_path: str,
+     strip: bool = False,
+     empty_fields: Optional[List[str]] = None,
+     min_len_field: Optional[str] = None,
+     min_len_value: Optional[int] = None,
+     max_len_field: Optional[str] = None,
+     max_len_value: Optional[int] = None,
+     keep_set: Optional[set] = None,
+     drop_fields_set: Optional[set] = None,
+ ) -> int:
+     """
+     Clean data in a streaming fashion.
+
+     Returns:
+         Number of records after processing
+     """
+
+     def clean_filter(item: Dict) -> bool:
+         """Filter function: return True to keep, False to drop"""
+         # Empty-value filtering
+         if empty_fields is not None:
+             if len(empty_fields) == 0:
+                 if any(_is_empty_value(v) for v in item.values()):
+                     return False
+             else:
+                 if any(_is_empty_value(item.get(f)) for f in empty_fields):
+                     return False
+
+         # Minimum length filtering
+         if min_len_field is not None:
+             if _get_value_len(item.get(min_len_field, "")) < min_len_value:
+                 return False
+
+         # Maximum length filtering
+         if max_len_field is not None:
+             if _get_value_len(item.get(max_len_field, "")) > max_len_value:
+                 return False
+
+         return True
+
+     def clean_transform(item: Dict) -> Dict:
+         """Transform function: strip + field management"""
+         # strip handling
+         if strip:
+             item = {k: v.strip() if isinstance(v, str) else v for k, v in item.items()}
+
+         # Field management
+         if keep_set is not None:
+             item = {k: v for k, v in item.items() if k in keep_set}
+         elif drop_fields_set is not None:
+             item = {k: v for k, v in item.items() if k not in drop_fields_set}
+
+         return item
+
+     # Build the streaming pipeline
+     st = load_stream(input_path)
+
+     # If strip is requested, apply it first (before filtering, so empty-value checks are more accurate)
+     if strip:
+         st = st.transform(
+             lambda x: {k: v.strip() if isinstance(v, str) else v for k, v in x.items()}
+         )
+
+     # Apply filtering
+     if empty_fields is not None or min_len_field is not None or max_len_field is not None:
+         st = st.filter(clean_filter)
+
+     # Apply field management (needed here even when strip is not set)
+     if keep_set is not None or drop_fields_set is not None:
+
+         def field_transform(item):
+             if keep_set is not None:
+                 return {k: v for k, v in item.items() if k in keep_set}
+             elif drop_fields_set is not None:
+                 return {k: v for k, v in item.items() if k not in drop_fields_set}
+             return item
+
+         st = st.transform(field_transform)
+
+     return st.save(output_path)
+
+
+ # ============ Run Command ============
+
+
+ def run(
+     config: str,
+     input: Optional[str] = None,
+     output: Optional[str] = None,
+ ) -> None:
+     """
+     Execute a pipeline config file.
+
+     Args:
+         config: pipeline YAML config file path
+         input: input file path (overrides input in the config)
+         output: output file path (overrides output in the config)
+
+     Examples:
+         dt run pipeline.yaml
+         dt run pipeline.yaml --input=new_data.jsonl
+         dt run pipeline.yaml --input=data.jsonl --output=result.jsonl
+     """
+     config_path = Path(config)
+
+     if not config_path.exists():
+         print(f"Error: config file not found - {config}")
+         return
+
+     if config_path.suffix.lower() not in (".yaml", ".yml"):
+         print("Error: the config file must be YAML (.yaml or .yml)")
+         return
+
+     # Validate the config
+     errors = validate_pipeline(config)
+     if errors:
+         print("❌ Config file validation failed:")
+         for err in errors:
+             print(f"  - {err}")
+         return
+
+     # Run the pipeline
+     try:
+         run_pipeline(config, input_file=input, output_file=output, verbose=True)
+     except Exception as e:
+         print(f"Error: {e}")
+         import traceback
+
+         traceback.print_exc()
+
+
+ # ============ Token Stats Command ============
+
+
+ def token_stats(
+     filename: str,
+     field: str = "messages",
+     model: str = "cl100k_base",
+     detailed: bool = False,
+ ) -> None:
+     """
+     Compute token statistics for a dataset.
+
+     Args:
+         filename: input file path
+         field: field to analyze (default messages)
+         model: tokenizer: cl100k_base (default), qwen2.5, llama3, gpt-4, etc.
+         detailed: whether to show detailed stats
+
+     Examples:
+         dt token-stats data.jsonl
+         dt token-stats data.jsonl --field=text --model=qwen2.5
+         dt token-stats data.jsonl --detailed
+     """
+     filepath = Path(filename)
+
+     if not filepath.exists():
+         print(f"Error: file not found - {filename}")
+         return
+
+     if not _check_file_format(filepath):
+         return
+
+     # Load data
+     print(f"📊 Loading data: {filepath}")
+     try:
+         data = load_data(str(filepath))
+     except Exception as e:
+         print(f"Error: failed to read file - {e}")
+         return
+
+     if not data:
+         print("File is empty")
+         return
+
+     total = len(data)
+     print(f"  {total} records in total")
+     print(f"🔢 Counting tokens (model: {model}, field: {field})...")
+
+     # Check the field type and pick a suitable stats method
+     sample = data[0]
+     field_value = sample.get(field)
+
+     try:
+         if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
+             # messages format
+             from ..tokenizers import messages_token_stats
+
+             stats = messages_token_stats(data, messages_field=field, model=model)
+             _print_messages_token_stats(stats, detailed)
+         else:
+             # Plain text field
+             from ..tokenizers import token_stats as compute_token_stats
+
+             stats = compute_token_stats(data, fields=field, model=model)
+             _print_text_token_stats(stats, detailed)
+     except ImportError as e:
+         print(f"Error: {e}")
+         return
+     except Exception as e:
+         print(f"Error: stats failed - {e}")
+         import traceback
+
+         traceback.print_exc()
+
+
+ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
+     """Print token stats for the messages format"""
+     try:
+         from rich.console import Console
+         from rich.panel import Panel
+         from rich.table import Table
+
+         console = Console()
+
+         # Overview
+         overview = (
+             f"[bold]Total samples:[/bold] {stats['count']:,}\n"
+             f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
+             f"[bold]Average tokens:[/bold] {stats['avg_tokens']:,}\n"
+             f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+             f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
+         )
+         console.print(Panel(overview, title="📊 Token Stats Overview", expand=False))
+
+         if detailed:
+             # Detailed stats
+             table = Table(title="📋 Per-Role Stats")
+             table.add_column("Role", style="cyan")
+             table.add_column("Tokens", justify="right")
+             table.add_column("Share", justify="right")
+
+             total = stats["total_tokens"]
+             for role, key in [
+                 ("User", "user_tokens"),
+                 ("Assistant", "assistant_tokens"),
+                 ("System", "system_tokens"),
+             ]:
+                 tokens = stats.get(key, 0)
+                 pct = tokens / total * 100 if total > 0 else 0
+                 table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
+
+             console.print(table)
+             console.print(f"\nAverage conversation turns: {stats.get('avg_turns', 0)}")
+
+     except ImportError:
+         # rich not installed; fall back to plain printing
+         print(f"\n{'=' * 40}")
+         print("📊 Token Stats Overview")
+         print(f"{'=' * 40}")
+         print(f"Total samples: {stats['count']:,}")
+         print(f"Total tokens: {stats['total_tokens']:,}")
+         print(f"Average tokens: {stats['avg_tokens']:,}")
+         print(f"Median: {stats['median_tokens']:,}")
+         print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
+
+         if detailed:
+             print(f"\n{'=' * 40}")
+             print("📋 Per-Role Stats")
+             print(f"{'=' * 40}")
+             total = stats["total_tokens"]
+             for role, key in [
+                 ("User", "user_tokens"),
+                 ("Assistant", "assistant_tokens"),
+                 ("System", "system_tokens"),
+             ]:
+                 tokens = stats.get(key, 0)
+                 pct = tokens / total * 100 if total > 0 else 0
+                 print(f"{role}: {tokens:,} ({pct:.1f}%)")
+             print(f"\nAverage conversation turns: {stats.get('avg_turns', 0)}")
+
+
+ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
+     """Print token stats for plain text"""
+     try:
+         from rich.console import Console
+         from rich.panel import Panel
+
+         console = Console()
+
+         overview = (
+             f"[bold]Total samples:[/bold] {stats['count']:,}\n"
+             f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
+             f"[bold]Average tokens:[/bold] {stats['avg_tokens']:.1f}\n"
+             f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+             f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
+         )
+         console.print(Panel(overview, title="📊 Token Stats", expand=False))
+
+     except ImportError:
+         print(f"\n{'=' * 40}")
+         print("📊 Token Stats")
+         print(f"{'=' * 40}")
+         print(f"Total samples: {stats['count']:,}")
+         print(f"Total tokens: {stats['total_tokens']:,}")
+         print(f"Average tokens: {stats['avg_tokens']:.1f}")
+         print(f"Median: {stats['median_tokens']:,}")
+         print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
+
+
+ # ============ Diff Command ============
+
+
+ def diff(
+     file1: str,
+     file2: str,
+     key: Optional[str] = None,
+     output: Optional[str] = None,
+ ) -> None:
+     """
+     Compare the differences between two datasets.
+
+     Args:
+         file1: first file path
+         file2: second file path
+         key: key field used for matching (optional)
+         output: path to save the diff report (optional)
+
+     Examples:
+         dt diff v1/train.jsonl v2/train.jsonl
+         dt diff a.jsonl b.jsonl --key=id
+         dt diff a.jsonl b.jsonl --output=diff_report.json
+     """
+     path1 = Path(file1)
+     path2 = Path(file2)
+
+     # Validate files
+     for p, name in [(path1, "file1"), (path2, "file2")]:
+         if not p.exists():
+             print(f"Error: file not found - {p}")
+             return
+         if not _check_file_format(p):
+             return
+
+     # Load data
+     print("📊 Loading data...")
+     try:
+         data1 = load_data(str(path1))
+         data2 = load_data(str(path2))
+     except Exception as e:
+         print(f"Error: failed to read file - {e}")
+         return
+
+     print(f"  File 1: {path1.name} ({len(data1)} records)")
+     print(f"  File 2: {path2.name} ({len(data2)} records)")
+
+     # Compute differences
+     print("🔍 Computing differences...")
+     diff_result = _compute_diff(data1, data2, key)
+
+     # Print the diff report
+     _print_diff_report(diff_result, path1.name, path2.name)
+
+     # Save the report
+     if output:
+         print(f"\n💾 Saving report: {output}")
+         save_data([diff_result], output)
+
+
+ def _compute_diff(
+     data1: List[Dict],
+     data2: List[Dict],
+     key: Optional[str] = None,
+ ) -> Dict[str, Any]:
+     """Compute the differences between two datasets"""
+     result = {
+         "summary": {
+             "file1_count": len(data1),
+             "file2_count": len(data2),
+             "added": 0,
+             "removed": 0,
+             "modified": 0,
+             "unchanged": 0,
+         },
+         "field_changes": {},
+         "details": {
+             "added": [],
+             "removed": [],
+             "modified": [],
+         },
+     }
+
+     if key:
+         # Exact matching based on the key
+         dict1 = {item.get(key): item for item in data1 if item.get(key) is not None}
+         dict2 = {item.get(key): item for item in data2 if item.get(key) is not None}
+
+         keys1 = set(dict1.keys())
+         keys2 = set(dict2.keys())
+
+         # Added
+         added_keys = keys2 - keys1
+         result["summary"]["added"] = len(added_keys)
+         result["details"]["added"] = [dict2[k] for k in list(added_keys)[:10]]  # show at most 10
+
+         # Removed
+         removed_keys = keys1 - keys2
+         result["summary"]["removed"] = len(removed_keys)
+         result["details"]["removed"] = [dict1[k] for k in list(removed_keys)[:10]]
+
+         # Modified / unchanged
+         common_keys = keys1 & keys2
+         for k in common_keys:
+             if dict1[k] == dict2[k]:
+                 result["summary"]["unchanged"] += 1
+             else:
+                 result["summary"]["modified"] += 1
+                 if len(result["details"]["modified"]) < 10:
+                     result["details"]["modified"].append(
+                         {
+                             "key": k,
+                             "before": dict1[k],
+                             "after": dict2[k],
+                         }
+                     )
+     else:
+         # Hash-based comparison
+         def _hash_item(item):
+             return orjson.dumps(item, option=orjson.OPT_SORT_KEYS)
+
+         set1 = {_hash_item(item) for item in data1}
+         set2 = {_hash_item(item) for item in data2}
+
+         added = set2 - set1
+         removed = set1 - set2
+         unchanged = set1 & set2
+
+         result["summary"]["added"] = len(added)
+         result["summary"]["removed"] = len(removed)
+         result["summary"]["unchanged"] = len(unchanged)
+
+         # Details
+         result["details"]["added"] = [orjson.loads(h) for h in list(added)[:10]]
+         result["details"]["removed"] = [orjson.loads(h) for h in list(removed)[:10]]
+
+     # Field change analysis
+     fields1 = set()
+     fields2 = set()
+     for item in data1[:1000]:  # sampled analysis
+         fields1.update(item.keys())
+     for item in data2[:1000]:
+         fields2.update(item.keys())
+
+     result["field_changes"] = {
+         "added_fields": list(fields2 - fields1),
+         "removed_fields": list(fields1 - fields2),
+         "common_fields": list(fields1 & fields2),
+     }
+
+     return result
+
+
+ def _print_diff_report(diff_result: Dict[str, Any], name1: str, name2: str) -> None:
+     """Print the diff report"""
+     summary = diff_result["summary"]
+     field_changes = diff_result["field_changes"]
+
+     try:
+         from rich.console import Console
+         from rich.panel import Panel
+         from rich.table import Table
+
+         console = Console()
+
+         # Overview
+         overview = (
+             f"[bold]{name1}:[/bold] {summary['file1_count']:,} records\n"
+             f"[bold]{name2}:[/bold] {summary['file2_count']:,} records\n"
+             f"\n"
+             f"[green]+ Added:[/green] {summary['added']:,}\n"
+             f"[red]- Removed:[/red] {summary['removed']:,}\n"
+             f"[yellow]~ Modified:[/yellow] {summary['modified']:,}\n"
+             f"[dim]= Unchanged:[/dim] {summary['unchanged']:,}"
+         )
+         console.print(Panel(overview, title="📊 Diff Overview", expand=False))
+
+         # Field changes
+         if field_changes["added_fields"] or field_changes["removed_fields"]:
+             console.print("\n[bold]📋 Field changes:[/bold]")
+             if field_changes["added_fields"]:
+                 console.print(
+                     f"  [green]+ Added fields:[/green] {', '.join(field_changes['added_fields'])}"
+                 )
+             if field_changes["removed_fields"]:
+                 console.print(
+                     f"  [red]- Removed fields:[/red] {', '.join(field_changes['removed_fields'])}"
+                 )
+
+     except ImportError:
+         print(f"\n{'=' * 50}")
+         print("📊 Diff Overview")
+         print(f"{'=' * 50}")
+         print(f"{name1}: {summary['file1_count']:,} records")
+         print(f"{name2}: {summary['file2_count']:,} records")
+         print()
+         print(f"+ Added: {summary['added']:,}")
+         print(f"- Removed: {summary['removed']:,}")
+         print(f"~ Modified: {summary['modified']:,}")
+         print(f"= Unchanged: {summary['unchanged']:,}")
+
+         if field_changes["added_fields"] or field_changes["removed_fields"]:
+             print("\n📋 Field changes:")
+             if field_changes["added_fields"]:
+                 print(f"  + Added fields: {', '.join(field_changes['added_fields'])}")
+             if field_changes["removed_fields"]:
+                 print(f"  - Removed fields: {', '.join(field_changes['removed_fields'])}")
+
+
+ # ============ History Command ============
+
+
+ def history(
+     filename: str,
+     json: bool = False,
+ ) -> None:
+     """
+     Show the lineage history of a data file.
+
+     Args:
+         filename: data file path
+         json: output in JSON format
+
+     Examples:
+         dt history data.jsonl
+         dt history data.jsonl --json
+     """
+     filepath = Path(filename)
+
+     if not filepath.exists():
+         print(f"Error: file not found - {filename}")
+         return
+
+     if not has_lineage(str(filepath)):
+         print(f"File {filename} has no lineage records")
+         print("\nHint: load data with track_lineage=True and save with lineage=True to record lineage")
+         print("Example:")
+         print("  dt = DataTransformer.load('data.jsonl', track_lineage=True)")
+         print("  dt.filter(...).transform(...).save('output.jsonl', lineage=True)")
+         return
+
+     if json:
+         # JSON output
+         chain = get_lineage_chain(str(filepath))
+         output = [record.to_dict() for record in chain]
+         print(orjson.dumps(output, option=orjson.OPT_INDENT_2).decode("utf-8"))
+     else:
+         # Formatted report
+         report = format_lineage_report(str(filepath))
+         print(report)
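For reference, the Python-side lineage workflow that the `history` hint prints maps onto code like the following. This is a sketch: the diff only shows DataTransformer being imported from dtflow's internal `core` module, so the import path and the filter predicate here are illustrative assumptions.

    from dtflow.core import DataTransformer  # import path as used inside the CLI

    # Load with lineage tracking, then save with lineage=True so that
    # `dt history output.jsonl` can reconstruct the chain.
    dt = DataTransformer.load("data.jsonl", track_lineage=True)
    dt.filter(lambda x: x.get("answer")).save("output.jsonl", lineage=True)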