dtflow 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- dtflow/__init__.py +36 -2
- dtflow/__main__.py +292 -239
- dtflow/cli/__init__.py +8 -2
- dtflow/cli/commands.py +1030 -92
- dtflow/converters.py +456 -0
- dtflow/core.py +96 -31
- dtflow/lineage.py +407 -0
- dtflow/mcp/cli.py +14 -14
- dtflow/pipeline.py +450 -0
- dtflow/storage/io.py +376 -370
- dtflow/streaming.py +661 -0
- dtflow/tokenizers.py +387 -31
- dtflow/utils/display.py +5 -4
- {dtflow-0.2.0.dist-info → dtflow-0.3.1.dist-info}/METADATA +234 -15
- dtflow-0.3.1.dist-info/RECORD +24 -0
- dtflow-0.2.0.dist-info/RECORD +0 -21
- {dtflow-0.2.0.dist-info → dtflow-0.3.1.dist-info}/WHEEL +0 -0
- {dtflow-0.2.0.dist-info → dtflow-0.3.1.dist-info}/entry_points.txt +0 -0
dtflow/tokenizers.py
CHANGED
@@ -2,49 +2,169 @@
 Token statistics module.
 
 Provides token counting and filtering by token length.
+Supports two backends: OpenAI (tiktoken) and open-source models (transformers).
 """
 from typing import Callable, Union, List, Dict, Any, Optional
 
 # Imported lazily so a missing optional dependency only fails when used
 _tokenizer_cache = {}
 
+# Default encoder (tiktoken's cl100k_base: fast, with a light dependency)
+DEFAULT_MODEL = "cl100k_base"
+
+# Model alias map: short name -> HuggingFace model path
+MODEL_ALIASES = {
+    # Qwen family
+    "qwen2.5": "Qwen/Qwen2.5-7B",
+    "qwen2.5-0.5b": "Qwen/Qwen2.5-0.5B",
+    "qwen2.5-1.5b": "Qwen/Qwen2.5-1.5B",
+    "qwen2.5-3b": "Qwen/Qwen2.5-3B",
+    "qwen2.5-7b": "Qwen/Qwen2.5-7B",
+    "qwen2.5-14b": "Qwen/Qwen2.5-14B",
+    "qwen2.5-32b": "Qwen/Qwen2.5-32B",
+    "qwen2.5-72b": "Qwen/Qwen2.5-72B",
+    "qwen3": "Qwen/Qwen3-8B",
+    "qwen3-0.6b": "Qwen/Qwen3-0.6B",
+    "qwen3-1.7b": "Qwen/Qwen3-1.7B",
+    "qwen3-4b": "Qwen/Qwen3-4B",
+    "qwen3-8b": "Qwen/Qwen3-8B",
+    "qwen3-14b": "Qwen/Qwen3-14B",
+    "qwen3-32b": "Qwen/Qwen3-32B",
+    "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B",
+    "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B",
+    "qwen2": "Qwen/Qwen2-7B",
+    "qwen2-0.5b": "Qwen/Qwen2-0.5B",
+    "qwen2-1.5b": "Qwen/Qwen2-1.5B",
+    "qwen2-7b": "Qwen/Qwen2-7B",
+    "qwen2-72b": "Qwen/Qwen2-72B",
+    # Llama family
+    "llama3": "meta-llama/Llama-3.1-8B",
+    "llama3.1": "meta-llama/Llama-3.1-8B",
+    "llama3.1-8b": "meta-llama/Llama-3.1-8B",
+    "llama3.1-70b": "meta-llama/Llama-3.1-70B",
+    "llama3.2": "meta-llama/Llama-3.2-3B",
+    "llama3.2-1b": "meta-llama/Llama-3.2-1B",
+    "llama3.2-3b": "meta-llama/Llama-3.2-3B",
+    "llama3.3": "meta-llama/Llama-3.3-70B-Instruct",
+    "llama3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
+    # Mistral family
+    "mistral": "mistralai/Mistral-7B-v0.3",
+    "mistral-7b": "mistralai/Mistral-7B-v0.3",
+    "mixtral": "mistralai/Mixtral-8x7B-v0.1",
+    "mixtral-8x7b": "mistralai/Mixtral-8x7B-v0.1",
+    # DeepSeek family
+    "deepseek": "deepseek-ai/DeepSeek-V3",
+    "deepseek-v3": "deepseek-ai/DeepSeek-V3",
+    "deepseek-v2": "deepseek-ai/DeepSeek-V2",
+    "deepseek-coder": "deepseek-ai/deepseek-coder-6.7b-base",
+    # Yi family
+    "yi": "01-ai/Yi-1.5-9B",
+    "yi-1.5": "01-ai/Yi-1.5-9B",
+    "yi-1.5-6b": "01-ai/Yi-1.5-6B",
+    "yi-1.5-9b": "01-ai/Yi-1.5-9B",
+    "yi-1.5-34b": "01-ai/Yi-1.5-34B",
+    # InternLM family
+    "internlm": "internlm/internlm2_5-7b",
+    "internlm2.5": "internlm/internlm2_5-7b",
+    "internlm2.5-7b": "internlm/internlm2_5-7b",
+    "internlm2.5-20b": "internlm/internlm2_5-20b",
+    # GLM family
+    "glm4": "THUDM/glm-4-9b",
+    "glm4-9b": "THUDM/glm-4-9b",
+    # Baichuan family
+    "baichuan2": "baichuan-inc/Baichuan2-13B-Base",
+    "baichuan2-7b": "baichuan-inc/Baichuan2-7B-Base",
+    "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Base",
+}
+
+# OpenAI models (handled by tiktoken)
+OPENAI_MODELS = {"gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo", "gpt-4-turbo", "o1", "o1-mini", "o1-preview", "o3", "o3-mini"}
+
+# tiktoken encoding names
+TIKTOKEN_ENCODINGS = {"cl100k_base", "p50k_base", "p50k_edit", "r50k_base", "o200k_base"}
+
+
+def resolve_model(model: str) -> str:
+    """
+    Resolve a model name, expanding aliases to full HuggingFace paths.
+
+    Args:
+        model: model name or alias
+
+    Returns:
+        The full model path.
+    """
+    return MODEL_ALIASES.get(model.lower(), model)
 
-
+
+def _get_tiktoken_encoder(model: str):
     """Get a tiktoken encoder (cached)."""
     if model not in _tokenizer_cache:
         try:
             import tiktoken
-
+            # Use an encoding name directly (cl100k_base etc.) or resolve via the model name
+            if model in TIKTOKEN_ENCODINGS:
+                _tokenizer_cache[model] = tiktoken.get_encoding(model)
+            else:
+                _tokenizer_cache[model] = tiktoken.encoding_for_model(model)
         except ImportError:
             raise ImportError("tiktoken is required: pip install tiktoken")
     return _tokenizer_cache[model]
 
 
-def
-    """
-
+def _get_hf_tokenizer(model: str):
+    """
+    Get a HuggingFace tokenizer (cached, with alias resolution).
+
+    Prefers the tokenizers library (Rust implementation, light and fast),
+    falling back to transformers if that fails.
+    """
+    resolved = resolve_model(model)
+    if resolved not in _tokenizer_cache:
+        # Prefer the tokenizers library (lighter weight)
         try:
-            from
-
-
-
+            from tokenizers import Tokenizer
+            from huggingface_hub import hf_hub_download
+
+            tokenizer_path = hf_hub_download(repo_id=resolved, filename="tokenizer.json")
+            _tokenizer_cache[resolved] = ("tokenizers", Tokenizer.from_file(tokenizer_path))
+        except Exception:
+            # Fall back to transformers (some models ship no tokenizer.json)
+            try:
+                from transformers import AutoTokenizer
+                tokenizer = AutoTokenizer.from_pretrained(resolved, trust_remote_code=True)
+                _tokenizer_cache[resolved] = ("transformers", tokenizer)
+            except ImportError:
+                raise ImportError(
+                    "tokenizers or transformers is required:\n"
+                    "  pip install tokenizers huggingface_hub  (recommended, lighter)\n"
+                    "  pip install transformers"
+                )
+    return _tokenizer_cache[resolved]
+
+
+def _encode_tokens(tokenizer_info, text: str) -> int:
+    """Encode text and return the token count."""
+    backend, tokenizer = tokenizer_info
+    if backend == "tokenizers":
+        return len(tokenizer.encode(text).ids)
+    else:
+        return len(tokenizer.encode(text))
 
 
 def count_tokens(
     text: str,
-    model: str =
-    backend: str =
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
 ) -> int:
     """
     Count the tokens in a text.
 
     Args:
         text: input text
-        model:
-        backend:
+        model: model name or alias, e.g. "qwen2.5", "gpt-4", "llama3"
+        backend: backend to use; None auto-detects
             - "tiktoken": OpenAI tiktoken (fast; GPT-family models)
-            - "transformers": HuggingFace transformers
+            - "transformers": HuggingFace transformers (open-source models)
 
     Returns:
         The token count.
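The alias table above makes short names and full HuggingFace paths interchangeable throughout the module. A minimal sketch of the resulting behavior, assuming the module is importable as `dtflow.tokenizers` (the import path is inferred from the file path shown in this diff):

```python
# Hypothetical usage sketch; import path assumed from the file path above.
from dtflow.tokenizers import resolve_model, count_tokens

print(resolve_model("qwen2.5"))          # "Qwen/Qwen2.5-7B" (alias expanded)
print(resolve_model("QWEN3-8B"))         # "Qwen/Qwen3-8B" (lookup is case-insensitive)
print(resolve_model("my-org/my-model"))  # unknown names pass through unchanged

# The default model is the cl100k_base encoding, so this needs only tiktoken:
print(count_tokens("Hello world"))       # a small integer, e.g. 2
```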
@@ -52,20 +172,22 @@ def count_tokens(
     if not text:
         return 0
 
-
+    _backend = backend or _auto_backend(model)
+
+    if _backend == "tiktoken":
         encoder = _get_tiktoken_encoder(model)
         return len(encoder.encode(text))
-    elif
-
-        return
+    elif _backend == "transformers":
+        tokenizer_info = _get_hf_tokenizer(model)
+        return _encode_tokens(tokenizer_info, text)
     else:
-        raise ValueError(f"Unsupported backend: {
+        raise ValueError(f"Unsupported backend: {_backend}")
 
 
 def token_counter(
     fields: Union[str, List[str]],
-    model: str =
-    backend: str =
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
     output_field: str = "token_count",
 ) -> Callable:
     """
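The rewritten `count_tokens` body is a thin dispatcher: it resolves a backend (explicit or auto-detected) and delegates. The tiktoken branch reduces to the two calls sketched below; the transformers branch is omitted here only because it needs a Hub download:

```python
import tiktoken

# What count_tokens does on the "tiktoken" path, minus the module-level cache:
enc = tiktoken.get_encoding("cl100k_base")
print(len(enc.encode("How many tokens is this?")))

# The "transformers" path instead fetches tokenizer.json once via
# huggingface_hub, caches ("tokenizers", Tokenizer) in _tokenizer_cache,
# and counts with len(tokenizer.encode(text).ids).
```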
@@ -73,8 +195,8 @@ def token_counter(
 
     Args:
         fields: field(s) to count (one name or a list)
-        model:
-        backend:
+        model: model name or alias, e.g. "qwen2.5", "gpt-4", "llama3"
+        backend: backend to use; None auto-detects
         output_field: name of the output field
 
     Returns:
@@ -82,7 +204,7 @@ def token_counter(
 
     Examples:
         >>> dt.transform(token_counter("text"))
-        >>> dt.transform(token_counter(["question", "answer"]))
+        >>> dt.transform(token_counter(["question", "answer"], model="qwen3"))
     """
     if isinstance(fields, str):
         fields = [fields]
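The hunk shows only the head of `token_counter`, so the exact aggregation across multiple fields is not visible here. The sketch below assumes the returned callable takes one record and writes a single combined count to `output_field`, which matches the `token_count` default and the docstring examples, but is an assumption:

```python
# Assumed record-level behavior of the returned transform (the aggregation
# across fields is not shown in this hunk and may differ).
counter = token_counter(["question", "answer"])  # defaults to cl100k_base

row = {"question": "What is 2 + 2?", "answer": "It is 4."}
row = counter(row)
print(row["token_count"])  # count added alongside the original fields
```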
@@ -104,8 +226,8 @@ def token_filter(
     fields: Union[str, List[str]],
     min_tokens: Optional[int] = None,
     max_tokens: Optional[int] = None,
-    model: str =
-    backend: str =
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
 ) -> Callable:
     """
     Create a filter function based on token length.
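`token_filter` gets the same signature treatment (DEFAULT_MODEL plus an auto-detected backend). A usage sketch, assuming the returned callable is a one-record predicate as with the later `messages_token_filter`:

```python
# Assumed predicate behavior; bounds semantics inferred from the parameter names.
keep = token_filter("text", min_tokens=5, max_tokens=4096)

rows = [
    {"text": "hi"},                                    # too short, filtered out
    {"text": "a longer example that clears the bar"},  # kept
]
kept = [r for r in rows if keep(r)]
print(len(kept))  # 1, under the stated assumptions
```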
@@ -146,8 +268,8 @@ def token_filter(
 def token_stats(
     data: List[Dict[str, Any]],
     fields: Union[str, List[str]],
-    model: str =
-    backend: str =
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
 ) -> Dict[str, Any]:
     """
     Compute token statistics for a dataset.
@@ -155,8 +277,8 @@ def token_stats(
     Args:
         data: list of records
         fields: fields to count
-        model:
-        backend:
+        model: model name or alias, e.g. "qwen2.5", "gpt-4"
+        backend: backend to use; None auto-detects
 
     Returns:
         A dict of statistics.
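`token_stats` keeps its shape and only swaps the defaults. Based on the aggregate keys visible at the tail of the function in the next hunk, a usage sketch:

```python
data = [{"text": "short"}, {"text": "a noticeably longer example sentence"}]
stats = token_stats(data, "text")  # cl100k_base via tiktoken by default

# Keys per the function tail shown in the next hunk:
print(stats["max_tokens"], stats["median_tokens"])
```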
@@ -184,3 +306,237 @@ def token_stats(
         "max_tokens": max(counts),
         "median_tokens": sorted(counts)[len(counts) // 2],
     }
+
+
+def _auto_backend(model: str) -> str:
+    """
+    Auto-detect the tokenizer backend.
+
+    Rules:
+    1. tiktoken encoding names (cl100k_base etc.) -> tiktoken
+    2. OpenAI models (gpt-*, o1*, o3*) -> tiktoken
+    3. Everything else (aliases and HuggingFace paths) -> transformers
+    """
+    model_lower = model.lower()
+
+    # tiktoken encoding names
+    if model_lower in TIKTOKEN_ENCODINGS:
+        return "tiktoken"
+
+    # OpenAI models go through tiktoken
+    if model_lower in OPENAI_MODELS or model_lower.startswith(("gpt-", "o1", "o3")):
+        return "tiktoken"
+
+    # Everything else goes through transformers
+    return "transformers"
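Since `_auto_backend` is pure string logic, its behavior follows directly from the three rules above; calling the private helper directly, as below, is for illustration only:

```python
from dtflow.tokenizers import _auto_backend  # private helper; illustration only

assert _auto_backend("cl100k_base") == "tiktoken"          # rule 1: encoding name
assert _auto_backend("gpt-4o-mini") == "tiktoken"          # rule 2: OpenAI model
assert _auto_backend("o3-mini") == "tiktoken"              # rule 2: o-series prefix
assert _auto_backend("qwen2.5") == "transformers"          # rule 3: alias
assert _auto_backend("Qwen/Qwen2.5-7B") == "transformers"  # rule 3: HF path
```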
+
+
+def _count_messages_tokens(
+    messages: List[Dict[str, Any]],
+    model: str,
+    backend: str,
+) -> Dict[str, int]:
+    """Count tokens per role in a messages list."""
+    role_tokens = {"user": 0, "assistant": 0, "system": 0, "other": 0}
+    turn_tokens = []
+
+    for msg in messages:
+        role = msg.get("role", "other")
+        content = msg.get("content", "")
+        if not content:
+            continue
+
+        tokens = count_tokens(str(content), model=model, backend=backend)
+
+        if role in role_tokens:
+            role_tokens[role] += tokens
+        else:
+            role_tokens["other"] += tokens
+
+        turn_tokens.append(tokens)
+
+    total = sum(role_tokens.values())
+    return {
+        "total": total,
+        "user": role_tokens["user"],
+        "assistant": role_tokens["assistant"],
+        "system": role_tokens["system"],
+        "turns": len(turn_tokens),
+        "avg_turn": total // len(turn_tokens) if turn_tokens else 0,
+        "max_turn": max(turn_tokens) if turn_tokens else 0,
+    }
+
+
+def messages_token_counter(
+    messages_field: str = "messages",
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
+    output_field: str = "token_stats",
+    detailed: bool = False,
+) -> Callable:
+    """
+    Create a transform that counts tokens over a messages list.
+
+    Args:
+        messages_field: name of the messages field
+        model: model name or alias
+            - aliases: "qwen2.5", "qwen3", "llama3", "deepseek", ...
+            - OpenAI models: "gpt-4", "gpt-4o", ... (via tiktoken)
+            - HuggingFace models: "Qwen/Qwen2.5-7B", ...
+            - local paths: "/path/to/model"
+        backend: force a backend; None auto-detects
+        output_field: name of the output field
+        detailed: if True, emit detailed stats; if False, just the total
+
+    Returns:
+        A transform function for dt.transform().
+
+    Examples:
+        >>> # Use the default encoder (cl100k_base)
+        >>> dt.transform(messages_token_counter())
+
+        >>> # Use Qwen3
+        >>> dt.transform(messages_token_counter(model="qwen3"))
+
+        >>> # Use an OpenAI model
+        >>> dt.transform(messages_token_counter(model="gpt-4"))
+
+        >>> # Detailed stats
+        >>> dt.transform(messages_token_counter(detailed=True))
+        # Output: {"token_stats": {"total": 500, "user": 200, "assistant": 300, ...}}
+    """
+    _backend = backend or _auto_backend(model)
+
+    def transform(item) -> dict:
+        result = item.to_dict() if hasattr(item, "to_dict") else dict(item)
+        messages = item.get(messages_field, []) if hasattr(item, "get") else item.get(messages_field, [])
+
+        if not messages:
+            result[output_field] = 0 if not detailed else {"total": 0}
+            return result
+
+        stats = _count_messages_tokens(messages, model=model, backend=_backend)
+
+        if detailed:
+            result[output_field] = stats
+        else:
+            result[output_field] = stats["total"]
+
+        return result
+
+    return transform
+
+
+def messages_token_filter(
+    messages_field: str = "messages",
+    min_tokens: Optional[int] = None,
+    max_tokens: Optional[int] = None,
+    min_turns: Optional[int] = None,
+    max_turns: Optional[int] = None,
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
+) -> Callable:
+    """
+    Create a filter function based on messages token counts.
+
+    Args:
+        messages_field: name of the messages field
+        min_tokens: minimum total token count
+        max_tokens: maximum total token count
+        min_turns: minimum number of turns
+        max_turns: maximum number of turns
+        model: model name or alias
+        backend: backend; None auto-detects
+
+    Returns:
+        A predicate for dt.filter().
+
+    Examples:
+        >>> dt.filter(messages_token_filter(min_tokens=100, max_tokens=2048))
+        >>> dt.filter(messages_token_filter(min_turns=2, max_turns=10, model="qwen3"))
+    """
+    _backend = backend or _auto_backend(model)
+
+    def filter_func(item) -> bool:
+        messages = item.get(messages_field, []) if hasattr(item, "get") else item.get(messages_field, [])
+
+        if not messages:
+            return False
+
+        stats = _count_messages_tokens(messages, model=model, backend=_backend)
+
+        if min_tokens is not None and stats["total"] < min_tokens:
+            return False
+        if max_tokens is not None and stats["total"] > max_tokens:
+            return False
+        if min_turns is not None and stats["turns"] < min_turns:
+            return False
+        if max_turns is not None and stats["turns"] > max_turns:
+            return False
+
+        return True
+
+    return filter_func
+
+
+def messages_token_stats(
+    data: List[Dict[str, Any]],
+    messages_field: str = "messages",
+    model: str = DEFAULT_MODEL,
+    backend: Optional[str] = None,
+) -> Dict[str, Any]:
+    """
+    Compute token statistics for the messages in a dataset.
+
+    Args:
+        data: list of records
+        messages_field: name of the messages field
+        model: model name or alias
+        backend: backend; None auto-detects
+
+    Returns:
+        A dict of statistics.
+
+    Examples:
+        >>> stats = messages_token_stats(dt.data)  # default encoder (cl100k_base)
+        >>> stats = messages_token_stats(dt.data, model="qwen3")
+        >>> print(stats)
+        {
+            "count": 1000,
+            "total_tokens": 500000,
+            "user_tokens": 200000,
+            "assistant_tokens": 290000,
+            "system_tokens": 10000,
+            "avg_tokens": 500,
+            "max_tokens": 2048,
+            "min_tokens": 50,
+            "avg_turns": 4,
+        }
+    """
+    _backend = backend or _auto_backend(model)
+
+    if not data:
+        return {"count": 0, "total_tokens": 0}
+
+    all_stats = []
+    for item in data:
+        messages = item.get(messages_field, [])
+        if messages:
+            all_stats.append(_count_messages_tokens(messages, model=model, backend=_backend))
+
+    if not all_stats:
+        return {"count": 0, "total_tokens": 0}
+
+    totals = [s["total"] for s in all_stats]
+    return {
+        "count": len(all_stats),
+        "total_tokens": sum(totals),
+        "user_tokens": sum(s["user"] for s in all_stats),
+        "assistant_tokens": sum(s["assistant"] for s in all_stats),
+        "system_tokens": sum(s["system"] for s in all_stats),
+        "avg_tokens": sum(totals) // len(totals),
+        "max_tokens": max(totals),
+        "min_tokens": min(totals),
+        "median_tokens": sorted(totals)[len(totals) // 2],
+        "avg_turns": sum(s["turns"] for s in all_stats) // len(all_stats),
+    }
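Taken together, the three `messages_*` additions cover annotate, filter, and aggregate for chat-format data. A combined usage sketch under the same `dtflow.tokenizers` import assumption, using the default tiktoken encoder so no Hub access is needed:

```python
from dtflow.tokenizers import (
    messages_token_counter,
    messages_token_filter,
    messages_token_stats,
)

data = [{
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi there!"},
        {"role": "assistant", "content": "Hello! How can I help you today?"},
    ],
}]

# Annotate: detailed=True yields per-role totals and turn stats
count = messages_token_counter(detailed=True)
print(count(data[0])["token_stats"]["turns"])  # 3

# Filter: bounds on total tokens and on turn count
keep = messages_token_filter(min_turns=2, max_tokens=4096)
print(keep(data[0]))  # True

# Aggregate over the whole dataset
print(messages_token_stats(data)["count"])  # 1
```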
dtflow/utils/display.py
CHANGED
@@ -2,7 +2,8 @@
 Data display utilities.
 """
 from typing import List, Dict, Any, Optional
-
+
+import orjson
 
 
 def display_data(data: List[Dict[str, Any]],
@@ -52,7 +53,7 @@ def _display_with_rich(data: List[Dict[str, Any]],
         display_item = {k: v for k, v in item.items() if k in fields}
 
         # Create a panel for each item
-        json_str =
+        json_str = orjson.dumps(display_item, option=orjson.OPT_INDENT_2).decode("utf-8")
 
         panel = Panel(
             JSON(json_str, indent=2),
@@ -84,7 +85,7 @@ def _display_plain(data: List[Dict[str, Any]],
         display_item = {k: v for k, v in item.items() if k in fields}
 
         # Pretty print JSON
-        print(
+        print(orjson.dumps(display_item, option=orjson.OPT_INDENT_2).decode("utf-8"))
 
         print(f"\n{separator}\n")
 
@@ -100,7 +101,7 @@ def format_item(item: Dict[str, Any], max_width: int = 80) -> str:
     Returns:
         Formatted string
     """
-    return
+    return orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8")
 
 
 def preview_fields(data: List[Dict[str, Any]], n: int = 5) -> Dict[str, List[Any]]:
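The display changes swap the serializer at three call sites (the removed lines are truncated in this view, so the old calls are not recoverable here). The behavioral detail to keep in mind: `orjson.dumps` returns `bytes` and `OPT_INDENT_2` is its only indentation option, hence the `.decode("utf-8")` on every new call. A standalone sketch:

```python
import orjson

item = {"question": "2 + 2?", "answer": 4, "tags": ["math", "easy"]}

# orjson.dumps returns bytes; OPT_INDENT_2 pretty-prints with 2-space indent,
# which is why every call site above ends in .decode("utf-8").
print(orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8"))
```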