dtflow 0.5.8__py3-none-any.whl → 0.5.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/SKILL.md +22 -8
- dtflow/__init__.py +1 -1
- dtflow/__main__.py +108 -14
- dtflow/cli/clean.py +90 -1
- dtflow/cli/commands.py +17 -1
- dtflow/cli/eval.py +288 -0
- dtflow/cli/export.py +81 -0
- dtflow/cli/sample.py +90 -3
- dtflow/cli/split.py +138 -0
- dtflow/cli/stats.py +10 -23
- dtflow/cli/validate.py +19 -52
- dtflow/eval.py +276 -0
- dtflow/schema.py +13 -99
- dtflow/tokenizers.py +21 -104
- dtflow/utils/text_parser.py +124 -0
- {dtflow-0.5.8.dist-info → dtflow-0.5.9.dist-info}/METADATA +29 -3
- {dtflow-0.5.8.dist-info → dtflow-0.5.9.dist-info}/RECORD +19 -15
- dtflow/parallel.py +0 -115
- {dtflow-0.5.8.dist-info → dtflow-0.5.9.dist-info}/WHEEL +0 -0
- {dtflow-0.5.8.dist-info → dtflow-0.5.9.dist-info}/entry_points.txt +0 -0
dtflow/cli/eval.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI eval 命令实现
|
|
3
|
+
|
|
4
|
+
对模型输出进行解析 + 指标计算,支持两阶段解析和管道式提取。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import re
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
from rich.console import Console
|
|
13
|
+
|
|
14
|
+
from ..storage.io import load_data
|
|
15
|
+
from ..utils.field_path import get_field
|
|
16
|
+
from ..utils.text_parser import extract_code_snippets, parse_generic_tags, strip_think_tags
|
|
17
|
+
|
|
18
|
+
# Shared rich console; every message printed by this module goes through it.
console = Console()

# Candidate field names tried, in list order, when auto-detecting the label
# column (first as top-level columns, then as keys of dict-valued columns).
LABEL_CANDIDATES = ["label", "labels", "content_label", "target", "ground_truth", "answer"]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def eval(
    result_file: str,
    source: Optional[str] = None,
    response_col: str = "content",
    label_col: Optional[str] = None,
    extract: str = "direct",
    sep: Optional[str] = None,
    mapping: Optional[str] = None,
    output_dir: str = "record",
):
    """Parse model outputs from a .jsonl file and compute evaluation metrics.

    Two-stage parsing:
      Stage 1 (automatic): remove <think>...</think> sections and, if the
      response contains fenced ```...``` code blocks, keep the last one.
      Stage 2 (``--extract``): a pipeline of operators separated by " | ".

    Args:
        result_file: Path to the model-output .jsonl file.
        source: Original input file, merged row-by-row by position (use it
            when result_file carries no labels). Columns that already exist in
            result_file are kept as they are.
        response_col: Field holding the model response; nested paths such as
            ``api_output.content`` are supported.
        label_col: Label field; auto-detected when omitted. Nested paths are
            supported.
        extract: Extraction pipeline, operators separated by " | ".
        sep: Delimiter used by the ``index:`` operator.
        mapping: Value mapping "k1:v1,k2:v2", applied to both predictions and
            labels.
        output_dir: Directory the metrics report is written to.
    """
    import pandas as pd

    from ..eval import export_eval_report

    # --- Load data ---
    # Check the paths up front (like the other CLI commands) so a typo gives a
    # readable message instead of a traceback from load_data.
    for path in (result_file, source):
        if path and not Path(path).exists():
            console.print(f"[red]错误: 文件不存在 - {path}[/red]")
            return

    data = load_data(result_file)
    df = pd.DataFrame(data)
    console.print(f"[cyan]加载 {result_file},共 {len(df)} 条[/cyan]")

    # Merge the source file by row position; columns already present in the
    # result file take precedence.
    if source:
        source_df = pd.DataFrame(load_data(source))
        if len(source_df) != len(df):
            console.print(f"[red]行数不一致: result={len(df)}, source={len(source_df)}[/red]")
            return
        for col in source_df.columns:
            if col not in df.columns:
                df[col] = source_df[col].values
        console.print(f"[dim]已合并 source 文件: {source}[/dim]")

    # --- Resolve response_col (nested paths are flattened into a new column) ---
    response_col_resolved = _resolve_nested_col(df, response_col)
    if response_col_resolved is None:
        console.print(f"[red]响应列 '{response_col}' 不存在。可用列: {list(df.columns)}[/red]")
        return

    # --- Auto-detect label_col ---
    if label_col is None:
        label_col = _auto_detect_label_col(df)
        if label_col is None:
            console.print(
                f"[red]未找到标签列,请通过 --label-col 指定。可用列: {list(df.columns)}[/red]"
            )
            return

    # Resolve label_col (nested paths supported)
    label_col_resolved = _resolve_nested_col(df, label_col)
    if label_col_resolved is None:
        console.print(f"[red]标签列 '{label_col}' 不存在。可用列: {list(df.columns)}[/red]")
        return

    console.print(
        f"[dim]response_col={response_col_resolved}, "
        f"label_col={label_col_resolved}, extract={extract}[/dim]"
    )

    # --- Stages 1 + 2: clean each response, then run the extraction pipeline ---
    ops = _parse_pipeline(extract)
    pred_col = "__pred__"
    df[pred_col] = df[response_col_resolved].apply(
        lambda x: _run_pipeline(_stage1_clean(x), ops, sep)
    )

    # --- Mapping stage ---
    if mapping:
        m = _parse_mapping(mapping)
        # Index of each target value in the mapping string. When one row
        # yields several values (the `lines` operator), the mapped value that
        # appears last in the mapping wins; unmapped values rank lowest.
        priority = {v: i for i, v in enumerate(m.values())}

        def map_value(x):
            if isinstance(x, list):
                if not x:
                    # Nothing to choose from; max() would raise on an empty list.
                    return x
                mapped = [m.get(v, v) for v in x]
                return max(mapped, key=lambda v: priority.get(v, -1))
            return m.get(x, x) if isinstance(x, str) else m.get(str(x), x)

        df[pred_col] = df[pred_col].apply(map_value)
        df[label_col_resolved] = df[label_col_resolved].apply(map_value)

    # Normalize predictions and labels to stripped strings before comparison.
    def to_text(x):
        return x.strip() if isinstance(x, str) else str(x).strip()

    df[pred_col] = df[pred_col].apply(to_text)
    df[label_col_resolved] = df[label_col_resolved].apply(to_text)

    # --- Report ---
    console.print("\n[bold green]评估结果[/bold green]")
    export_eval_report(
        df,
        pred_col=pred_col,
        label_col=label_col_resolved,
        record_folder=output_dir,
        input_name=Path(result_file).stem,
    )
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# ============ 内部工具函数 ============
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _resolve_nested_col(df, col_name: str) -> Optional[str]:
|
|
144
|
+
"""解析嵌套字段路径,将其展开为 DataFrame 的新列
|
|
145
|
+
|
|
146
|
+
使用 dtflow 的 get_field() 支持完整的嵌套路径语法。
|
|
147
|
+
|
|
148
|
+
Returns:
|
|
149
|
+
解析后的列名,或 None(如果字段不存在)
|
|
150
|
+
"""
|
|
151
|
+
# 简单情况:直接列名存在
|
|
152
|
+
if col_name in df.columns:
|
|
153
|
+
return col_name
|
|
154
|
+
|
|
155
|
+
# 尝试嵌套路径
|
|
156
|
+
if "." not in col_name and "[" not in col_name:
|
|
157
|
+
return None
|
|
158
|
+
|
|
159
|
+
# 用 get_field 从第一个非空行试探
|
|
160
|
+
sample_row = None
|
|
161
|
+
for _, row in df.iterrows():
|
|
162
|
+
row_dict = row.to_dict()
|
|
163
|
+
val = get_field(row_dict, col_name)
|
|
164
|
+
if val is not None:
|
|
165
|
+
sample_row = row_dict
|
|
166
|
+
break
|
|
167
|
+
|
|
168
|
+
if sample_row is None:
|
|
169
|
+
return None
|
|
170
|
+
|
|
171
|
+
# 展开嵌套字段到新列
|
|
172
|
+
resolved_name = col_name.replace(".", "__").replace("[", "_").replace("]", "")
|
|
173
|
+
df[resolved_name] = df.apply(lambda row: get_field(row.to_dict(), col_name), axis=1)
|
|
174
|
+
return resolved_name
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _auto_detect_label_col(df) -> Optional[str]:
|
|
178
|
+
"""自动检测 label 列"""
|
|
179
|
+
# 优先在顶层列中查找
|
|
180
|
+
for c in LABEL_CANDIDATES:
|
|
181
|
+
if c in df.columns:
|
|
182
|
+
return c
|
|
183
|
+
|
|
184
|
+
# 搜索 dict 类型列的嵌套 key
|
|
185
|
+
for col in df.columns:
|
|
186
|
+
non_null = df[col].dropna()
|
|
187
|
+
if non_null.empty:
|
|
188
|
+
continue
|
|
189
|
+
sample = non_null.iloc[0]
|
|
190
|
+
if isinstance(sample, dict):
|
|
191
|
+
for c in LABEL_CANDIDATES:
|
|
192
|
+
if c in sample:
|
|
193
|
+
return f"{col}.{c}"
|
|
194
|
+
|
|
195
|
+
return None
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _stage1_clean(text) -> str:
|
|
199
|
+
"""阶段1:自动清洗(去思考链 + 提取代码块)"""
|
|
200
|
+
if not isinstance(text, str):
|
|
201
|
+
return str(text) if text is not None else ""
|
|
202
|
+
text = strip_think_tags(text)
|
|
203
|
+
snippets = extract_code_snippets(text)
|
|
204
|
+
if snippets:
|
|
205
|
+
return snippets[-1]["code"]
|
|
206
|
+
return text.strip()
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _parse_pipeline(extract_str: str) -> list:
|
|
210
|
+
"""解析管道表达式,按 ' | ' 分割"""
|
|
211
|
+
return [op.strip() for op in extract_str.split(" | ") if op.strip()]
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _apply_op(text: str, op: str, sep: Optional[str] = None) -> str:
|
|
215
|
+
"""对单个字符串执行单个算子"""
|
|
216
|
+
if op == "direct":
|
|
217
|
+
return text
|
|
218
|
+
elif op.startswith("tag:"):
|
|
219
|
+
tag_name = op[4:]
|
|
220
|
+
tags = parse_generic_tags(text)
|
|
221
|
+
return tags.get(tag_name, text)
|
|
222
|
+
elif op.startswith("json_key:"):
|
|
223
|
+
key = op[9:]
|
|
224
|
+
try:
|
|
225
|
+
obj = json.loads(text)
|
|
226
|
+
except Exception:
|
|
227
|
+
return text
|
|
228
|
+
if isinstance(obj, dict):
|
|
229
|
+
return str(obj.get(key, text))
|
|
230
|
+
return text
|
|
231
|
+
elif op.startswith("index:"):
|
|
232
|
+
idx = int(op[6:])
|
|
233
|
+
delimiter = sep if sep else ","
|
|
234
|
+
parts = text.split(delimiter)
|
|
235
|
+
if 0 <= idx < len(parts):
|
|
236
|
+
return parts[idx].strip()
|
|
237
|
+
return text
|
|
238
|
+
elif op.startswith("line:"):
|
|
239
|
+
n = int(op[5:])
|
|
240
|
+
text_lines = [line.strip() for line in text.splitlines() if line.strip()]
|
|
241
|
+
if text_lines and -len(text_lines) <= n < len(text_lines):
|
|
242
|
+
return text_lines[n]
|
|
243
|
+
return text
|
|
244
|
+
elif op.startswith("regex:"):
|
|
245
|
+
pattern = op[6:]
|
|
246
|
+
m = re.search(pattern, text)
|
|
247
|
+
if m:
|
|
248
|
+
return m.group(1) if m.lastindex else m.group(0)
|
|
249
|
+
return text
|
|
250
|
+
else:
|
|
251
|
+
console.print(f"[yellow]未知算子: {op},跳过[/yellow]")
|
|
252
|
+
return text
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _run_pipeline(text: str, ops: list, sep: Optional[str] = None):
|
|
256
|
+
"""执行管道,处理 lines 展开"""
|
|
257
|
+
if "lines" in ops:
|
|
258
|
+
pos = ops.index("lines")
|
|
259
|
+
# lines 之前的算子先执行
|
|
260
|
+
for op in ops[:pos]:
|
|
261
|
+
text = _apply_op(text, op, sep)
|
|
262
|
+
# 展开为多行
|
|
263
|
+
items = [line.strip() for line in text.splitlines() if line.strip()]
|
|
264
|
+
# 每行独立走后续管道
|
|
265
|
+
rest_ops = ops[pos + 1 :]
|
|
266
|
+
results = []
|
|
267
|
+
for item in items:
|
|
268
|
+
for op in rest_ops:
|
|
269
|
+
item = _apply_op(item, op, sep)
|
|
270
|
+
results.append(item)
|
|
271
|
+
if not results:
|
|
272
|
+
return text
|
|
273
|
+
return results if len(results) > 1 else results[0]
|
|
274
|
+
else:
|
|
275
|
+
for op in ops:
|
|
276
|
+
text = _apply_op(text, op, sep)
|
|
277
|
+
return text
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _parse_mapping(mapping_str: str) -> dict:
|
|
281
|
+
"""解析 'k1:v1,k2:v2' 格式的映射"""
|
|
282
|
+
m = {}
|
|
283
|
+
for pair in mapping_str.split(","):
|
|
284
|
+
pair = pair.strip()
|
|
285
|
+
if ":" in pair:
|
|
286
|
+
k, v = pair.split(":", 1)
|
|
287
|
+
m[k.strip()] = v.strip()
|
|
288
|
+
return m
|
dtflow/cli/export.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI 训练框架导出命令
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from ..core import DataTransformer
|
|
9
|
+
from ..framework import check_compatibility, detect_format, export_for
|
|
10
|
+
from .common import _check_file_format
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def export(
    filename: str,
    framework: str,
    output: Optional[str] = None,
    name: Optional[str] = None,
    check: bool = False,
) -> None:
    """
    Export a dataset for a training framework (LLaMA-Factory, ms-swift, Axolotl).

    Steps:
    1. Load the file with ``DataTransformer.load``.
    2. Report its detected format.
    3. Check compatibility with ``framework``.
    4. Unless ``check`` is set or the check fails, write the data with
       ``export_for``.

    Errors while loading or exporting are caught and printed, after which the
    function returns.

    Args:
        filename: Input file path.
        framework: Target framework (llama-factory, swift, axolotl).
        output: Output directory. Defaults to ``{stem}_{framework}`` next to
            the input, with the framework lower-cased and "-" replaced by "_".
        name: Dataset name (defaults to ``custom_dataset``).
        check: Only run the compatibility check; do not export.
    """
    path = Path(filename)
    if not path.exists():
        print(f"错误: 文件不存在 - {filename}")
        return
    if not _check_file_format(path):
        return

    print(f"📊 加载数据: {path}")
    try:
        transformer = DataTransformer.load(str(path))
    except Exception as exc:
        print(f"错误: 无法读取文件 - {exc}")
        return

    records = transformer.data
    print(f" 共 {len(records)} 条数据")

    print(f"📋 检测到格式: {detect_format(records)}")

    report = check_compatibility(records, framework)
    print(f"\n{report}")
    if check:
        return
    if not report.valid:
        print("\n❌ 兼容性检查未通过,跳过导出")
        return

    if output is None:
        suffix = framework.lower().replace("-", "_")
        output = str(path.parent / f"{path.stem}_{suffix}")
    dataset_name = name or "custom_dataset"

    print(f"\n📦 导出到 {framework}...")
    try:
        export_for(records, framework, output, dataset_name=dataset_name)
    except Exception as exc:
        print(f"错误: 导出失败 - {exc}")
        return

    print(f"\n✅ 导出完成! 文件保存在: {output}")
|
dtflow/cli/sample.py
CHANGED
|
@@ -143,7 +143,7 @@ def sample(
|
|
|
143
143
|
by: Optional[str] = None,
|
|
144
144
|
uniform: bool = False,
|
|
145
145
|
fields: Optional[str] = None,
|
|
146
|
-
raw: bool =
|
|
146
|
+
raw: bool = True,
|
|
147
147
|
where: Optional[List[str]] = None,
|
|
148
148
|
) -> None:
|
|
149
149
|
"""
|
|
@@ -389,7 +389,7 @@ def head(
|
|
|
389
389
|
num: int = 10,
|
|
390
390
|
output: Optional[str] = None,
|
|
391
391
|
fields: Optional[str] = None,
|
|
392
|
-
raw: bool =
|
|
392
|
+
raw: bool = True,
|
|
393
393
|
) -> None:
|
|
394
394
|
"""
|
|
395
395
|
显示文件的前 N 条数据(dt sample --type=head 的快捷方式)。
|
|
@@ -415,12 +415,99 @@ def head(
|
|
|
415
415
|
sample(filename, num=num, type="head", output=output, fields=fields, raw=raw)
|
|
416
416
|
|
|
417
417
|
|
|
418
|
+
def slice_data(
    filename: str,
    range_str: str,
    output: Optional[str] = None,
    fields: Optional[str] = None,
    raw: bool = True,
) -> None:
    """
    View rows by index range (Python slice syntax).

    Args:
        filename: Input file path.
        range_str: Row range ``start:end`` (0-based, end exclusive). Either
            bound may be omitted or negative, exactly as in Python slicing:
            - 10:20   rows 10-19 (10 rows)
            - :100    first 100 rows
            - 100:    row 100 to the end
            - -10:    last 10 rows
        output: Save the slice to this file instead of printing it.
        fields: Only show these fields (comma-separated). Takes precedence
            over ``raw`` so that ``--fields`` works with the default settings.
        raw: Print each row as indented raw JSON.

    Examples:
        dt slice data.jsonl 10:20
        dt slice data.jsonl :100
        dt slice data.jsonl 100:
        dt slice data.jsonl -10:
        dt slice data.jsonl 10:20 --output=sliced.jsonl
        dt slice data.jsonl 10:20 --fields=question,answer
    """
    filepath = Path(filename)

    if not filepath.exists():
        print(f"错误: 文件不存在 - {filename}")
        return

    if not _check_file_format(filepath):
        return

    # Parse the range
    if ":" not in range_str:
        print(f"错误: 无效的范围格式 '{range_str}',应为 start:end(如 10:20)")
        return

    start_str, end_str = (part.strip() for part in range_str.split(":", 1))
    try:
        start = int(start_str) if start_str else None
        end = int(end_str) if end_str else None
    except ValueError:
        print(f"错误: 无效的范围格式 '{range_str}',start 和 end 必须为整数")
        return

    # Load and slice
    try:
        data = load_data(str(filepath))
    except Exception as e:
        print(f"错误: {e}")
        return

    total = len(data)
    window = slice(start, end)
    sliced = data[window]

    if not sliced:
        print(f"⚠️ 范围 [{range_str}] 无数据(文件共 {total} 行)")
        return

    # slice.indices() normalizes omitted/negative bounds exactly like the
    # slicing above, so the reported range is also right for e.g. "10:-5".
    actual_start, actual_end, _ = window.indices(total)
    print(f"📍 行 {actual_start}-{actual_end - 1}(共 {len(sliced)} 条,文件共 {total} 行)")

    # Output
    if output:
        save_data(sliced, output)
        print(f"已保存 {len(sliced)} 条数据到 {output}")
    elif raw and not fields:
        for item in sliced:
            print(orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8"))
    else:
        field_list = _parse_field_list(fields) if fields else None
        _print_samples(sliced, filepath.name, total, field_list, filepath.stat().st_size)
|
|
503
|
+
|
|
504
|
+
|
|
418
505
|
def tail(
|
|
419
506
|
filename: str,
|
|
420
507
|
num: int = 10,
|
|
421
508
|
output: Optional[str] = None,
|
|
422
509
|
fields: Optional[str] = None,
|
|
423
|
-
raw: bool =
|
|
510
|
+
raw: bool = True,
|
|
424
511
|
) -> None:
|
|
425
512
|
"""
|
|
426
513
|
显示文件的后 N 条数据(dt sample --type=tail 的快捷方式)。
|
dtflow/cli/split.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI 数据集切分命令
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import List, Optional
|
|
7
|
+
|
|
8
|
+
from ..core import DataTransformer
|
|
9
|
+
from ..storage.io import save_data
|
|
10
|
+
from .common import _check_file_format
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _parse_ratio(ratio_str: str) -> List[float]:
|
|
14
|
+
"""
|
|
15
|
+
解析比例参数。
|
|
16
|
+
|
|
17
|
+
- "0.8" -> [0.8, 0.2](二分)
|
|
18
|
+
- "0.8,0.1,0.1" -> [0.8, 0.1, 0.1](三分)
|
|
19
|
+
"""
|
|
20
|
+
parts = [float(x.strip()) for x in ratio_str.split(",")]
|
|
21
|
+
|
|
22
|
+
if len(parts) == 1:
|
|
23
|
+
if not (0 < parts[0] < 1):
|
|
24
|
+
raise ValueError(f"比例必须在 0-1 之间: {parts[0]}")
|
|
25
|
+
parts.append(round(1 - parts[0], 10))
|
|
26
|
+
|
|
27
|
+
total = sum(parts)
|
|
28
|
+
if abs(total - 1.0) > 1e-6:
|
|
29
|
+
raise ValueError(f"比例之和必须为 1.0,当前为 {total}")
|
|
30
|
+
|
|
31
|
+
if any(p <= 0 for p in parts):
|
|
32
|
+
raise ValueError("每个比例都必须大于 0")
|
|
33
|
+
|
|
34
|
+
return parts
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Split names: 2-way splits use train/test; 3-way and larger splits use
# train/val/test followed by part4, part5, ... (see _get_split_names).
_SPLIT_NAMES_2 = ["train", "test"]
_SPLIT_NAMES_3 = ["train", "val", "test"]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_split_names(count: int) -> List[str]:
    """Return one output name per split part.

    2 parts -> train/test; 3 parts -> train/val/test; more parts append
    part4, part5, ... The returned list always has ``count`` entries (for
    count < 2 this is a prefix of train/val/test). A fresh list is returned on
    every call, so callers cannot mutate the module-level name constants.
    """
    if count == 2:
        return list(_SPLIT_NAMES_2)
    names = _SPLIT_NAMES_3[:count]  # slicing already makes a copy
    names.extend(f"part{i + 1}" for i in range(3, count))
    return names
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def split(
    filename: str,
    ratio: str = "0.8",
    seed: Optional[int] = None,
    output: Optional[str] = None,
) -> None:
    """
    Split a dataset into train/test (or train/val/test, ...) files.

    The rows are shuffled with ``DataTransformer.shuffle(seed)`` and cut into
    consecutive chunks. Every part except the last gets ``floor(total * r)``
    rows; the last part takes the remainder. Each chunk is written to
    ``{stem}_{name}{suffix}`` in the output directory.

    Args:
        filename: Input file path.
        ratio: Split ratios, e.g. "0.8" or "0.7,0.15,0.15".
        seed: Random seed for the shuffle.
        output: Output directory (created if needed; defaults to the input
            file's directory).
    """
    filepath = Path(filename)

    if not filepath.exists():
        print(f"错误: 文件不存在 - {filename}")
        return

    if not _check_file_format(filepath):
        return

    # Parse the ratios
    try:
        ratios = _parse_ratio(ratio)
    except ValueError as e:
        print(f"错误: {e}")
        return

    split_names = _get_split_names(len(ratios))

    # Load data
    print(f"📊 加载数据: {filepath}")
    try:
        dt = DataTransformer.load(str(filepath))
    except Exception as e:
        print(f"错误: 无法读取文件 - {e}")
        return

    total = len(dt)
    print(f" 共 {total} 条数据")

    # Shuffle
    shuffled = dt.shuffle(seed)
    if seed is not None:
        print(f"🎲 随机种子: {seed}")

    # Cut into consecutive chunks. Rounding total * r to 6 decimals before
    # flooring removes float noise such as 100 * 0.29 == 28.999999999999996,
    # which would otherwise floor to 28 rows instead of 29.
    data = shuffled.data
    parts = []
    prev = 0
    for r in ratios[:-1]:
        end = prev + int(round(total * r, 6))
        parts.append(data[prev:end])
        prev = end
    parts.append(data[prev:])

    # Output directory
    if output:
        output_dir = Path(output)
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        output_dir = filepath.parent

    # Save each part
    stem = filepath.stem
    ext = filepath.suffix

    print(f"\n🔀 切分比例: {' / '.join(f'{r:.0%}' for r in ratios)}")
    for name, part, part_ratio in zip(split_names, parts, ratios):
        output_path = output_dir / f"{stem}_{name}{ext}"
        save_data(part, str(output_path))
        pct = part_ratio * 100
        print(f" {name}: {len(part)} 条 ({pct:.1f}%) -> {output_path}")

    print(f"\n✅ 完成! 共切分为 {len(ratios)} 个部分")
|