dtflow 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +1 -1
- dtflow/__main__.py +4 -5
- dtflow/cli/sample.py +161 -13
- {dtflow-0.5.4.dist-info → dtflow-0.5.6.dist-info}/METADATA +4 -2
- {dtflow-0.5.4.dist-info → dtflow-0.5.6.dist-info}/RECORD +7 -7
- {dtflow-0.5.4.dist-info → dtflow-0.5.6.dist-info}/WHEEL +0 -0
- {dtflow-0.5.4.dist-info → dtflow-0.5.6.dist-info}/entry_points.txt +0 -0
dtflow/__init__.py
CHANGED
dtflow/__main__.py
CHANGED
|
@@ -60,17 +60,18 @@ def sample(
|
|
|
60
60
|
filename: str = typer.Argument(..., help="输入文件路径"),
|
|
61
61
|
num_arg: Optional[int] = typer.Argument(None, help="采样数量", metavar="NUM"),
|
|
62
62
|
num: int = typer.Option(10, "--num", "-n", help="采样数量", show_default=True),
|
|
63
|
-
type: str = typer.Option("
|
|
63
|
+
type: str = typer.Option("random", "--type", "-t", help="采样方式: random/head/tail"),
|
|
64
64
|
output: Optional[str] = typer.Option(None, "--output", "-o", help="输出文件路径"),
|
|
65
65
|
seed: Optional[int] = typer.Option(None, "--seed", help="随机种子"),
|
|
66
66
|
by: Optional[str] = typer.Option(None, "--by", help="分层采样字段"),
|
|
67
67
|
uniform: bool = typer.Option(False, "--uniform", help="均匀采样模式"),
|
|
68
68
|
fields: Optional[str] = typer.Option(None, "--fields", "-f", help="只显示指定字段(逗号分隔)"),
|
|
69
69
|
raw: bool = typer.Option(False, "--raw", "-r", help="输出原始 JSON(不截断)"),
|
|
70
|
+
where: Optional[List[str]] = typer.Option(None, "--where", "-w", help="筛选条件 (可多次使用)"),
|
|
70
71
|
):
|
|
71
72
|
"""从数据文件中采样指定数量的数据"""
|
|
72
73
|
actual_num = num_arg if num_arg is not None else num
|
|
73
|
-
_sample(filename, actual_num, type, output, seed, by, uniform, fields, raw)
|
|
74
|
+
_sample(filename, actual_num, type, output, seed, by, uniform, fields, raw, where)
|
|
74
75
|
|
|
75
76
|
|
|
76
77
|
@app.command()
|
|
@@ -223,9 +224,7 @@ def validate(
|
|
|
223
224
|
None, "--preset", "-p", help="预设 Schema: openai_chat, alpaca, dpo, sharegpt"
|
|
224
225
|
),
|
|
225
226
|
output: Optional[str] = typer.Option(None, "--output", "-o", help="输出有效数据的文件路径"),
|
|
226
|
-
filter: bool = typer.Option(
|
|
227
|
-
False, "--filter", "-f", help="过滤无效数据并保存"
|
|
228
|
-
),
|
|
227
|
+
filter: bool = typer.Option(False, "--filter", "-f", help="过滤无效数据并保存"),
|
|
229
228
|
max_errors: int = typer.Option(20, "--max-errors", help="最多显示的错误数量"),
|
|
230
229
|
verbose: bool = typer.Option(False, "--verbose", "-v", help="显示详细信息"),
|
|
231
230
|
):
|
dtflow/cli/sample.py
CHANGED
|
@@ -2,8 +2,9 @@
|
|
|
2
2
|
CLI 采样相关命令
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
import re
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Dict, List, Literal, Optional
|
|
7
|
+
from typing import Any, Callable, Dict, List, Literal, Optional
|
|
7
8
|
|
|
8
9
|
import orjson
|
|
9
10
|
|
|
@@ -16,17 +17,134 @@ from .common import (
|
|
|
16
17
|
_print_samples,
|
|
17
18
|
)
|
|
18
19
|
|
|
20
|
+
# where 条件解析正则:field op value
|
|
21
|
+
_WHERE_PATTERN = re.compile(r"^(.+?)(!=|~=|>=|<=|>|<|=)(.*)$")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _parse_where(condition: str) -> Callable[[dict], bool]:
|
|
25
|
+
"""
|
|
26
|
+
解析 where 条件字符串,返回筛选函数。
|
|
27
|
+
|
|
28
|
+
支持的操作符:
|
|
29
|
+
= 等于
|
|
30
|
+
!= 不等于
|
|
31
|
+
~= 包含(字符串)
|
|
32
|
+
> 大于
|
|
33
|
+
>= 大于等于
|
|
34
|
+
< 小于
|
|
35
|
+
<= 小于等于
|
|
36
|
+
|
|
37
|
+
Examples:
|
|
38
|
+
_parse_where("category=tech")
|
|
39
|
+
_parse_where("meta.source!=wiki")
|
|
40
|
+
_parse_where("content~=机器学习")
|
|
41
|
+
_parse_where("messages.#>=2")
|
|
42
|
+
"""
|
|
43
|
+
match = _WHERE_PATTERN.match(condition)
|
|
44
|
+
if not match:
|
|
45
|
+
raise ValueError(f"无效的 where 条件: {condition}")
|
|
46
|
+
|
|
47
|
+
field, op, value = match.groups()
|
|
48
|
+
|
|
49
|
+
# 尝试转换 value 为数值
|
|
50
|
+
def parse_value(v: str) -> Any:
|
|
51
|
+
if v.lower() == "true":
|
|
52
|
+
return True
|
|
53
|
+
if v.lower() == "false":
|
|
54
|
+
return False
|
|
55
|
+
try:
|
|
56
|
+
return int(v)
|
|
57
|
+
except ValueError:
|
|
58
|
+
try:
|
|
59
|
+
return float(v)
|
|
60
|
+
except ValueError:
|
|
61
|
+
return v
|
|
62
|
+
|
|
63
|
+
parsed_value = parse_value(value)
|
|
64
|
+
|
|
65
|
+
def filter_fn(item: dict) -> bool:
|
|
66
|
+
field_value = get_field_with_spec(item, field)
|
|
67
|
+
|
|
68
|
+
if op == "=":
|
|
69
|
+
# 字符串比较或数值比较
|
|
70
|
+
if field_value is None:
|
|
71
|
+
return value == "" or value.lower() == "none"
|
|
72
|
+
return str(field_value) == value or field_value == parsed_value
|
|
73
|
+
elif op == "!=":
|
|
74
|
+
if field_value is None:
|
|
75
|
+
return value != "" and value.lower() != "none"
|
|
76
|
+
return str(field_value) != value and field_value != parsed_value
|
|
77
|
+
elif op == "~=":
|
|
78
|
+
# 包含
|
|
79
|
+
if field_value is None:
|
|
80
|
+
return False
|
|
81
|
+
return value in str(field_value)
|
|
82
|
+
elif op in (">", ">=", "<", "<="):
|
|
83
|
+
# 数值比较
|
|
84
|
+
if field_value is None:
|
|
85
|
+
return False
|
|
86
|
+
try:
|
|
87
|
+
num_field = float(field_value)
|
|
88
|
+
num_value = float(value)
|
|
89
|
+
if op == ">":
|
|
90
|
+
return num_field > num_value
|
|
91
|
+
elif op == ">=":
|
|
92
|
+
return num_field >= num_value
|
|
93
|
+
elif op == "<":
|
|
94
|
+
return num_field < num_value
|
|
95
|
+
else: # <=
|
|
96
|
+
return num_field <= num_value
|
|
97
|
+
except (ValueError, TypeError):
|
|
98
|
+
return False
|
|
99
|
+
return False
|
|
100
|
+
|
|
101
|
+
return filter_fn
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _apply_where_filters(data: List[Dict], where_conditions: List[str]) -> List[Dict]:
|
|
105
|
+
"""应用多个 where 条件(AND 关系)"""
|
|
106
|
+
if not where_conditions:
|
|
107
|
+
return data
|
|
108
|
+
|
|
109
|
+
filters = [_parse_where(cond) for cond in where_conditions]
|
|
110
|
+
return [item for item in data if all(f(item) for f in filters)]
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _sample_from_list(
|
|
114
|
+
data: List[Dict],
|
|
115
|
+
num: int,
|
|
116
|
+
sample_type: str,
|
|
117
|
+
seed: Optional[int] = None,
|
|
118
|
+
) -> List[Dict]:
|
|
119
|
+
"""从列表中采样"""
|
|
120
|
+
import random
|
|
121
|
+
|
|
122
|
+
if seed is not None:
|
|
123
|
+
random.seed(seed)
|
|
124
|
+
|
|
125
|
+
total = len(data)
|
|
126
|
+
if num <= 0 or num > total:
|
|
127
|
+
num = total
|
|
128
|
+
|
|
129
|
+
if sample_type == "random":
|
|
130
|
+
return random.sample(data, num)
|
|
131
|
+
elif sample_type == "head":
|
|
132
|
+
return data[:num]
|
|
133
|
+
else: # tail
|
|
134
|
+
return data[-num:]
|
|
135
|
+
|
|
19
136
|
|
|
20
137
|
def sample(
|
|
21
138
|
filename: str,
|
|
22
139
|
num: int = 10,
|
|
23
|
-
type: Literal["random", "head", "tail"] = "
|
|
140
|
+
type: Literal["random", "head", "tail"] = "random",
|
|
24
141
|
output: Optional[str] = None,
|
|
25
142
|
seed: Optional[int] = None,
|
|
26
143
|
by: Optional[str] = None,
|
|
27
144
|
uniform: bool = False,
|
|
28
145
|
fields: Optional[str] = None,
|
|
29
146
|
raw: bool = False,
|
|
147
|
+
where: Optional[List[str]] = None,
|
|
30
148
|
) -> None:
|
|
31
149
|
"""
|
|
32
150
|
从数据文件中采样指定数量的数据。
|
|
@@ -37,13 +155,14 @@ def sample(
|
|
|
37
155
|
- num > 0: 采样指定数量
|
|
38
156
|
- num = 0: 采样所有数据
|
|
39
157
|
- num < 0: Python 切片风格(如 -1 表示最后 1 条,-10 表示最后 10 条)
|
|
40
|
-
type: 采样方式,可选 random/head/tail,默认
|
|
158
|
+
type: 采样方式,可选 random/head/tail,默认 random
|
|
41
159
|
output: 输出文件路径,不指定则打印到控制台
|
|
42
160
|
seed: 随机种子(仅在 type=random 时有效)
|
|
43
161
|
by: 分层采样字段名,按该字段的值分组采样
|
|
44
162
|
uniform: 均匀采样模式(需配合 --by 使用),各组采样相同数量
|
|
45
163
|
fields: 只显示指定字段(逗号分隔),仅在预览模式下有效
|
|
46
164
|
raw: 输出原始 JSON 格式(不截断,完整显示所有内容)
|
|
165
|
+
where: 筛选条件列表,支持 =, !=, ~=, >, >=, <, <= 操作符
|
|
47
166
|
|
|
48
167
|
Examples:
|
|
49
168
|
dt sample data.jsonl 5
|
|
@@ -54,6 +173,9 @@ def sample(
|
|
|
54
173
|
dt sample data.jsonl 1000 --by=category # 按比例分层采样
|
|
55
174
|
dt sample data.jsonl 1000 --by=category --uniform # 均匀分层采样
|
|
56
175
|
dt sample data.jsonl --fields=question,answer # 只显示指定字段
|
|
176
|
+
dt sample data.jsonl --where="category=tech" # 筛选 category 为 tech 的数据
|
|
177
|
+
dt sample data.jsonl --where="meta.source~=wiki" # 筛选 meta.source 包含 wiki
|
|
178
|
+
dt sample data.jsonl --where="messages.#>=2" # 筛选消息数量 >= 2
|
|
57
179
|
"""
|
|
58
180
|
filepath = Path(filename)
|
|
59
181
|
|
|
@@ -69,23 +191,46 @@ def sample(
|
|
|
69
191
|
print("错误: --uniform 必须配合 --by 使用")
|
|
70
192
|
return
|
|
71
193
|
|
|
194
|
+
# 处理 where 筛选
|
|
195
|
+
where_conditions = where or []
|
|
196
|
+
filtered_data = None
|
|
197
|
+
original_count = None
|
|
198
|
+
|
|
199
|
+
if where_conditions:
|
|
200
|
+
# 有 where 条件时,先加载全部数据再筛选
|
|
201
|
+
try:
|
|
202
|
+
all_data = load_data(str(filepath))
|
|
203
|
+
original_count = len(all_data)
|
|
204
|
+
filtered_data = _apply_where_filters(all_data, where_conditions)
|
|
205
|
+
print(f"🔍 筛选: {original_count} → {len(filtered_data)} 条")
|
|
206
|
+
if not filtered_data:
|
|
207
|
+
print("⚠️ 筛选后无数据")
|
|
208
|
+
return
|
|
209
|
+
except ValueError as e:
|
|
210
|
+
print(f"错误: {e}")
|
|
211
|
+
return
|
|
212
|
+
|
|
72
213
|
# 分层采样模式
|
|
73
214
|
if by:
|
|
74
215
|
try:
|
|
75
|
-
sampled = _stratified_sample(filepath, num, by, uniform, seed, type)
|
|
216
|
+
sampled = _stratified_sample(filepath, num, by, uniform, seed, type, data=filtered_data)
|
|
76
217
|
except Exception as e:
|
|
77
218
|
print(f"错误: {e}")
|
|
78
219
|
return
|
|
79
220
|
else:
|
|
80
221
|
# 普通采样
|
|
81
222
|
try:
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
223
|
+
if filtered_data is not None:
|
|
224
|
+
# 已筛选的数据,直接采样
|
|
225
|
+
sampled = _sample_from_list(filtered_data, num, type, seed)
|
|
226
|
+
else:
|
|
227
|
+
sampled = sample_file(
|
|
228
|
+
str(filepath),
|
|
229
|
+
num=num,
|
|
230
|
+
sample_type=type,
|
|
231
|
+
seed=seed,
|
|
232
|
+
output=None, # 先不保存,统一在最后处理
|
|
233
|
+
)
|
|
89
234
|
except Exception as e:
|
|
90
235
|
print(f"错误: {e}")
|
|
91
236
|
return
|
|
@@ -117,6 +262,7 @@ def _stratified_sample(
|
|
|
117
262
|
uniform: bool,
|
|
118
263
|
seed: Optional[int],
|
|
119
264
|
sample_type: str,
|
|
265
|
+
data: Optional[List[Dict]] = None,
|
|
120
266
|
) -> List[Dict]:
|
|
121
267
|
"""
|
|
122
268
|
分层采样实现。
|
|
@@ -133,6 +279,7 @@ def _stratified_sample(
|
|
|
133
279
|
uniform: 是否均匀采样(各组相同数量)
|
|
134
280
|
seed: 随机种子
|
|
135
281
|
sample_type: 采样方式(用于组内采样)
|
|
282
|
+
data: 预筛选的数据(可选,如果提供则不从文件加载)
|
|
136
283
|
|
|
137
284
|
Returns:
|
|
138
285
|
采样后的数据列表
|
|
@@ -143,8 +290,9 @@ def _stratified_sample(
|
|
|
143
290
|
if seed is not None:
|
|
144
291
|
random.seed(seed)
|
|
145
292
|
|
|
146
|
-
#
|
|
147
|
-
data
|
|
293
|
+
# 加载数据(如果没有预筛选数据)
|
|
294
|
+
if data is None:
|
|
295
|
+
data = load_data(str(filepath))
|
|
148
296
|
total = len(data)
|
|
149
297
|
|
|
150
298
|
if num <= 0 or num > total:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dtflow
|
|
3
|
-
Version: 0.5.
|
|
3
|
+
Version: 0.5.6
|
|
4
4
|
Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
|
|
5
5
|
Project-URL: Homepage, https://github.com/yourusername/DataTransformer
|
|
6
6
|
Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
|
|
@@ -423,6 +423,8 @@ dt sample data.csv --num=100 --sample_type=head
|
|
|
423
423
|
dt sample data.jsonl 1000 --by=category # 分层采样
|
|
424
424
|
dt sample data.jsonl 1000 --by=meta.source # 按嵌套字段分层采样
|
|
425
425
|
dt sample data.jsonl 1000 --by=messages.# # 按消息数量分层采样
|
|
426
|
+
dt sample data.jsonl --where="category=tech" # 筛选后采样
|
|
427
|
+
dt sample data.jsonl --where="category=tech" --where="messages.#>=2" # 多条件筛选(AND)
|
|
426
428
|
|
|
427
429
|
# 数据转换 - 预设模式
|
|
428
430
|
dt transform data.jsonl --preset=openai_chat
|
|
@@ -496,7 +498,7 @@ CLI 命令中的字段参数支持嵌套路径语法,可访问深层嵌套的
|
|
|
496
498
|
|
|
497
499
|
| 命令 | 参数 | 示例 |
|
|
498
500
|
|------|------|------|
|
|
499
|
-
| `sample` | `--by=` | `--by=meta.source`、`--
|
|
501
|
+
| `sample` | `--by=`, `--where=` | `--by=meta.source`、`--where=messages.#>=2` |
|
|
500
502
|
| `dedupe` | `--key=` | `--key=meta.id`、`--key=messages[0].content` |
|
|
501
503
|
| `clean` | `--drop-empty=` | `--drop-empty=meta.source` |
|
|
502
504
|
| `clean` | `--min-len=` | `--min-len=messages.#:2` |
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
dtflow/__init__.py,sha256=
|
|
2
|
-
dtflow/__main__.py,sha256=
|
|
1
|
+
dtflow/__init__.py,sha256=_KUxZUD08hQhhLugGbjo_jlP5JuMCFAcCs0o0SCCoVM,3031
|
|
2
|
+
dtflow/__main__.py,sha256=OJ60M0PbA0PcsQfA7FP9k9CflJgzexKhIl-yc-CPXkw,12675
|
|
3
3
|
dtflow/converters.py,sha256=X3qeFD7FCOMnfiP3MicL5MXimOm4XUYBs5pczIkudU0,22331
|
|
4
4
|
dtflow/core.py,sha256=qMo6B3LK--TWRK7ZBKObGcs3pKFnd0NPoaM0T8JC7Jw,38135
|
|
5
5
|
dtflow/framework.py,sha256=jyICi_RWHjX7WfsXdSbWmP1SL7y1OWSPyd5G5Y-lvg4,17578
|
|
@@ -16,7 +16,7 @@ dtflow/cli/common.py,sha256=gCwnF5Sw2ploqfZJO_z3Ms9mR1HNT7Lj6ydHn0uVaIw,13817
|
|
|
16
16
|
dtflow/cli/io_ops.py,sha256=BMDisP6dxzzmSjYwmeFwaHmpHHPqirmXAWeNTD-9MQM,13254
|
|
17
17
|
dtflow/cli/lineage.py,sha256=_lNh35nF9AA0Zy6FyZ4g8IzrXH2ZQnp3inF-o2Hs1pw,1383
|
|
18
18
|
dtflow/cli/pipeline.py,sha256=QNEo-BJlaC1CVnVeRZr7TwfuZYloJ4TebIzJ5ALzry0,1426
|
|
19
|
-
dtflow/cli/sample.py,sha256=
|
|
19
|
+
dtflow/cli/sample.py,sha256=pubpx4AIzsarBEalD150MC2apYQSt4bal70IZkTfFO0,15475
|
|
20
20
|
dtflow/cli/stats.py,sha256=u4ehCfgw1X8WuOyAjrApMRgcIO3BVmINbsTjxEscQro,24086
|
|
21
21
|
dtflow/cli/transform.py,sha256=w6xqMOxPxQvL2u_BPCfpDHuPSC9gmcqMPVN8s-B6bbY,15052
|
|
22
22
|
dtflow/cli/validate.py,sha256=65aGVlMS_Rq0Ch0YQ-TclVJ03RQP4CnG137wthzb8Ao,4384
|
|
@@ -31,7 +31,7 @@ dtflow/utils/__init__.py,sha256=Pn-ltwV04fBQmeZG7FxInDQmzH29LYOi90LgeLMEuQk,506
|
|
|
31
31
|
dtflow/utils/display.py,sha256=OeOdTh6mbDwSkDWlmkjfpTjy2QG8ZUaYU0NpHUWkpEQ,5881
|
|
32
32
|
dtflow/utils/field_path.py,sha256=K8nU196RxTSJ1OoieTWGcYOWl9KjGq2iSxCAkfjECuM,7621
|
|
33
33
|
dtflow/utils/helpers.py,sha256=JXN176_B2pm53GLVyZ1wj3wrmBJG52Tkw6AMQSdj7M8,791
|
|
34
|
-
dtflow-0.5.
|
|
35
|
-
dtflow-0.5.
|
|
36
|
-
dtflow-0.5.
|
|
37
|
-
dtflow-0.5.
|
|
34
|
+
dtflow-0.5.6.dist-info/METADATA,sha256=TPSDq-fQDini8uKERCdm_4cZYw-b9t6V8UQ1MlTJ7iA,22698
|
|
35
|
+
dtflow-0.5.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
36
|
+
dtflow-0.5.6.dist-info/entry_points.txt,sha256=dadIDOK7Iu9pMxnMPBfpb4aAPe4hQbBOshpQYjVYpGc,44
|
|
37
|
+
dtflow-0.5.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|