dtflow 0.5.5__tar.gz → 0.5.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. {dtflow-0.5.5 → dtflow-0.5.6}/PKG-INFO +4 -2
  2. {dtflow-0.5.5 → dtflow-0.5.6}/README.md +3 -1
  3. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/__init__.py +1 -1
  4. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/__main__.py +2 -1
  5. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/sample.py +159 -11
  6. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_cli_sample.py +99 -0
  7. {dtflow-0.5.5 → dtflow-0.5.6}/.gitignore +0 -0
  8. {dtflow-0.5.5 → dtflow-0.5.6}/CHANGELOG.md +0 -0
  9. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/__init__.py +0 -0
  10. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/clean.py +0 -0
  11. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/commands.py +0 -0
  12. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/common.py +0 -0
  13. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/io_ops.py +0 -0
  14. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/lineage.py +0 -0
  15. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/pipeline.py +0 -0
  16. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/stats.py +0 -0
  17. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/transform.py +0 -0
  18. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/cli/validate.py +0 -0
  19. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/converters.py +0 -0
  20. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/core.py +0 -0
  21. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/framework.py +0 -0
  22. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/lineage.py +0 -0
  23. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/mcp/__init__.py +0 -0
  24. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/mcp/__main__.py +0 -0
  25. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/mcp/cli.py +0 -0
  26. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/mcp/docs.py +0 -0
  27. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/mcp/server.py +0 -0
  28. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/pipeline.py +0 -0
  29. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/presets.py +0 -0
  30. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/schema.py +0 -0
  31. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/storage/__init__.py +0 -0
  32. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/storage/io.py +0 -0
  33. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/streaming.py +0 -0
  34. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/tokenizers.py +0 -0
  35. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/utils/__init__.py +0 -0
  36. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/utils/display.py +0 -0
  37. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/utils/field_path.py +0 -0
  38. {dtflow-0.5.5 → dtflow-0.5.6}/dtflow/utils/helpers.py +0 -0
  39. {dtflow-0.5.5 → dtflow-0.5.6}/pyproject.toml +0 -0
  40. {dtflow-0.5.5 → dtflow-0.5.6}/tests/README.md +0 -0
  41. {dtflow-0.5.5 → dtflow-0.5.6}/tests/benchmark_io.py +0 -0
  42. {dtflow-0.5.5 → dtflow-0.5.6}/tests/benchmark_sharegpt.py +0 -0
  43. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_cli_benchmark.py +0 -0
  44. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_cli_clean.py +0 -0
  45. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_cli_stats.py +0 -0
  46. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_cli_transform.py +0 -0
  47. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_converters.py +0 -0
  48. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_field_path.py +0 -0
  49. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_framework.py +0 -0
  50. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_io.py +0 -0
  51. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_lineage.py +0 -0
  52. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_pipeline.py +0 -0
  53. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_schema.py +0 -0
  54. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_streaming.py +0 -0
  55. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_tokenizers.py +0 -0
  56. {dtflow-0.5.5 → dtflow-0.5.6}/tests/test_transformer.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dtflow
3
- Version: 0.5.5
3
+ Version: 0.5.6
4
4
  Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
5
5
  Project-URL: Homepage, https://github.com/yourusername/DataTransformer
6
6
  Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
@@ -423,6 +423,8 @@ dt sample data.csv --num=100 --sample_type=head
423
423
  dt sample data.jsonl 1000 --by=category # 分层采样
424
424
  dt sample data.jsonl 1000 --by=meta.source # 按嵌套字段分层采样
425
425
  dt sample data.jsonl 1000 --by=messages.# # 按消息数量分层采样
426
+ dt sample data.jsonl --where="category=tech" # 筛选后采样
427
+ dt sample data.jsonl --where="category=tech" --where="messages.#>=2" # 多条件筛选(AND)
426
428
 
427
429
  # 数据转换 - 预设模式
428
430
  dt transform data.jsonl --preset=openai_chat
@@ -496,7 +498,7 @@ CLI 命令中的字段参数支持嵌套路径语法,可访问深层嵌套的
496
498
 
497
499
  | 命令 | 参数 | 示例 |
498
500
  |------|------|------|
499
- | `sample` | `--by=` | `--by=meta.source`、`--by=messages.#` |
501
+ | `sample` | `--by=`, `--where=` | `--by=meta.source`、`--where=messages.#>=2` |
500
502
  | `dedupe` | `--key=` | `--key=meta.id`、`--key=messages[0].content` |
501
503
  | `clean` | `--drop-empty=` | `--drop-empty=meta.source` |
502
504
  | `clean` | `--min-len=` | `--min-len=messages.#:2` |
@@ -337,6 +337,8 @@ dt sample data.csv --num=100 --sample_type=head
337
337
  dt sample data.jsonl 1000 --by=category # 分层采样
338
338
  dt sample data.jsonl 1000 --by=meta.source # 按嵌套字段分层采样
339
339
  dt sample data.jsonl 1000 --by=messages.# # 按消息数量分层采样
340
+ dt sample data.jsonl --where="category=tech" # 筛选后采样
341
+ dt sample data.jsonl --where="category=tech" --where="messages.#>=2" # 多条件筛选(AND)
340
342
 
341
343
  # 数据转换 - 预设模式
342
344
  dt transform data.jsonl --preset=openai_chat
@@ -410,7 +412,7 @@ CLI 命令中的字段参数支持嵌套路径语法,可访问深层嵌套的
410
412
 
411
413
  | 命令 | 参数 | 示例 |
412
414
  |------|------|------|
413
- | `sample` | `--by=` | `--by=meta.source`、`--by=messages.#` |
415
+ | `sample` | `--by=`, `--where=` | `--by=meta.source`、`--where=messages.#>=2` |
414
416
  | `dedupe` | `--key=` | `--key=meta.id`、`--key=messages[0].content` |
415
417
  | `clean` | `--drop-empty=` | `--drop-empty=meta.source` |
416
418
  | `clean` | `--min-len=` | `--min-len=messages.#:2` |
@@ -60,7 +60,7 @@ from .tokenizers import (
60
60
  token_stats,
61
61
  )
62
62
 
63
- __version__ = "0.5.5"
63
+ __version__ = "0.5.6"
64
64
 
65
65
  __all__ = [
66
66
  # core
@@ -67,10 +67,11 @@ def sample(
67
67
  uniform: bool = typer.Option(False, "--uniform", help="均匀采样模式"),
68
68
  fields: Optional[str] = typer.Option(None, "--fields", "-f", help="只显示指定字段(逗号分隔)"),
69
69
  raw: bool = typer.Option(False, "--raw", "-r", help="输出原始 JSON(不截断)"),
70
+ where: Optional[List[str]] = typer.Option(None, "--where", "-w", help="筛选条件 (可多次使用)"),
70
71
  ):
71
72
  """从数据文件中采样指定数量的数据"""
72
73
  actual_num = num_arg if num_arg is not None else num
73
- _sample(filename, actual_num, type, output, seed, by, uniform, fields, raw)
74
+ _sample(filename, actual_num, type, output, seed, by, uniform, fields, raw, where)
74
75
 
75
76
 
76
77
  @app.command()
@@ -2,8 +2,9 @@
2
2
  CLI 采样相关命令
3
3
  """
4
4
 
5
+ import re
5
6
  from pathlib import Path
6
- from typing import Any, Dict, List, Literal, Optional
7
+ from typing import Any, Callable, Dict, List, Literal, Optional
7
8
 
8
9
  import orjson
9
10
 
@@ -16,6 +17,122 @@ from .common import (
16
17
  _print_samples,
17
18
  )
18
19
 
20
# Parser for a where condition "<field><op><value>".
# The field is matched non-greedily and the alternation lists each two-character
# operator before its one-character prefix, so "a>=1" splits as ("a", ">=", "1")
# and "content~=x=y" as ("content", "~=", "x=y").
_WHERE_PATTERN = re.compile(r"^(.+?)(!=|~=|>=|<=|>|<|=)(.*)$")

# Operators that order values numerically; their right-hand side must be a number.
_WHERE_ORDERING_OPS = frozenset({">", ">=", "<", "<="})


def _is_where_number(value: Any) -> bool:
    """Return True for int/float values, excluding bool (a subclass of int)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _parse_where_value(raw: str) -> Any:
    """Convert a condition's value text to bool, int or float; otherwise return it unchanged.

    "true"/"false" (case-insensitive) become booleans. int is tried before float
    so integer literals stay exact.
    """
    lowered = raw.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    try:
        return int(raw)
    except ValueError:
        pass
    try:
        return float(raw)
    except ValueError:
        return raw


def _where_as_number(value: Any) -> Any:
    """Return *value* as a number for ordering comparisons, or None if it is not numeric.

    ints/floats (and bools, which order as 0/1) are used as-is so large integers
    compare exactly; anything else goes through float(), e.g. numeric strings.
    """
    if isinstance(value, (int, float)):
        return value
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _where_equals(field_value: Any, raw: str, parsed: Any) -> bool:
    """Equality test shared by '=' and '!='.

    A missing field (None) equals "" or "none". Otherwise the field matches if its
    str() equals the raw text, or if it equals the parsed value with both sides
    agreeing on being a bool. Without that check, 1 == True would let
    "flag=true" match the integer 1.
    """
    if field_value is None:
        return raw == "" or raw.lower() == "none"
    if str(field_value) == raw:
        return True
    if isinstance(field_value, bool) != isinstance(parsed, bool):
        return False
    return field_value == parsed


def _match_where(field_value: Any, op: str, raw: str, parsed: Any) -> bool:
    """Evaluate one where operator against an already-extracted field value.

    Args:
        field_value: value read from the record (None when the field is missing).
        op: one of =, !=, ~=, >, >=, <, <=.
        raw: the condition's value text exactly as written.
        parsed: ``raw`` converted by ``_parse_where_value``.
    """
    if op == "=":
        return _where_equals(field_value, raw, parsed)
    if op == "!=":
        return not _where_equals(field_value, raw, parsed)
    if op == "~=":
        # Substring test on the field's string form; a missing field never matches.
        return field_value is not None and raw in str(field_value)
    if op in _WHERE_ORDERING_OPS:
        number = _where_as_number(field_value)
        if number is None:
            return False
        if op == ">":
            return number > parsed
        if op == ">=":
            return number >= parsed
        if op == "<":
            return number < parsed
        return number <= parsed
    return False


def _parse_where(condition: str) -> Callable[[dict], bool]:
    """Parse a where condition string into a record predicate.

    Supported operators:
        =   equal (string form, or typed value with bool kept distinct from numbers)
        !=  not equal
        ~=  contains (substring of the field's string form)
        >, >=, <, <=  numeric comparison

    Whitespace around the field path is ignored; the value is taken verbatim.

    Examples:
        _parse_where("category=tech")
        _parse_where("meta.source!=wiki")
        _parse_where("content~=机器学习")
        _parse_where("messages.#>=2")

    Raises:
        ValueError: if the condition has no operator, an empty field, or a
            non-numeric value for an ordering operator.
    """
    match = _WHERE_PATTERN.match(condition)
    if not match:
        raise ValueError(f"无效的 where 条件: {condition}")

    field, op, raw_value = match.groups()
    field = field.strip()
    if not field:
        raise ValueError(f"无效的 where 条件: {condition}")

    parsed_value = _parse_where_value(raw_value)
    # Reject bad numeric comparisons up front instead of silently failing every record.
    if op in _WHERE_ORDERING_OPS and not _is_where_number(parsed_value):
        raise ValueError(f"无效的 where 条件: {condition}(比较操作符 {op} 需要数值)")

    def filter_fn(item: dict) -> bool:
        return _match_where(get_field_with_spec(item, field), op, raw_value, parsed_value)

    return filter_fn
102
+
103
+
104
def _apply_where_filters(data: List[Dict], where_conditions: List[str]) -> List[Dict]:
    """Keep only the records that satisfy every where condition (conditions are ANDed).

    An empty or None condition list returns ``data`` itself, unfiltered. All
    conditions are parsed before any record is inspected, so a malformed one
    raises ValueError (from ``_parse_where``) up front.
    """
    if where_conditions:
        predicates = list(map(_parse_where, where_conditions))
        return [
            record
            for record in data
            if all(predicate(record) for predicate in predicates)
        ]
    return data
111
+
112
+
113
def _sample_from_list(
    data: List[Dict],
    num: int,
    sample_type: str,
    seed: Optional[int] = None,
) -> List[Dict]:
    """Sample ``num`` records from an in-memory list.

    Args:
        data: records to sample from (e.g. already filtered by --where).
        num: number of records to return; ``num <= 0`` or ``num > len(data)``
            means "all records".
        sample_type: "random", "head", or anything else for tail.
        seed: seed for reproducible random sampling.

    Returns:
        A new list with the sampled records.
    """
    import random

    total = len(data)
    if num <= 0 or num > total:
        num = total

    if sample_type == "random":
        # A private generator leaves the process-wide ``random`` state untouched;
        # seeded results are identical to seeding the global generator.
        rng = random.Random(seed)
        return rng.sample(data, num)
    if sample_type == "head":
        return data[:num]
    # tail; num == 0 only happens when data is empty (data[-0:] would be the whole list)
    return data[-num:] if num else []
135
+
19
136
 
20
137
  def sample(
21
138
  filename: str,
@@ -27,6 +144,7 @@ def sample(
27
144
  uniform: bool = False,
28
145
  fields: Optional[str] = None,
29
146
  raw: bool = False,
147
+ where: Optional[List[str]] = None,
30
148
  ) -> None:
31
149
  """
32
150
  从数据文件中采样指定数量的数据。
@@ -44,6 +162,7 @@ def sample(
44
162
  uniform: 均匀采样模式(需配合 --by 使用),各组采样相同数量
45
163
  fields: 只显示指定字段(逗号分隔),仅在预览模式下有效
46
164
  raw: 输出原始 JSON 格式(不截断,完整显示所有内容)
165
+ where: 筛选条件列表,支持 =, !=, ~=, >, >=, <, <= 操作符
47
166
 
48
167
  Examples:
49
168
  dt sample data.jsonl 5
@@ -54,6 +173,9 @@ def sample(
54
173
  dt sample data.jsonl 1000 --by=category # 按比例分层采样
55
174
  dt sample data.jsonl 1000 --by=category --uniform # 均匀分层采样
56
175
  dt sample data.jsonl --fields=question,answer # 只显示指定字段
176
+ dt sample data.jsonl --where="category=tech" # 筛选 category 为 tech 的数据
177
+ dt sample data.jsonl --where="meta.source~=wiki" # 筛选 meta.source 包含 wiki
178
+ dt sample data.jsonl --where="messages.#>=2" # 筛选消息数量 >= 2
57
179
  """
58
180
  filepath = Path(filename)
59
181
 
@@ -69,23 +191,46 @@ def sample(
69
191
  print("错误: --uniform 必须配合 --by 使用")
70
192
  return
71
193
 
194
+ # 处理 where 筛选
195
+ where_conditions = where or []
196
+ filtered_data = None
197
+ original_count = None
198
+
199
+ if where_conditions:
200
+ # 有 where 条件时,先加载全部数据再筛选
201
+ try:
202
+ all_data = load_data(str(filepath))
203
+ original_count = len(all_data)
204
+ filtered_data = _apply_where_filters(all_data, where_conditions)
205
+ print(f"🔍 筛选: {original_count} → {len(filtered_data)} 条")
206
+ if not filtered_data:
207
+ print("⚠️ 筛选后无数据")
208
+ return
209
+ except ValueError as e:
210
+ print(f"错误: {e}")
211
+ return
212
+
72
213
  # 分层采样模式
73
214
  if by:
74
215
  try:
75
- sampled = _stratified_sample(filepath, num, by, uniform, seed, type)
216
+ sampled = _stratified_sample(filepath, num, by, uniform, seed, type, data=filtered_data)
76
217
  except Exception as e:
77
218
  print(f"错误: {e}")
78
219
  return
79
220
  else:
80
221
  # 普通采样
81
222
  try:
82
- sampled = sample_file(
83
- str(filepath),
84
- num=num,
85
- sample_type=type,
86
- seed=seed,
87
- output=None, # 先不保存,统一在最后处理
88
- )
223
+ if filtered_data is not None:
224
+ # 已筛选的数据,直接采样
225
+ sampled = _sample_from_list(filtered_data, num, type, seed)
226
+ else:
227
+ sampled = sample_file(
228
+ str(filepath),
229
+ num=num,
230
+ sample_type=type,
231
+ seed=seed,
232
+ output=None, # 先不保存,统一在最后处理
233
+ )
89
234
  except Exception as e:
90
235
  print(f"错误: {e}")
91
236
  return
@@ -117,6 +262,7 @@ def _stratified_sample(
117
262
  uniform: bool,
118
263
  seed: Optional[int],
119
264
  sample_type: str,
265
+ data: Optional[List[Dict]] = None,
120
266
  ) -> List[Dict]:
121
267
  """
122
268
  分层采样实现。
@@ -133,6 +279,7 @@ def _stratified_sample(
133
279
  uniform: 是否均匀采样(各组相同数量)
134
280
  seed: 随机种子
135
281
  sample_type: 采样方式(用于组内采样)
282
+ data: 预筛选的数据(可选,如果提供则不从文件加载)
136
283
 
137
284
  Returns:
138
285
  采样后的数据列表
@@ -143,8 +290,9 @@ def _stratified_sample(
143
290
  if seed is not None:
144
291
  random.seed(seed)
145
292
 
146
- # 加载数据
147
- data = load_data(str(filepath))
293
+ # 加载数据(如果没有预筛选数据)
294
+ if data is None:
295
+ data = load_data(str(filepath))
148
296
  total = len(data)
149
297
 
150
298
  if num <= 0 or num > total:
@@ -240,3 +240,102 @@ class TestRawOutput:
240
240
  # Raw mode outputs JSON with indentation
241
241
  assert "question" in captured.out
242
242
  assert "Question 0" in captured.out
243
+
244
+
245
+ # ============== Where Filter Tests ==============
246
+
247
+
248
class TestWhereFilter:
    """Tests for the ``--where`` record filter of ``sample``."""

    @staticmethod
    def _filtered(filepath, tmp_path, where, num=100):
        """Sample *filepath* with the given where conditions into a temp file and reload it."""
        out_path = tmp_path / "filtered.jsonl"
        sample(str(filepath), num=num, output=str(out_path), where=where)
        return load_data(str(out_path))

    def test_where_equal(self, sample_qa_file, tmp_path, capsys):
        """`=` keeps only records whose field equals the value."""
        src, _unused = sample_qa_file
        rows = self._filtered(src, tmp_path, ["category=cat0"])
        assert len(rows) > 0
        for row in rows:
            assert row["category"] == "cat0"

    def test_where_not_equal(self, sample_qa_file, tmp_path, capsys):
        """`!=` drops records whose field equals the value."""
        src, _unused = sample_qa_file
        rows = self._filtered(src, tmp_path, ["category!=cat0"])
        assert len(rows) > 0
        for row in rows:
            assert row["category"] != "cat0"

    def test_where_contains(self, sample_qa_file, tmp_path, capsys):
        """`~=` keeps records whose field contains the substring."""
        src, _unused = sample_qa_file
        rows = self._filtered(src, tmp_path, ["question~=Question 1"])
        assert len(rows) > 0
        for row in rows:
            assert "Question 1" in row["question"]

    def test_where_nested_field(self, sample_nested_file, tmp_path, capsys):
        """Dotted field paths reach into nested objects."""
        src, _unused = sample_nested_file
        rows = self._filtered(src, tmp_path, ["meta.source=source0"])
        assert len(rows) > 0
        for row in rows:
            assert row["meta"]["source"] == "source0"

    def test_where_numeric_comparison(self, sample_nested_file, tmp_path, capsys):
        """`>=` compares numerically."""
        src, _unused = sample_nested_file
        rows = self._filtered(src, tmp_path, ["id>=10"])
        assert len(rows) > 0
        for row in rows:
            assert row["id"] >= 10

    def test_where_multiple_conditions(self, sample_qa_file, tmp_path, capsys):
        """Several --where conditions are combined with AND."""
        src, _unused = sample_qa_file
        rows = self._filtered(
            src,
            tmp_path,
            ["category=cat0", "question~=Question 0"],
        )
        # Per the original note, category=cat0 covers ids 0, 3, 6, 9, 12, 15, 18
        # (fixture layout not visible here -- TODO confirm), and only "Question 0"
        # also contains the substring "Question 0", so exactly one record survives.
        assert len(rows) == 1
        first = rows[0]
        assert first["category"] == "cat0"
        assert "Question 0" in first["question"]

    def test_where_no_match(self, sample_qa_file, capsys):
        """A condition matching nothing reports that the filtered set is empty."""
        src, _unused = sample_qa_file
        sample(str(src), num=10, where=["category=nonexistent"])
        assert "筛选后无数据" in capsys.readouterr().out

    def test_where_invalid_condition(self, sample_qa_file, capsys):
        """A condition without an operator is reported as invalid."""
        src, _unused = sample_qa_file
        sample(str(src), num=10, where=["invalid_condition"])
        assert "无效的 where 条件" in capsys.readouterr().out
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes