dtflow-0.2.0-py3-none-any.whl → dtflow-0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/tokenizers.py CHANGED
@@ -2,49 +2,169 @@
 Token counting module.
 
 Provides token counting and token-length-based filtering.
+Supports two backends: OpenAI (tiktoken) and open-source models (transformers).
 """
 from typing import Callable, Union, List, Dict, Any, Optional
 
 # Lazy imports, so a missing optional dependency only fails at call time
 _tokenizer_cache = {}
 
+# Default encoder (tiktoken's cl100k_base: fast, with light dependencies)
+DEFAULT_MODEL = "cl100k_base"
+
+# Model alias map: short name -> HuggingFace model path
+MODEL_ALIASES = {
+    # Qwen family
+    "qwen2.5": "Qwen/Qwen2.5-7B",
+    "qwen2.5-0.5b": "Qwen/Qwen2.5-0.5B",
+    "qwen2.5-1.5b": "Qwen/Qwen2.5-1.5B",
+    "qwen2.5-3b": "Qwen/Qwen2.5-3B",
+    "qwen2.5-7b": "Qwen/Qwen2.5-7B",
+    "qwen2.5-14b": "Qwen/Qwen2.5-14B",
+    "qwen2.5-32b": "Qwen/Qwen2.5-32B",
+    "qwen2.5-72b": "Qwen/Qwen2.5-72B",
+    "qwen3": "Qwen/Qwen3-8B",
+    "qwen3-0.6b": "Qwen/Qwen3-0.6B",
+    "qwen3-1.7b": "Qwen/Qwen3-1.7B",
+    "qwen3-4b": "Qwen/Qwen3-4B",
+    "qwen3-8b": "Qwen/Qwen3-8B",
+    "qwen3-14b": "Qwen/Qwen3-14B",
+    "qwen3-32b": "Qwen/Qwen3-32B",
+    "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B",
+    "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B",
+    "qwen2": "Qwen/Qwen2-7B",
+    "qwen2-0.5b": "Qwen/Qwen2-0.5B",
+    "qwen2-1.5b": "Qwen/Qwen2-1.5B",
+    "qwen2-7b": "Qwen/Qwen2-7B",
+    "qwen2-72b": "Qwen/Qwen2-72B",
+    # Llama family
+    "llama3": "meta-llama/Llama-3.1-8B",
+    "llama3.1": "meta-llama/Llama-3.1-8B",
+    "llama3.1-8b": "meta-llama/Llama-3.1-8B",
+    "llama3.1-70b": "meta-llama/Llama-3.1-70B",
+    "llama3.2": "meta-llama/Llama-3.2-3B",
+    "llama3.2-1b": "meta-llama/Llama-3.2-1B",
+    "llama3.2-3b": "meta-llama/Llama-3.2-3B",
+    "llama3.3": "meta-llama/Llama-3.3-70B-Instruct",
+    "llama3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
+    # Mistral family
+    "mistral": "mistralai/Mistral-7B-v0.3",
+    "mistral-7b": "mistralai/Mistral-7B-v0.3",
+    "mixtral": "mistralai/Mixtral-8x7B-v0.1",
+    "mixtral-8x7b": "mistralai/Mixtral-8x7B-v0.1",
+    # DeepSeek family
+    "deepseek": "deepseek-ai/DeepSeek-V3",
+    "deepseek-v3": "deepseek-ai/DeepSeek-V3",
+    "deepseek-v2": "deepseek-ai/DeepSeek-V2",
+    "deepseek-coder": "deepseek-ai/deepseek-coder-6.7b-base",
+    # Yi family
+    "yi": "01-ai/Yi-1.5-9B",
+    "yi-1.5": "01-ai/Yi-1.5-9B",
+    "yi-1.5-6b": "01-ai/Yi-1.5-6B",
+    "yi-1.5-9b": "01-ai/Yi-1.5-9B",
+    "yi-1.5-34b": "01-ai/Yi-1.5-34B",
+    # InternLM family
+    "internlm": "internlm/internlm2_5-7b",
+    "internlm2.5": "internlm/internlm2_5-7b",
+    "internlm2.5-7b": "internlm/internlm2_5-7b",
+    "internlm2.5-20b": "internlm/internlm2_5-20b",
+    # GLM family
+    "glm4": "THUDM/glm-4-9b",
+    "glm4-9b": "THUDM/glm-4-9b",
+    # Baichuan family
+    "baichuan2": "baichuan-inc/Baichuan2-13B-Base",
+    "baichuan2-7b": "baichuan-inc/Baichuan2-7B-Base",
+    "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Base",
+}
+
+# OpenAI models (handled by tiktoken)
+OPENAI_MODELS = {"gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo", "gpt-4-turbo", "o1", "o1-mini", "o1-preview", "o3", "o3-mini"}
+
+# tiktoken encoding names
+TIKTOKEN_ENCODINGS = {"cl100k_base", "p50k_base", "p50k_edit", "r50k_base", "o200k_base"}
+
+
+def resolve_model(model: str) -> str:
+    """
+    Resolve a model name, expanding aliases to full HuggingFace paths.
+
+    Args:
+        model: model name or alias
+
+    Returns:
+        The full model path
+    """
+    return MODEL_ALIASES.get(model.lower(), model)
 
 
-def _get_tiktoken_encoder(model: str = "gpt-4"):
+def _get_tiktoken_encoder(model: str):
     """Get a tiktoken encoder (cached)."""
     if model not in _tokenizer_cache:
         try:
             import tiktoken
-            _tokenizer_cache[model] = tiktoken.encoding_for_model(model)
+            # Use an encoding name directly (cl100k_base, etc.) or look one up by model name
+            if model in TIKTOKEN_ENCODINGS:
+                _tokenizer_cache[model] = tiktoken.get_encoding(model)
+            else:
+                _tokenizer_cache[model] = tiktoken.encoding_for_model(model)
         except ImportError:
             raise ImportError("tiktoken is required: pip install tiktoken")
     return _tokenizer_cache[model]
 
 
-def _get_transformers_tokenizer(model: str):
-    """Get a transformers tokenizer (cached)."""
-    if model not in _tokenizer_cache:
+def _get_hf_tokenizer(model: str):
+    """
+    Get a HuggingFace tokenizer (cached, with alias resolution).
+
+    Prefers the tokenizers library (Rust implementation, lightweight and fast);
+    falls back to transformers on failure.
+    """
+    resolved = resolve_model(model)
+    if resolved not in _tokenizer_cache:
+        # Prefer the tokenizers library (lighter weight)
         try:
-            from transformers import AutoTokenizer
-            _tokenizer_cache[model] = AutoTokenizer.from_pretrained(model)
-        except ImportError:
-            raise ImportError("transformers is required: pip install transformers")
-    return _tokenizer_cache[model]
+            from tokenizers import Tokenizer
+            from huggingface_hub import hf_hub_download
+
+            tokenizer_path = hf_hub_download(repo_id=resolved, filename="tokenizer.json")
+            _tokenizer_cache[resolved] = ("tokenizers", Tokenizer.from_file(tokenizer_path))
+        except Exception:
+            # Fall back to transformers (some models may not ship a tokenizer.json)
+            try:
+                from transformers import AutoTokenizer
+                tokenizer = AutoTokenizer.from_pretrained(resolved, trust_remote_code=True)
+                _tokenizer_cache[resolved] = ("transformers", tokenizer)
+            except ImportError:
+                raise ImportError(
+                    "tokenizers or transformers is required:\n"
+                    "  pip install tokenizers huggingface_hub  (recommended, lighter)\n"
+                    "  pip install transformers"
+                )
+    return _tokenizer_cache[resolved]
+
+
+def _encode_tokens(tokenizer_info, text: str) -> int:
+    """Encode text and return the token count."""
+    backend, tokenizer = tokenizer_info
+    if backend == "tokenizers":
+        return len(tokenizer.encode(text).ids)
+    else:
+        return len(tokenizer.encode(text))
 
 
 def count_tokens(
     text: str,
-    model: str = "gpt-4",
-    backend: str = "tiktoken",
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
 ) -> int:
     """
     Count the number of tokens in a text.
 
     Args:
         text: input text
-        model: model name
-        backend: backend selection
+        model: model name or alias, e.g. "qwen2.5", "gpt-4", "llama3"
+        backend: backend selection; None auto-detects
             - "tiktoken": OpenAI tiktoken (fast; supports the GPT family)
-            - "transformers": HuggingFace transformers (supports more models)
+            - "transformers": HuggingFace transformers (supports open-source models)
 
     Returns:
         The token count
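Between hunks, a minimal usage sketch of the alias and backend layers introduced above (an illustration, not part of the diff; it assumes the module is importable as dtflow.tokenizers and that tiktoken plus the HuggingFace tooling are installed):

# Sketch: resolve_model() expansion and the new count_tokens() defaults.
from dtflow.tokenizers import count_tokens, resolve_model

resolve_model("qwen2.5")          # "Qwen/Qwen2.5-7B"  (alias expanded)
resolve_model("Qwen/Qwen3-8B")    # "Qwen/Qwen3-8B"    (full paths pass through)

count_tokens("hello world")                    # DEFAULT_MODEL -> tiktoken cl100k_base
count_tokens("hello world", model="gpt-4o")    # OpenAI name   -> tiktoken backend
count_tokens("hello world", model="qwen2.5")   # alias         -> HuggingFace backend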
@@ -52,20 +172,22 @@ def count_tokens(
     if not text:
         return 0
 
-    if backend == "tiktoken":
+    _backend = backend or _auto_backend(model)
+
+    if _backend == "tiktoken":
         encoder = _get_tiktoken_encoder(model)
         return len(encoder.encode(text))
-    elif backend == "transformers":
-        tokenizer = _get_transformers_tokenizer(model)
-        return len(tokenizer.encode(text))
+    elif _backend == "transformers":
+        tokenizer_info = _get_hf_tokenizer(model)
+        return _encode_tokens(tokenizer_info, text)
     else:
-        raise ValueError(f"Unsupported backend: {backend}")
+        raise ValueError(f"Unsupported backend: {_backend}")
 
 
 def token_counter(
     fields: Union[str, List[str]],
-    model: str = "gpt-4",
-    backend: str = "tiktoken",
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
     output_field: str = "token_count",
 ) -> Callable:
     """
@@ -73,8 +195,8 @@ def token_counter(
 
     Args:
         fields: field(s) to count (one or more)
-        model: model name
-        backend: tiktoken or transformers
+        model: model name or alias, e.g. "qwen2.5", "gpt-4", "llama3"
+        backend: backend selection; None auto-detects
         output_field: name of the output field
 
     Returns:
@@ -82,7 +204,7 @@ def token_counter(
 
     Examples:
         >>> dt.transform(token_counter("text"))
-        >>> dt.transform(token_counter(["question", "answer"]))
+        >>> dt.transform(token_counter(["question", "answer"], model="qwen3"))
     """
     if isinstance(fields, str):
         fields = [fields]
@@ -104,8 +226,8 @@ def token_filter(
     fields: Union[str, List[str]],
     min_tokens: Optional[int] = None,
     max_tokens: Optional[int] = None,
-    model: str = "gpt-4",
-    backend: str = "tiktoken",
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
 ) -> Callable:
     """
     Create a filter function based on token length.
@@ -146,8 +268,8 @@
 def token_stats(
     data: List[Dict[str, Any]],
     fields: Union[str, List[str]],
-    model: str = "gpt-4",
-    backend: str = "tiktoken",
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
 ) -> Dict[str, Any]:
     """
     Compute token statistics for a dataset.
@@ -155,8 +277,8 @@ def token_stats(
     Args:
         data: list of records
        fields: field(s) to count
-        model: model name
-        backend: tiktoken or transformers
+        model: model name or alias, e.g. "qwen2.5", "gpt-4"
+        backend: backend selection; None auto-detects
 
     Returns:
         A dict of statistics
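With these signature changes, the three field-level helpers share the same model/backend defaults. A hedged sketch of the combined flow (dt stands for a dtflow dataset object, as in the docstring examples above):

# Sketch: field-level counting, filtering, and stats under the new defaults.
from dtflow.tokenizers import token_counter, token_filter, token_stats

dt.transform(token_counter("text", model="qwen3"))               # adds "token_count"
dt.filter(token_filter("text", min_tokens=10, max_tokens=4096))  # drops length outliers
print(token_stats(dt.data, ["question", "answer"]))              # aggregate stats dict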
@@ -184,3 +306,237 @@ def token_stats(
         "max_tokens": max(counts),
         "median_tokens": sorted(counts)[len(counts) // 2],
     }
+
+
+def _auto_backend(model: str) -> str:
+    """
+    Auto-detect the tokenizer backend.
+
+    Rules:
+        1. tiktoken encoding names (cl100k_base, etc.) -> tiktoken
+        2. OpenAI models (gpt-*, o1*, o3*) -> tiktoken
+        3. everything else (aliases and HuggingFace paths) -> transformers
+    """
+    model_lower = model.lower()
+
+    # tiktoken encoding names
+    if model_lower in TIKTOKEN_ENCODINGS:
+        return "tiktoken"
+
+    # OpenAI models use tiktoken
+    if model_lower in OPENAI_MODELS or model_lower.startswith(("gpt-", "o1", "o3")):
+        return "tiktoken"
+
+    # Everything else goes through transformers
+    return "transformers"
+
+
+def _count_messages_tokens(
+    messages: List[Dict[str, Any]],
+    model: str,
+    backend: str,
+) -> Dict[str, int]:
+    """Count tokens per role across a list of messages."""
+    role_tokens = {"user": 0, "assistant": 0, "system": 0, "other": 0}
+    turn_tokens = []
+
+    for msg in messages:
+        role = msg.get("role", "other")
+        content = msg.get("content", "")
+        if not content:
+            continue
+
+        tokens = count_tokens(str(content), model=model, backend=backend)
+
+        if role in role_tokens:
+            role_tokens[role] += tokens
+        else:
+            role_tokens["other"] += tokens
+
+        turn_tokens.append(tokens)
+
+    total = sum(role_tokens.values())
+    return {
+        "total": total,
+        "user": role_tokens["user"],
+        "assistant": role_tokens["assistant"],
+        "system": role_tokens["system"],
+        "turns": len(turn_tokens),
+        "avg_turn": total // len(turn_tokens) if turn_tokens else 0,
+        "max_turn": max(turn_tokens) if turn_tokens else 0,
+    }
+
+
+def messages_token_counter(
+    messages_field: str = "messages",
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
+    output_field: str = "token_stats",
+    detailed: bool = False,
+) -> Callable:
+    """
+    Create a transform function that counts tokens in messages.
+
+    Args:
+        messages_field: name of the messages field
+        model: model name or alias
+            - aliases: "qwen2.5", "qwen3", "llama3", "deepseek", etc.
+            - OpenAI models: "gpt-4", "gpt-4o", etc. (uses tiktoken)
+            - HuggingFace models: "Qwen/Qwen2.5-7B", etc.
+            - local paths: "/path/to/model"
+        backend: force a specific backend; None auto-detects
+        output_field: name of the output field
+        detailed: True emits detailed stats, False emits only the total
+
+    Returns:
+        A transform function for dt.transform()
+
+    Examples:
+        >>> # Use the default encoder (cl100k_base)
+        >>> dt.transform(messages_token_counter())
+
+        >>> # Use Qwen3
+        >>> dt.transform(messages_token_counter(model="qwen3"))
+
+        >>> # Use an OpenAI model
+        >>> dt.transform(messages_token_counter(model="gpt-4"))
+
+        >>> # Detailed statistics
+        >>> dt.transform(messages_token_counter(detailed=True))
+        # Output: {"token_stats": {"total": 500, "user": 200, "assistant": 300, ...}}
+    """
+    _backend = backend or _auto_backend(model)
+
+    def transform(item) -> dict:
+        result = item.to_dict() if hasattr(item, "to_dict") else dict(item)
+        messages = item.get(messages_field, [])
+
+        if not messages:
+            result[output_field] = 0 if not detailed else {"total": 0}
+            return result
+
+        stats = _count_messages_tokens(messages, model=model, backend=_backend)
+
+        if detailed:
+            result[output_field] = stats
+        else:
+            result[output_field] = stats["total"]
+
+        return result
+
+    return transform
+
+
+def messages_token_filter(
+    messages_field: str = "messages",
+    min_tokens: Optional[int] = None,
+    max_tokens: Optional[int] = None,
+    min_turns: Optional[int] = None,
+    max_turns: Optional[int] = None,
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
+) -> Callable:
+    """
+    Create a filter function based on messages token counts.
+
+    Args:
+        messages_field: name of the messages field
+        min_tokens: minimum total tokens
+        max_tokens: maximum total tokens
+        min_turns: minimum number of turns
+        max_turns: maximum number of turns
+        model: model name or alias
+        backend: backend; None auto-detects
+
+    Returns:
+        A filter predicate for dt.filter()
+
+    Examples:
+        >>> dt.filter(messages_token_filter(min_tokens=100, max_tokens=2048))
+        >>> dt.filter(messages_token_filter(min_turns=2, max_turns=10, model="qwen3"))
+    """
+    _backend = backend or _auto_backend(model)
+
+    def filter_func(item) -> bool:
+        messages = item.get(messages_field, [])
+
+        if not messages:
+            return False
+
+        stats = _count_messages_tokens(messages, model=model, backend=_backend)
+
+        if min_tokens is not None and stats["total"] < min_tokens:
+            return False
+        if max_tokens is not None and stats["total"] > max_tokens:
+            return False
+        if min_turns is not None and stats["turns"] < min_turns:
+            return False
+        if max_turns is not None and stats["turns"] > max_turns:
+            return False
+
+        return True
+
+    return filter_func
+
+
+def messages_token_stats(
+    data: List[Dict[str, Any]],
+    messages_field: str = "messages",
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
+) -> Dict[str, Any]:
+    """
+    Compute token statistics over the messages in a dataset.
+
+    Args:
+        data: list of records
+        messages_field: name of the messages field
+        model: model name or alias
+        backend: backend; None auto-detects
+
+    Returns:
+        A dict of statistics
+
+    Examples:
+        >>> stats = messages_token_stats(dt.data)  # uses the default encoder (cl100k_base)
+        >>> stats = messages_token_stats(dt.data, model="qwen3")
+        >>> print(stats)
+        {
+            "count": 1000,
+            "total_tokens": 500000,
+            "user_tokens": 200000,
+            "assistant_tokens": 290000,
+            "system_tokens": 10000,
+            "avg_tokens": 500,
+            "max_tokens": 2048,
+            "min_tokens": 50,
+            "avg_turns": 4,
+        }
+    """
+    _backend = backend or _auto_backend(model)
+
+    if not data:
+        return {"count": 0, "total_tokens": 0}
+
+    all_stats = []
+    for item in data:
+        messages = item.get(messages_field, [])
+        if messages:
+            all_stats.append(_count_messages_tokens(messages, model=model, backend=_backend))
+
+    if not all_stats:
+        return {"count": 0, "total_tokens": 0}
+
+    totals = [s["total"] for s in all_stats]
+    return {
+        "count": len(all_stats),
+        "total_tokens": sum(totals),
+        "user_tokens": sum(s["user"] for s in all_stats),
+        "assistant_tokens": sum(s["assistant"] for s in all_stats),
+        "system_tokens": sum(s["system"] for s in all_stats),
+        "avg_tokens": sum(totals) // len(totals),
+        "max_tokens": max(totals),
+        "min_tokens": min(totals),
+        "median_tokens": sorted(totals)[len(totals) // 2],
+        "avg_turns": sum(s["turns"] for s in all_stats) // len(all_stats),
+    }
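The new messages_* helpers mirror the field-level API for chat-format data. A self-contained sketch (an illustration under assumptions: tiktoken is installed for the default cl100k_base encoder, and the helpers are importable as shown):

from dtflow.tokenizers import (
    messages_token_counter,
    messages_token_filter,
    messages_token_stats,
)

data = [{
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "Hello! How can I help?"},
    ],
}]

counter = messages_token_counter(detailed=True)
print(counter(data[0])["token_stats"])   # {"total": ..., "user": ..., "turns": 3, ...}

keep = messages_token_filter(max_tokens=2048, min_turns=2)
print([keep(item) for item in data])     # [True] for this tiny conversation

print(messages_token_stats(data))        # dataset-level aggregate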
dtflow/utils/display.py CHANGED
@@ -2,7 +2,8 @@
 Data display utilities.
 """
 from typing import List, Dict, Any, Optional
-import json
+
+import orjson
 
 
 def display_data(data: List[Dict[str, Any]],
@@ -52,7 +53,7 @@ def _display_with_rich(data: List[Dict[str, Any]],
         display_item = {k: v for k, v in item.items() if k in fields}
 
         # Create a panel for each item
-        json_str = json.dumps(display_item, indent=2, ensure_ascii=False)
+        json_str = orjson.dumps(display_item, option=orjson.OPT_INDENT_2).decode("utf-8")
 
         panel = Panel(
             JSON(json_str, indent=2),
@@ -84,7 +85,7 @@ def _display_plain(data: List[Dict[str, Any]],
         display_item = {k: v for k, v in item.items() if k in fields}
 
         # Pretty print JSON
-        print(json.dumps(display_item, indent=2, ensure_ascii=False))
+        print(orjson.dumps(display_item, option=orjson.OPT_INDENT_2).decode("utf-8"))
 
     print(f"\n{separator}\n")
 
@@ -100,7 +101,7 @@ def format_item(item: Dict[str, Any], max_width: int = 80) -> str:
     Returns:
         Formatted string
     """
-    return json.dumps(item, indent=2, ensure_ascii=False)
+    return orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8")
 
 
 def preview_fields(data: List[Dict[str, Any]], n: int = 5) -> Dict[str, List[Any]]:
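One behavioral note on the json → orjson swap above: orjson.dumps returns bytes (hence the added .decode("utf-8")), and it never escapes non-ASCII, so it matches the old ensure_ascii=False output. A quick check:

import json
import orjson

item = {"text": "你好", "n": 1}

print(json.dumps(item, indent=2, ensure_ascii=False))
print(orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8"))
# Both keep "你好" unescaped; only orjson returns bytes before .decode().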