dtflow 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/__init__.py CHANGED
@@ -60,7 +60,7 @@ from .tokenizers import (
60
60
  token_stats,
61
61
  )
62
62
 
63
- __version__ = "0.5.4"
63
+ __version__ = "0.5.6"
64
64
 
65
65
  __all__ = [
66
66
  # core
dtflow/__main__.py CHANGED
@@ -60,17 +60,18 @@ def sample(
60
60
  filename: str = typer.Argument(..., help="输入文件路径"),
61
61
  num_arg: Optional[int] = typer.Argument(None, help="采样数量", metavar="NUM"),
62
62
  num: int = typer.Option(10, "--num", "-n", help="采样数量", show_default=True),
63
- type: str = typer.Option("head", "--type", "-t", help="采样方式: random/head/tail"),
63
+ type: str = typer.Option("random", "--type", "-t", help="采样方式: random/head/tail"),
64
64
  output: Optional[str] = typer.Option(None, "--output", "-o", help="输出文件路径"),
65
65
  seed: Optional[int] = typer.Option(None, "--seed", help="随机种子"),
66
66
  by: Optional[str] = typer.Option(None, "--by", help="分层采样字段"),
67
67
  uniform: bool = typer.Option(False, "--uniform", help="均匀采样模式"),
68
68
  fields: Optional[str] = typer.Option(None, "--fields", "-f", help="只显示指定字段(逗号分隔)"),
69
69
  raw: bool = typer.Option(False, "--raw", "-r", help="输出原始 JSON(不截断)"),
70
+ where: Optional[List[str]] = typer.Option(None, "--where", "-w", help="筛选条件 (可多次使用)"),
70
71
  ):
71
72
  """从数据文件中采样指定数量的数据"""
72
73
  actual_num = num_arg if num_arg is not None else num
73
- _sample(filename, actual_num, type, output, seed, by, uniform, fields, raw)
74
+ _sample(filename, actual_num, type, output, seed, by, uniform, fields, raw, where)
74
75
 
75
76
 
76
77
  @app.command()
@@ -223,9 +224,7 @@ def validate(
223
224
  None, "--preset", "-p", help="预设 Schema: openai_chat, alpaca, dpo, sharegpt"
224
225
  ),
225
226
  output: Optional[str] = typer.Option(None, "--output", "-o", help="输出有效数据的文件路径"),
226
- filter: bool = typer.Option(
227
- False, "--filter", "-f", help="过滤无效数据并保存"
228
- ),
227
+ filter: bool = typer.Option(False, "--filter", "-f", help="过滤无效数据并保存"),
229
228
  max_errors: int = typer.Option(20, "--max-errors", help="最多显示的错误数量"),
230
229
  verbose: bool = typer.Option(False, "--verbose", "-v", help="显示详细信息"),
231
230
  ):
dtflow/cli/sample.py CHANGED
@@ -2,8 +2,9 @@
2
2
  CLI 采样相关命令
3
3
  """
4
4
 
5
+ import re
5
6
  from pathlib import Path
6
- from typing import Any, Dict, List, Literal, Optional
7
+ from typing import Any, Callable, Dict, List, Literal, Optional
7
8
 
8
9
  import orjson
9
10
 
@@ -16,17 +17,134 @@ from .common import (
16
17
  _print_samples,
17
18
  )
18
19
 
20
# Grammar of a where condition: <field><op><value>.
# The field group is lazy, so the first operator found splits the string, and
# the two-character operators are listed before their one-character prefixes
# so that ">=" is not read as ">" followed by the value "=2".
_WHERE_PATTERN = re.compile(r"^(.+?)(!=|~=|>=|<=|>|<|=)(.*)$")


def _parse_where(condition: str) -> Callable[[dict], bool]:
    """
    Parse a ``field<op>value`` where-condition into a record predicate.

    Supported operators:
        =   equal
        !=  not equal
        ~=  contains (substring test on the field's string form)
        >   greater than
        >=  greater than or equal
        <   less than
        <=  less than or equal

    Matching rules:
        * ``=`` matches when the field's ``str()`` form equals the raw value,
          or when the field equals the value converted to bool/int/float.
          Booleans and numbers are never considered equal to each other, so
          ``flag=1`` does not match a field holding ``True``.
        * ``!=`` is the exact negation of ``=``.
        * A missing (``None``) field equals only the empty value or the
          word ``none`` (case-insensitive); it never satisfies ``~=`` or an
          ordering operator.
        * Ordering operators compare both sides as floats; if either side is
          not numeric the record simply does not match.

    The field is resolved with ``get_field_with_spec``, which is defined
    outside this block; the path syntax in the examples below (dotted paths,
    ``#``) is whatever that helper accepts.

    Args:
        condition: Condition text such as ``"category=tech"``.

    Returns:
        A function that takes one record and returns whether it matches.

    Raises:
        ValueError: If ``condition`` contains no operator or no field name.

    Examples:
        _parse_where("category=tech")
        _parse_where("meta.source!=wiki")
        _parse_where("content~=机器学习")
        _parse_where("messages.#>=2")
    """
    import operator

    match = _WHERE_PATTERN.match(condition)
    if not match:
        raise ValueError(f"无效的 where 条件: {condition}")

    field, op, value = match.groups()

    def parse_value(v: str) -> Any:
        """Convert the raw text to bool, int or float, else keep the string."""
        lowered = v.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        for convert in (int, float):
            try:
                return convert(v)
            except ValueError:
                continue
        return v

    parsed_value = parse_value(value)
    value_is_bool = isinstance(parsed_value, bool)

    # The right-hand side of an ordering comparison is the same for every
    # record, so convert it once here instead of once per record.
    try:
        threshold: Optional[float] = float(value)
    except ValueError:
        threshold = None

    compare = {
        ">": operator.gt,
        ">=": operator.ge,
        "<": operator.lt,
        "<=": operator.le,
    }.get(op)

    def is_equal(field_value: Any) -> bool:
        """Equality used by both ``=`` and ``!=``."""
        if field_value is None:
            return value == "" or value.lower() == "none"
        if str(field_value) == value:
            return True
        # bool is a subclass of int, so True == 1 and False == 0 in Python.
        # Require both sides to agree on being a bool so that "flag=1" does
        # not match True and "count=true" does not match 1.
        if isinstance(field_value, bool) != value_is_bool:
            return False
        return field_value == parsed_value

    def filter_fn(item: dict) -> bool:
        field_value = get_field_with_spec(item, field)

        if op == "=":
            return is_equal(field_value)
        if op == "!=":
            return not is_equal(field_value)

        # Every remaining operator rejects a missing field.
        if field_value is None:
            return False
        if op == "~=":
            return value in str(field_value)

        # Ordering operators: a non-numeric side means "no match", not an error.
        if compare is None or threshold is None:
            return False
        try:
            return compare(float(field_value), threshold)
        except (ValueError, TypeError):
            return False

    return filter_fn
102
+
103
+
104
def _apply_where_filters(data: List[Dict], where_conditions: List[str]) -> List[Dict]:
    """
    Keep only the records that satisfy every where-condition (logical AND).

    Args:
        data: Records to filter.
        where_conditions: Condition strings understood by ``_parse_where``.
            When empty (or ``None``), ``data`` itself is returned unchanged.

    Returns:
        ``data`` when there are no conditions, otherwise a new list holding
        the matching records in their original order.

    Raises:
        ValueError: Propagated from ``_parse_where`` for a malformed
            condition; all conditions are parsed before any record is tested.
    """
    if where_conditions:
        predicates = [_parse_where(expr) for expr in where_conditions]
        kept = []
        for record in data:
            # all() stops at the first predicate that rejects the record.
            if all(check(record) for check in predicates):
                kept.append(record)
        return kept
    return data
111
+
112
+
113
def _sample_from_list(
    data: List[Dict],
    num: int,
    sample_type: str,
    seed: Optional[int] = None,
) -> List[Dict]:
    """
    Pick ``num`` records from an in-memory list.

    Args:
        data: Records to sample from; the list is not modified.
        num: Number of records wanted. ``0``, or a value larger than
            ``len(data)``, selects every record. A negative value follows
            Python slice semantics and returns the last ``abs(num)`` records
            regardless of ``sample_type``.
        sample_type: ``"random"`` draws without replacement, ``"head"`` takes
            the first ``num`` records, and any other value takes the last
            ``num`` records.
        seed: Seed for the random draw; only used when ``sample_type`` is
            ``"random"``. ``None`` uses the module-level generator as is.

    Returns:
        A new list containing the selected records.
    """
    import random

    if num < 0:
        # The ``sample`` command documents negative counts as slice-style
        # ("-1" means the last record), so honour that here rather than
        # treating a negative count as "everything".
        return data[num:]

    total = len(data)
    if num == 0 or num > total:
        num = total

    if sample_type == "random":
        # A private generator produces the same draw as seeding the module
        # generator would, without resetting random state for the rest of
        # the process.
        rng = random.Random(seed) if seed is not None else random
        return rng.sample(data, num)
    if sample_type == "head":
        return data[:num]
    # tail; written as a start index so an empty list yields [] as well
    return data[total - num:]
135
+
19
136
 
20
137
  def sample(
21
138
  filename: str,
22
139
  num: int = 10,
23
- type: Literal["random", "head", "tail"] = "head",
140
+ type: Literal["random", "head", "tail"] = "random",
24
141
  output: Optional[str] = None,
25
142
  seed: Optional[int] = None,
26
143
  by: Optional[str] = None,
27
144
  uniform: bool = False,
28
145
  fields: Optional[str] = None,
29
146
  raw: bool = False,
147
+ where: Optional[List[str]] = None,
30
148
  ) -> None:
31
149
  """
32
150
  从数据文件中采样指定数量的数据。
@@ -37,13 +155,14 @@ def sample(
37
155
  - num > 0: 采样指定数量
38
156
  - num = 0: 采样所有数据
39
157
  - num < 0: Python 切片风格(如 -1 表示最后 1 条,-10 表示最后 10 条)
40
- type: 采样方式,可选 random/head/tail,默认 head
158
+ type: 采样方式,可选 random/head/tail,默认 random
41
159
  output: 输出文件路径,不指定则打印到控制台
42
160
  seed: 随机种子(仅在 type=random 时有效)
43
161
  by: 分层采样字段名,按该字段的值分组采样
44
162
  uniform: 均匀采样模式(需配合 --by 使用),各组采样相同数量
45
163
  fields: 只显示指定字段(逗号分隔),仅在预览模式下有效
46
164
  raw: 输出原始 JSON 格式(不截断,完整显示所有内容)
165
+ where: 筛选条件列表,支持 =, !=, ~=, >, >=, <, <= 操作符
47
166
 
48
167
  Examples:
49
168
  dt sample data.jsonl 5
@@ -54,6 +173,9 @@ def sample(
54
173
  dt sample data.jsonl 1000 --by=category # 按比例分层采样
55
174
  dt sample data.jsonl 1000 --by=category --uniform # 均匀分层采样
56
175
  dt sample data.jsonl --fields=question,answer # 只显示指定字段
176
+ dt sample data.jsonl --where="category=tech" # 筛选 category 为 tech 的数据
177
+ dt sample data.jsonl --where="meta.source~=wiki" # 筛选 meta.source 包含 wiki
178
+ dt sample data.jsonl --where="messages.#>=2" # 筛选消息数量 >= 2
57
179
  """
58
180
  filepath = Path(filename)
59
181
 
@@ -69,23 +191,46 @@ def sample(
69
191
  print("错误: --uniform 必须配合 --by 使用")
70
192
  return
71
193
 
194
+ # 处理 where 筛选
195
+ where_conditions = where or []
196
+ filtered_data = None
197
+ original_count = None
198
+
199
+ if where_conditions:
200
+ # 有 where 条件时,先加载全部数据再筛选
201
+ try:
202
+ all_data = load_data(str(filepath))
203
+ original_count = len(all_data)
204
+ filtered_data = _apply_where_filters(all_data, where_conditions)
205
+ print(f"🔍 筛选: {original_count} → {len(filtered_data)} 条")
206
+ if not filtered_data:
207
+ print("⚠️ 筛选后无数据")
208
+ return
209
+ except ValueError as e:
210
+ print(f"错误: {e}")
211
+ return
212
+
72
213
  # 分层采样模式
73
214
  if by:
74
215
  try:
75
- sampled = _stratified_sample(filepath, num, by, uniform, seed, type)
216
+ sampled = _stratified_sample(filepath, num, by, uniform, seed, type, data=filtered_data)
76
217
  except Exception as e:
77
218
  print(f"错误: {e}")
78
219
  return
79
220
  else:
80
221
  # 普通采样
81
222
  try:
82
- sampled = sample_file(
83
- str(filepath),
84
- num=num,
85
- sample_type=type,
86
- seed=seed,
87
- output=None, # 先不保存,统一在最后处理
88
- )
223
+ if filtered_data is not None:
224
+ # 已筛选的数据,直接采样
225
+ sampled = _sample_from_list(filtered_data, num, type, seed)
226
+ else:
227
+ sampled = sample_file(
228
+ str(filepath),
229
+ num=num,
230
+ sample_type=type,
231
+ seed=seed,
232
+ output=None, # 先不保存,统一在最后处理
233
+ )
89
234
  except Exception as e:
90
235
  print(f"错误: {e}")
91
236
  return
@@ -117,6 +262,7 @@ def _stratified_sample(
117
262
  uniform: bool,
118
263
  seed: Optional[int],
119
264
  sample_type: str,
265
+ data: Optional[List[Dict]] = None,
120
266
  ) -> List[Dict]:
121
267
  """
122
268
  分层采样实现。
@@ -133,6 +279,7 @@ def _stratified_sample(
133
279
  uniform: 是否均匀采样(各组相同数量)
134
280
  seed: 随机种子
135
281
  sample_type: 采样方式(用于组内采样)
282
+ data: 预筛选的数据(可选,如果提供则不从文件加载)
136
283
 
137
284
  Returns:
138
285
  采样后的数据列表
@@ -143,8 +290,9 @@ def _stratified_sample(
143
290
  if seed is not None:
144
291
  random.seed(seed)
145
292
 
146
- # 加载数据
147
- data = load_data(str(filepath))
293
+ # 加载数据(如果没有预筛选数据)
294
+ if data is None:
295
+ data = load_data(str(filepath))
148
296
  total = len(data)
149
297
 
150
298
  if num <= 0 or num > total:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dtflow
3
- Version: 0.5.4
3
+ Version: 0.5.6
4
4
  Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
5
5
  Project-URL: Homepage, https://github.com/yourusername/DataTransformer
6
6
  Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
@@ -423,6 +423,8 @@ dt sample data.csv --num=100 --sample_type=head
423
423
  dt sample data.jsonl 1000 --by=category # 分层采样
424
424
  dt sample data.jsonl 1000 --by=meta.source # 按嵌套字段分层采样
425
425
  dt sample data.jsonl 1000 --by=messages.# # 按消息数量分层采样
426
+ dt sample data.jsonl --where="category=tech" # 筛选后采样
427
+ dt sample data.jsonl --where="messages.#>=2" # 多条件筛选
426
428
 
427
429
  # 数据转换 - 预设模式
428
430
  dt transform data.jsonl --preset=openai_chat
@@ -496,7 +498,7 @@ CLI 命令中的字段参数支持嵌套路径语法,可访问深层嵌套的
496
498
 
497
499
  | 命令 | 参数 | 示例 |
498
500
  |------|------|------|
499
- | `sample` | `--by=` | `--by=meta.source`、`--by=messages.#` |
501
+ | `sample` | `--by=`, `--where=` | `--by=meta.source`、`--where=messages.#>=2` |
500
502
  | `dedupe` | `--key=` | `--key=meta.id`、`--key=messages[0].content` |
501
503
  | `clean` | `--drop-empty=` | `--drop-empty=meta.source` |
502
504
  | `clean` | `--min-len=` | `--min-len=messages.#:2` |
@@ -1,5 +1,5 @@
1
- dtflow/__init__.py,sha256=yUwvKuVAmhDnp-1tYhZGlZcTdiEnZ3Jh-IJymgMIUhA,3031
2
- dtflow/__main__.py,sha256=ySpqvEn7k-vsrYFPx-8O6p-yx_24KccgnOSPd2XybhM,12572
1
+ dtflow/__init__.py,sha256=_KUxZUD08hQhhLugGbjo_jlP5JuMCFAcCs0o0SCCoVM,3031
2
+ dtflow/__main__.py,sha256=OJ60M0PbA0PcsQfA7FP9k9CflJgzexKhIl-yc-CPXkw,12675
3
3
  dtflow/converters.py,sha256=X3qeFD7FCOMnfiP3MicL5MXimOm4XUYBs5pczIkudU0,22331
4
4
  dtflow/core.py,sha256=qMo6B3LK--TWRK7ZBKObGcs3pKFnd0NPoaM0T8JC7Jw,38135
5
5
  dtflow/framework.py,sha256=jyICi_RWHjX7WfsXdSbWmP1SL7y1OWSPyd5G5Y-lvg4,17578
@@ -16,7 +16,7 @@ dtflow/cli/common.py,sha256=gCwnF5Sw2ploqfZJO_z3Ms9mR1HNT7Lj6ydHn0uVaIw,13817
16
16
  dtflow/cli/io_ops.py,sha256=BMDisP6dxzzmSjYwmeFwaHmpHHPqirmXAWeNTD-9MQM,13254
17
17
  dtflow/cli/lineage.py,sha256=_lNh35nF9AA0Zy6FyZ4g8IzrXH2ZQnp3inF-o2Hs1pw,1383
18
18
  dtflow/cli/pipeline.py,sha256=QNEo-BJlaC1CVnVeRZr7TwfuZYloJ4TebIzJ5ALzry0,1426
19
- dtflow/cli/sample.py,sha256=LRCkpFi9t0CI2QjRKADmvwWMdGfLriqdNkoFG6_wQkY,10497
19
+ dtflow/cli/sample.py,sha256=pubpx4AIzsarBEalD150MC2apYQSt4bal70IZkTfFO0,15475
20
20
  dtflow/cli/stats.py,sha256=u4ehCfgw1X8WuOyAjrApMRgcIO3BVmINbsTjxEscQro,24086
21
21
  dtflow/cli/transform.py,sha256=w6xqMOxPxQvL2u_BPCfpDHuPSC9gmcqMPVN8s-B6bbY,15052
22
22
  dtflow/cli/validate.py,sha256=65aGVlMS_Rq0Ch0YQ-TclVJ03RQP4CnG137wthzb8Ao,4384
@@ -31,7 +31,7 @@ dtflow/utils/__init__.py,sha256=Pn-ltwV04fBQmeZG7FxInDQmzH29LYOi90LgeLMEuQk,506
31
31
  dtflow/utils/display.py,sha256=OeOdTh6mbDwSkDWlmkjfpTjy2QG8ZUaYU0NpHUWkpEQ,5881
32
32
  dtflow/utils/field_path.py,sha256=K8nU196RxTSJ1OoieTWGcYOWl9KjGq2iSxCAkfjECuM,7621
33
33
  dtflow/utils/helpers.py,sha256=JXN176_B2pm53GLVyZ1wj3wrmBJG52Tkw6AMQSdj7M8,791
34
- dtflow-0.5.4.dist-info/METADATA,sha256=mQIIV3B-6VBOuNSRiPQjqOwdLTs6Nir6to1_FIER3d0,22544
35
- dtflow-0.5.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
36
- dtflow-0.5.4.dist-info/entry_points.txt,sha256=dadIDOK7Iu9pMxnMPBfpb4aAPe4hQbBOshpQYjVYpGc,44
37
- dtflow-0.5.4.dist-info/RECORD,,
34
+ dtflow-0.5.6.dist-info/METADATA,sha256=TPSDq-fQDini8uKERCdm_4cZYw-b9t6V8UQ1MlTJ7iA,22698
35
+ dtflow-0.5.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
36
+ dtflow-0.5.6.dist-info/entry_points.txt,sha256=dadIDOK7Iu9pMxnMPBfpb4aAPe4hQbBOshpQYjVYpGc,44
37
+ dtflow-0.5.6.dist-info/RECORD,,
File without changes