dtflow 0.5.8__py3-none-any.whl → 0.5.9__py3-none-any.whl

dtflow/cli/eval.py ADDED
@@ -0,0 +1,288 @@
+ """
+ CLI eval command implementation.
+
+ Parses model outputs and computes metrics; supports two-stage parsing and pipeline-style extraction.
+ """
+
+ import json
+ import re
+ from pathlib import Path
+ from typing import Optional
+
+ from rich.console import Console
+
+ from ..storage.io import load_data
+ from ..utils.field_path import get_field
+ from ..utils.text_parser import extract_code_snippets, parse_generic_tags, strip_think_tags
+
+ console = Console()
+
+ # Candidate field names for label auto-detection
+ LABEL_CANDIDATES = ["label", "labels", "content_label", "target", "ground_truth", "answer"]
+
+
+ def eval(
+     result_file: str,
+     source: Optional[str] = None,
+     response_col: str = "content",
+     label_col: Optional[str] = None,
+     extract: str = "direct",
+     sep: Optional[str] = None,
+     mapping: Optional[str] = None,
+     output_dir: str = "record",
+ ):
+     """Parse a model-output .jsonl file and compute metrics.
+
+     Two-stage parsing flow:
+         Stage 1 (automatic): strip <think>...</think>, extract ```...``` code blocks
+         Stage 2 (set via --extract): pipeline-style extraction
+
+     Args:
+         result_file: Path to the model-output .jsonl file
+         source: Original input file, merged by row number (used when result_file has no label)
+         response_col: Field holding the model response (nested paths supported, e.g. api_output.content)
+         label_col: Label field name (auto-detected if omitted; nested paths supported)
+         extract: Pipeline extraction rules, operators separated by " | "
+         sep: Delimiter used together with the index operator
+         mapping: Value mapping in the form "k1:v1,k2:v2"
+         output_dir: Output directory for the metrics report
+     """
+     import pandas as pd
+
+     from ..eval import export_eval_report
+
+     # --- Load data ---
+     data = load_data(result_file)
+     df = pd.DataFrame(data)
+     console.print(f"[cyan]Loaded {result_file}: {len(df)} rows[/cyan]")
+
+     # Merge the source file
+     if source:
+         source_data = load_data(source)
+         source_df = pd.DataFrame(source_data)
+         if len(source_df) != len(df):
+             console.print(f"[red]Row count mismatch: result={len(df)}, source={len(source_df)}[/red]")
+             return
+         for col in source_df.columns:
+             if col not in df.columns:
+                 df[col] = source_df[col].values
+         console.print(f"[dim]Merged source file: {source}[/dim]")
+
+     # --- Resolve response_col (nested paths supported) ---
+     response_col_resolved = _resolve_nested_col(df, response_col)
+     if response_col_resolved is None:
+         console.print(f"[red]Response column '{response_col}' not found. Available columns: {list(df.columns)}[/red]")
+         return
+
+     # --- Auto-detect label_col ---
+     if label_col is None:
+         label_col = _auto_detect_label_col(df)
+         if label_col is None:
+             console.print(
+                 f"[red]No label column found; specify one via --label-col. Available columns: {list(df.columns)}[/red]"
+             )
+             return
+
+     # Resolve label_col (nested paths supported)
+     label_col_resolved = _resolve_nested_col(df, label_col)
+     if label_col_resolved is None:
+         console.print(f"[red]Label column '{label_col}' not found. Available columns: {list(df.columns)}[/red]")
+         return
+
+     console.print(
+         f"[dim]response_col={response_col_resolved}, "
+         f"label_col={label_col_resolved}, extract={extract}[/dim]"
+     )
+
+     # --- Stages 1+2: parse ---
+     ops = _parse_pipeline(extract)
+     pred_col = "__pred__"
+     df[pred_col] = df[response_col_resolved].apply(
+         lambda x: _run_pipeline(_stage1_clean(x), ops, sep)
+     )
+
+     # --- Mapping stage ---
+     if mapping:
+         m = _parse_mapping(mapping)
+         priority = {}
+         for i, v in enumerate(m.values()):
+             priority[v] = i
+
+         def map_value(x):
+             if isinstance(x, list):
+                 mapped = [m.get(v, v) for v in x]
+                 return max(mapped, key=lambda v: priority.get(v, -1))
+             return m.get(x, x) if isinstance(x, str) else m.get(str(x), x)
+
+         df[pred_col] = df[pred_col].apply(map_value)
+         df[label_col_resolved] = df[label_col_resolved].apply(map_value)
+
+     # Normalize everything to stripped strings
+     df[pred_col] = df[pred_col].apply(
+         lambda x: str(x).strip() if not isinstance(x, str) else x.strip()
+     )
+     df[label_col_resolved] = df[label_col_resolved].apply(
+         lambda x: str(x).strip() if not isinstance(x, str) else x.strip()
+     )
+
+     # --- Call export_eval_report ---
+     console.print("\n[bold green]Evaluation results[/bold green]")
+     input_name = Path(result_file).stem
+     export_eval_report(
+         df,
+         pred_col=pred_col,
+         label_col=label_col_resolved,
+         record_folder=output_dir,
+         input_name=input_name,
+     )
+
+
+ # ============ Internal helpers ============
+
+
+ def _resolve_nested_col(df, col_name: str) -> Optional[str]:
+     """Resolve a nested field path and expand it into a new DataFrame column.
+
+     Uses dtflow's get_field(), which supports the full nested-path syntax.
+
+     Returns:
+         The resolved column name, or None if the field does not exist.
+     """
+     # Simple case: the column exists as-is
+     if col_name in df.columns:
+         return col_name
+
+     # Otherwise try a nested path
+     if "." not in col_name and "[" not in col_name:
+         return None
+
+     # Probe with get_field until a row yields a non-null value
+     sample_row = None
+     for _, row in df.iterrows():
+         row_dict = row.to_dict()
+         val = get_field(row_dict, col_name)
+         if val is not None:
+             sample_row = row_dict
+             break
+
+     if sample_row is None:
+         return None
+
+     # Expand the nested field into a new column
+     resolved_name = col_name.replace(".", "__").replace("[", "_").replace("]", "")
+     df[resolved_name] = df.apply(lambda row: get_field(row.to_dict(), col_name), axis=1)
+     return resolved_name
+
+
+ def _auto_detect_label_col(df) -> Optional[str]:
+     """Auto-detect the label column."""
+     # Prefer top-level columns
+     for c in LABEL_CANDIDATES:
+         if c in df.columns:
+             return c
+
+     # Search nested keys of dict-typed columns
+     for col in df.columns:
+         non_null = df[col].dropna()
+         if non_null.empty:
+             continue
+         sample = non_null.iloc[0]
+         if isinstance(sample, dict):
+             for c in LABEL_CANDIDATES:
+                 if c in sample:
+                     return f"{col}.{c}"
+
+     return None
+
+
+ def _stage1_clean(text) -> str:
+     """Stage 1: automatic cleanup (strip chain-of-thought, extract code blocks)."""
+     if not isinstance(text, str):
+         return str(text) if text is not None else ""
+     text = strip_think_tags(text)
+     snippets = extract_code_snippets(text)
+     if snippets:
+         return snippets[-1]["code"]
+     return text.strip()
+
+
+ def _parse_pipeline(extract_str: str) -> list:
+     """Parse a pipeline expression, splitting on ' | '."""
+     return [op.strip() for op in extract_str.split(" | ") if op.strip()]
+
+
+ def _apply_op(text: str, op: str, sep: Optional[str] = None) -> str:
+     """Apply a single operator to a single string."""
+     if op == "direct":
+         return text
+     elif op.startswith("tag:"):
+         tag_name = op[4:]
+         tags = parse_generic_tags(text)
+         return tags.get(tag_name, text)
+     elif op.startswith("json_key:"):
+         key = op[9:]
+         try:
+             obj = json.loads(text)
+         except Exception:
+             return text
+         if isinstance(obj, dict):
+             return str(obj.get(key, text))
+         return text
+     elif op.startswith("index:"):
+         idx = int(op[6:])
+         delimiter = sep if sep else ","
+         parts = text.split(delimiter)
+         if 0 <= idx < len(parts):
+             return parts[idx].strip()
+         return text
+     elif op.startswith("line:"):
+         n = int(op[5:])
+         text_lines = [line.strip() for line in text.splitlines() if line.strip()]
+         if text_lines and -len(text_lines) <= n < len(text_lines):
+             return text_lines[n]
+         return text
+     elif op.startswith("regex:"):
+         pattern = op[6:]
+         m = re.search(pattern, text)
+         if m:
+             return m.group(1) if m.lastindex else m.group(0)
+         return text
+     else:
+         console.print(f"[yellow]Unknown operator: {op}, skipping[/yellow]")
+         return text
+
+
+ def _run_pipeline(text: str, ops: list, sep: Optional[str] = None):
+     """Run the pipeline, handling 'lines' fan-out."""
+     if "lines" in ops:
+         pos = ops.index("lines")
+         # Operators before 'lines' run on the whole text first
+         for op in ops[:pos]:
+             text = _apply_op(text, op, sep)
+         # Fan out into individual non-empty lines
+         items = [line.strip() for line in text.splitlines() if line.strip()]
+         # Each line goes through the remaining pipeline independently
+         rest_ops = ops[pos + 1 :]
+         results = []
+         for item in items:
+             for op in rest_ops:
+                 item = _apply_op(item, op, sep)
+             results.append(item)
+         if not results:
+             return text
+         return results if len(results) > 1 else results[0]
+     else:
+         for op in ops:
+             text = _apply_op(text, op, sep)
+         return text
+
+
+ def _parse_mapping(mapping_str: str) -> dict:
+     """Parse a mapping in the form 'k1:v1,k2:v2'."""
+     m = {}
+     for pair in mapping_str.split(","):
+         pair = pair.strip()
+         if ":" in pair:
+             k, v = pair.split(":", 1)
+             m[k.strip()] = v.strip()
+     return m
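
A minimal sketch of how the two stages compose. The sample string and pipeline are invented for illustration, and it assumes strip_think_tags/extract_code_snippets (not shown in this diff) behave as the helpers above use them, i.e. the snippet dicts carry the fenced block body under "code":

    raw = "<think>reasoning...</think>Result:\n```\nanswer: yes, maybe\n```"
    ops = _parse_pipeline("regex:answer: (.*) | index:0")  # -> ["regex:answer: (.*)", "index:0"]
    # Stage 1 drops the think block and keeps the last code fence: "answer: yes, maybe"
    # Stage 2: the regex captures "yes, maybe", then index:0 with sep="," keeps "yes"
    print(_run_pipeline(_stage1_clean(raw), ops, sep=","))  # yes
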
dtflow/cli/export.py ADDED
@@ -0,0 +1,81 @@
+ """
+ CLI training-framework export command.
+ """
+
+ from pathlib import Path
+ from typing import Optional
+
+ from ..core import DataTransformer
+ from ..framework import check_compatibility, detect_format, export_for
+ from .common import _check_file_format
+
+
+ def export(
+     filename: str,
+     framework: str,
+     output: Optional[str] = None,
+     name: Optional[str] = None,
+     check: bool = False,
+ ) -> None:
+     """
+     Export data to a training framework (LLaMA-Factory, ms-swift, Axolotl).
+
+     Args:
+         filename: Input file path
+         framework: Target framework (llama-factory, swift, axolotl)
+         output: Output directory (defaults to {stem}_{framework}/)
+         name: Dataset name (defaults to custom_dataset)
+         check: Only check compatibility, do not export
+     """
+     filepath = Path(filename)
+
+     if not filepath.exists():
+         print(f"Error: file not found - {filename}")
+         return
+
+     if not _check_file_format(filepath):
+         return
+
+     # Load data
+     print(f"📊 Loading data: {filepath}")
+     try:
+         dt = DataTransformer.load(str(filepath))
+     except Exception as e:
+         print(f"Error: cannot read file - {e}")
+         return
+
+     data = dt.data
+     total = len(data)
+     print(f"  {total} rows total")
+
+     # Detect the data format
+     fmt = detect_format(data)
+     print(f"📋 Detected format: {fmt}")
+
+     # Compatibility check
+     result = check_compatibility(data, framework)
+     print(f"\n{result}")
+
+     if check:
+         return
+
+     if not result.valid:
+         print("\n❌ Compatibility check failed, skipping export")
+         return
+
+     # Determine the output directory
+     if output is None:
+         fw_short = framework.lower().replace("-", "_")
+         output = str(filepath.parent / f"{filepath.stem}_{fw_short}")
+
+     dataset_name = name or "custom_dataset"
+
+     # Run the export
+     print(f"\n📦 Exporting to {framework}...")
+     try:
+         export_for(data, framework, output, dataset_name=dataset_name)
+     except Exception as e:
+         print(f"Error: export failed - {e}")
+         return
+
+     print(f"\n✅ Export complete! Files saved to: {output}")
dtflow/cli/sample.py CHANGED
@@ -143,7 +143,7 @@ def sample(
      by: Optional[str] = None,
      uniform: bool = False,
      fields: Optional[str] = None,
-     raw: bool = False,
+     raw: bool = True,
      where: Optional[List[str]] = None,
  ) -> None:
      """
@@ -389,7 +389,7 @@ def head(
      num: int = 10,
      output: Optional[str] = None,
      fields: Optional[str] = None,
-     raw: bool = False,
+     raw: bool = True,
  ) -> None:
      """
      Show the first N rows of a file (shortcut for dt sample --type=head).
@@ -415,12 +415,99 @@
      sample(filename, num=num, type="head", output=output, fields=fields, raw=raw)


+ def slice_data(
+     filename: str,
+     range_str: str,
+     output: Optional[str] = None,
+     fields: Optional[str] = None,
+     raw: bool = True,
+ ) -> None:
+     """
+     View data by row-number range (Python slice syntax).
+
+     Args:
+         filename: Input file path
+         range_str: Row range in the form start:end (0-based, half-open)
+             - 10:20  rows 10-19 (10 rows)
+             - :100   first 100 rows
+             - 100:   row 100 to the end
+             - -10:   last 10 rows
+         output: Output file path
+         fields: Show only the given fields (comma-separated)
+         raw: Print raw JSON output
+
+     Examples:
+         dt slice data.jsonl 10:20
+         dt slice data.jsonl :100
+         dt slice data.jsonl 100:
+         dt slice data.jsonl -10:
+         dt slice data.jsonl 10:20 --output=sliced.jsonl
+         dt slice data.jsonl 10:20 --fields=question,answer
+     """
+     filepath = Path(filename)
+
+     if not filepath.exists():
+         print(f"Error: file not found - {filename}")
+         return
+
+     if not _check_file_format(filepath):
+         return
+
+     # Parse the range
+     if ":" not in range_str:
+         print(f"Error: invalid range format '{range_str}', expected start:end (e.g. 10:20)")
+         return
+
+     parts = range_str.split(":", 1)
+     start_str, end_str = parts[0].strip(), parts[1].strip()
+
+     try:
+         start = int(start_str) if start_str else None
+         end = int(end_str) if end_str else None
+     except ValueError:
+         print(f"Error: invalid range format '{range_str}', start and end must be integers")
+         return
+
+     # Load data and slice
+     try:
+         data = load_data(str(filepath))
+     except Exception as e:
+         print(f"Error: {e}")
+         return
+
+     sliced = data[start:end]
+
+     if not sliced:
+         total = len(data)
+         print(f"⚠️ Range [{range_str}] matched no data (file has {total} rows)")
+         return
+
+     # Show range info; slice().indices() normalizes negative bounds the same
+     # way the slice above did
+     total = len(data)
+     actual_start, actual_end, _ = slice(start, end).indices(total)
+     print(f"📍 Rows {actual_start}-{actual_end - 1} ({len(sliced)} rows, file has {total} rows)")
+
+     # Write the result
+     if output:
+         save_data(sliced, output)
+         print(f"Saved {len(sliced)} rows to {output}")
+     elif raw:
+         for item in sliced:
+             print(orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8"))
+     else:
+         field_list = _parse_field_list(fields) if fields else None
+         _print_samples(sliced, filepath.name, total, field_list, filepath.stat().st_size)
+
+
  def tail(
      filename: str,
      num: int = 10,
      output: Optional[str] = None,
      fields: Optional[str] = None,
-     raw: bool = False,
+     raw: bool = True,
  ) -> None:
      """
      Show the last N rows of a file (shortcut for dt sample --type=tail).
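
Since slice_data feeds start/end straight into a Python slice, the range strings behave exactly like list slicing; a quick illustration with a plain list:

    rows = list(range(100))
    rows[10:20]   # dt slice data.jsonl 10:20 -> rows 10-19
    rows[:100]    # dt slice data.jsonl :100  -> first 100 rows
    rows[-10:]    # dt slice data.jsonl -10:  -> last 10 rows
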
dtflow/cli/split.py ADDED
@@ -0,0 +1,138 @@
+ """
+ CLI dataset-splitting command.
+ """
+
+ from pathlib import Path
+ from typing import List, Optional
+
+ from ..core import DataTransformer
+ from ..storage.io import save_data
+ from .common import _check_file_format
+
+
+ def _parse_ratio(ratio_str: str) -> List[float]:
+     """
+     Parse the ratio argument.
+
+     - "0.8" -> [0.8, 0.2] (two-way split)
+     - "0.8,0.1,0.1" -> [0.8, 0.1, 0.1] (three-way split)
+     """
+     parts = [float(x.strip()) for x in ratio_str.split(",")]
+
+     if len(parts) == 1:
+         if not (0 < parts[0] < 1):
+             raise ValueError(f"Ratio must be between 0 and 1: {parts[0]}")
+         parts.append(round(1 - parts[0], 10))
+
+     total = sum(parts)
+     if abs(total - 1.0) > 1e-6:
+         raise ValueError(f"Ratios must sum to 1.0, got {total}")
+
+     if any(p <= 0 for p in parts):
+         raise ValueError("Every ratio must be greater than 0")
+
+     return parts
+
+
+ # Split names: two-way uses train/test; three-way and beyond uses train/val/test/part4/part5...
+ _SPLIT_NAMES_2 = ["train", "test"]
+ _SPLIT_NAMES_3 = ["train", "val", "test"]
+
+
+ def _get_split_names(count: int) -> List[str]:
+     """Get split names for the given number of parts."""
+     if count == 2:
+         return _SPLIT_NAMES_2
+     elif count == 3:
+         return _SPLIT_NAMES_3
+     else:
+         names = ["train", "val", "test"]
+         for i in range(3, count):
+             names.append(f"part{i + 1}")
+         return names
+
+
+ def split(
+     filename: str,
+     ratio: str = "0.8",
+     seed: Optional[int] = None,
+     output: Optional[str] = None,
+ ) -> None:
+     """
+     Split a dataset into train/test (or train/val/test).
+
+     Args:
+         filename: Input file path
+         ratio: Split ratio, e.g. "0.8" or "0.7,0.15,0.15"
+         seed: Random seed
+         output: Output directory (defaults to the input file's directory)
+     """
+     filepath = Path(filename)
+
+     if not filepath.exists():
+         print(f"Error: file not found - {filename}")
+         return
+
+     if not _check_file_format(filepath):
+         return
+
+     # Parse the ratios
+     try:
+         ratios = _parse_ratio(ratio)
+     except ValueError as e:
+         print(f"Error: {e}")
+         return
+
+     split_names = _get_split_names(len(ratios))
+
+     # Load data
+     print(f"📊 Loading data: {filepath}")
+     try:
+         dt = DataTransformer.load(str(filepath))
+     except Exception as e:
+         print(f"Error: cannot read file - {e}")
+         return
+
+     total = len(dt)
+     print(f"  {total} rows total")
+
+     # Shuffle
+     shuffled = dt.shuffle(seed)
+     if seed is not None:
+         print(f"🎲 Random seed: {seed}")
+
+     # Compute cut points
+     data = shuffled.data
+     split_indices = []
+     acc = 0
+     for r in ratios[:-1]:
+         acc += int(total * r)
+         split_indices.append(acc)
+
+     # Slice the data
+     parts = []
+     prev = 0
+     for idx in split_indices:
+         parts.append(data[prev:idx])
+         prev = idx
+     parts.append(data[prev:])
+
+     # Determine the output directory
+     if output:
+         output_dir = Path(output)
+         output_dir.mkdir(parents=True, exist_ok=True)
+     else:
+         output_dir = filepath.parent
+
+     # Save each part
+     stem = filepath.stem
+     ext = filepath.suffix
+
+     print(f"\n🔀 Split ratio: {' / '.join(f'{r:.0%}' for r in ratios)}")
+     for i, (name, part) in enumerate(zip(split_names, parts)):
+         output_path = output_dir / f"{stem}_{name}{ext}"
+         save_data(part, str(output_path))
+         pct = ratios[i] * 100
+         print(f"  {name}: {len(part)} rows ({pct:.1f}%) -> {output_path}")
+
+     print(f"\n✅ Done! Split into {len(ratios)} parts")