webmainbench 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- webmainbench/__init__.py +33 -0
- webmainbench/config.py +11 -0
- webmainbench/data/__init__.py +16 -0
- webmainbench/data/dataset.py +161 -0
- webmainbench/data/loader.py +231 -0
- webmainbench/data/saver.py +507 -0
- webmainbench/evaluator/__init__.py +13 -0
- webmainbench/evaluator/evaluator.py +598 -0
- webmainbench/evaluator/main_html_evaluator.py +316 -0
- webmainbench/extractors/__init__.py +29 -0
- webmainbench/extractors/base.py +216 -0
- webmainbench/extractors/dripper_extractor.py +95 -0
- webmainbench/extractors/factory.py +141 -0
- webmainbench/extractors/jina_extractor.py +106 -0
- webmainbench/extractors/llm_webkit_extractor.py +808 -0
- webmainbench/extractors/magic_html_extractor.py +84 -0
- webmainbench/extractors/resiliparse_extractor.py +128 -0
- webmainbench/extractors/test_model_extractor.py +27 -0
- webmainbench/extractors/trafilatura_extractor.py +126 -0
- webmainbench/extractors/trafilatura_txt_extractor.py +132 -0
- webmainbench/metrics/__init__.py +38 -0
- webmainbench/metrics/base.py +294 -0
- webmainbench/metrics/base_content_splitter.py +101 -0
- webmainbench/metrics/calculator.py +301 -0
- webmainbench/metrics/code_extractor.py +91 -0
- webmainbench/metrics/formula_extractor.py +115 -0
- webmainbench/metrics/formula_metrics.py +74 -0
- webmainbench/metrics/mainhtml_calculator.py +51 -0
- webmainbench/metrics/table_extractor.py +90 -0
- webmainbench/metrics/table_metrics.py +114 -0
- webmainbench/metrics/teds_metrics.py +295 -0
- webmainbench/metrics/text_metrics.py +417 -0
- webmainbench/utils/__init__.py +14 -0
- webmainbench/utils/helpers.py +111 -0
- webmainbench/utils/main_html.py +104 -0
- webmainbench-0.1.0.dist-info/METADATA +496 -0
- webmainbench-0.1.0.dist-info/RECORD +41 -0
- webmainbench-0.1.0.dist-info/WHEEL +5 -0
- webmainbench-0.1.0.dist-info/entry_points.txt +2 -0
- webmainbench-0.1.0.dist-info/licenses/LICENSE +201 -0
- webmainbench-0.1.0.dist-info/top_level.txt +1 -0
webmainbench/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""
WebMainBench: a benchmark suite for web main-content extraction.

The package bundles a standardized evaluation framework (datasets and
loaders, extractors, metrics and evaluators) for comparing different web
content extraction tools and methods against a shared ground truth.
"""

# Package metadata.
__version__: str = "0.1.0"
__author__: str = "WebMainBench Team"
|
|
10
|
+
|
|
11
|
+
from .data import DataLoader, DataSaver, BenchmarkDataset, DataSample
|
|
12
|
+
from .extractors import BaseExtractor, ExtractorFactory, ExtractionResult
|
|
13
|
+
from .metrics import BaseMetric, MetricCalculator, MetricResult
|
|
14
|
+
from .evaluator import Evaluator, EvaluationResult, MainHTMLEvaluator
|
|
15
|
+
from .utils import setup_logging, format_results
|
|
16
|
+
|
|
17
|
+
# Public API re-exported by the package, grouped by the subpackage that
# defines each name (order kept identical to the original list).
__all__ = [
    # data
    "DataLoader",
    "DataSaver",
    "BenchmarkDataset",
    "DataSample",
    # extractors
    "BaseExtractor",
    "ExtractorFactory",
    "ExtractionResult",
    # metrics
    "BaseMetric",
    "MetricCalculator",
    "MetricResult",
    # evaluator
    "Evaluator",
    "EvaluationResult",
    # utils
    "setup_logging",
    "format_results",
    # evaluator (main-HTML variant)
    "MainHTMLEvaluator",
]
|
webmainbench/data/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
Data module for WebMainBench.

This module handles loading, saving and managing benchmark datasets. It
re-exports ``BenchmarkDataset`` and ``DataSample`` (from ``.dataset``),
``DataLoader`` (from ``.loader``) and ``DataSaver`` (from ``.saver``).
"""
|
|
6
|
+
|
|
7
|
+
from .dataset import BenchmarkDataset, DataSample
|
|
8
|
+
from .loader import DataLoader
|
|
9
|
+
from .saver import DataSaver
|
|
10
|
+
|
|
11
|
+
# Names re-exported from the data submodules (same order as before).
__all__ = [
    # .dataset
    "BenchmarkDataset",
    "DataSample",
    # .loader
    "DataLoader",
    # .saver
    "DataSaver",
]
|
|
webmainbench/data/dataset.py
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""
Dataset classes for WebMainBench.

Defines ``DataSample`` (a single benchmark record) and ``BenchmarkDataset``
(an in-memory, list-backed collection of samples plus free-form metadata).
"""
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Dict, List, Optional, Any, Union
|
|
8
|
+
import json
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class DataSample:
    """Single data sample in the benchmark dataset.

    ``id``, ``html`` and ``groundtruth_content`` are required; every other
    field is optional and defaults to ``None``.
    """

    # Required fields
    id: str
    html: str  # HTML with cc-select=true annotations
    groundtruth_content: str  # Groundtruth markdown content
    # Optional fields with defaults
    groundtruth_content_list: Optional[List[Dict[str, Any]]] = None  # Groundtruth content_list from llm-webkit
    content_list: Optional[List[Dict[str, Any]]] = None  # Content_list from llm-webkit
    content: Optional[str] = None  # Content from llm-webkit
    # Optional metadata
    url: Optional[str] = None
    domain: Optional[str] = None
    language: Optional[str] = None
    content_type: Optional[str] = None  # article, forum, blog, etc.
    difficulty: Optional[str] = None  # easy, medium, hard
    tags: Optional[List[str]] = None
    llm_webkit_md: Optional[str] = None
    llm_webkit_html: Optional[str] = None  # Preprocessed HTML
    main_html: Optional[str] = None  # Main HTML content

    # Extracted results (populated during evaluation)
    extracted_results: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dictionary keyed by field name.

        The result can be fed back to :meth:`from_dict` to rebuild an equal
        sample.
        """
        return {
            "id": self.id,
            "html": self.html,
            "groundtruth_content": self.groundtruth_content,
            "groundtruth_content_list": self.groundtruth_content_list,
            "content_list": self.content_list,
            "content": self.content,
            "llm_webkit_md": self.llm_webkit_md,
            "llm_webkit_html": self.llm_webkit_html,
            "main_html": self.main_html,
            "url": self.url,
            "domain": self.domain,
            "language": self.language,
            "content_type": self.content_type,
            "difficulty": self.difficulty,
            "tags": self.tags,
            "extracted_results": self.extracted_results,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DataSample":
        """Create a sample from a dictionary, ignoring unknown keys.

        External datasets use different key names, so a few aliases are
        understood (``track_id`` -> ``id``, ``content`` /
        ``convert_main_content`` -> ``groundtruth_content``, ``content_list``
        -> ``groundtruth_content_list``). An alias is applied only when its
        canonical key is absent from ``data``; otherwise the canonical key
        wins and an alias that is itself a field (``content``,
        ``content_list``) is stored under its own name. This keeps
        ``from_dict(sample.to_dict())`` lossless.

        Args:
            data: Mapping of field names (or aliases) to values.

        Returns:
            A new ``DataSample``.

        Raises:
            TypeError: If a required field (``id``, ``html``,
                ``groundtruth_content``) cannot be resolved.
        """
        import dataclasses

        # All field names declared on the dataclass.
        field_names = {f.name for f in dataclasses.fields(cls)}

        # Field-name aliases (external key -> internal field name).
        field_mapping = {
            "track_id": "id",
            "content": "groundtruth_content",
            "convert_main_content": "groundtruth_content",
            "content_list": "groundtruth_content_list",
        }

        filtered_data: Dict[str, Any] = {}
        for key, value in data.items():
            target = field_mapping.get(key)
            if target is not None and target not in data:
                # Alias for a canonical field the record does not provide.
                # When several aliases share a target, the last one wins.
                filtered_data[target] = value
            elif key in field_names:
                filtered_data[key] = value
            # Anything else (e.g. layout_id, max_layer_n) is ignored.

        return cls(**filtered_data)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class BenchmarkDataset:
    """Main dataset class for WebMainBench.

    Holds an ordered list of ``DataSample`` objects plus a free-form metadata
    dictionary. Supports ``len()``, iteration and indexing over the samples.
    """

    def __init__(self, name: str, description: str = ""):
        self.name = name
        self.description = description
        self.samples: List[DataSample] = []
        self._metadata: Dict[str, Any] = {}

    def add_sample(self, sample: DataSample) -> None:
        """Append a data sample to the dataset."""
        self.samples.append(sample)

    def get_sample(self, sample_id: str) -> Optional[DataSample]:
        """Return the first sample whose ``id`` equals ``sample_id``, or ``None``.

        This is a linear scan over ``self.samples``.
        """
        return next((s for s in self.samples if s.id == sample_id), None)

    def filter_by_criteria(self, **kwargs) -> List[DataSample]:
        """Filter samples by attribute equality (e.g. language='en', difficulty='hard').

        Every keyword must match; an attribute missing on a sample compares
        as ``None``.
        """
        filtered = self.samples
        for key, value in kwargs.items():
            filtered = [s for s in filtered if getattr(s, key, None) == value]
        return filtered

    def get_statistics(self) -> Dict[str, Any]:
        """Return the sample count and per-value counts of several metadata fields.

        Missing (falsy) values are counted under ``"unknown"``.
        """
        from collections import Counter

        def _count(attr: str) -> Dict[str, int]:
            # Insertion order follows first occurrence, like a manual dict count.
            return dict(Counter(getattr(s, attr) or "unknown" for s in self.samples))

        return {
            "total_samples": len(self.samples),
            "languages": _count("language"),
            "content_types": _count("content_type"),
            "difficulties": _count("difficulty"),
            "domains": _count("domain"),
        }

    def set_metadata(self, key: str, value: Any) -> None:
        """Set a dataset metadata entry."""
        self._metadata[key] = value

    def get_metadata(self, key: Optional[str] = None) -> Union[Any, Dict[str, Any]]:
        """Get dataset metadata.

        Args:
            key: Entry to look up. When ``None`` (the default), a shallow copy
                of the whole metadata dict is returned instead.

        Returns:
            The value for ``key`` (``None`` if absent), or a copy of all metadata.
        """
        # Compare against None so that falsy keys such as "" are still looked up.
        if key is not None:
            return self._metadata.get(key)
        return self._metadata.copy()

    def __len__(self) -> int:
        return len(self.samples)

    def __iter__(self):
        return iter(self.samples)

    def __getitem__(self, index: int) -> DataSample:
        return self.samples[index]
|
|
webmainbench/data/loader.py
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""
Data loader for WebMainBench.

Provides ``DataLoader``, whose static methods build ``BenchmarkDataset``
objects from JSON / JSONL files or directories, merge datasets, and stream
``DataSample`` objects (singly or in batches) from JSONL files.
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import jsonlines
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import List, Dict, Any, Optional, Union, Iterator
|
|
9
|
+
from .dataset import BenchmarkDataset, DataSample
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DataLoader:
    """Data loader for various input formats (JSON, JSONL, directories)."""

    @staticmethod
    def load_jsonl(file_path: Union[str, Path], **kwargs) -> BenchmarkDataset:
        """
        Load dataset from JSONL file.

        Each decoded line is converted with ``DataSample.from_dict``; lines
        that fail conversion are reported with a printed warning and skipped.

        Args:
            file_path: Path to the JSONL file
            **kwargs: Additional parameters for dataset creation; only
                ``name`` is read (defaults to the file stem)

        Returns:
            BenchmarkDataset instance
        """
        file_path = Path(file_path)
        dataset_name = kwargs.get('name', file_path.stem)
        dataset = BenchmarkDataset(name=dataset_name)

        # NOTE(review): only sample construction is guarded below; a line the
        # reader itself cannot decode is raised by the iterator, outside the
        # try, and presumably aborts the whole load -- confirm against
        # jsonlines' error behaviour.
        with jsonlines.open(file_path, 'r') as reader:
            for idx, line in enumerate(reader):
                try:
                    # from_dict applies field aliasing and drops unknown keys.
                    sample = DataSample.from_dict(line)
                except Exception as e:
                    print(f"Warning: Failed to load sample at line {idx}: {e}")
                    continue
                dataset.add_sample(sample)

        return dataset

    @staticmethod
    def load_json(file_path: Union[str, Path], **kwargs) -> BenchmarkDataset:
        """
        Load dataset from JSON file.

        Accepted top-level structures:
          * a list of sample dicts;
          * a dict with a ``samples`` list and an optional ``metadata`` dict
            (copied into the dataset metadata);
          * any other dict, treated as a single sample.

        Samples without an ``id`` (or ``track_id``) receive ``sample_{idx}``.
        Samples that fail conversion are reported and skipped.

        Args:
            file_path: Path to the JSON file
            **kwargs: Additional parameters for dataset creation; only
                ``name`` is read (defaults to the file stem)

        Returns:
            BenchmarkDataset instance

        Raises:
            ValueError: If the top-level JSON value is neither a list nor a dict.
        """
        file_path = Path(file_path)
        dataset_name = kwargs.get('name', file_path.stem)
        dataset = BenchmarkDataset(name=dataset_name)

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Handle different JSON structures
        if isinstance(data, list):
            # Array of samples
            samples_data = data
        elif isinstance(data, dict):
            if 'samples' in data:
                # Structured format with metadata
                samples_data = data['samples']
                if 'metadata' in data:
                    for key, value in data['metadata'].items():
                        dataset.set_metadata(key, value)
            else:
                # Single sample in dict format
                samples_data = [data]
        else:
            raise ValueError(f"Unsupported JSON structure in {file_path}")

        for idx, sample_data in enumerate(samples_data):
            try:
                if (isinstance(sample_data, dict)
                        and "id" not in sample_data
                        and "track_id" not in sample_data):
                    # Supply the fallback ID before construction: ``id`` is a
                    # required field, so from_dict would otherwise raise.
                    sample_data = {**sample_data, "id": f"sample_{idx}"}
                sample = DataSample.from_dict(sample_data)
                if not sample.id:
                    # Present but empty/None ID.
                    sample.id = f"sample_{idx}"
                dataset.add_sample(sample)
            except Exception as e:
                print(f"Warning: Failed to load sample {idx}: {e}")
                continue

        return dataset

    @staticmethod
    def load_from_directory(dir_path: Union[str, Path],
                            pattern: str = "*.jsonl",
                            **kwargs) -> Dict[str, BenchmarkDataset]:
        """
        Load multiple datasets from a directory.

        Files matched by ``pattern`` are dispatched on their suffix
        (``.jsonl`` -> :meth:`load_jsonl`, ``.json`` -> :meth:`load_json`).
        Other suffixes, and files that fail to load, are reported with a
        printed message and skipped.

        Args:
            dir_path: Directory containing dataset files
            pattern: ``Path.glob`` pattern to match (default: "*.jsonl")
            **kwargs: Forwarded unchanged to each per-file loader, so a
                ``name`` given here is applied to every dataset

        Returns:
            Dictionary mapping file stems to BenchmarkDataset instances
        """
        dir_path = Path(dir_path)
        datasets = {}

        for file_path in dir_path.glob(pattern):
            try:
                if file_path.suffix == '.jsonl':
                    dataset = DataLoader.load_jsonl(file_path, **kwargs)
                elif file_path.suffix == '.json':
                    dataset = DataLoader.load_json(file_path, **kwargs)
                else:
                    print(f"Warning: Unsupported file format: {file_path}")
                    continue

                # Files sharing a stem (e.g. a.json and a.jsonl) overwrite each other.
                datasets[file_path.stem] = dataset

            except Exception as e:
                print(f"Error loading {file_path}: {e}")
                continue

        return datasets

    @staticmethod
    def merge_datasets(datasets: List[BenchmarkDataset],
                       name: str = "merged_dataset") -> BenchmarkDataset:
        """
        Merge multiple datasets into one.

        Sample IDs are kept unique: a sample whose ID is already taken is
        added as a shallow copy renamed to ``{id}_1``, ``{id}_2``, ... (first
        free suffix). The input datasets and their samples are not modified;
        samples that keep their ID are shared with the input dataset.

        Args:
            datasets: List of BenchmarkDataset instances to merge
            name: Name for the merged dataset

        Returns:
            Merged BenchmarkDataset instance
        """
        import dataclasses

        merged = BenchmarkDataset(name=name)
        # Set lookup instead of merged.get_sample() keeps merging linear.
        used_ids = set()

        for dataset in datasets:
            for sample in dataset.samples:
                original_id = sample.id
                new_id = original_id
                counter = 1
                while new_id in used_ids:
                    new_id = f"{original_id}_{counter}"
                    counter += 1

                if new_id != original_id:
                    # Copy rather than mutate a sample owned by an input dataset.
                    sample = dataclasses.replace(sample, id=new_id)

                used_ids.add(new_id)
                merged.add_sample(sample)

        return merged

    @staticmethod
    def stream_jsonl(file_path: Union[str, Path],
                     categories: Optional[List[str]] = None,
                     max_samples: Optional[int] = None) -> Iterator[DataSample]:
        """
        Stream a JSONL file, yielding one DataSample at a time to limit memory use.

        Args:
            file_path: Path to the JSONL file
            categories: If non-empty, only samples whose ``content_type`` is
                in this collection are yielded
            max_samples: Stop after this many samples have been yielded
                (``None`` or ``0`` means no limit)

        Yields:
            DataSample: one data sample at a time
        """
        file_path = Path(file_path)

        sample_count = 0
        with jsonlines.open(file_path, 'r') as reader:
            for line_idx, line in enumerate(reader):
                try:
                    sample = DataSample.from_dict(line)
                    # Category filter
                    if categories and sample.content_type not in categories:
                        continue
                except Exception as e:
                    print(f"Warning: Failed to load sample at line {line_idx}: {e}")
                    continue

                # Yield outside the try so exceptions raised by the consumer
                # (e.g. via generator.throw) are not mistaken for load failures.
                yield sample
                sample_count += 1

                # Sample-count limit
                if max_samples and sample_count >= max_samples:
                    break

    @staticmethod
    def stream_jsonl_batched(file_path: Union[str, Path],
                             batch_size: int = 50,
                             categories: Optional[List[str]] = None,
                             max_samples: Optional[int] = None) -> Iterator[List[DataSample]]:
        """
        Stream a JSONL file, yielding lists of DataSample objects in batches.

        Args:
            file_path: Path to the JSONL file
            batch_size: Number of samples per batch (the last batch may be smaller)
            categories: Category filter, as in :meth:`stream_jsonl`
            max_samples: Maximum total number of samples, as in :meth:`stream_jsonl`

        Yields:
            List[DataSample]: a batch of data samples
        """
        batch = []
        sample_count = 0

        for sample in DataLoader.stream_jsonl(file_path, categories, max_samples):
            batch.append(sample)
            sample_count += 1

            # Emit when the batch is full or the sample limit is reached.
            if len(batch) >= batch_size or (max_samples and sample_count >= max_samples):
                yield batch
                batch = []

            if max_samples and sample_count >= max_samples:
                break

        # Emit the final partial batch, if any.
        if batch:
            yield batch
|