dtflow 0.5.2__tar.gz → 0.5.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow-0.5.3/CHANGELOG.md +19 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/PKG-INFO +1 -1
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/__init__.py +7 -7
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/converters.py +17 -13
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/presets.py +14 -15
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/utils/__init__.py +3 -0
- dtflow-0.5.3/dtflow/utils/helpers.py +30 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/pyproject.toml +6 -4
- dtflow-0.5.3/tests/test_cli_clean.py +314 -0
- dtflow-0.5.3/tests/test_cli_sample.py +242 -0
- dtflow-0.5.3/tests/test_cli_stats.py +213 -0
- dtflow-0.5.3/tests/test_cli_transform.py +304 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/.gitignore +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/README.md +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/__main__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/__init__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/clean.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/commands.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/common.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/io_ops.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/lineage.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/pipeline.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/sample.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/stats.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/transform.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/validate.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/core.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/framework.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/lineage.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/__init__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/__main__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/cli.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/docs.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/server.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/pipeline.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/schema.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/storage/__init__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/storage/io.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/streaming.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/tokenizers.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/utils/display.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/utils/field_path.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/README.md +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/benchmark_io.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/benchmark_sharegpt.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_cli_benchmark.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_converters.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_field_path.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_framework.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_io.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_lineage.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_pipeline.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_schema.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_streaming.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_tokenizers.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_transformer.py +0 -0

dtflow-0.5.3/CHANGELOG.md

@@ -0,0 +1,19 @@
+# Changelog
+
+## [0.5.2] - 2026-01-18
+
+### Miscellaneous
+
+- Bump version to 0.5.2
+- Add pre-commit configuration and release script
+
+## [0.5.1] - 2026-01-18
+
+### Features
+
+- Improve text preview display for the sample command
+
+### Testing
+
+- Add notes on running the tests
+- Add performance tests for tail/token-stats/validate

{dtflow-0.5.2 → dtflow-0.5.3}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dtflow
-Version: 0.5.2
+Version: 0.5.3
 Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
 Project-URL: Homepage, https://github.com/yourusername/DataTransformer
 Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme

{dtflow-0.5.2 → dtflow-0.5.3}/dtflow/__init__.py

@@ -26,6 +26,12 @@ from .converters import (  # LLaMA-Factory extensions; ms-swift
     to_swift_vlm,
 )
 from .core import DataTransformer, DictWrapper, TransformError, TransformErrors
+from .framework import (
+    CompatibilityResult,
+    check_compatibility,
+    detect_format,
+    export_for,
+)
 from .presets import get_preset, list_presets
 from .schema import (
     Field,

@@ -38,12 +44,6 @@ from .schema import (
     sharegpt_schema,
     validate_data,
 )
-from .framework import (
-    CompatibilityResult,
-    check_compatibility,
-    detect_format,
-    export_for,
-)
 from .storage import load_data, sample_file, save_data
 from .streaming import StreamingTransformer, load_sharded, load_stream, process_shards
 from .tokenizers import (

@@ -60,7 +60,7 @@ from .tokenizers import (
     token_stats,
 )
 
-__version__ = "0.5.2"
+__version__ = "0.5.3"
 
 __all__ = [
     # core
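
Note: the .framework import block is not new code; it moved so that first-party imports stay alphabetized (.framework now sits between .core and .presets instead of after .schema, matching the isort settings in pyproject.toml below). A minimal smoke check of the re-exported surface; the diff does not show these functions' signatures, so only the imports and the version constant are exercised here:

    # Sketch: names re-exported at the dtflow package root after this change.
    import dtflow
    from dtflow import (
        CompatibilityResult,
        check_compatibility,
        detect_format,
        export_for,
    )

    assert dtflow.__version__ == "0.5.3"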

{dtflow-0.5.2 → dtflow-0.5.3}/dtflow/converters.py

@@ -4,7 +4,7 @@
 Provides conversions to and from common formats such as HuggingFace datasets.
 """
 
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 
 def to_hf_dataset(data: List[Dict[str, Any]]):

@@ -143,14 +143,16 @@ def to_openai_batch(
     >>> batch_input = dt.to(to_openai_batch(model="gpt-4o"))
     """
 
-    def transform(item) -> dict:
+    counter = {"idx": 0}
+
+    def transform(item) -> dict:
         messages = item.get(messages_field, []) if hasattr(item, "get") else item[messages_field]
 
         if custom_id_field:
             custom_id = item.get(custom_id_field) if hasattr(item, "get") else item[custom_id_field]
         else:
-            custom_id = f"request-{idx}"
-            idx += 1
+            custom_id = f"request-{counter['idx']}"
+            counter["idx"] += 1
 
         return {
             "custom_id": str(custom_id),
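
Note: this hunk fixes a closure bug. The old transform incremented a bare idx that was never bound in the enclosing scope, so the auto-generated custom_id path would raise UnboundLocalError at runtime. Storing the counter in a mutable dict lets the nested function update shared state without a nonlocal declaration. A standalone sketch of the same pattern (not dtflow code):

    def make_transform():
        counter = {"idx": 0}  # mutable container shared with the closure

        def transform(item):
            custom_id = f"request-{counter['idx']}"
            counter["idx"] += 1  # rebinding a plain int here would need nonlocal
            return {"custom_id": custom_id, **item}

        return transform

    t = make_transform()
    print(t({"a": 1})["custom_id"])  # request-0
    print(t({"a": 2})["custom_id"])  # request-1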

@@ -196,7 +198,7 @@ def to_llama_factory(
     """
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f, "") if hasattr(item, "get") else item
+        get = lambda f: item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
 
         result = {
             "instruction": get(instruction_field),

@@ -248,7 +250,7 @@ def to_axolotl(
         conversations = (
            item.get(conversations_field, [])
            if hasattr(item, "get")
-           else item
+           else getattr(item, conversations_field, [])
        )
 
        # If it is already in the correct format, return it as-is

@@ -257,7 +259,9 @@ def to_axolotl(
            return {"conversations": conversations}
 
        # Try converting from the messages format
-       messages = item.get("messages", []) if hasattr(item, "get") else item
+       messages = (
+           item.get("messages", []) if hasattr(item, "get") else getattr(item, "messages", [])
+       )
        if messages:
            role_map = {"user": "human", "assistant": "gpt", "system": "system"}
            conversations = [

@@ -312,7 +316,7 @@ def to_llama_factory_sharegpt(
     }
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f, "") if hasattr(item, "get") else item
+        get = lambda f: item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
         messages = get(messages_field) or []
 
         conversations = []

@@ -385,7 +389,7 @@ def to_llama_factory_vlm(
     """
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
         messages = get(messages_field) or []
 
         instruction = ""

@@ -467,7 +471,7 @@ def to_llama_factory_vlm_sharegpt(
     role_map = {"user": "human", "assistant": "gpt", "system": "system"}
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
         messages = get(messages_field) or []
 
         conversations = []

@@ -541,7 +545,7 @@ def to_swift_messages(
     """
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
         messages = get(messages_field) or []
 
         # Copy messages to avoid mutating the original data

@@ -600,7 +604,7 @@ def to_swift_query_response(
     """
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
 
         query = get(query_field)
         response = get(response_field)

@@ -693,7 +697,7 @@ def to_swift_vlm(
     """
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
         messages = get(messages_field) or []
 
         result_messages = []
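
Note: every converter's get helper previously fell back to returning the item object itself when it had no get method, which silently produced garbage for attribute-style records; the new fallback reads the named attribute. A sketch of the before/after behavior (SimpleNamespace stands in for any attribute-style record; dtflow's own DictWrapper exposes .get and is unaffected):

    from types import SimpleNamespace

    item = SimpleNamespace(messages=[{"role": "user", "content": "hi"}])

    old_get = lambda f: item.get(f) if hasattr(item, "get") else item
    new_get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)

    print(old_get("messages"))  # namespace(...) -- the whole object, not the field
    print(new_get("messages"))  # [{'role': 'user', 'content': 'hi'}]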

{dtflow-0.5.2 → dtflow-0.5.3}/dtflow/presets.py

@@ -6,6 +6,8 @@
 
 from typing import Any, Callable
 
+from dtflow.utils.helpers import get_field_value
+
 
 def openai_chat(
     user_field: str = "q", assistant_field: str = "a", system_prompt: str = None

@@ -33,8 +35,8 @@ def openai_chat(
         if system_prompt:
             messages.append({"role": "system", "content": system_prompt})
 
-        user_content = getattr(item, user_field, None) or item.get(user_field, "")
-        assistant_content = getattr(item, assistant_field, None) or item.get(assistant_field, "")
+        user_content = get_field_value(item, user_field)
+        assistant_content = get_field_value(item, assistant_field)
 
         messages.append({"role": "user", "content": user_content})
         messages.append({"role": "assistant", "content": assistant_content})

@@ -60,10 +62,9 @@ def alpaca(
 
     def transform(item: Any) -> dict:
         return {
-            "instruction": getattr(item, instruction_field, None)
-            or item.get(instruction_field, ""),
-            "input": getattr(item, input_field, None) or item.get(input_field, ""),
-            "output": getattr(item, output_field, None) or item.get(output_field, ""),
+            "instruction": get_field_value(item, instruction_field),
+            "input": get_field_value(item, input_field),
+            "output": get_field_value(item, output_field),
         }
 
     return transform

@@ -84,9 +85,7 @@ def sharegpt(conversations_field: str = "conversations", role_mapping: dict = None)
    role_mapping = role_mapping or {"user": "human", "assistant": "gpt"}
 
    def transform(item: Any) -> dict:
-       conversations = getattr(item, conversations_field, None) or item.get(
-           conversations_field, []
-       )
+       conversations = get_field_value(item, conversations_field, [])
 
        # If it is already in conversation format, return it as-is
        if conversations:

@@ -102,7 +101,7 @@ def sharegpt(conversations_field: str = "conversations", role_mapping: dict = None)
            ("answer", "gpt"),
            ("output", "gpt"),
        ]:
-           value = getattr(item, field, None) or item.get(field, None)
+           value = get_field_value(item, field, None)
            if value:
                result.append({"from": role, "value": value})

@@ -127,9 +126,9 @@ def dpo_pair(
 
     def transform(item: Any) -> dict:
         return {
-            "prompt": getattr(item, prompt_field, None) or item.get(prompt_field, ""),
-            "chosen": getattr(item, chosen_field, None) or item.get(chosen_field, ""),
-            "rejected": getattr(item, rejected_field, None) or item.get(rejected_field, ""),
+            "prompt": get_field_value(item, prompt_field),
+            "chosen": get_field_value(item, chosen_field),
+            "rejected": get_field_value(item, rejected_field),
         }
 
     return transform

@@ -148,8 +147,8 @@ def simple_qa(question_field: str = "q", answer_field: str = "a") -> Callable:
 
     def transform(item: Any) -> dict:
         return {
-            "question": getattr(item, question_field, None) or item.get(question_field, ""),
-            "answer": getattr(item, answer_field, None) or item.get(answer_field, ""),
+            "question": get_field_value(item, question_field),
+            "answer": get_field_value(item, answer_field),
         }
 
     return transform
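
Note: all five presets now route field access through get_field_value, so they behave identically for plain dicts and attribute-style records. A usage sketch of the alpaca preset; the keyword names are taken from the fields used in the body above, but the default values are not shown in this diff, so everything is passed explicitly:

    # Hypothetical usage of the alpaca preset; field names are illustrative.
    from dtflow.presets import alpaca

    transform = alpaca(
        instruction_field="prompt", input_field="context", output_field="answer"
    )
    row = {"prompt": "Summarize:", "context": "Some long text.", "answer": "A summary."}
    print(transform(row))
    # {'instruction': 'Summarize:', 'input': 'Some long text.', 'output': 'A summary.'}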

{dtflow-0.5.2 → dtflow-0.5.3}/dtflow/utils/__init__.py

@@ -9,6 +9,7 @@ from .field_path import (
     get_field_with_spec,
     parse_field_spec,
 )
+from .helpers import get_field_value
 
 __all__ = [
     "display_data",

@@ -20,4 +21,6 @@ __all__ = [
     "extract",
     "extract_with_spec",
     "ExpandMode",
+    # helpers
+    "get_field_value",
 ]

dtflow-0.5.3/dtflow/utils/helpers.py

@@ -0,0 +1,30 @@
+"""Shared helper functions."""
+
+from typing import Any
+
+
+def get_field_value(item: Any, field: str, default: Any = "") -> Any:
+    """
+    Get a field value, supporting both DictWrapper and plain dict.
+
+    Tries dict.get() first and falls back to getattr() when the object has no get method.
+
+    Args:
+        item: Data object (dict or DictWrapper)
+        field: Field name
+        default: Default value
+
+    Returns:
+        The field value, or the default.
+
+    Examples:
+        >>> get_field_value({"name": "test"}, "name")
+        'test'
+        >>> get_field_value({"name": ""}, "name", "default")
+        'default'
+    """
+    if hasattr(item, "get"):
+        value = item.get(field, default)
+    else:
+        value = getattr(item, field, default)
+    return value if value else default
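
Note the final line of the helper: because of the truthiness check, falsy field values (empty string, 0, empty list) are replaced by the default, which is exactly what the second doctest demonstrates. A quick behavior check:

    from dtflow.utils import get_field_value

    print(get_field_value({"name": "test"}, "name"))         # test
    print(get_field_value({"name": ""}, "name", "default"))  # default
    print(get_field_value({"count": 0}, "count", -1))        # -1, since 0 is falsy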

{dtflow-0.5.2 → dtflow-0.5.3}/pyproject.toml

@@ -300,8 +300,10 @@ ignore_missing_imports = true
 
 # Ruff configuration (optional alternative to flake8)
 [tool.ruff]
-target-version = "
+target-version = "py38"
 line-length = 100
+
+[tool.ruff.lint]
 select = [
     "E",  # pycodestyle errors
     "W",  # pycodestyle warnings

@@ -311,13 +313,13 @@ select = [
     "B",  # flake8-bugbear
 ]
 ignore = [
-    "E501",  # line too long, handled by
+    "E501",  # line too long, handled by ruff-format
     "B008",  # do not perform function calls in argument defaults
     "C901",  # too complex
 ]
 
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 "__init__.py" = ["F401"]
 
-[tool.ruff.isort]
+[tool.ruff.lint.isort]
 known-first-party = ["dtflow"]

dtflow-0.5.3/tests/test_cli_clean.py

@@ -0,0 +1,314 @@
+"""
+Tests for CLI clean and dedupe commands.
+"""
+
+import pytest
+
+from dtflow.cli.clean import _clean_data_single_pass, _parse_len_param, clean, dedupe
+from dtflow.storage.io import load_data, save_data
+
+# ============== Fixtures ==============
+
+
+@pytest.fixture
+def sample_data_file(tmp_path):
+    """Create a sample dataset file."""
+    data = [
+        {"text": "Hello world", "score": 0.8, "category": "greeting"},
+        {"text": "How are you?", "score": 0.9, "category": "question"},
+        {"text": " Needs trimming ", "score": 0.7, "category": "test"},
+        {"text": "", "score": 0.6, "category": "empty_text"},
+        {"text": "Short", "score": None, "category": None},
+        {"text": "Hello world", "score": 0.85, "category": "duplicate"},  # duplicate text
+    ]
+    filepath = tmp_path / "test_data.jsonl"
+    save_data(data, str(filepath))
+    return filepath, data
+
+
+@pytest.fixture
+def sample_nested_file(tmp_path):
+    """Create a sample dataset with nested fields."""
+    data = [
+        {
+            "id": 1,
+            "meta": {"source": "web", "score": 0.9},
+            "messages": [
+                {"role": "user", "content": "Hello"},
+                {"role": "assistant", "content": "Hi!"},
+            ],
+        },
+        {
+            "id": 2,
+            "meta": {"source": "api", "score": 0.8},
+            "messages": [{"role": "user", "content": "Hi"}],
+        },
+        {
+            "id": 3,
+            "meta": {"source": None, "score": 0.5},
+            "messages": [
+                {"role": "user", "content": "A"},
+                {"role": "assistant", "content": "B"},
+                {"role": "user", "content": "C"},
+            ],
+        },
+    ]
+    filepath = tmp_path / "test_nested.jsonl"
+    save_data(data, str(filepath))
+    return filepath, data
+
+
+# ============== Clean Command Tests ==============
+
+
+class TestCleanBasic:
+    """Test basic clean functionality."""
+
+    def test_clean_drop_empty(self, sample_data_file, tmp_path):
+        """Test dropping empty records."""
+        filepath, _ = sample_data_file
+        output = tmp_path / "output.jsonl"
+
+        clean(str(filepath), drop_empty="text", output=str(output))
+
+        result = load_data(str(output))
+        # Should remove the record with empty text
+        for item in result:
+            assert item["text"] != ""
+
+    def test_clean_drop_empty_all_fields(self, sample_data_file, tmp_path):
+        """Test dropping records with any empty field."""
+        filepath, _ = sample_data_file
+        output = tmp_path / "output.jsonl"
+
+        # drop_empty="" means check all fields
+        clean(str(filepath), drop_empty="", output=str(output))
+
+        result = load_data(str(output))
+        # Should remove records with any None or empty value
+        for item in result:
+            assert all(v is not None and v != "" for v in item.values())
+
+    def test_clean_strip(self, sample_data_file, tmp_path):
+        """Test stripping whitespace."""
+        filepath, _ = sample_data_file
+        output = tmp_path / "output.jsonl"
+
+        clean(str(filepath), strip=True, output=str(output))
+
+        result = load_data(str(output))
+        # Find the item that had extra whitespace
+        trimmed = [item for item in result if "Needs trimming" in item.get("text", "")]
+        if trimmed:
+            assert trimmed[0]["text"] == "Needs trimming"
+
+    def test_clean_min_len(self, sample_data_file, tmp_path):
+        """Test minimum length filtering."""
+        filepath, _ = sample_data_file
+        output = tmp_path / "output.jsonl"
+
+        clean(str(filepath), min_len="text:10", output=str(output))
+
+        result = load_data(str(output))
+        for item in result:
+            assert len(item["text"]) >= 10
+
+    def test_clean_max_len(self, sample_data_file, tmp_path):
+        """Test maximum length filtering."""
+        filepath, _ = sample_data_file
+        output = tmp_path / "output.jsonl"
+
+        clean(str(filepath), max_len="text:10", output=str(output))
+
+        result = load_data(str(output))
+        for item in result:
+            assert len(item["text"]) <= 10
+
+    def test_clean_keep_fields(self, sample_data_file, tmp_path):
+        """Test keeping only specified fields."""
+        filepath, _ = sample_data_file
+        output = tmp_path / "output.jsonl"
+
+        clean(str(filepath), keep="text,category", output=str(output))
+
+        result = load_data(str(output))
+        for item in result:
+            assert set(item.keys()) == {"text", "category"}
+
+    def test_clean_drop_fields(self, sample_data_file, tmp_path):
+        """Test dropping specified fields."""
+        filepath, _ = sample_data_file
+        output = tmp_path / "output.jsonl"
+
+        clean(str(filepath), drop="score", output=str(output))
+
+        result = load_data(str(output))
+        for item in result:
+            assert "score" not in item
+
+
+# ============== Clean with Nested Fields Tests ==============
+
+
+class TestCleanNested:
+    """Test clean with nested field paths."""
+
+    def test_clean_drop_empty_nested(self, sample_nested_file, tmp_path):
+        """Test dropping records with empty nested field."""
+        filepath, _ = sample_nested_file
+        output = tmp_path / "output.jsonl"
+
+        clean(str(filepath), drop_empty="meta.source", output=str(output))
+
+        result = load_data(str(output))
+        assert len(result) == 2  # Should remove the one with None source
+        for item in result:
+            assert item["meta"]["source"] is not None
+
+    def test_clean_min_len_messages(self, sample_nested_file, tmp_path):
+        """Test filtering by message count using .# syntax."""
+        filepath, _ = sample_nested_file
+        output = tmp_path / "output.jsonl"
+
+        clean(str(filepath), min_len="messages.#:2", output=str(output))
+
+        result = load_data(str(output))
+        for item in result:
+            assert len(item["messages"]) >= 2
+
+
+# ============== Dedupe Command Tests ==============
+
+
+class TestDedupeBasic:
+    """Test basic dedupe functionality."""
+
+    def test_dedupe_by_field(self, sample_data_file, tmp_path):
+        """Test deduplication by specific field."""
+        filepath, _ = sample_data_file
+        output = tmp_path / "output.jsonl"
+
+        dedupe(str(filepath), key="text", output=str(output))
+
+        result = load_data(str(output))
+        # Should have removed the duplicate "Hello world"
+        texts = [item["text"] for item in result]
+        assert len(texts) == len(set(texts))
+
+    def test_dedupe_full(self, sample_data_file, tmp_path):
+        """Test full record deduplication."""
+        filepath, _ = sample_data_file
+        output = tmp_path / "output.jsonl"
+
+        dedupe(str(filepath), output=str(output))
+
+        result = load_data(str(output))
+        # All records are unique, so should have same count
+        # (unless there are exact duplicates)
+        assert len(result) >= 1
+
+    def test_dedupe_overwrite(self, sample_data_file):
+        """Test deduplication with overwrite (no output specified)."""
+        filepath, original_data = sample_data_file
+
+        dedupe(str(filepath), key="text")
+
+        result = load_data(str(filepath))
+        texts = [item["text"] for item in result]
+        assert len(texts) == len(set(texts))
+
+
+# ============== Parameter Parsing Tests ==============
+
+
+class TestParamParsing:
+    """Test parameter parsing functions."""
+
+    def test_parse_len_param_valid(self):
+        """Test valid length parameter parsing."""
+        field, value = _parse_len_param("text:100")
+        assert field == "text"
+        assert value == 100
+
+    def test_parse_len_param_nested(self):
+        """Test nested field length parameter."""
+        field, value = _parse_len_param("messages.#:5")
+        assert field == "messages.#"
+        assert value == 5
+
+    def test_parse_len_param_invalid_no_colon(self):
+        """Test invalid parameter without colon."""
+        with pytest.raises(ValueError):
+            _parse_len_param("text100")
+
+    def test_parse_len_param_invalid_non_numeric(self):
+        """Test invalid parameter with non-numeric value."""
+        with pytest.raises(ValueError):
+            _parse_len_param("text:abc")
+
+
+# ============== Clean Single Pass Tests ==============
+
+
+class TestCleanSinglePass:
+    """Test _clean_data_single_pass function."""
+
+    def test_single_pass_strip(self):
+        """Test strip in single pass."""
+        data = [{"text": " hello ", "value": " world "}]
+        result, _ = _clean_data_single_pass(data, strip=True)
+
+        assert result[0]["text"] == "hello"
+        assert result[0]["value"] == "world"
+
+    def test_single_pass_drop_empty(self):
+        """Test drop empty in single pass."""
+        data = [
+            {"text": "hello", "value": "world"},
+            {"text": "", "value": "test"},
+            {"text": "hi", "value": None},
+        ]
+        result, stats = _clean_data_single_pass(data, empty_fields=["text", "value"])
+
+        assert len(result) == 1
+        assert result[0]["text"] == "hello"
+
+    def test_single_pass_combined(self):
+        """Test combined operations in single pass."""
+        data = [
+            {"text": " long text here ", "score": 0.9},
+            {"text": " hi ", "score": 0.8},
+        ]
+        result, stats = _clean_data_single_pass(
+            data, strip=True, min_len_field="text", min_len_value=5
+        )
+
+        assert len(result) == 1
+        assert result[0]["text"] == "long text here"
+
+
+# ============== Error Handling Tests ==============
+
+
+class TestCleanErrors:
+    """Test error handling in clean commands."""
+
+    def test_clean_file_not_exists(self, tmp_path, capsys):
+        """Test error when file doesn't exist."""
+        clean(str(tmp_path / "nonexistent.jsonl"))
+        captured = capsys.readouterr()
+        assert "文件不存在" in captured.out
+
+    def test_dedupe_similar_without_key(self, sample_data_file, capsys):
+        """Test error when using similar without key."""
+        filepath, _ = sample_data_file
+        dedupe(str(filepath), similar=0.8)
+        captured = capsys.readouterr()
+        assert "需要指定 --key" in captured.out
+
+    def test_dedupe_invalid_similar_range(self, sample_data_file, capsys):
+        """Test error when similar value is out of range."""
+        filepath, _ = sample_data_file
+        dedupe(str(filepath), key="text", similar=1.5)
+        captured = capsys.readouterr()
+        assert "0-1 之间" in captured.out
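
Note: these tests double as usage documentation for the Python-level clean/dedupe API. A condensed end-to-end sketch using the same functions and keyword arguments; the file paths are hypothetical, and combining several cleaning options in one call is assumed from the single-pass implementation exercised above:

    from dtflow.cli.clean import clean, dedupe

    # Strip whitespace, drop records whose text is empty or under 10 chars,
    # and keep only the text/category fields.
    clean(
        "data.jsonl",
        strip=True,
        drop_empty="text",
        min_len="text:10",
        keep="text,category",
        output="cleaned.jsonl",
    )

    # Then drop rows whose text duplicates an earlier row.
    dedupe("cleaned.jsonl", key="text", output="unique.jsonl")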