dtflow 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/SKILL.md CHANGED
@@ -154,6 +154,7 @@ dt token-stats data.jsonl             # counts the "messages" field by default
  dt token-stats data.jsonl -f text        # choose the field to count
  dt token-stats data.jsonl -m qwen2.5     # choose the tokenizer (cl100k_base/qwen2.5/llama3)
  dt token-stats data.jsonl --detailed     # show detailed statistics
+ dt token-stats data.jsonl -w 4           # multi-process speedup (enabled automatically for >= 1000 records)
 
  # Sampling (supports field-path syntax)
  dt sample data.jsonl 100                 # randomly sample 100 records
@@ -204,6 +205,7 @@ dt validate data.jsonl --preset=openai_chat   # presets: openai_chat/alpaca/d
  dt validate data.jsonl -p alpaca -f -o valid.jsonl      # filter out invalid records and save
  dt validate data.jsonl -p openai_chat -v                # show verbose details
  dt validate data.jsonl -p openai_chat --max-errors=50   # show at most 50 errors
+ dt validate data.jsonl -p openai_chat -w 4              # multi-process speedup
 
  # Transform
  dt transform data.jsonl --preset=openai_chat
dtflow/__init__.py CHANGED
@@ -60,7 +60,7 @@ from .tokenizers import (
      token_stats,
  )
 
- __version__ = "0.5.9"
+ __version__ = "0.5.10"
 
  __all__ = [
      # core
dtflow/__main__.py CHANGED
@@ -256,9 +256,12 @@ def token_stats(
          "cl100k_base", "--model", "-m", help="Tokenizer: cl100k_base (default), qwen2.5, llama3, gpt-4, etc."
      ),
      detailed: bool = typer.Option(False, "--detailed", "-d", help="Show detailed statistics"),
+     workers: Optional[int] = typer.Option(
+         None, "--workers", "-w", help="Number of worker processes (auto by default, 1 disables parallelism)"
+     ),
  ):
      """Report token statistics for a dataset"""
-     _token_stats(filename, field, model, detailed)
+     _token_stats(filename, field, model, detailed, workers)
 
 
  @app.command()
@@ -359,9 +362,12 @@ def validate(
      filter: bool = typer.Option(False, "--filter", "-f", help="Filter out invalid records and save"),
      max_errors: int = typer.Option(20, "--max-errors", help="Maximum number of errors to display"),
      verbose: bool = typer.Option(False, "--verbose", "-v", help="Show verbose details"),
+     workers: Optional[int] = typer.Option(
+         None, "--workers", "-w", help="Number of worker processes (auto by default, 1 disables parallelism)"
+     ),
  ):
      """Validate data format against a preset schema"""
-     _validate(filename, preset, output, filter, max_errors, verbose)
+     _validate(filename, preset, output, filter, max_errors, verbose, workers)
 
 
  # ============ Utility commands ============
dtflow/cli/stats.py CHANGED
@@ -209,7 +209,7 @@ def _quick_stats(filepath: Path) -> None:
      print(f"Fields: {len(fields)}")
 
      if fields:
-         print(f"\n📋 Field structure:")
+         print("\n📋 Field structure:")
          for i, f in enumerate(fields, 1):
              print(f" {i}. {f['field']} ({f['type']})")
 
@@ -597,14 +597,14 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -
      except ImportError:
          # rich not available, fall back to plain printing
          print(f"\n{'=' * 50}")
-         print(f"📊 Dataset overview")
+         print("📊 Dataset overview")
          print(f"{'=' * 50}")
          print(f"File: {filename}")
          print(f"Total: {total:,} records")
          print(f"Fields: {len(field_stats)}")
 
          print(f"\n{'=' * 50}")
-         print(f"📋 Field statistics")
+         print("📋 Field statistics")
          print(f"{'=' * 50}")
          print(f"{'Field':<20} {'Type':<8} {'Non-null':<8} {'Unique':<8}")
          print("-" * 50)
@@ -620,6 +620,7 @@ def token_stats(
      field: str = "messages",
      model: str = "cl100k_base",
      detailed: bool = False,
+     workers: Optional[int] = None,
  ) -> None:
      """
      Report token statistics for a dataset.
@@ -629,6 +630,7 @@
          field: field to count (default "messages"), supports nested path syntax
          model: tokenizer: cl100k_base (default), qwen2.5, llama3, gpt-4, etc.
          detailed: whether to show detailed statistics
+         workers: number of worker processes; None auto-detects, 1 disables parallelism
 
      Examples:
          dt token-stats data.jsonl
@@ -636,6 +638,7 @@
          dt token-stats data.jsonl --field=conversation.messages
          dt token-stats data.jsonl --field=messages[-1].content  # count only the last message
          dt token-stats data.jsonl --detailed
+         dt token-stats data.jsonl --workers=4  # use 4 processes
      """
      filepath = Path(filename)
 
@@ -667,7 +670,7 @@
 
      # Try to use a rich progress bar
      try:
-         from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
+         from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
 
          with Progress(
              SpinnerColumn(),
@@ -685,14 +688,22 @@
                  from ..tokenizers import messages_token_stats
 
                  stats_result = messages_token_stats(
-                     data, messages_field=field, model=model, progress_callback=update_progress
+                     data,
+                     messages_field=field,
+                     model=model,
+                     progress_callback=update_progress,
+                     workers=workers,
                  )
                  _print_messages_token_stats(stats_result, detailed)
              else:
                  from ..tokenizers import token_stats as compute_token_stats
 
                  stats_result = compute_token_stats(
-                     data, fields=field, model=model, progress_callback=update_progress
+                     data,
+                     fields=field,
+                     model=model,
+                     progress_callback=update_progress,
+                     workers=workers,
                  )
                  _print_text_token_stats(stats_result, detailed)
 
@@ -703,12 +714,14 @@
          if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
              from ..tokenizers import messages_token_stats
 
-             stats_result = messages_token_stats(data, messages_field=field, model=model)
+             stats_result = messages_token_stats(
+                 data, messages_field=field, model=model, workers=workers
+             )
              _print_messages_token_stats(stats_result, detailed)
          else:
              from ..tokenizers import token_stats as compute_token_stats
 
-             stats_result = compute_token_stats(data, fields=field, model=model)
+             stats_result = compute_token_stats(data, fields=field, model=model, workers=workers)
              _print_text_token_stats(stats_result, detailed)
      except ImportError as e:
          print(f"Error: {e}")
@@ -788,7 +801,7 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
      print(f"Average tokens: {stats['avg_tokens']:,} (std: {std:.1f})")
      print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
-     print(f"\n📈 Percentile distribution:")
+     print("\n📈 Percentile distribution:")
      print(f" P25: {stats.get('p25', '-'):,}   P50: {stats.get('median_tokens', '-'):,}")
      print(f" P75: {stats.get('p75', '-'):,}   P90: {stats.get('p90', '-'):,}")
      print(f" P95: {stats.get('p95', '-'):,}   P99: {stats.get('p99', '-'):,}")
@@ -855,7 +868,7 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
      print(f"Average tokens: {stats['avg_tokens']:.1f} (std: {std:.1f})")
      print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
-     print(f"\n📈 Percentile distribution:")
+     print("\n📈 Percentile distribution:")
      print(f" P25: {stats.get('p25', '-'):,}   P50: {stats.get('median_tokens', '-'):,}")
      print(f" P75: {stats.get('p75', '-'):,}   P90: {stats.get('p90', '-'):,}")
      print(f" P95: {stats.get('p95', '-'):,}   P99: {stats.get('p99', '-'):,}")
dtflow/cli/validate.py CHANGED
@@ -6,8 +6,6 @@ from pathlib import Path
  from typing import Optional
 
  from ..schema import (
-     Schema,
-     Field,
      alpaca_schema,
      dpo_schema,
      openai_chat_schema,
@@ -16,7 +14,6 @@ from ..schema import (
  from ..storage.io import load_data, save_data
  from .common import _check_file_format
 
-
  # Preset schema mapping
  PRESET_SCHEMAS = {
      "openai_chat": openai_chat_schema,
@@ -36,6 +33,7 @@ def validate(
      filter_invalid: bool = False,
      max_errors: int = 20,
      verbose: bool = False,
+     workers: Optional[int] = None,
  ) -> None:
      """
      Validate a data file against a schema.
@@ -47,11 +45,13 @@
          filter_invalid: filter out invalid records and save
          max_errors: maximum number of errors to display
          verbose: show verbose details
+         workers: number of worker processes; None auto-detects, 1 disables parallelism
 
      Examples:
          dt validate data.jsonl --preset=openai_chat
          dt validate data.jsonl --preset=alpaca -o valid.jsonl
          dt validate data.jsonl --preset=chat --filter
+         dt validate data.jsonl --preset=chat --workers=4
      """
      filepath = Path(filename)
 
@@ -99,19 +99,54 @@
      print(f"Total records: {total}")
      print()
 
-     # Validate
-     valid_data = []
-     invalid_count = 0
-     error_samples = []
-
-     for i, item in enumerate(data):
-         result = schema.validate(item)
-         if result.valid:
-             valid_data.append(item)
-         else:
-             invalid_count += 1
-             if len(error_samples) < max_errors:
-                 error_samples.append((i, result))
+     # Validate (parallel or serial)
+     use_parallel = workers != 1 and total >= 1000
+
+     if use_parallel:
+         # Use a progress bar if rich is available
+         try:
+             from rich.progress import (
+                 BarColumn,
+                 Progress,
+                 SpinnerColumn,
+                 TaskProgressColumn,
+                 TextColumn,
+             )
+
+             with Progress(
+                 SpinnerColumn(),
+                 TextColumn("[bold blue]Validating data"),
+                 BarColumn(),
+                 TaskProgressColumn(),
+             ) as progress:
+                 task = progress.add_task("", total=total)
+
+                 def update_progress(current: int, total_count: int):
+                     progress.update(task, completed=current)
+
+                 valid_data, invalid_results = schema.validate_parallel(
+                     data, workers=workers, progress_callback=update_progress
+                 )
+         except ImportError:
+             print("🔍 Validating data...")
+             valid_data, invalid_results = schema.validate_parallel(data, workers=workers)
+
+         invalid_count = len(invalid_results)
+         error_samples = invalid_results[:max_errors]
+     else:
+         # Serial validation
+         valid_data = []
+         invalid_count = 0
+         error_samples = []
+
+         for i, item in enumerate(data):
+             result = schema.validate(item)
+             if result.valid:
+                 valid_data.append(item)
+             else:
+                 invalid_count += 1
+                 if len(error_samples) < max_errors:
+                     error_samples.append((i, result))
 
      valid_count = len(valid_data)
      valid_ratio = valid_count / total * 100 if total > 0 else 0
@@ -138,9 +173,7 @@
 
      # Save the valid data
      if output or filter_invalid:
-         output_path = output or str(filepath).replace(
-             filepath.suffix, f"_valid{filepath.suffix}"
-         )
+         output_path = output or str(filepath).replace(filepath.suffix, f"_valid{filepath.suffix}")
          save_data(valid_data, output_path)
          print(f"✅ Valid data saved: {output_path} ({valid_count} records)")
 
dtflow/parallel.py ADDED
@@ -0,0 +1,115 @@
+ """
+ Parallel processing module
+
+ Provides multi-process helpers to speed up token statistics and schema validation on large datasets.
+ """
+
+ from multiprocessing import Pool, cpu_count
+ from typing import Callable, List, Optional, TypeVar
+
+ T = TypeVar("T")
+ R = TypeVar("R")
+
+
+ def parallel_map(
+     func: Callable[[T], R],
+     data: List[T],
+     workers: Optional[int] = None,
+     threshold: int = 1000,
+     chunksize: Optional[int] = None,
+ ) -> List[R]:
+     """
+     Parallel map operation.
+
+     Args:
+         func: processing function (must be picklable; no lambdas or closures)
+         data: list of items
+         workers: number of processes; None uses the CPU core count
+         threshold: dataset-size threshold below which the serial path is used
+         chunksize: task chunk size per process; None computes it automatically
+
+     Returns:
+         List of results (order preserved)
+     """
+     n = len(data)
+
+     # Small dataset or an explicit single worker: run serially
+     if n < threshold or workers == 1:
+         return [func(item) for item in data]
+
+     workers = workers or cpu_count()
+     workers = min(workers, n)  # never more processes than items
+
+     # Compute chunksize automatically
+     if chunksize is None:
+         chunksize = max(1, n // (workers * 4))
+
+     with Pool(processes=workers) as pool:
+         return pool.map(func, data, chunksize=chunksize)
+
+
+ def parallel_imap(
+     func: Callable[[T], R],
+     data: List[T],
+     workers: Optional[int] = None,
+     threshold: int = 1000,
+     chunksize: Optional[int] = None,
+ ):
+     """
+     Parallel imap operation (lazy iterator variant, suitable for progress callbacks).
+
+     Args:
+         func: processing function (must be picklable)
+         data: list of items
+         workers: number of processes; None uses the CPU core count
+         threshold: dataset-size threshold below which the serial path is used
+         chunksize: task chunk size per process
+
+     Yields:
+         Results, in order.
+     """
+     n = len(data)
+
+     # Small dataset or an explicit single worker: run serially
+     if n < threshold or workers == 1:
+         for item in data:
+             yield func(item)
+         return
+
+     workers = workers or cpu_count()
+     workers = min(workers, n)
+
+     if chunksize is None:
+         chunksize = max(1, n // (workers * 4))
+
+     with Pool(processes=workers) as pool:
+         for result in pool.imap(func, data, chunksize=chunksize):
+             yield result
+
+
+ def get_optimal_workers(data_size: int, default: Optional[int] = None) -> int:
+     """
+     Pick a sensible number of worker processes for a given dataset size.
+
+     Args:
+         data_size: number of records
+         default: user-specified worker count; None means auto
+
+     Returns:
+         The worker count to use.
+     """
+     if default is not None:
+         return default
+
+     cpu_cores = cpu_count()
+
+     # Below the threshold, stay single-process
+     if data_size < 1000:
+         return 1
+
+     # Medium dataset: use half the CPUs
+     if data_size < 10000:
+         return max(1, cpu_cores // 2)
+
+     # Large dataset: use all CPUs
+     return cpu_cores
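
For orientation, here is a minimal usage sketch of the new helpers; the dataset, the mapped function, and the worker count are illustrative and not taken from the package:

```python
from dtflow.parallel import get_optimal_workers, parallel_map


def record_size(record: dict) -> int:
    # Must be a module-level function: multiprocessing.Pool pickles the callable.
    return len(str(record))


if __name__ == "__main__":
    data = [{"text": f"sample {i}"} for i in range(5000)]  # illustrative data
    workers = get_optimal_workers(len(data))               # auto-select a process count
    sizes = parallel_map(record_size, data, workers=workers)
    print(len(sizes), max(sizes))
```

Because `Pool.map` pickles the callable, only importable top-level functions can be dispatched, which is why the schema and tokenizer changes below add module-level wrapper helpers rather than closures.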
dtflow/schema.py CHANGED
@@ -26,10 +26,35 @@ Schema validation module
      results = dt.validate_schema(schema)
  """
 
- from dataclasses import dataclass, field as dataclass_field
- from typing import Any, Callable, Dict, List, Literal, Optional, Set, Union
+ from dataclasses import dataclass
+ from dataclasses import field as dataclass_field
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+ from .utils.field_path import _parse_path, get_field
+
+
+ def _validate_item_wrapper(args: tuple) -> Tuple[int, bool, list]:
+     """
+     Validate a single record (used by the multi-process path).
+
+     Args:
+         args: an (index, item, schema_fields) tuple
+
+     Returns:
+         (index, is_valid, errors_as_dicts) - errors are returned as dicts rather than objects (pickle-friendly)
+     """
+     idx, item, fields = args
+     # Rebuild the Schema inside the child process
+     schema = Schema(fields)
+     result = schema.validate(item)
+
+     if result.valid:
+         return (idx, True, [])
+     else:
+         # Convert errors to dicts (pickle-friendly)
+         errors = [{"path": e.path, "message": e.message, "value": e.value} for e in result.errors]
+         return (idx, False, errors)
 
- from .utils.field_path import get_field, _parse_path, _get_value_by_segments
 
  # Supported types
  FieldType = Literal["str", "int", "float", "bool", "list", "dict", "any"]
@@ -162,9 +187,7 @@ class Field:
 
      # Choice check
      if self.choices is not None and value not in self.choices:
-         errors.append(
-             ValidationError(path, f"Value must be one of {self.choices}", value)
-         )
+         errors.append(ValidationError(path, f"Value must be one of {self.choices}", value))
 
      # Regular-expression check
      if self.pattern is not None and isinstance(value, str):
@@ -324,9 +347,7 @@ class Schema:
 
          return errors
 
-     def validate_batch(
-         self, data: List[dict], max_errors: int = 100
-     ) -> List[tuple]:
+     def validate_batch(self, data: List[dict], max_errors: int = 100) -> List[tuple]:
          """
          Validate records in batch
 
@@ -350,9 +371,76 @@
 
          return failed
 
+     def validate_parallel(
+         self,
+         data: List[dict],
+         workers: Optional[int] = None,
+         progress_callback: Optional[Callable[[int, int], None]] = None,
+     ) -> tuple:
+         """
+         Validate a list of records in parallel.
+
+         Args:
+             data: list of records
+             workers: number of processes; None auto-detects, 1 disables parallelism
+             progress_callback: progress callback function
+
+         Returns:
+             A (valid_data, invalid_indices_results) tuple:
+             - valid_data: list of valid records
+             - invalid_indices_results: [(index, ValidationResult), ...] for invalid records
+         """
+         if not data:
+             return [], []
+
+         total = len(data)
+         use_parallel = workers != 1 and total >= 1000
+
+         valid_data = []
+         invalid_results = []
+
+         if use_parallel:
+             from .parallel import get_optimal_workers, parallel_imap
+
+             actual_workers = get_optimal_workers(total, workers)
+             # Prepare arguments: (index, item, schema_fields)
+             args_list = [(i, item, self._fields) for i, item in enumerate(data)]
+
+             for i, (idx, is_valid, result_data) in enumerate(
+                 parallel_imap(
+                     _validate_item_wrapper,
+                     args_list,
+                     workers=actual_workers,
+                     threshold=1000,
+                 )
+             ):
+                 if is_valid:
+                     valid_data.append(data[idx])
+                 else:
+                     # Rebuild the ValidationResult (it cannot be pickled directly)
+                     errors = [
+                         ValidationError(path=e["path"], message=e["message"], value=e.get("value"))
+                         for e in result_data
+                     ]
+                     invalid_results.append((idx, ValidationResult(valid=False, errors=errors)))
+                 if progress_callback:
+                     progress_callback(i + 1, total)
+         else:
+             # Serial path
+             for i, item in enumerate(data):
+                 result = self.validate(item)
+                 if result.valid:
+                     valid_data.append(item)
+                 else:
+                     invalid_results.append((i, result))
+                 if progress_callback:
+                     progress_callback(i + 1, total)
+
+         return valid_data, invalid_results
+
      def __repr__(self) -> str:
          field_strs = [f" {path}: {field_def}" for path, field_def in self._fields.items()]
-         return f"Schema({{\n" + ",\n".join(field_strs) + "\n}})"
+         return "Schema({\n" + ",\n".join(field_strs) + "\n}})"
 
 
  # ============================================================================
@@ -461,9 +549,7 @@ def sharegpt_schema(
      """
      return Schema(
          {
-             "conversations": Field(
-                 type="list", required=True, min_length=min_conversations
-             ),
+             "conversations": Field(type="list", required=True, min_length=min_conversations),
              "conversations[*].from": Field(
                  type="str", required=True, choices=[human_role, gpt_role]
              ),
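
A hedged sketch of how the new `Schema.validate_parallel` API might be called, assuming the preset builders work with their default arguments; the sample records, worker count, and progress output are illustrative:

```python
from dtflow.schema import openai_chat_schema


def report(current: int, total: int) -> None:
    # Matches the progress_callback(current, total) signature used throughout this diff.
    if current == total:
        print(f"validated {current}/{total} records")


if __name__ == "__main__":
    schema = openai_chat_schema()
    data = [{"messages": [{"role": "user", "content": f"hi {i}"}]} for i in range(2000)]

    # Returns (valid_data, [(index, ValidationResult), ...]); with fewer than
    # 1000 records or workers=1 it falls back to the serial path.
    valid, invalid = schema.validate_parallel(data, workers=4, progress_callback=report)
    print(f"valid={len(valid)}, invalid={len(invalid)}")
```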
dtflow/tokenizers.py CHANGED
@@ -122,8 +122,8 @@ def _get_tiktoken_encoder(model: str):
              _tokenizer_cache[model] = tiktoken.get_encoding(model)
          else:
              _tokenizer_cache[model] = tiktoken.encoding_for_model(model)
-     except ImportError:
-         raise ImportError("tiktoken is required: pip install tiktoken")
+     except ImportError as e:
+         raise ImportError("tiktoken is required: pip install tiktoken") from e
      return _tokenizer_cache[model]
 
 
@@ -149,12 +149,12 @@ def _get_hf_tokenizer(model: str):
 
          tokenizer = AutoTokenizer.from_pretrained(resolved, trust_remote_code=True)
          _tokenizer_cache[resolved] = ("transformers", tokenizer)
-     except ImportError:
+     except ImportError as e:
          raise ImportError(
              "tokenizers or transformers is required:\n"
              "  pip install tokenizers huggingface_hub  (recommended, lighter-weight)\n"
              "  pip install transformers"
-         )
+         ) from e
      return _tokenizer_cache[resolved]
 
 
@@ -309,12 +309,29 @@ def _std(counts: List[int], avg: float) -> float:
      return variance**0.5
 
 
+ def _count_item_tokens(args: tuple) -> int:
+     """
+     Count tokens for a single record (used by the multi-process path).
+
+     Args:
+         args: an (item, fields, model, backend) tuple
+     """
+     item, fields, model, backend = args
+     total = 0
+     for field in fields:
+         value = get_field_with_spec(item, field, default="")
+         if value:
+             total += count_tokens(str(value), model=model, backend=backend)
+     return total
+
+
  def token_stats(
      data: List[Dict[str, Any]],
      fields: Union[str, List[str]],
      model: str = DEFAULT_MODEL,
      backend: Optional[str] = None,
      progress_callback: Optional[Callable[[int, int], None]] = None,
+     workers: Optional[int] = None,
  ) -> Dict[str, Any]:
      """
      Compute token statistics for a dataset.
@@ -325,6 +342,7 @@ def token_stats(
          model: model name or alias, e.g. "qwen2.5", "gpt-4"
          backend: backend to use; None auto-detects
          progress_callback: progress callback receiving (current, total)
+         workers: number of processes; None auto-detects, 1 disables parallelism
 
      Returns:
          A statistics dict containing:
@@ -342,17 +360,42 @@
      if not data:
          return {"total_tokens": 0, "count": 0}
 
-     counts = []
      total_items = len(data)
-     for i, item in enumerate(data):
-         total = 0
-         for field in fields:
-             value = get_field_with_spec(item, field, default="")
-             if value:
-                 total += count_tokens(str(value), model=model, backend=backend)
-         counts.append(total)
-         if progress_callback:
-             progress_callback(i + 1, total_items)
+     _backend = backend or _auto_backend(model)
+
+     # Decide whether to use multiple processes
+     use_parallel = workers != 1 and total_items >= 1000
+
+     if use_parallel:
+         from .parallel import get_optimal_workers, parallel_imap
+
+         actual_workers = get_optimal_workers(total_items, workers)
+         # Prepare arguments
+         args_list = [(item, fields, model, _backend) for item in data]
+         counts = []
+         for i, result in enumerate(
+             parallel_imap(
+                 _count_item_tokens,
+                 args_list,
+                 workers=actual_workers,
+                 threshold=1000,
+             )
+         ):
+             counts.append(result)
+             if progress_callback:
+                 progress_callback(i + 1, total_items)
+     else:
+         # Serial path
+         counts = []
+         for i, item in enumerate(data):
+             total = 0
+             for field in fields:
+                 value = get_field_with_spec(item, field, default="")
+                 if value:
+                     total += count_tokens(str(value), model=model, backend=_backend)
+             counts.append(total)
+             if progress_callback:
+                 progress_callback(i + 1, total_items)
 
      sorted_counts = sorted(counts)
      avg = sum(counts) / len(counts)
@@ -548,12 +591,27 @@ def messages_token_filter(
      return filter_func
 
 
+ def _count_messages_tokens_wrapper(args: tuple) -> Optional[Dict[str, int]]:
+     """
+     Count tokens for a single messages list (used by the multi-process path).
+
+     Args:
+         args: an (item, messages_field, model, backend) tuple
+     """
+     item, messages_field, model, backend = args
+     messages = get_field_with_spec(item, messages_field, default=[])
+     if messages:
+         return _count_messages_tokens(messages, model=model, backend=backend)
+     return None
+
+
  def messages_token_stats(
      data: List[Dict[str, Any]],
      messages_field: str = "messages",
      model: str = DEFAULT_MODEL,
      backend: Optional[str] = None,
      progress_callback: Optional[Callable[[int, int], None]] = None,
+     workers: Optional[int] = None,
  ) -> Dict[str, Any]:
      """
      Compute token statistics over the messages in a dataset.
@@ -564,6 +622,7 @@ def messages_token_stats(
          model: model name or alias
          backend: backend to use; None auto-detects
          progress_callback: progress callback receiving (current, total)
+         workers: number of processes; None auto-detects, 1 disables parallelism
 
      Returns:
          A statistics dict containing:
@@ -581,14 +640,38 @@
      if not data:
          return {"count": 0, "total_tokens": 0}
 
-     all_stats = []
      total_items = len(data)
-     for i, item in enumerate(data):
-         messages = get_field_with_spec(item, messages_field, default=[])
-         if messages:
-             all_stats.append(_count_messages_tokens(messages, model=model, backend=_backend))
-         if progress_callback:
-             progress_callback(i + 1, total_items)
+
+     # Decide whether to use multiple processes
+     use_parallel = workers != 1 and total_items >= 1000
+
+     all_stats = []
+     if use_parallel:
+         from .parallel import get_optimal_workers, parallel_imap
+
+         actual_workers = get_optimal_workers(total_items, workers)
+         args_list = [(item, messages_field, model, _backend) for item in data]
+
+         for i, result in enumerate(
+             parallel_imap(
+                 _count_messages_tokens_wrapper,
+                 args_list,
+                 workers=actual_workers,
+                 threshold=1000,
+             )
+         ):
+             if result is not None:
+                 all_stats.append(result)
+             if progress_callback:
+                 progress_callback(i + 1, total_items)
+     else:
+         # Serial path
+         for i, item in enumerate(data):
+             messages = get_field_with_spec(item, messages_field, default=[])
+             if messages:
+                 all_stats.append(_count_messages_tokens(messages, model=model, backend=_backend))
+             if progress_callback:
+                 progress_callback(i + 1, total_items)
 
      if not all_stats:
          return {"count": 0, "total_tokens": 0}
dtflow-0.5.9.dist-info/METADATA → dtflow-0.5.10.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dtflow
- Version: 0.5.9
+ Version: 0.5.10
  Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
  Project-URL: Homepage, https://github.com/yourusername/DataTransformer
  Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
@@ -464,6 +464,7 @@ dt run pipeline.yaml --input=new_data.jsonl --output=result.jsonl
  dt token-stats data.jsonl --field=messages --model=gpt-4
  dt token-stats data.jsonl --field=messages[-1].content  # count only the last message
  dt token-stats data.jsonl --field=text --detailed
+ dt token-stats data.jsonl --workers=4  # multi-process speedup (enabled automatically for large datasets)
 
  # Dataset comparison
  dt diff v1/train.jsonl v2/train.jsonl
@@ -521,6 +522,7 @@ dt validate data.jsonl --preset=openai_chat   # validate with a preset schema
  dt validate data.jsonl --preset=alpaca --verbose                          # verbose output
  dt validate data.jsonl --preset=sharegpt --filter-invalid -o valid.jsonl  # keep only valid records
  dt validate data.jsonl --preset=dpo --max-errors=100                      # limit the number of reported errors
+ dt validate data.jsonl --preset=openai_chat --workers=4                   # multi-process speedup
  ```
 
  ### Field-path syntax
dtflow-0.5.9.dist-info/RECORD → dtflow-0.5.10.dist-info/RECORD
@@ -1,16 +1,17 @@
- dtflow/SKILL.md,sha256=hPxJhroGmNbBv8MLZUkOA2yW1TDdUKEUYYlz9tW2mao,10393
- dtflow/__init__.py,sha256=9ZqhqD8qQM9w2dfHKyUWIaqSX-X4elWtbaQN4CNBhgg,3031
- dtflow/__main__.py,sha256=gg3v7u-Ot7AicgKrP1fuyKtMJXVduNuLmhy7L1LUPDg,17710
+ dtflow/SKILL.md,sha256=Oq8Kb5JghZMJ1WoP8OWhX3qAWaUY9Sip_iWAv8S2eMg,10567
+ dtflow/__init__.py,sha256=2A-P6k9VBIWZXRgXwYPFOwHMCmgkfKZVYuGuBziqqhc,3032
+ dtflow/__main__.py,sha256=_wrpYfOog6G83I17yuBe-hryBsaCrIwbXSEnzT-r28g,18008
  dtflow/converters.py,sha256=X3qeFD7FCOMnfiP3MicL5MXimOm4XUYBs5pczIkudU0,22331
  dtflow/core.py,sha256=qMo6B3LK--TWRK7ZBKObGcs3pKFnd0NPoaM0T8JC7Jw,38135
  dtflow/eval.py,sha256=_c-XP2zsOBznYltSyKEScOqvmPVX2orqepg5cNhXXB0,9836
  dtflow/framework.py,sha256=jyICi_RWHjX7WfsXdSbWmP1SL7y1OWSPyd5G5Y-lvg4,17578
  dtflow/lineage.py,sha256=jie3OL1qK90-_cOOqqLbhSJ1oGUktDM1x5HRpQ5Qiyc,12800
+ dtflow/parallel.py,sha256=EnIdGEGMrZUNT2-CBIV93UFfpqr_jU_heqqvdGXcP-Y,3046
  dtflow/pipeline.py,sha256=zZaC4fg5vsp_30Fhbg75vu0yggsdvf28bWBiVDWzZ6Y,13901
  dtflow/presets.py,sha256=qa8WQJhbNMuGxqqgA9BFadEBwDB9s0zWNxxhzF3q1K8,4701
- dtflow/schema.py,sha256=IFcij22_UFKcgKT1YWwRg2QJO0vcAvCb1arZmsGByts,16824
+ dtflow/schema.py,sha256=zCZNEAqTMT1BS_p2t0CYczR5S9rqyDREa7ZsYI5pFGA,19885
  dtflow/streaming.py,sha256=dxpNd1-Wz_PTLTdvM5qn06_2TJr5NRlIIuw0LOSS2Iw,24755
- dtflow/tokenizers.py,sha256=7ZAelSmcDxLWH5kICgH9Q1ULH3_BfDZb9suHMjJJRZU,20589
+ dtflow/tokenizers.py,sha256=GFQsuLSLn2GHn2kaXhJkP8G85lgsdLzYtJNbppQhYPE,23408
  dtflow/cli/__init__.py,sha256=QhZ-thgx9IBTFII7T_hdoWFUl0CCsdGQHN5ZEZw2XB0,423
  dtflow/cli/clean.py,sha256=BEQQlH2q6luCbx51M3oxxOwcnwlOA8vo9WX3Fp7I6AY,29498
  dtflow/cli/commands.py,sha256=LvyDQ_nWUM7UlPDEFQadRdw5O2ZKDLgF41_xAJRhYxI,1583
@@ -23,9 +24,9 @@ dtflow/cli/pipeline.py,sha256=QNEo-BJlaC1CVnVeRZr7TwfuZYloJ4TebIzJ5ALzry0,1426
  dtflow/cli/sample.py,sha256=etbro5I0pyNgn0Qfhp1M6Bh-95JN-AntDa5AwVe_oKY,18269
  dtflow/cli/skill.py,sha256=opiTEBejA7JHKrEMftMOPDQlOgZ4n59rwaHXGU1Nukk,2022
  dtflow/cli/split.py,sha256=96bhWnxHnjIqifoliLgciApkLbwQU8bWHovK8bcMk9g,3667
- dtflow/cli/stats.py,sha256=Jx3d4X0ftgpzU5q5RAWZEVJWwXviQTF4EAwBmz1IliA,31366
+ dtflow/cli/stats.py,sha256=HkTZD80h4tzYXTtMnfpjLUMP6kl_es6ifcmExxzGdMU,31813
  dtflow/cli/transform.py,sha256=w6xqMOxPxQvL2u_BPCfpDHuPSC9gmcqMPVN8s-B6bbY,15052
- dtflow/cli/validate.py,sha256=65aGVlMS_Rq0Ch0YQ-TclVJ03RQP4CnG137wthzb8Ao,4384
+ dtflow/cli/validate.py,sha256=Frs-jKcDHmYozpmIYZueDSX5o2i1Xn-WW81FGUyUrng,5796
  dtflow/storage/__init__.py,sha256=C0jpWNQU808Ezz7lWneddABal3wILy8ijFUNiSKbHV4,362
  dtflow/storage/io.py,sha256=ZH2aSE-S89gpy3z4oTqhcqWf4u10OdkDoyul7o_YBDI,23374
  dtflow/utils/__init__.py,sha256=Pn-ltwV04fBQmeZG7FxInDQmzH29LYOi90LgeLMEuQk,506
@@ -33,7 +34,7 @@ dtflow/utils/display.py,sha256=OeOdTh6mbDwSkDWlmkjfpTjy2QG8ZUaYU0NpHUWkpEQ,5881
  dtflow/utils/field_path.py,sha256=K8nU196RxTSJ1OoieTWGcYOWl9KjGq2iSxCAkfjECuM,7621
  dtflow/utils/helpers.py,sha256=JXN176_B2pm53GLVyZ1wj3wrmBJG52Tkw6AMQSdj7M8,791
  dtflow/utils/text_parser.py,sha256=0t2TMOSha4dTiDu9H4ygdb67cI20zhtBH1XavDspL_g,3727
- dtflow-0.5.9.dist-info/METADATA,sha256=Pu92Dz2vj7U_dki4A0e5xgka36BTT9K2PnN1LIeEhN0,25839
- dtflow-0.5.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- dtflow-0.5.9.dist-info/entry_points.txt,sha256=dadIDOK7Iu9pMxnMPBfpb4aAPe4hQbBOshpQYjVYpGc,44
- dtflow-0.5.9.dist-info/RECORD,,
+ dtflow-0.5.10.dist-info/METADATA,sha256=OGefMoe17by5IbxdxZgqoJ1Y6OWPt_iGEFM4KgltRZw,26023
+ dtflow-0.5.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ dtflow-0.5.10.dist-info/entry_points.txt,sha256=dadIDOK7Iu9pMxnMPBfpb4aAPe4hQbBOshpQYjVYpGc,44
+ dtflow-0.5.10.dist-info/RECORD,,