dtflow 0.5.2__tar.gz → 0.5.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. dtflow-0.5.3/CHANGELOG.md +19 -0
  2. {dtflow-0.5.2 → dtflow-0.5.3}/PKG-INFO +1 -1
  3. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/__init__.py +7 -7
  4. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/converters.py +17 -13
  5. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/presets.py +14 -15
  6. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/utils/__init__.py +3 -0
  7. dtflow-0.5.3/dtflow/utils/helpers.py +30 -0
  8. {dtflow-0.5.2 → dtflow-0.5.3}/pyproject.toml +6 -4
  9. dtflow-0.5.3/tests/test_cli_clean.py +314 -0
  10. dtflow-0.5.3/tests/test_cli_sample.py +242 -0
  11. dtflow-0.5.3/tests/test_cli_stats.py +213 -0
  12. dtflow-0.5.3/tests/test_cli_transform.py +304 -0
  13. {dtflow-0.5.2 → dtflow-0.5.3}/.gitignore +0 -0
  14. {dtflow-0.5.2 → dtflow-0.5.3}/README.md +0 -0
  15. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/__main__.py +0 -0
  16. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/__init__.py +0 -0
  17. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/clean.py +0 -0
  18. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/commands.py +0 -0
  19. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/common.py +0 -0
  20. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/io_ops.py +0 -0
  21. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/lineage.py +0 -0
  22. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/pipeline.py +0 -0
  23. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/sample.py +0 -0
  24. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/stats.py +0 -0
  25. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/transform.py +0 -0
  26. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/cli/validate.py +0 -0
  27. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/core.py +0 -0
  28. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/framework.py +0 -0
  29. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/lineage.py +0 -0
  30. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/__init__.py +0 -0
  31. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/__main__.py +0 -0
  32. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/cli.py +0 -0
  33. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/docs.py +0 -0
  34. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/mcp/server.py +0 -0
  35. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/pipeline.py +0 -0
  36. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/schema.py +0 -0
  37. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/storage/__init__.py +0 -0
  38. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/storage/io.py +0 -0
  39. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/streaming.py +0 -0
  40. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/tokenizers.py +0 -0
  41. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/utils/display.py +0 -0
  42. {dtflow-0.5.2 → dtflow-0.5.3}/dtflow/utils/field_path.py +0 -0
  43. {dtflow-0.5.2 → dtflow-0.5.3}/tests/README.md +0 -0
  44. {dtflow-0.5.2 → dtflow-0.5.3}/tests/benchmark_io.py +0 -0
  45. {dtflow-0.5.2 → dtflow-0.5.3}/tests/benchmark_sharegpt.py +0 -0
  46. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_cli_benchmark.py +0 -0
  47. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_converters.py +0 -0
  48. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_field_path.py +0 -0
  49. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_framework.py +0 -0
  50. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_io.py +0 -0
  51. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_lineage.py +0 -0
  52. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_pipeline.py +0 -0
  53. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_schema.py +0 -0
  54. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_streaming.py +0 -0
  55. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_tokenizers.py +0 -0
  56. {dtflow-0.5.2 → dtflow-0.5.3}/tests/test_transformer.py +0 -0
@@ -0,0 +1,19 @@ dtflow-0.5.3/CHANGELOG.md
+ # Changelog
+
+ ## [0.5.2] - 2026-01-18
+
+ ### Miscellaneous
+
+ - Bump version to 0.5.2
+ - Add pre-commit configuration and release script
+
+ ## [0.5.1] - 2026-01-18
+
+ ### Features
+
+ - Improve text preview display for the sample command
+
+ ### Testing
+
+ - Add test run instructions
+ - Add tail/token-stats/validate performance tests
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dtflow
- Version: 0.5.2
+ Version: 0.5.3
  Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
  Project-URL: Homepage, https://github.com/yourusername/DataTransformer
  Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
@@ -26,6 +26,12 @@ from .converters import (  # LLaMA-Factory extensions; ms-swift
      to_swift_vlm,
  )
  from .core import DataTransformer, DictWrapper, TransformError, TransformErrors
+ from .framework import (
+     CompatibilityResult,
+     check_compatibility,
+     detect_format,
+     export_for,
+ )
  from .presets import get_preset, list_presets
  from .schema import (
      Field,
@@ -38,12 +44,6 @@ from .schema import (
      sharegpt_schema,
      validate_data,
  )
- from .framework import (
-     CompatibilityResult,
-     check_compatibility,
-     detect_format,
-     export_for,
- )
  from .storage import load_data, sample_file, save_data
  from .streaming import StreamingTransformer, load_sharded, load_stream, process_shards
  from .tokenizers import (
@@ -60,7 +60,7 @@ from .tokenizers import (
      token_stats,
  )
  
- __version__ = "0.5.2"
+ __version__ = "0.5.3"
  
  __all__ = [
      # core
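
Net effect of the three `__init__.py` hunks: the `from .framework` block moves ahead of `presets` (import ordering, matching the isort config added below) and the version string is bumped. A quick sanity check of the re-exported names, illustrative only, assuming an installed dtflow 0.5.3:

```python
import dtflow
from dtflow import CompatibilityResult, check_compatibility, detect_format, export_for

# The framework symbols stay available at the package root after the
# reorder, and the version matches this release.
assert dtflow.__version__ == "0.5.3"
```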
@@ -4,7 +4,7 @@
  Provides interconversion with common formats such as HuggingFace datasets.
  """
  
- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional
  
  
  def to_hf_dataset(data: List[Dict[str, Any]]):
@@ -143,14 +143,16 @@ def to_openai_batch(
      >>> batch_input = dt.to(to_openai_batch(model="gpt-4o"))
      """
  
-     def transform(item, idx=[0]) -> dict:
+     counter = {"idx": 0}
+
+     def transform(item) -> dict:
          messages = item.get(messages_field, []) if hasattr(item, "get") else item[messages_field]
  
          if custom_id_field:
              custom_id = item.get(custom_id_field) if hasattr(item, "get") else item[custom_id_field]
          else:
-             custom_id = f"request-{idx[0]}"
-             idx[0] += 1
+             custom_id = f"request-{counter['idx']}"
+             counter["idx"] += 1
  
          return {
              "custom_id": str(custom_id),
@@ -196,7 +198,7 @@ def to_llama_factory(
      """
  
      def transform(item) -> dict:
-         get = lambda f: (item.get(f, "") if hasattr(item, "get") else item.get(f, ""))
+         get = lambda f: item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
  
          result = {
              "instruction": get(instruction_field),
@@ -248,7 +250,7 @@ def to_axolotl(
          conversations = (
              item.get(conversations_field, [])
              if hasattr(item, "get")
-             else item.get(conversations_field, [])
+             else getattr(item, conversations_field, [])
          )
  
          # If already in the correct format, return as-is
@@ -257,7 +259,9 @@
              return {"conversations": conversations}
  
          # Try converting from the messages format
-         messages = item.get("messages", []) if hasattr(item, "get") else item.get("messages", [])
+         messages = (
+             item.get("messages", []) if hasattr(item, "get") else getattr(item, "messages", [])
+         )
          if messages:
              role_map = {"user": "human", "assistant": "gpt", "system": "system"}
              conversations = [
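The fallback branch converts OpenAI-style `messages` into ShareGPT-style `conversations`. The list comprehension itself is cut off in this hunk, so the sketch below only mirrors the surrounding logic and the `{"from", "value"}` turn shape used elsewhere in this diff; the `.get` fallback on the role is an assumption:

```python
role_map = {"user": "human", "assistant": "gpt", "system": "system"}

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi!"},
]

# Each chat message becomes one ShareGPT turn.
conversations = [
    {"from": role_map.get(m["role"], m["role"]), "value": m["content"]}
    for m in messages
]
assert conversations == [
    {"from": "human", "value": "Hello"},
    {"from": "gpt", "value": "Hi!"},
]
```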
@@ -312,7 +316,7 @@ def to_llama_factory_sharegpt(
      }
  
      def transform(item) -> dict:
-         get = lambda f: (item.get(f, "") if hasattr(item, "get") else item.get(f, ""))
+         get = lambda f: item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
          messages = get(messages_field) or []
  
          conversations = []
@@ -385,7 +389,7 @@ def to_llama_factory_vlm(
      """
  
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
          messages = get(messages_field) or []
  
          instruction = ""
@@ -467,7 +471,7 @@ def to_llama_factory_vlm_sharegpt(
      role_map = {"user": "human", "assistant": "gpt", "system": "system"}
  
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
          messages = get(messages_field) or []
  
          conversations = []
@@ -541,7 +545,7 @@ def to_swift_messages(
      """
  
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
          messages = get(messages_field) or []
  
          # Copy messages to avoid mutating the original data
@@ -600,7 +604,7 @@ def to_swift_query_response(
      """
  
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
  
          query = get(query_field)
          response = get(response_field)
@@ -693,7 +697,7 @@ def to_swift_vlm(
      """
  
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
          messages = get(messages_field) or []
  
          result_messages = []
@@ -6,6 +6,8 @@
  
  from typing import Any, Callable
  
+ from dtflow.utils.helpers import get_field_value
+
  
  def openai_chat(
      user_field: str = "q", assistant_field: str = "a", system_prompt: str = None
@@ -33,8 +35,8 @@ def openai_chat(
          if system_prompt:
              messages.append({"role": "system", "content": system_prompt})
  
-         user_content = getattr(item, user_field, None) or item.get(user_field, "")
-         assistant_content = getattr(item, assistant_field, None) or item.get(assistant_field, "")
+         user_content = get_field_value(item, user_field)
+         assistant_content = get_field_value(item, assistant_field)
  
          messages.append({"role": "user", "content": user_content})
          messages.append({"role": "assistant", "content": assistant_content})
@@ -60,10 +62,9 @@ def alpaca(
  
      def transform(item: Any) -> dict:
          return {
-             "instruction": getattr(item, instruction_field, None)
-             or item.get(instruction_field, ""),
-             "input": getattr(item, input_field, None) or item.get(input_field, ""),
-             "output": getattr(item, output_field, None) or item.get(output_field, ""),
+             "instruction": get_field_value(item, instruction_field),
+             "input": get_field_value(item, input_field),
+             "output": get_field_value(item, output_field),
          }
  
      return transform
@@ -84,9 +85,7 @@ def sharegpt(conversations_field: str = "conversations", role_mapping: dict = None
      role_mapping = role_mapping or {"user": "human", "assistant": "gpt"}
  
      def transform(item: Any) -> dict:
-         conversations = getattr(item, conversations_field, None) or item.get(
-             conversations_field, []
-         )
+         conversations = get_field_value(item, conversations_field, [])
  
          # If already in conversation format, return as-is
          if conversations:
@@ -102,7 +101,7 @@ def sharegpt(conversations_field: str = "conversations", role_mapping: dict = None
              ("answer", "gpt"),
              ("output", "gpt"),
          ]:
-             value = getattr(item, field, None) or item.get(field, None)
+             value = get_field_value(item, field, None)
              if value:
                  result.append({"from": role, "value": value})
  
@@ -127,9 +126,9 @@ def dpo_pair(
  
      def transform(item: Any) -> dict:
          return {
-             "prompt": getattr(item, prompt_field, None) or item.get(prompt_field, ""),
-             "chosen": getattr(item, chosen_field, None) or item.get(chosen_field, ""),
-             "rejected": getattr(item, rejected_field, None) or item.get(rejected_field, ""),
+             "prompt": get_field_value(item, prompt_field),
+             "chosen": get_field_value(item, chosen_field),
+             "rejected": get_field_value(item, rejected_field),
          }
  
      return transform
@@ -148,8 +147,8 @@ def simple_qa(question_field: str = "q", answer_field: str = "a") -> Callable:
  
      def transform(item: Any) -> dict:
          return {
-             "question": getattr(item, question_field, None) or item.get(question_field, ""),
-             "answer": getattr(item, answer_field, None) or item.get(answer_field, ""),
+             "question": get_field_value(item, question_field),
+             "answer": get_field_value(item, answer_field),
          }
  
      return transform
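
All five presets now route field access through the shared helper instead of repeating the `getattr(...) or item.get(...)` dance. Usage is unchanged; a quick check against the documented defaults:

```python
from dtflow.presets import simple_qa

to_qa = simple_qa(question_field="q", answer_field="a")
assert to_qa({"q": "What is 2+2?", "a": "4"}) == {"question": "What is 2+2?", "answer": "4"}
```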
@@ -9,6 +9,7 @@ from .field_path import (
      get_field_with_spec,
      parse_field_spec,
  )
+ from .helpers import get_field_value
  
  __all__ = [
      "display_data",
@@ -20,4 +21,6 @@ __all__ = [
      "extract",
      "extract_with_spec",
      "ExpandMode",
+     # helpers
+     "get_field_value",
  ]
@@ -0,0 +1,30 @@ dtflow-0.5.3/dtflow/utils/helpers.py
+ """Shared helper functions."""
+
+ from typing import Any
+
+
+ def get_field_value(item: Any, field: str, default: Any = "") -> Any:
+     """
+     Get a field value, supporting both DictWrapper and plain dicts.
+
+     Tries dict.get() first; if the object has no get method, falls back to getattr().
+
+     Args:
+         item: Data object (dict or DictWrapper).
+         field: Field name.
+         default: Default value.
+
+     Returns:
+         The field value, or the default.
+
+     Examples:
+         >>> get_field_value({"name": "test"}, "name")
+         'test'
+         >>> get_field_value({"name": ""}, "name", "default")
+         'default'
+     """
+     if hasattr(item, "get"):
+         value = item.get(field, default)
+     else:
+         value = getattr(item, field, default)
+     return value if value else default
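
Note the final `return value if value else default`: any falsy value (empty string, `0`, `False`, empty list) collapses to the default, not just missing keys. That matches the docstring example, but callers storing legitimate falsy values should be aware of it:

```python
from dtflow.utils.helpers import get_field_value

# Missing key and empty string both fall back to the default...
assert get_field_value({}, "score", "n/a") == "n/a"
assert get_field_value({"score": ""}, "score", "n/a") == "n/a"

# ...and so does any other falsy value, e.g. a numeric zero.
assert get_field_value({"score": 0}, "score", "n/a") == "n/a"
```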
@@ -300,8 +300,10 @@ ignore_missing_imports = true
  
  # Ruff configuration (optional alternative to flake8)
  [tool.ruff]
- target-version = "py37"
+ target-version = "py38"
  line-length = 100
+
+ [tool.ruff.lint]
  select = [
      "E",  # pycodestyle errors
      "W",  # pycodestyle warnings
@@ -311,13 +313,13 @@ select = [
      "B",  # flake8-bugbear
  ]
  ignore = [
-     "E501",  # line too long, handled by black
+     "E501",  # line too long, handled by ruff-format
      "B008",  # do not perform function calls in argument defaults
      "C901",  # too complex
  ]
  
- [tool.ruff.per-file-ignores]
+ [tool.ruff.lint.per-file-ignores]
  "__init__.py" = ["F401"]
  
- [tool.ruff.isort]
+ [tool.ruff.lint.isort]
  known-first-party = ["dtflow"]
@@ -0,0 +1,314 @@ dtflow-0.5.3/tests/test_cli_clean.py
+ """
+ Tests for CLI clean and dedupe commands.
+ """
+
+ import pytest
+
+ from dtflow.cli.clean import _clean_data_single_pass, _parse_len_param, clean, dedupe
+ from dtflow.storage.io import load_data, save_data
+
+ # ============== Fixtures ==============
+
+
+ @pytest.fixture
+ def sample_data_file(tmp_path):
+     """Create a sample dataset file."""
+     data = [
+         {"text": "Hello world", "score": 0.8, "category": "greeting"},
+         {"text": "How are you?", "score": 0.9, "category": "question"},
+         {"text": " Needs trimming ", "score": 0.7, "category": "test"},
+         {"text": "", "score": 0.6, "category": "empty_text"},
+         {"text": "Short", "score": None, "category": None},
+         {"text": "Hello world", "score": 0.85, "category": "duplicate"},  # duplicate text
+     ]
+     filepath = tmp_path / "test_data.jsonl"
+     save_data(data, str(filepath))
+     return filepath, data
+
+
+ @pytest.fixture
+ def sample_nested_file(tmp_path):
+     """Create a sample dataset with nested fields."""
+     data = [
+         {
+             "id": 1,
+             "meta": {"source": "web", "score": 0.9},
+             "messages": [
+                 {"role": "user", "content": "Hello"},
+                 {"role": "assistant", "content": "Hi!"},
+             ],
+         },
+         {
+             "id": 2,
+             "meta": {"source": "api", "score": 0.8},
+             "messages": [{"role": "user", "content": "Hi"}],
+         },
+         {
+             "id": 3,
+             "meta": {"source": None, "score": 0.5},
+             "messages": [
+                 {"role": "user", "content": "A"},
+                 {"role": "assistant", "content": "B"},
+                 {"role": "user", "content": "C"},
+             ],
+         },
+     ]
+     filepath = tmp_path / "test_nested.jsonl"
+     save_data(data, str(filepath))
+     return filepath, data
+
+
+ # ============== Clean Command Tests ==============
+
+
+ class TestCleanBasic:
+     """Test basic clean functionality."""
+
+     def test_clean_drop_empty(self, sample_data_file, tmp_path):
+         """Test dropping empty records."""
+         filepath, _ = sample_data_file
+         output = tmp_path / "output.jsonl"
+
+         clean(str(filepath), drop_empty="text", output=str(output))
+
+         result = load_data(str(output))
+         # Should remove the record with empty text
+         for item in result:
+             assert item["text"] != ""
+
+     def test_clean_drop_empty_all_fields(self, sample_data_file, tmp_path):
+         """Test dropping records with any empty field."""
+         filepath, _ = sample_data_file
+         output = tmp_path / "output.jsonl"
+
+         # drop_empty="" means check all fields
+         clean(str(filepath), drop_empty="", output=str(output))
+
+         result = load_data(str(output))
+         # Should remove records with any None or empty value
+         for item in result:
+             assert all(v is not None and v != "" for v in item.values())
+
+     def test_clean_strip(self, sample_data_file, tmp_path):
+         """Test stripping whitespace."""
+         filepath, _ = sample_data_file
+         output = tmp_path / "output.jsonl"
+
+         clean(str(filepath), strip=True, output=str(output))
+
+         result = load_data(str(output))
+         # Find the item that had extra whitespace
+         trimmed = [item for item in result if "Needs trimming" in item.get("text", "")]
+         if trimmed:
+             assert trimmed[0]["text"] == "Needs trimming"
+
+     def test_clean_min_len(self, sample_data_file, tmp_path):
+         """Test minimum length filtering."""
+         filepath, _ = sample_data_file
+         output = tmp_path / "output.jsonl"
+
+         clean(str(filepath), min_len="text:10", output=str(output))
+
+         result = load_data(str(output))
+         for item in result:
+             assert len(item["text"]) >= 10
+
+     def test_clean_max_len(self, sample_data_file, tmp_path):
+         """Test maximum length filtering."""
+         filepath, _ = sample_data_file
+         output = tmp_path / "output.jsonl"
+
+         clean(str(filepath), max_len="text:10", output=str(output))
+
+         result = load_data(str(output))
+         for item in result:
+             assert len(item["text"]) <= 10
+
+     def test_clean_keep_fields(self, sample_data_file, tmp_path):
+         """Test keeping only specified fields."""
+         filepath, _ = sample_data_file
+         output = tmp_path / "output.jsonl"
+
+         clean(str(filepath), keep="text,category", output=str(output))
+
+         result = load_data(str(output))
+         for item in result:
+             assert set(item.keys()) == {"text", "category"}
+
+     def test_clean_drop_fields(self, sample_data_file, tmp_path):
+         """Test dropping specified fields."""
+         filepath, _ = sample_data_file
+         output = tmp_path / "output.jsonl"
+
+         clean(str(filepath), drop="score", output=str(output))
+
+         result = load_data(str(output))
+         for item in result:
+             assert "score" not in item
+
+
+ # ============== Clean with Nested Fields Tests ==============
+
+
+ class TestCleanNested:
+     """Test clean with nested field paths."""
+
+     def test_clean_drop_empty_nested(self, sample_nested_file, tmp_path):
+         """Test dropping records with empty nested field."""
+         filepath, _ = sample_nested_file
+         output = tmp_path / "output.jsonl"
+
+         clean(str(filepath), drop_empty="meta.source", output=str(output))
+
+         result = load_data(str(output))
+         assert len(result) == 2  # Should remove the one with None source
+         for item in result:
+             assert item["meta"]["source"] is not None
+
+     def test_clean_min_len_messages(self, sample_nested_file, tmp_path):
+         """Test filtering by message count using .# syntax."""
+         filepath, _ = sample_nested_file
+         output = tmp_path / "output.jsonl"
+
+         clean(str(filepath), min_len="messages.#:2", output=str(output))
+
+         result = load_data(str(output))
+         for item in result:
+             assert len(item["messages"]) >= 2
+
+
+ # ============== Dedupe Command Tests ==============
+
+
+ class TestDedupeBasic:
+     """Test basic dedupe functionality."""
+
+     def test_dedupe_by_field(self, sample_data_file, tmp_path):
+         """Test deduplication by specific field."""
+         filepath, _ = sample_data_file
+         output = tmp_path / "output.jsonl"
+
+         dedupe(str(filepath), key="text", output=str(output))
+
+         result = load_data(str(output))
+         # Should have removed the duplicate "Hello world"
+         texts = [item["text"] for item in result]
+         assert len(texts) == len(set(texts))
+
+     def test_dedupe_full(self, sample_data_file, tmp_path):
+         """Test full record deduplication."""
+         filepath, _ = sample_data_file
+         output = tmp_path / "output.jsonl"
+
+         dedupe(str(filepath), output=str(output))
+
+         result = load_data(str(output))
+         # All records are unique, so should have same count
+         # (unless there are exact duplicates)
+         assert len(result) >= 1
+
+     def test_dedupe_overwrite(self, sample_data_file):
+         """Test deduplication with overwrite (no output specified)."""
+         filepath, original_data = sample_data_file
+
+         dedupe(str(filepath), key="text")
+
+         result = load_data(str(filepath))
+         texts = [item["text"] for item in result]
+         assert len(texts) == len(set(texts))
+
+
+ # ============== Parameter Parsing Tests ==============
+
+
+ class TestParamParsing:
+     """Test parameter parsing functions."""
+
+     def test_parse_len_param_valid(self):
+         """Test valid length parameter parsing."""
+         field, value = _parse_len_param("text:100")
+         assert field == "text"
+         assert value == 100
+
+     def test_parse_len_param_nested(self):
+         """Test nested field length parameter."""
+         field, value = _parse_len_param("messages.#:5")
+         assert field == "messages.#"
+         assert value == 5
+
+     def test_parse_len_param_invalid_no_colon(self):
+         """Test invalid parameter without colon."""
+         with pytest.raises(ValueError):
+             _parse_len_param("text100")
+
+     def test_parse_len_param_invalid_non_numeric(self):
+         """Test invalid parameter with non-numeric value."""
+         with pytest.raises(ValueError):
+             _parse_len_param("text:abc")
+
+
+ # ============== Clean Single Pass Tests ==============
+
+
+ class TestCleanSinglePass:
+     """Test _clean_data_single_pass function."""
+
+     def test_single_pass_strip(self):
+         """Test strip in single pass."""
+         data = [{"text": " hello ", "value": " world "}]
+         result, _ = _clean_data_single_pass(data, strip=True)
+
+         assert result[0]["text"] == "hello"
+         assert result[0]["value"] == "world"
+
+     def test_single_pass_drop_empty(self):
+         """Test drop empty in single pass."""
+         data = [
+             {"text": "hello", "value": "world"},
+             {"text": "", "value": "test"},
+             {"text": "hi", "value": None},
+         ]
+         result, stats = _clean_data_single_pass(data, empty_fields=["text", "value"])
+
+         assert len(result) == 1
+         assert result[0]["text"] == "hello"
+
+     def test_single_pass_combined(self):
+         """Test combined operations in single pass."""
+         data = [
+             {"text": " long text here ", "score": 0.9},
+             {"text": " hi ", "score": 0.8},
+         ]
+         result, stats = _clean_data_single_pass(
+             data, strip=True, min_len_field="text", min_len_value=5
+         )
+
+         assert len(result) == 1
+         assert result[0]["text"] == "long text here"
+
+
+ # ============== Error Handling Tests ==============
+
+
+ class TestCleanErrors:
+     """Test error handling in clean commands."""
+
+     def test_clean_file_not_exists(self, tmp_path, capsys):
+         """Test error when file doesn't exist."""
+         clean(str(tmp_path / "nonexistent.jsonl"))
+         captured = capsys.readouterr()
+         assert "文件不存在" in captured.out
+
+     def test_dedupe_similar_without_key(self, sample_data_file, capsys):
+         """Test error when using similar without key."""
+         filepath, _ = sample_data_file
+         dedupe(str(filepath), similar=0.8)
+         captured = capsys.readouterr()
+         assert "需要指定 --key" in captured.out
+
+     def test_dedupe_invalid_similar_range(self, sample_data_file, capsys):
+         """Test error when similar value is out of range."""
+         filepath, _ = sample_data_file
+         dedupe(str(filepath), key="text", similar=1.5)
+         captured = capsys.readouterr()
+         assert "0-1 之间" in captured.out