dtflow 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +17 -1
- dtflow/__main__.py +292 -239
- dtflow/cli/__init__.py +8 -2
- dtflow/cli/commands.py +1030 -92
- dtflow/core.py +96 -31
- dtflow/lineage.py +407 -0
- dtflow/mcp/cli.py +14 -14
- dtflow/pipeline.py +450 -0
- dtflow/storage/io.py +376 -370
- dtflow/streaming.py +661 -0
- dtflow/tokenizers.py +188 -51
- dtflow/utils/display.py +5 -4
- {dtflow-0.3.0.dist-info → dtflow-0.3.1.dist-info}/METADATA +153 -7
- dtflow-0.3.1.dist-info/RECORD +24 -0
- dtflow-0.3.0.dist-info/RECORD +0 -21
- {dtflow-0.3.0.dist-info → dtflow-0.3.1.dist-info}/WHEEL +0 -0
- {dtflow-0.3.0.dist-info → dtflow-0.3.1.dist-info}/entry_points.txt +0 -0
dtflow/tokenizers.py
CHANGED
|
@@ -2,49 +2,169 @@
|
|
|
2
2
|
Token 统计模块
|
|
3
3
|
|
|
4
4
|
提供 token 计数和基于 token 长度的过滤功能。
|
|
5
|
+
支持 OpenAI (tiktoken) 和开源模型 (transformers) 两种后端。
|
|
5
6
|
"""
|
|
6
7
|
from typing import Callable, Union, List, Dict, Any, Optional
|
|
7
8
|
|
|
8
9
|
# 延迟导入,避免未安装时报错
|
|
9
10
|
_tokenizer_cache = {}
|
|
10
11
|
|
|
12
|
+
# Default encoder: tiktoken's cl100k_base (fast, light on dependencies).
DEFAULT_MODEL = "cl100k_base"

# Alias table: short model name -> HuggingFace repository id.
# Keys are lower-case; resolve_model() lower-cases its argument before lookup.
MODEL_ALIASES = {
    # Qwen family
    "qwen2.5": "Qwen/Qwen2.5-7B",
    "qwen2.5-0.5b": "Qwen/Qwen2.5-0.5B",
    "qwen2.5-1.5b": "Qwen/Qwen2.5-1.5B",
    "qwen2.5-3b": "Qwen/Qwen2.5-3B",
    "qwen2.5-7b": "Qwen/Qwen2.5-7B",
    "qwen2.5-14b": "Qwen/Qwen2.5-14B",
    "qwen2.5-32b": "Qwen/Qwen2.5-32B",
    "qwen2.5-72b": "Qwen/Qwen2.5-72B",
    "qwen3": "Qwen/Qwen3-8B",
    "qwen3-0.6b": "Qwen/Qwen3-0.6B",
    "qwen3-1.7b": "Qwen/Qwen3-1.7B",
    "qwen3-4b": "Qwen/Qwen3-4B",
    "qwen3-8b": "Qwen/Qwen3-8B",
    "qwen3-14b": "Qwen/Qwen3-14B",
    "qwen3-32b": "Qwen/Qwen3-32B",
    "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B",
    "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B",
    "qwen2": "Qwen/Qwen2-7B",
    "qwen2-0.5b": "Qwen/Qwen2-0.5B",
    "qwen2-1.5b": "Qwen/Qwen2-1.5B",
    "qwen2-7b": "Qwen/Qwen2-7B",
    "qwen2-72b": "Qwen/Qwen2-72B",
    # Llama family
    "llama3": "meta-llama/Llama-3.1-8B",
    "llama3.1": "meta-llama/Llama-3.1-8B",
    "llama3.1-8b": "meta-llama/Llama-3.1-8B",
    "llama3.1-70b": "meta-llama/Llama-3.1-70B",
    "llama3.2": "meta-llama/Llama-3.2-3B",
    "llama3.2-1b": "meta-llama/Llama-3.2-1B",
    "llama3.2-3b": "meta-llama/Llama-3.2-3B",
    "llama3.3": "meta-llama/Llama-3.3-70B-Instruct",
    "llama3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
    # Mistral family
    "mistral": "mistralai/Mistral-7B-v0.3",
    "mistral-7b": "mistralai/Mistral-7B-v0.3",
    "mixtral": "mistralai/Mixtral-8x7B-v0.1",
    "mixtral-8x7b": "mistralai/Mixtral-8x7B-v0.1",
    # DeepSeek family
    "deepseek": "deepseek-ai/DeepSeek-V3",
    "deepseek-v3": "deepseek-ai/DeepSeek-V3",
    "deepseek-v2": "deepseek-ai/DeepSeek-V2",
    "deepseek-coder": "deepseek-ai/deepseek-coder-6.7b-base",
    # Yi family
    "yi": "01-ai/Yi-1.5-9B",
    "yi-1.5": "01-ai/Yi-1.5-9B",
    "yi-1.5-6b": "01-ai/Yi-1.5-6B",
    "yi-1.5-9b": "01-ai/Yi-1.5-9B",
    "yi-1.5-34b": "01-ai/Yi-1.5-34B",
    # InternLM family
    "internlm": "internlm/internlm2_5-7b",
    "internlm2.5": "internlm/internlm2_5-7b",
    "internlm2.5-7b": "internlm/internlm2_5-7b",
    "internlm2.5-20b": "internlm/internlm2_5-20b",
    # GLM family
    "glm4": "THUDM/glm-4-9b",
    "glm4-9b": "THUDM/glm-4-9b",
    # Baichuan family
    "baichuan2": "baichuan-inc/Baichuan2-13B-Base",
    "baichuan2-7b": "baichuan-inc/Baichuan2-7B-Base",
    "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Base",
}

# OpenAI model names that are tokenized with tiktoken.
OPENAI_MODELS = {
    "gpt-3.5-turbo",
    "gpt-4",
    "gpt-4-turbo",
    "gpt-4o",
    "gpt-4o-mini",
    "o1",
    "o1-mini",
    "o1-preview",
    "o3",
    "o3-mini",
}

# Encoding names that can be passed directly to tiktoken.get_encoding().
TIKTOKEN_ENCODINGS = {
    "cl100k_base",
    "o200k_base",
    "p50k_base",
    "p50k_edit",
    "r50k_base",
}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def resolve_model(model: str) -> str:
    """
    Expand a short model alias into its full HuggingFace path.

    Args:
        model: Model name or alias. The alias lookup is case-insensitive.

    Returns:
        The path registered in ``MODEL_ALIASES`` when the lower-cased name is
        a known alias; otherwise ``model`` exactly as it was passed in.
    """
    alias_key = model.lower()
    if alias_key in MODEL_ALIASES:
        return MODEL_ALIASES[alias_key]
    # Not an alias: hand the caller's string back untouched (case preserved).
    return model
|
|
98
|
+
|
|
11
99
|
|
|
12
|
-
def _get_tiktoken_encoder(model: str
|
|
100
|
+
def _get_tiktoken_encoder(model: str):
    """
    Return the tiktoken encoder for ``model``, loading it on first use.

    Args:
        model: Either a tiktoken encoding name listed in
            ``TIKTOKEN_ENCODINGS`` (matched case-insensitively) or a model
            name that is passed to ``tiktoken.encoding_for_model``.

    Returns:
        The encoder object. It is stored in ``_tokenizer_cache`` under
        ``model`` and reused on later calls.

    Raises:
        ImportError: If tiktoken is not installed.
    """
    if model not in _tokenizer_cache:
        # Only the import sits inside the try, so an ImportError raised while
        # tiktoken loads an encoding is not misreported as "not installed".
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError("需要安装 tiktoken: pip install tiktoken") from err

        # _auto_backend lower-cases the name before matching it against
        # TIKTOKEN_ENCODINGS; do the same here so a name such as
        # "CL100K_BASE" is loaded as an encoding rather than being passed to
        # encoding_for_model().
        encoding_name = model.lower()
        if encoding_name in TIKTOKEN_ENCODINGS:
            _tokenizer_cache[model] = tiktoken.get_encoding(encoding_name)
        else:
            _tokenizer_cache[model] = tiktoken.encoding_for_model(model)
    return _tokenizer_cache[model]
|
|
21
113
|
|
|
22
114
|
|
|
23
|
-
def
|
|
24
|
-
"""
|
|
25
|
-
|
|
115
|
+
def _get_hf_tokenizer(model: str):
    """
    Return a cached HuggingFace tokenizer for ``model`` (aliases are resolved).

    The ``tokenizers`` library is tried first: ``tokenizer.json`` is fetched
    with ``hf_hub_download`` and loaded with ``Tokenizer.from_file``. If that
    fails for any reason, ``transformers.AutoTokenizer`` is used instead.

    Args:
        model: Model alias, HuggingFace repository id, or any other string
            accepted by ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(backend, tokenizer)`` tuple where ``backend`` is ``"tokenizers"``
        or ``"transformers"``; this is the shape ``_encode_tokens`` unpacks.

    Raises:
        ImportError: If the ``tokenizers`` route failed and ``transformers``
            is not installed.
    """
    resolved = resolve_model(model)
    # _tokenizer_cache is shared with _get_tiktoken_encoder, which stores bare
    # encoder objects under plain string keys. A tuple key keeps the two kinds
    # of entry apart, so neither loader can return the other's object.
    cache_key = ("hf", resolved)
    if cache_key not in _tokenizer_cache:
        # Preferred route: the lighter `tokenizers` library.
        try:
            from tokenizers import Tokenizer
            from huggingface_hub import hf_hub_download

            tokenizer_path = hf_hub_download(repo_id=resolved, filename="tokenizer.json")
            _tokenizer_cache[cache_key] = ("tokenizers", Tokenizer.from_file(tokenizer_path))
        except Exception:
            # Deliberate best-effort: a missing package, a repository without
            # tokenizer.json, or a download error all fall back to transformers.
            try:
                from transformers import AutoTokenizer
            except ImportError as err:
                raise ImportError(
                    "需要安装 tokenizers 或 transformers:\n"
                    "  pip install tokenizers huggingface_hub  (推荐,更轻量)\n"
                    "  pip install transformers"
                ) from err
            # NOTE(security): trust_remote_code=True allows the model
            # repository to run its own Python code while loading. Kept as in
            # the original call; only use with repositories you trust.
            tokenizer = AutoTokenizer.from_pretrained(resolved, trust_remote_code=True)
            _tokenizer_cache[cache_key] = ("transformers", tokenizer)
    return _tokenizer_cache[cache_key]
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _encode_tokens(tokenizer_info, text: str) -> int:
|
|
146
|
+
"""编码文本,返回 token 数量"""
|
|
147
|
+
backend, tokenizer = tokenizer_info
|
|
148
|
+
if backend == "tokenizers":
|
|
149
|
+
return len(tokenizer.encode(text).ids)
|
|
150
|
+
else:
|
|
151
|
+
return len(tokenizer.encode(text))
|
|
32
152
|
|
|
33
153
|
|
|
34
154
|
def count_tokens(
|
|
35
155
|
text: str,
|
|
36
|
-
model: str =
|
|
37
|
-
backend: str =
|
|
156
|
+
model: str = DEFAULT_MODEL,
|
|
157
|
+
backend: Optional[str] = None,
|
|
38
158
|
) -> int:
|
|
39
159
|
"""
|
|
40
160
|
计算文本的 token 数量。
|
|
41
161
|
|
|
42
162
|
Args:
|
|
43
163
|
text: 输入文本
|
|
44
|
-
model:
|
|
45
|
-
backend:
|
|
164
|
+
model: 模型名称或别名,如 "qwen2.5", "gpt-4", "llama3" 等
|
|
165
|
+
backend: 后端选择,None 则自动检测
|
|
46
166
|
- "tiktoken": OpenAI tiktoken(快速,支持 GPT 系列)
|
|
47
|
-
- "transformers": HuggingFace transformers
|
|
167
|
+
- "transformers": HuggingFace transformers(支持开源模型)
|
|
48
168
|
|
|
49
169
|
Returns:
|
|
50
170
|
token 数量
|
|
@@ -52,20 +172,22 @@ def count_tokens(
|
|
|
52
172
|
if not text:
|
|
53
173
|
return 0
|
|
54
174
|
|
|
55
|
-
|
|
175
|
+
_backend = backend or _auto_backend(model)
|
|
176
|
+
|
|
177
|
+
if _backend == "tiktoken":
|
|
56
178
|
encoder = _get_tiktoken_encoder(model)
|
|
57
179
|
return len(encoder.encode(text))
|
|
58
|
-
elif
|
|
59
|
-
|
|
60
|
-
return
|
|
180
|
+
elif _backend == "transformers":
|
|
181
|
+
tokenizer_info = _get_hf_tokenizer(model)
|
|
182
|
+
return _encode_tokens(tokenizer_info, text)
|
|
61
183
|
else:
|
|
62
|
-
raise ValueError(f"不支持的 backend: {
|
|
184
|
+
raise ValueError(f"不支持的 backend: {_backend}")
|
|
63
185
|
|
|
64
186
|
|
|
65
187
|
def token_counter(
|
|
66
188
|
fields: Union[str, List[str]],
|
|
67
|
-
model: str =
|
|
68
|
-
backend: str =
|
|
189
|
+
model: str = DEFAULT_MODEL,
|
|
190
|
+
backend: Optional[str] = None,
|
|
69
191
|
output_field: str = "token_count",
|
|
70
192
|
) -> Callable:
|
|
71
193
|
"""
|
|
@@ -73,8 +195,8 @@ def token_counter(
|
|
|
73
195
|
|
|
74
196
|
Args:
|
|
75
197
|
fields: 要统计的字段(单个或多个)
|
|
76
|
-
model:
|
|
77
|
-
backend:
|
|
198
|
+
model: 模型名称或别名,如 "qwen2.5", "gpt-4", "llama3" 等
|
|
199
|
+
backend: 后端选择,None 则自动检测
|
|
78
200
|
output_field: 输出字段名
|
|
79
201
|
|
|
80
202
|
Returns:
|
|
@@ -82,7 +204,7 @@ def token_counter(
|
|
|
82
204
|
|
|
83
205
|
Examples:
|
|
84
206
|
>>> dt.transform(token_counter("text"))
|
|
85
|
-
>>> dt.transform(token_counter(["question", "answer"]))
|
|
207
|
+
>>> dt.transform(token_counter(["question", "answer"], model="qwen3"))
|
|
86
208
|
"""
|
|
87
209
|
if isinstance(fields, str):
|
|
88
210
|
fields = [fields]
|
|
@@ -104,8 +226,8 @@ def token_filter(
|
|
|
104
226
|
fields: Union[str, List[str]],
|
|
105
227
|
min_tokens: Optional[int] = None,
|
|
106
228
|
max_tokens: Optional[int] = None,
|
|
107
|
-
model: str =
|
|
108
|
-
backend: str =
|
|
229
|
+
model: str = DEFAULT_MODEL,
|
|
230
|
+
backend: Optional[str] = None,
|
|
109
231
|
) -> Callable:
|
|
110
232
|
"""
|
|
111
233
|
创建基于 token 长度的过滤函数。
|
|
@@ -146,8 +268,8 @@ def token_filter(
|
|
|
146
268
|
def token_stats(
|
|
147
269
|
data: List[Dict[str, Any]],
|
|
148
270
|
fields: Union[str, List[str]],
|
|
149
|
-
model: str =
|
|
150
|
-
backend: str =
|
|
271
|
+
model: str = DEFAULT_MODEL,
|
|
272
|
+
backend: Optional[str] = None,
|
|
151
273
|
) -> Dict[str, Any]:
|
|
152
274
|
"""
|
|
153
275
|
统计数据集的 token 信息。
|
|
@@ -155,8 +277,8 @@ def token_stats(
|
|
|
155
277
|
Args:
|
|
156
278
|
data: 数据列表
|
|
157
279
|
fields: 要统计的字段
|
|
158
|
-
model:
|
|
159
|
-
backend:
|
|
280
|
+
model: 模型名称或别名,如 "qwen2.5", "gpt-4" 等
|
|
281
|
+
backend: 后端选择,None 则自动检测
|
|
160
282
|
|
|
161
283
|
Returns:
|
|
162
284
|
统计信息字典
|
|
@@ -187,15 +309,25 @@ def token_stats(
|
|
|
187
309
|
|
|
188
310
|
|
|
189
311
|
def _auto_backend(model: str) -> str:
    """
    Choose the tokenizer backend for ``model``.

    The name is lower-cased and then checked in this order:

    1. a tiktoken encoding name (member of ``TIKTOKEN_ENCODINGS``) -> "tiktoken"
    2. an OpenAI model (member of ``OPENAI_MODELS``, or a name starting with
       "gpt-", "o1" or "o3") -> "tiktoken"
    3. anything else, including aliases and HuggingFace paths -> "transformers"

    Args:
        model: Model name, alias, or encoding name.

    Returns:
        ``"tiktoken"`` or ``"transformers"``.
    """
    name = model.lower()
    uses_tiktoken = (
        name in TIKTOKEN_ENCODINGS
        or name in OPENAI_MODELS
        or name.startswith(("gpt-", "o1", "o3"))
    )
    return "tiktoken" if uses_tiktoken else "transformers"
|
|
200
332
|
|
|
201
333
|
|
|
@@ -237,7 +369,7 @@ def _count_messages_tokens(
|
|
|
237
369
|
|
|
238
370
|
def messages_token_counter(
|
|
239
371
|
messages_field: str = "messages",
|
|
240
|
-
model: str =
|
|
372
|
+
model: str = DEFAULT_MODEL,
|
|
241
373
|
backend: Optional[str] = None,
|
|
242
374
|
output_field: str = "token_stats",
|
|
243
375
|
detailed: bool = False,
|
|
@@ -247,9 +379,10 @@ def messages_token_counter(
|
|
|
247
379
|
|
|
248
380
|
Args:
|
|
249
381
|
messages_field: messages 字段名
|
|
250
|
-
model:
|
|
382
|
+
model: 模型名称或别名
|
|
383
|
+
- 别名: "qwen2.5", "qwen3", "llama3", "deepseek" 等
|
|
251
384
|
- OpenAI 模型: "gpt-4", "gpt-4o" 等(使用 tiktoken)
|
|
252
|
-
- HuggingFace 模型: "Qwen/Qwen2-7B" 等
|
|
385
|
+
- HuggingFace 模型: "Qwen/Qwen2.5-7B" 等
|
|
253
386
|
- 本地路径: "/path/to/model"
|
|
254
387
|
backend: 强制指定后端,None 则自动检测
|
|
255
388
|
output_field: 输出字段名
|
|
@@ -259,11 +392,14 @@ def messages_token_counter(
|
|
|
259
392
|
转换函数,用于 dt.transform()
|
|
260
393
|
|
|
261
394
|
Examples:
|
|
262
|
-
>>> #
|
|
263
|
-
>>> dt.transform(messages_token_counter(
|
|
395
|
+
>>> # 使用默认模型 (cl100k_base)
|
|
396
|
+
>>> dt.transform(messages_token_counter())
|
|
264
397
|
|
|
265
|
-
>>> #
|
|
266
|
-
>>> dt.transform(messages_token_counter(model="
|
|
398
|
+
>>> # 使用 Qwen3
|
|
399
|
+
>>> dt.transform(messages_token_counter(model="qwen3"))
|
|
400
|
+
|
|
401
|
+
>>> # 使用 OpenAI 模型
|
|
402
|
+
>>> dt.transform(messages_token_counter(model="gpt-4"))
|
|
267
403
|
|
|
268
404
|
>>> # 详细统计
|
|
269
405
|
>>> dt.transform(messages_token_counter(detailed=True))
|
|
@@ -297,7 +433,7 @@ def messages_token_filter(
|
|
|
297
433
|
max_tokens: Optional[int] = None,
|
|
298
434
|
min_turns: Optional[int] = None,
|
|
299
435
|
max_turns: Optional[int] = None,
|
|
300
|
-
model: str =
|
|
436
|
+
model: str = DEFAULT_MODEL,
|
|
301
437
|
backend: Optional[str] = None,
|
|
302
438
|
) -> Callable:
|
|
303
439
|
"""
|
|
@@ -309,7 +445,7 @@ def messages_token_filter(
|
|
|
309
445
|
max_tokens: 最大总 token 数
|
|
310
446
|
min_turns: 最小对话轮数
|
|
311
447
|
max_turns: 最大对话轮数
|
|
312
|
-
model:
|
|
448
|
+
model: 模型名称或别名
|
|
313
449
|
backend: 后端,None 则自动检测
|
|
314
450
|
|
|
315
451
|
Returns:
|
|
@@ -317,7 +453,7 @@ def messages_token_filter(
|
|
|
317
453
|
|
|
318
454
|
Examples:
|
|
319
455
|
>>> dt.filter(messages_token_filter(min_tokens=100, max_tokens=2048))
|
|
320
|
-
>>> dt.filter(messages_token_filter(min_turns=2, max_turns=10))
|
|
456
|
+
>>> dt.filter(messages_token_filter(min_turns=2, max_turns=10, model="qwen3"))
|
|
321
457
|
"""
|
|
322
458
|
_backend = backend or _auto_backend(model)
|
|
323
459
|
|
|
@@ -346,7 +482,7 @@ def messages_token_filter(
|
|
|
346
482
|
def messages_token_stats(
|
|
347
483
|
data: List[Dict[str, Any]],
|
|
348
484
|
messages_field: str = "messages",
|
|
349
|
-
model: str =
|
|
485
|
+
model: str = DEFAULT_MODEL,
|
|
350
486
|
backend: Optional[str] = None,
|
|
351
487
|
) -> Dict[str, Any]:
|
|
352
488
|
"""
|
|
@@ -355,14 +491,15 @@ def messages_token_stats(
|
|
|
355
491
|
Args:
|
|
356
492
|
data: 数据列表
|
|
357
493
|
messages_field: messages 字段名
|
|
358
|
-
model:
|
|
494
|
+
model: 模型名称或别名
|
|
359
495
|
backend: 后端,None 则自动检测
|
|
360
496
|
|
|
361
497
|
Returns:
|
|
362
498
|
统计信息字典
|
|
363
499
|
|
|
364
500
|
Examples:
|
|
365
|
-
>>> stats = messages_token_stats(dt.data
|
|
501
|
+
>>> stats = messages_token_stats(dt.data) # 使用默认 cl100k_base
|
|
502
|
+
>>> stats = messages_token_stats(dt.data, model="qwen3")
|
|
366
503
|
>>> print(stats)
|
|
367
504
|
{
|
|
368
505
|
"count": 1000,
|
dtflow/utils/display.py
CHANGED
|
@@ -2,7 +2,8 @@
|
|
|
2
2
|
Data display utilities.
|
|
3
3
|
"""
|
|
4
4
|
from typing import List, Dict, Any, Optional
|
|
5
|
-
|
|
5
|
+
|
|
6
|
+
import orjson
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def display_data(data: List[Dict[str, Any]],
|
|
@@ -52,7 +53,7 @@ def _display_with_rich(data: List[Dict[str, Any]],
|
|
|
52
53
|
display_item = {k: v for k, v in item.items() if k in fields}
|
|
53
54
|
|
|
54
55
|
# Create a panel for each item
|
|
55
|
-
json_str =
|
|
56
|
+
json_str = orjson.dumps(display_item, option=orjson.OPT_INDENT_2).decode("utf-8")
|
|
56
57
|
|
|
57
58
|
panel = Panel(
|
|
58
59
|
JSON(json_str, indent=2),
|
|
@@ -84,7 +85,7 @@ def _display_plain(data: List[Dict[str, Any]],
|
|
|
84
85
|
display_item = {k: v for k, v in item.items() if k in fields}
|
|
85
86
|
|
|
86
87
|
# Pretty print JSON
|
|
87
|
-
print(
|
|
88
|
+
print(orjson.dumps(display_item, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
|
88
89
|
|
|
89
90
|
print(f"\n{separator}\n")
|
|
90
91
|
|
|
@@ -100,7 +101,7 @@ def format_item(item: Dict[str, Any], max_width: int = 80) -> str:
|
|
|
100
101
|
Returns:
|
|
101
102
|
Formatted string
|
|
102
103
|
"""
|
|
103
|
-
return
|
|
104
|
+
return orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8")
|
|
104
105
|
|
|
105
106
|
|
|
106
107
|
def preview_fields(data: List[Dict[str, Any]], n: int = 5) -> Dict[str, List[Any]]:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dtflow
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
|
|
5
5
|
Project-URL: Homepage, https://github.com/yourusername/DataTransformer
|
|
6
6
|
Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
|
|
@@ -27,11 +27,12 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
27
27
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
28
28
|
Classifier: Topic :: Text Processing
|
|
29
29
|
Requires-Python: >=3.8
|
|
30
|
-
Requires-Dist: fire>=0.4.0
|
|
31
30
|
Requires-Dist: numpy>=1.20.0
|
|
32
31
|
Requires-Dist: orjson>=3.9.0
|
|
32
|
+
Requires-Dist: polars>=0.20.0
|
|
33
33
|
Requires-Dist: pyyaml>=5.4.0
|
|
34
|
-
Requires-Dist:
|
|
34
|
+
Requires-Dist: rich>=10.0.0
|
|
35
|
+
Requires-Dist: typer>=0.9.0
|
|
35
36
|
Provides-Extra: converters
|
|
36
37
|
Requires-Dist: datasets>=2.0.0; extra == 'converters'
|
|
37
38
|
Provides-Extra: dev
|
|
@@ -42,7 +43,6 @@ Requires-Dist: mypy>=0.910; extra == 'dev'
|
|
|
42
43
|
Requires-Dist: pytest-cov>=2.12.0; extra == 'dev'
|
|
43
44
|
Requires-Dist: pytest>=6.0.0; extra == 'dev'
|
|
44
45
|
Provides-Extra: display
|
|
45
|
-
Requires-Dist: rich>=10.0.0; extra == 'display'
|
|
46
46
|
Provides-Extra: docs
|
|
47
47
|
Requires-Dist: myst-parser>=0.15.0; extra == 'docs'
|
|
48
48
|
Requires-Dist: sphinx-rtd-theme>=0.5.0; extra == 'docs'
|
|
@@ -50,21 +50,28 @@ Requires-Dist: sphinx>=4.0.0; extra == 'docs'
|
|
|
50
50
|
Provides-Extra: full
|
|
51
51
|
Requires-Dist: datasets>=2.0.0; extra == 'full'
|
|
52
52
|
Requires-Dist: datasketch>=1.5.0; extra == 'full'
|
|
53
|
-
Requires-Dist:
|
|
53
|
+
Requires-Dist: huggingface-hub>=0.20.0; extra == 'full'
|
|
54
54
|
Requires-Dist: pyarrow; extra == 'full'
|
|
55
55
|
Requires-Dist: rich>=10.0.0; extra == 'full'
|
|
56
56
|
Requires-Dist: scikit-learn>=0.24.0; extra == 'full'
|
|
57
57
|
Requires-Dist: tiktoken>=0.5.0; extra == 'full'
|
|
58
|
+
Requires-Dist: tokenizers>=0.15.0; extra == 'full'
|
|
59
|
+
Requires-Dist: toolong>=1.5.0; extra == 'full'
|
|
60
|
+
Provides-Extra: logs
|
|
61
|
+
Requires-Dist: toolong>=1.5.0; extra == 'logs'
|
|
58
62
|
Provides-Extra: mcp
|
|
59
63
|
Requires-Dist: mcp>=1.0.0; extra == 'mcp'
|
|
60
64
|
Provides-Extra: similarity
|
|
61
65
|
Requires-Dist: datasketch>=1.5.0; extra == 'similarity'
|
|
62
66
|
Requires-Dist: scikit-learn>=0.24.0; extra == 'similarity'
|
|
63
67
|
Provides-Extra: storage
|
|
64
|
-
Requires-Dist: pandas>=1.3.0; extra == 'storage'
|
|
65
68
|
Requires-Dist: pyarrow; extra == 'storage'
|
|
66
69
|
Provides-Extra: tokenizers
|
|
67
70
|
Requires-Dist: tiktoken>=0.5.0; extra == 'tokenizers'
|
|
71
|
+
Provides-Extra: tokenizers-hf
|
|
72
|
+
Requires-Dist: huggingface-hub>=0.20.0; extra == 'tokenizers-hf'
|
|
73
|
+
Requires-Dist: tiktoken>=0.5.0; extra == 'tokenizers-hf'
|
|
74
|
+
Requires-Dist: tokenizers>=0.15.0; extra == 'tokenizers-hf'
|
|
68
75
|
Description-Content-Type: text/markdown
|
|
69
76
|
|
|
70
77
|
# dtflow
|
|
@@ -101,7 +108,7 @@ dt = DataTransformer.load("data.jsonl")
|
|
|
101
108
|
### 数据加载与保存
|
|
102
109
|
|
|
103
110
|
```python
|
|
104
|
-
# 支持 JSONL、JSON、CSV、Parquet
|
|
111
|
+
# 支持 JSONL、JSON、CSV、Parquet、Arrow(使用 Polars 引擎,比 Pandas 快 3x)
|
|
105
112
|
dt = DataTransformer.load("data.jsonl")
|
|
106
113
|
dt.save("output.jsonl")
|
|
107
114
|
|
|
@@ -293,6 +300,7 @@ dt.shuffle(seed=42)
|
|
|
293
300
|
# 数据采样
|
|
294
301
|
dt sample data.jsonl --num=10
|
|
295
302
|
dt sample data.csv --num=100 --sample_type=head
|
|
303
|
+
dt sample data.jsonl 1000 --by=category # 分层采样
|
|
296
304
|
|
|
297
305
|
# 数据转换 - 预设模式
|
|
298
306
|
dt transform data.jsonl --preset=openai_chat
|
|
@@ -303,6 +311,18 @@ dt transform data.jsonl # 首次运行生成配置文件
|
|
|
303
311
|
# 编辑 .dt/data.py 后再次运行
|
|
304
312
|
dt transform data.jsonl --num=100 # 执行转换
|
|
305
313
|
|
|
314
|
+
# Pipeline 执行(可复现的数据处理流程)
|
|
315
|
+
dt run pipeline.yaml
|
|
316
|
+
dt run pipeline.yaml --input=new_data.jsonl --output=result.jsonl
|
|
317
|
+
|
|
318
|
+
# Token 统计
|
|
319
|
+
dt token-stats data.jsonl --field=messages --model=gpt-4
|
|
320
|
+
dt token-stats data.jsonl --field=text --detailed
|
|
321
|
+
|
|
322
|
+
# 数据对比
|
|
323
|
+
dt diff v1/train.jsonl v2/train.jsonl
|
|
324
|
+
dt diff a.jsonl b.jsonl --key=id
|
|
325
|
+
|
|
306
326
|
# 数据清洗
|
|
307
327
|
dt clean data.jsonl --drop-empty # 删除任意空值记录
|
|
308
328
|
dt clean data.jsonl --drop-empty=text,answer # 删除指定字段为空的记录
|
|
@@ -325,6 +345,132 @@ dt concat a.jsonl b.jsonl -o merged.jsonl
|
|
|
325
345
|
dt stats data.jsonl
|
|
326
346
|
```
|
|
327
347
|
|
|
348
|
+
### Pipeline 配置
|
|
349
|
+
|
|
350
|
+
使用 YAML 配置文件定义可复现的数据处理流程:
|
|
351
|
+
|
|
352
|
+
```yaml
|
|
353
|
+
# pipeline.yaml
|
|
354
|
+
version: "1.0"
|
|
355
|
+
seed: 42
|
|
356
|
+
input: raw_data.jsonl
|
|
357
|
+
output: processed.jsonl
|
|
358
|
+
|
|
359
|
+
steps:
|
|
360
|
+
- type: filter
|
|
361
|
+
condition: "score > 0.5"
|
|
362
|
+
|
|
363
|
+
- type: filter
|
|
364
|
+
condition: "len(text) > 10"
|
|
365
|
+
|
|
366
|
+
- type: transform
|
|
367
|
+
preset: openai_chat
|
|
368
|
+
params:
|
|
369
|
+
user_field: q
|
|
370
|
+
assistant_field: a
|
|
371
|
+
|
|
372
|
+
- type: dedupe
|
|
373
|
+
key: text
|
|
374
|
+
```
|
|
375
|
+
|
|
376
|
+
支持的步骤类型:
|
|
377
|
+
|
|
378
|
+
| 步骤 | 参数 | 说明 |
|
|
379
|
+
|------|------|------|
|
|
380
|
+
| `filter` | `condition` | 条件过滤:`score > 0.5`, `len(text) > 10`, `field is not empty` |
|
|
381
|
+
| `transform` | `preset`, `params` | 格式转换,使用预设模板 |
|
|
382
|
+
| `dedupe` | `key`, `similar` | 去重,支持精确和相似度去重 |
|
|
383
|
+
| `sample` | `num`, `seed` | 随机采样 |
|
|
384
|
+
| `head` | `num` | 取前 N 条 |
|
|
385
|
+
| `tail` | `num` | 取后 N 条 |
|
|
386
|
+
| `shuffle` | `seed` | 打乱顺序 |
|
|
387
|
+
| `split` | `ratio`, `seed` | 数据集分割 |
|
|
388
|
+
|
|
389
|
+
执行 Pipeline:
|
|
390
|
+
|
|
391
|
+
```bash
|
|
392
|
+
dt run pipeline.yaml
|
|
393
|
+
dt run pipeline.yaml --input=new_data.jsonl # 覆盖输入文件
|
|
394
|
+
```
|
|
395
|
+
|
|
396
|
+
### 数据血缘追踪
|
|
397
|
+
|
|
398
|
+
记录数据处理的完整历史,支持可复现和问题追溯:
|
|
399
|
+
|
|
400
|
+
```python
|
|
401
|
+
# 启用血缘追踪
|
|
402
|
+
dt = DataTransformer.load("raw.jsonl", track_lineage=True)
|
|
403
|
+
|
|
404
|
+
# 正常进行数据处理
|
|
405
|
+
result = (dt
|
|
406
|
+
.filter(lambda x: x.score > 0.5)
|
|
407
|
+
.transform(lambda x: {"q": x.q, "a": x.a})
|
|
408
|
+
.dedupe("q")
|
|
409
|
+
)
|
|
410
|
+
|
|
411
|
+
# 保存时记录血缘
|
|
412
|
+
result.save("processed.jsonl", lineage=True)
|
|
413
|
+
# 自动生成 processed.jsonl.lineage.json
|
|
414
|
+
```
|
|
415
|
+
|
|
416
|
+
查看血缘历史:
|
|
417
|
+
|
|
418
|
+
```bash
|
|
419
|
+
dt history processed.jsonl
|
|
420
|
+
# 输出:
|
|
421
|
+
# 📊 数据血缘报告: processed.jsonl
|
|
422
|
+
# └─ 版本 1
|
|
423
|
+
# 来源: raw.jsonl
|
|
424
|
+
# 操作链:
|
|
425
|
+
# ├─ filter: 1000 → 800
|
|
426
|
+
# ├─ transform: 800 → 800
|
|
427
|
+
# └─ dedupe: 800 → 750
|
|
428
|
+
# 输出数量: 750
|
|
429
|
+
|
|
430
|
+
dt history processed.jsonl --json # JSON 格式输出
|
|
431
|
+
```
|
|
432
|
+
|
|
433
|
+
### 大文件流式处理
|
|
434
|
+
|
|
435
|
+
专为超大文件设计的流式处理接口,内存占用 O(1),支持 JSONL、CSV、Parquet、Arrow 格式:
|
|
436
|
+
|
|
437
|
+
```python
|
|
438
|
+
from dtflow import load_stream, load_sharded
|
|
439
|
+
|
|
440
|
+
# 流式加载和处理(100GB 文件也只用常量内存)
|
|
441
|
+
(load_stream("huge_100gb.jsonl")
|
|
442
|
+
.filter(lambda x: x["score"] > 0.5)
|
|
443
|
+
.transform(lambda x: {"text": x["content"]})
|
|
444
|
+
.save("output.jsonl"))
|
|
445
|
+
|
|
446
|
+
# 跨格式转换(CSV → Parquet)
|
|
447
|
+
(load_stream("data.csv")
|
|
448
|
+
.filter(lambda x: x["score"] > 0.5)
|
|
449
|
+
.save("output.parquet"))
|
|
450
|
+
|
|
451
|
+
# 分片文件加载(支持多格式)
|
|
452
|
+
(load_sharded("data/train_*.parquet")
|
|
453
|
+
.filter(lambda x: len(x["text"]) > 10)
|
|
454
|
+
.save("merged.jsonl"))
|
|
455
|
+
|
|
456
|
+
# 分片保存
|
|
457
|
+
(load_stream("huge.jsonl")
|
|
458
|
+
.transform(lambda x: {"q": x["question"], "a": x["answer"]})
|
|
459
|
+
.save_sharded("output/", shard_size=100000))
|
|
460
|
+
# 生成: output/part-00000.jsonl, output/part-00001.jsonl, ...
|
|
461
|
+
|
|
462
|
+
# 批次处理(适合需要批量调用 API 的场景)
|
|
463
|
+
for batch in load_stream("data.jsonl").batch(1000):
|
|
464
|
+
results = call_api(batch) # 批量处理
|
|
465
|
+
```
|
|
466
|
+
|
|
467
|
+
特点:
|
|
468
|
+
- **惰性执行**:filter/transform 不会立即执行,只在 save/collect 时才触发
|
|
469
|
+
- **O(1) 内存**:无论文件多大,内存占用恒定(读取侧)
|
|
470
|
+
- **多格式支持**:JSONL、CSV、Parquet、Arrow 均支持流式处理
|
|
471
|
+
- **跨格式转换**:可直接从 CSV 读取并保存为 Parquet 等
|
|
472
|
+
- **分片支持**:支持 glob 模式加载多个分片,自动合并处理
|
|
473
|
+
|
|
328
474
|
## 错误处理
|
|
329
475
|
|
|
330
476
|
```python
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
dtflow/__init__.py,sha256=H3bkYYp8XL2NTWZVLgYcTUAeMDA3BvRLxM0E8-qo2-U,2296
|
|
2
|
+
dtflow/__main__.py,sha256=DqE7wYJ-DXTtA0dHi7x7xGg23gP9vtOlcbuYjadFKFQ,10353
|
|
3
|
+
dtflow/converters.py,sha256=cONDqSMlIQKHBS6fZ2pj1BPStIWXp62LqAbgddhDH14,21860
|
|
4
|
+
dtflow/core.py,sha256=1bBPS1QJ6IB_qlXQeOmu4BSu2sFPdJxKekPL6U2c79c,27903
|
|
5
|
+
dtflow/lineage.py,sha256=sCG-p70RTUb6PawMTiSoFSS-dM0DI2Jlf69_5yIt0mg,12146
|
|
6
|
+
dtflow/pipeline.py,sha256=JbMMZn211v-ARFamFupS5GlRdmt8iTvgzTuymneWMAI,13701
|
|
7
|
+
dtflow/presets.py,sha256=ZoTBWLAgCEGRZUPMRwBZ_4in6Gfb4u8_Avt0UCjF3C4,4970
|
|
8
|
+
dtflow/streaming.py,sha256=KMWr9_YzmRXzRv-dI7BzJkWGS9u1Gf19yNHo11kQBu8,21618
|
|
9
|
+
dtflow/tokenizers.py,sha256=ru1SCeiszmVyRBpQ3qoiBa1H2WSaT_1wc-Akj8_JaHw,17770
|
|
10
|
+
dtflow/cli/__init__.py,sha256=2W6vhT9W7-woP8WWTtTLuksB0GfUaR-0X5KdCIA7aYc,319
|
|
11
|
+
dtflow/cli/commands.py,sha256=DuQhoeu1x9OPZZKsl27yyX5WB1-a6vPcnCx6y7Bzdgc,81213
|
|
12
|
+
dtflow/mcp/__init__.py,sha256=8yperojW-yc_40DnC-fRurYR2C_rutDbGdoMJaPeGww,895
|
|
13
|
+
dtflow/mcp/__main__.py,sha256=BRki4AdGIIY5O1Ly1Sa4PKdFPlCTeLXTgecl6KG41tU,444
|
|
14
|
+
dtflow/mcp/cli.py,sha256=GUM-QffXp_kuHVLPCN_LuhYnhLP1ZjSn99o3E_4mfnI,12952
|
|
15
|
+
dtflow/mcp/docs.py,sha256=YQ9P7Kb45h1MSuNucw3hR4FDlbTAJgqIascytCdLrxA,8844
|
|
16
|
+
dtflow/mcp/server.py,sha256=Nf0UlqDGhV55ndGuEglfr7VRjDWAC_9rRsNhdr0-ssM,4275
|
|
17
|
+
dtflow/storage/__init__.py,sha256=8TZVXPpmz882OjLgoXEDeQQFLyPqZApX_GM0nFnxhlc,360
|
|
18
|
+
dtflow/storage/io.py,sha256=7EnqeYt9iNZaIz1dPZweQaQAzXYnudAMYdkGXEHiBFI,21791
|
|
19
|
+
dtflow/utils/__init__.py,sha256=NY_s-2r0R-199mpRLMk_jo5O-aDsKhRxEwRuZHRcTqs,113
|
|
20
|
+
dtflow/utils/display.py,sha256=ZxojDIGYvQco9DnSsq7dE4NYLitX6FqspDLp9tuSUYM,6043
|
|
21
|
+
dtflow-0.3.1.dist-info/METADATA,sha256=WwnC1hM0RVL7l_ctU7Y1G8UCwmuxthmRmohSTDslRBA,16326
|
|
22
|
+
dtflow-0.3.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
23
|
+
dtflow-0.3.1.dist-info/entry_points.txt,sha256=dadIDOK7Iu9pMxnMPBfpb4aAPe4hQbBOshpQYjVYpGc,44
|
|
24
|
+
dtflow-0.3.1.dist-info/RECORD,,
|