fusion-bench 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/constants/__init__.py +5 -1
- fusion_bench/constants/runtime.py +111 -7
- fusion_bench/dataset/gsm8k.py +6 -2
- fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
- fusion_bench/method/__init__.py +1 -1
- fusion_bench/method/classification/image_classification_finetune.py +1 -2
- fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
- fusion_bench/metrics/nyuv2/__init__.py +31 -0
- fusion_bench/metrics/nyuv2/depth.py +30 -0
- fusion_bench/metrics/nyuv2/loss.py +40 -0
- fusion_bench/metrics/nyuv2/noise.py +24 -0
- fusion_bench/metrics/nyuv2/normal.py +34 -1
- fusion_bench/metrics/nyuv2/segmentation.py +35 -1
- fusion_bench/mixins/clip_classification.py +30 -2
- fusion_bench/mixins/lightning_fabric.py +46 -5
- fusion_bench/mixins/rich_live.py +76 -0
- fusion_bench/modelpool/base_pool.py +86 -5
- fusion_bench/scripts/webui.py +250 -17
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/data.py +100 -9
- fusion_bench/utils/fabric.py +185 -4
- fusion_bench/utils/json.py +6 -0
- fusion_bench/utils/validation.py +197 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/METADATA +66 -7
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/RECORD +35 -35
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
- fusion_bench_config/llama_full_finetune.yaml +4 -16
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +4 -13
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench/utils/auto.py +0 -31
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/top_level.txt +0 -0
fusion_bench/scripts/webui.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
"""
|
|
2
|
-
|
|
2
|
+
Web UI for FusionBench Command Generator with per-session state management.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
import argparse
|
|
@@ -13,39 +13,94 @@ import hydra
|
|
|
13
13
|
import yaml
|
|
14
14
|
from colorama import Fore, Style # For cross-platform color support
|
|
15
15
|
from hydra import compose, initialize_config_dir
|
|
16
|
+
from hydra.core.hydra_config import HydraConfig
|
|
16
17
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
17
18
|
|
|
18
19
|
from fusion_bench.scripts.cli import _get_default_config_path
|
|
19
20
|
|
|
20
21
|
|
|
22
|
+
def escape_overrides(value: str) -> str:
|
|
23
|
+
"""
|
|
24
|
+
Escapes special characters in Hydra command-line override values.
|
|
25
|
+
|
|
26
|
+
Adds quotes around values containing spaces and escapes equals signs
|
|
27
|
+
to prevent them from being interpreted as key-value separators.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
value (str): The override value to escape.
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
str: The escaped value ready for use in command-line overrides.
|
|
34
|
+
"""
|
|
35
|
+
if " " in value and not (value.startswith('"') or value.startswith("'")):
|
|
36
|
+
return f"'{value}'"
|
|
37
|
+
if "=" in value:
|
|
38
|
+
return value.replace("=", "\\=")
|
|
39
|
+
return value
|
|
40
|
+
|
|
41
|
+
|
|
21
42
|
class ConfigGroupNode:
|
|
43
|
+
"""
|
|
44
|
+
Represents a node in the configuration directory tree.
|
|
45
|
+
|
|
46
|
+
This class recursively builds a tree structure representing the Hydra
|
|
47
|
+
configuration directory hierarchy, including subdirectories (child groups)
|
|
48
|
+
and YAML configuration files.
|
|
49
|
+
|
|
50
|
+
Attributes:
|
|
51
|
+
name (str): Name of the configuration group (directory name).
|
|
52
|
+
path (Path): Full path to the directory.
|
|
53
|
+
parent (Optional[ConfigGroupNode]): Parent node in the tree.
|
|
54
|
+
children (List[ConfigGroupNode]): Child directory nodes.
|
|
55
|
+
configs (List[str]): List of YAML config file names (without extension).
|
|
56
|
+
"""
|
|
57
|
+
|
|
22
58
|
name: str
|
|
23
59
|
path: Path
|
|
24
|
-
parent: Optional["ConfigGroupNode"]
|
|
60
|
+
parent: Optional["ConfigGroupNode"]
|
|
25
61
|
children: List["ConfigGroupNode"]
|
|
26
62
|
configs: List[str]
|
|
27
63
|
|
|
28
|
-
def __init__(self, path: str | Path):
|
|
64
|
+
def __init__(self, path: str | Path, parent: Optional["ConfigGroupNode"] = None):
|
|
65
|
+
"""
|
|
66
|
+
Initialize a ConfigGroupNode.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
path: Path to the configuration directory.
|
|
70
|
+
parent: Parent node in the tree (None for root).
|
|
71
|
+
"""
|
|
29
72
|
self.path = Path(path)
|
|
30
73
|
assert self.path.is_dir()
|
|
31
74
|
self.name = self.path.stem
|
|
75
|
+
self.parent = parent
|
|
32
76
|
self.children = []
|
|
33
77
|
self.configs = []
|
|
34
78
|
for child in self.path.iterdir():
|
|
35
79
|
if child.is_dir():
|
|
36
|
-
child_node = ConfigGroupNode(child)
|
|
37
|
-
child_node.parent = self
|
|
80
|
+
child_node = ConfigGroupNode(child, parent=self)
|
|
38
81
|
self.children.append(child_node)
|
|
39
82
|
elif child.is_file() and child.suffix == ".yaml":
|
|
40
83
|
self.configs.append(child.stem)
|
|
41
84
|
|
|
42
85
|
def __repr__(self):
|
|
43
86
|
"""
|
|
44
|
-
Return string of the tree structure
|
|
87
|
+
Return a colored string representation of the tree structure.
|
|
88
|
+
|
|
89
|
+
Returns:
|
|
90
|
+
str: Tree structure with colored group names.
|
|
45
91
|
"""
|
|
46
92
|
return f"{Fore.BLUE}{self.name}{Style.RESET_ALL}\n" + self._repr_indented()
|
|
47
93
|
|
|
48
94
|
def _repr_indented(self, prefix=""):
|
|
95
|
+
"""
|
|
96
|
+
Generate indented tree representation recursively.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
prefix: String prefix for indentation.
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
str: Indented tree structure.
|
|
103
|
+
"""
|
|
49
104
|
result = ""
|
|
50
105
|
|
|
51
106
|
items = self.configs + self.children
|
|
@@ -63,9 +118,30 @@ class ConfigGroupNode:
|
|
|
63
118
|
return result
|
|
64
119
|
|
|
65
120
|
def has_child_group(self, name: str) -> bool:
|
|
121
|
+
"""
|
|
122
|
+
Check if this node has a child group with the given name.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
name: Name of the child group to check.
|
|
126
|
+
|
|
127
|
+
Returns:
|
|
128
|
+
bool: True if child group exists, False otherwise.
|
|
129
|
+
"""
|
|
66
130
|
return any(child.name == name for child in self.children)
|
|
67
131
|
|
|
68
132
|
def __getitem__(self, key: str) -> Union["ConfigGroupNode", str]:
|
|
133
|
+
"""
|
|
134
|
+
Get a child group or config by name.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
key: Name of the child group or config file.
|
|
138
|
+
|
|
139
|
+
Returns:
|
|
140
|
+
Union[ConfigGroupNode, str]: Child node or config file name.
|
|
141
|
+
|
|
142
|
+
Raises:
|
|
143
|
+
KeyError: If no child group or config with that name exists.
|
|
144
|
+
"""
|
|
69
145
|
for child in self.children:
|
|
70
146
|
if child.name == key:
|
|
71
147
|
return child
|
|
@@ -76,12 +152,28 @@ class ConfigGroupNode:
|
|
|
76
152
|
|
|
77
153
|
@functools.cached_property
|
|
78
154
|
def prefix(self) -> str:
|
|
155
|
+
"""
|
|
156
|
+
Get the dot-separated prefix path from root to this node.
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
str: Prefix path (e.g., "method.modelpool.").
|
|
160
|
+
"""
|
|
79
161
|
if self.parent is None:
|
|
80
162
|
return ""
|
|
81
163
|
return self.parent.prefix + self.name + "."
|
|
82
164
|
|
|
83
165
|
|
|
84
166
|
def priority_iterable(iter, priority_keys):
|
|
167
|
+
"""
|
|
168
|
+
Iterate over items with priority keys first, then remaining items.
|
|
169
|
+
|
|
170
|
+
Args:
|
|
171
|
+
iter: Iterable to process.
|
|
172
|
+
priority_keys: Keys to yield first.
|
|
173
|
+
|
|
174
|
+
Yields:
|
|
175
|
+
Items from iter, with priority_keys first.
|
|
176
|
+
"""
|
|
85
177
|
items = list(iter)
|
|
86
178
|
for key in priority_keys:
|
|
87
179
|
if key in items:
|
|
@@ -93,7 +185,16 @@ def priority_iterable(iter, priority_keys):
|
|
|
93
185
|
|
|
94
186
|
class AppState:
|
|
95
187
|
"""
|
|
96
|
-
|
|
188
|
+
Per-session state of the app.
|
|
189
|
+
|
|
190
|
+
Manages the current configuration state including the selected config name,
|
|
191
|
+
overrides, and the composed Hydra configuration.
|
|
192
|
+
|
|
193
|
+
Attributes:
|
|
194
|
+
config_name (str): Name of the root configuration file.
|
|
195
|
+
hydra_options (List[str]): Hydra-specific command line options.
|
|
196
|
+
overrides (List[str]): List of configuration overrides.
|
|
197
|
+
config (DictConfig): The composed OmegaConf configuration.
|
|
97
198
|
"""
|
|
98
199
|
|
|
99
200
|
config_name: str
|
|
@@ -108,6 +209,15 @@ class AppState:
|
|
|
108
209
|
hydra_options: List[str] = [],
|
|
109
210
|
overrides: List[str] = [],
|
|
110
211
|
) -> None:
|
|
212
|
+
"""
|
|
213
|
+
Initialize the application state.
|
|
214
|
+
|
|
215
|
+
Args:
|
|
216
|
+
config_path: Path to the config directory.
|
|
217
|
+
config_name: Name of the root config file.
|
|
218
|
+
hydra_options: Hydra command line options.
|
|
219
|
+
overrides: Initial configuration overrides.
|
|
220
|
+
"""
|
|
111
221
|
super().__init__()
|
|
112
222
|
self.config_path = config_path
|
|
113
223
|
self.config_name = config_name
|
|
@@ -117,6 +227,12 @@ class AppState:
|
|
|
117
227
|
|
|
118
228
|
@property
|
|
119
229
|
def config_str(self):
|
|
230
|
+
"""
|
|
231
|
+
Get the YAML string representation of the current configuration.
|
|
232
|
+
|
|
233
|
+
Returns:
|
|
234
|
+
str: YAML formatted configuration.
|
|
235
|
+
"""
|
|
120
236
|
return OmegaConf.to_yaml(self.config)
|
|
121
237
|
|
|
122
238
|
def update_config(
|
|
@@ -124,6 +240,16 @@ class AppState:
|
|
|
124
240
|
config_name: str = None,
|
|
125
241
|
overrides: List[str] = None,
|
|
126
242
|
) -> "AppState":
|
|
243
|
+
"""
|
|
244
|
+
Update the configuration with new name and/or overrides.
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
config_name: New root config name (optional).
|
|
248
|
+
overrides: New list of overrides (optional).
|
|
249
|
+
|
|
250
|
+
Returns:
|
|
251
|
+
AppState: Self for method chaining.
|
|
252
|
+
"""
|
|
127
253
|
if config_name is not None:
|
|
128
254
|
self.config_name = config_name
|
|
129
255
|
if overrides is not None:
|
|
@@ -133,11 +259,21 @@ class AppState:
|
|
|
133
259
|
self.config = ""
|
|
134
260
|
else:
|
|
135
261
|
self.config = compose(
|
|
136
|
-
config_name=self.config_name,
|
|
262
|
+
config_name=self.config_name,
|
|
263
|
+
overrides=self.overrides,
|
|
264
|
+
return_hydra_config=True,
|
|
137
265
|
)
|
|
266
|
+
HydraConfig().set_config(self.config)
|
|
267
|
+
del self.config.hydra
|
|
138
268
|
return self
|
|
139
269
|
|
|
140
270
|
def generate_command(self):
|
|
271
|
+
"""
|
|
272
|
+
Generate the fusion_bench CLI command from current state.
|
|
273
|
+
|
|
274
|
+
Returns:
|
|
275
|
+
str: Complete command ready to execute in shell.
|
|
276
|
+
"""
|
|
141
277
|
# Generate the command according to `config_name` and `overrides` (a list of strings)
|
|
142
278
|
command = "fusion_bench \\\n"
|
|
143
279
|
if self.config_path is not None:
|
|
@@ -152,23 +288,66 @@ class AppState:
|
|
|
152
288
|
|
|
153
289
|
@property
|
|
154
290
|
def config_str_and_command(self):
|
|
291
|
+
"""
|
|
292
|
+
Get both config string and command as a tuple.
|
|
293
|
+
|
|
294
|
+
Returns:
|
|
295
|
+
Tuple[str, str]: (YAML config, shell command).
|
|
296
|
+
"""
|
|
155
297
|
return self.config_str, self.generate_command()
|
|
156
298
|
|
|
157
299
|
def get_override(self, key: str):
|
|
300
|
+
"""
|
|
301
|
+
Get the override value for a specific key.
|
|
302
|
+
|
|
303
|
+
Args:
|
|
304
|
+
key: Configuration key to look up.
|
|
305
|
+
|
|
306
|
+
Returns:
|
|
307
|
+
Optional[str]: Override value or None if not found.
|
|
308
|
+
"""
|
|
158
309
|
for ov in self.overrides:
|
|
159
310
|
if ov.startswith(f"{key}="):
|
|
160
311
|
return "".join(ov.split("=")[1:])
|
|
161
312
|
return None
|
|
162
313
|
|
|
163
314
|
def update_override(self, key: str, value):
|
|
315
|
+
"""
|
|
316
|
+
Update or add an override for a specific key.
|
|
317
|
+
|
|
318
|
+
Args:
|
|
319
|
+
key: Configuration key to override.
|
|
320
|
+
value: New value for the key.
|
|
321
|
+
|
|
322
|
+
Returns:
|
|
323
|
+
AppState: Updated state after recomposing config.
|
|
324
|
+
"""
|
|
164
325
|
self.overrides = [ov for ov in self.overrides if not ov.startswith(f"{key}=")]
|
|
165
326
|
if value:
|
|
166
|
-
self.overrides.append(f"{key}={value}")
|
|
327
|
+
self.overrides.append(f"{key}={escape_overrides(value)}")
|
|
167
328
|
return self.update_config()
|
|
168
329
|
|
|
169
330
|
|
|
170
331
|
class App:
|
|
332
|
+
"""
|
|
333
|
+
Main application class for the FusionBench WebUI.
|
|
334
|
+
|
|
335
|
+
Manages the Gradio interface, configuration tree, and application state.
|
|
336
|
+
|
|
337
|
+
Attributes:
|
|
338
|
+
args: Command line arguments.
|
|
339
|
+
group_tree (ConfigGroupNode): Root of the config directory tree.
|
|
340
|
+
init_config_name (str): Initial configuration name.
|
|
341
|
+
app_state (AppState): Current application state.
|
|
342
|
+
"""
|
|
343
|
+
|
|
171
344
|
def __init__(self, args):
|
|
345
|
+
"""
|
|
346
|
+
Initialize the application.
|
|
347
|
+
|
|
348
|
+
Args:
|
|
349
|
+
args: Parsed command line arguments.
|
|
350
|
+
"""
|
|
172
351
|
super().__init__()
|
|
173
352
|
self.args = args
|
|
174
353
|
group_tree = ConfigGroupNode(self.config_path)
|
|
@@ -177,8 +356,8 @@ class App:
|
|
|
177
356
|
|
|
178
357
|
self.group_tree = group_tree
|
|
179
358
|
|
|
180
|
-
if "
|
|
181
|
-
self.init_config_name = "
|
|
359
|
+
if "fabric_model_fusion" in group_tree.configs:
|
|
360
|
+
self.init_config_name = "fabric_model_fusion"
|
|
182
361
|
else:
|
|
183
362
|
self.init_config_name = group_tree.configs[0]
|
|
184
363
|
|
|
@@ -197,31 +376,64 @@ class App:
|
|
|
197
376
|
|
|
198
377
|
@functools.cached_property
|
|
199
378
|
def config_path(self):
|
|
379
|
+
"""
|
|
380
|
+
Get the configuration directory path.
|
|
381
|
+
|
|
382
|
+
Returns:
|
|
383
|
+
Path: Path to the config directory.
|
|
384
|
+
"""
|
|
200
385
|
if self.args.config_path:
|
|
201
386
|
return Path(self.args.config_path)
|
|
202
387
|
else:
|
|
203
388
|
return _get_default_config_path()
|
|
204
389
|
|
|
205
390
|
def __getattr__(self, name):
|
|
391
|
+
"""
|
|
392
|
+
Delegate attribute access to app_state if not found in App.
|
|
393
|
+
|
|
394
|
+
Args:
|
|
395
|
+
name: Attribute name.
|
|
396
|
+
|
|
397
|
+
Returns:
|
|
398
|
+
Attribute value from app_state.
|
|
399
|
+
|
|
400
|
+
Raises:
|
|
401
|
+
AttributeError: If attribute not found in app_state either.
|
|
402
|
+
"""
|
|
206
403
|
if hasattr(self.app_state, name):
|
|
207
404
|
return getattr(self.app_state, name)
|
|
208
405
|
raise AttributeError(f"App object has no attribute {name}")
|
|
209
406
|
|
|
210
407
|
def generate_ui(self):
|
|
408
|
+
"""
|
|
409
|
+
Generate the Gradio user interface.
|
|
410
|
+
|
|
411
|
+
Creates interactive UI components for configuration selection,
|
|
412
|
+
parameter editing, and command generation.
|
|
413
|
+
|
|
414
|
+
Returns:
|
|
415
|
+
gr.Blocks: Gradio application instance.
|
|
416
|
+
"""
|
|
211
417
|
with gr.Blocks() as app:
|
|
212
418
|
gr.Markdown("# FusionBench Command Generator")
|
|
213
419
|
|
|
214
420
|
# 1. Choose a root config file
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
421
|
+
with gr.Row(equal_height=True):
|
|
422
|
+
root_configs = gr.Dropdown(
|
|
423
|
+
choices=self.group_tree.configs,
|
|
424
|
+
value=self.config_name,
|
|
425
|
+
label="Root Config",
|
|
426
|
+
scale=4,
|
|
427
|
+
)
|
|
428
|
+
reset_button = gr.Button("Reset", scale=1)
|
|
220
429
|
|
|
221
430
|
with gr.Row():
|
|
222
431
|
with gr.Column(scale=2):
|
|
223
432
|
command_output = gr.Code(
|
|
224
|
-
|
|
433
|
+
value=self.app_state.generate_command(),
|
|
434
|
+
language="shell",
|
|
435
|
+
label="Generated Command",
|
|
436
|
+
interactive=False,
|
|
225
437
|
)
|
|
226
438
|
|
|
227
439
|
@gr.render(inputs=[root_configs, command_output])
|
|
@@ -352,10 +564,26 @@ class App:
|
|
|
352
564
|
outputs=[config_output, command_output],
|
|
353
565
|
)
|
|
354
566
|
|
|
567
|
+
def reset_app(config_name):
|
|
568
|
+
# Reset overrides and update config
|
|
569
|
+
self.app_state.overrides = []
|
|
570
|
+
return self.app_state.update_config(config_name).config_str_and_command
|
|
571
|
+
|
|
572
|
+
reset_button.click(
|
|
573
|
+
reset_app,
|
|
574
|
+
inputs=[root_configs],
|
|
575
|
+
outputs=[config_output, command_output],
|
|
576
|
+
)
|
|
355
577
|
return app
|
|
356
578
|
|
|
357
579
|
|
|
358
580
|
def parse_args():
|
|
581
|
+
"""
|
|
582
|
+
Parse command line arguments for the WebUI.
|
|
583
|
+
|
|
584
|
+
Returns:
|
|
585
|
+
argparse.Namespace: Parsed arguments.
|
|
586
|
+
"""
|
|
359
587
|
parser = argparse.ArgumentParser(description="FusionBench Command Generator")
|
|
360
588
|
parser.add_argument(
|
|
361
589
|
"--config-path",
|
|
@@ -390,6 +618,11 @@ def parse_args():
|
|
|
390
618
|
|
|
391
619
|
|
|
392
620
|
def main() -> None:
|
|
621
|
+
"""
|
|
622
|
+
Main entry point for the FusionBench WebUI application.
|
|
623
|
+
|
|
624
|
+
Parses arguments, initializes the app, and launches the Gradio interface.
|
|
625
|
+
"""
|
|
393
626
|
args = parse_args()
|
|
394
627
|
|
|
395
628
|
app = App(args).generate_ui()
|
fusion_bench/utils/__init__.py
CHANGED
|
@@ -93,6 +93,13 @@ _import_structure = {
|
|
|
93
93
|
"StateDictType",
|
|
94
94
|
"TorchModelType",
|
|
95
95
|
],
|
|
96
|
+
"validation": [
|
|
97
|
+
"validate_path_exists",
|
|
98
|
+
"validate_file_exists",
|
|
99
|
+
"validate_directory_exists",
|
|
100
|
+
"validate_model_name",
|
|
101
|
+
"ValidationError",
|
|
102
|
+
],
|
|
96
103
|
}
|
|
97
104
|
|
|
98
105
|
if TYPE_CHECKING:
|
|
@@ -159,6 +166,13 @@ if TYPE_CHECKING:
|
|
|
159
166
|
)
|
|
160
167
|
from .timer import timeit_context
|
|
161
168
|
from .type import BoolStateDictType, StateDictType, TorchModelType
|
|
169
|
+
from .validation import (
|
|
170
|
+
ValidationError,
|
|
171
|
+
validate_directory_exists,
|
|
172
|
+
validate_file_exists,
|
|
173
|
+
validate_model_name,
|
|
174
|
+
validate_path_exists,
|
|
175
|
+
)
|
|
162
176
|
|
|
163
177
|
else:
|
|
164
178
|
sys.modules[__name__] = LazyImporter(
|
fusion_bench/utils/data.py
CHANGED
|
@@ -7,6 +7,8 @@ import torch
|
|
|
7
7
|
import torch.utils.data
|
|
8
8
|
from torch.utils.data import DataLoader, Dataset
|
|
9
9
|
|
|
10
|
+
from fusion_bench.utils.validation import ValidationError, validate_file_exists
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
class InfiniteDataLoader:
|
|
12
14
|
"""
|
|
@@ -18,23 +20,105 @@ class InfiniteDataLoader:
|
|
|
18
20
|
|
|
19
21
|
Attributes:
|
|
20
22
|
data_loader (DataLoader): The DataLoader to wrap.
|
|
21
|
-
|
|
23
|
+
_data_iter (iterator): An iterator over the DataLoader.
|
|
24
|
+
_iteration_count (int): Number of complete iterations through the dataset.
|
|
25
|
+
|
|
26
|
+
Example:
|
|
27
|
+
>>> train_loader = DataLoader(dataset, batch_size=32)
|
|
28
|
+
>>> infinite_loader = InfiniteDataLoader(train_loader)
|
|
29
|
+
>>> for i, batch in enumerate(infinite_loader):
|
|
30
|
+
... if i >= 1000: # Train for 1000 steps
|
|
31
|
+
... break
|
|
32
|
+
... train_step(batch)
|
|
22
33
|
"""
|
|
23
34
|
|
|
24
|
-
def __init__(self, data_loader: DataLoader):
|
|
35
|
+
def __init__(self, data_loader: DataLoader, max_retries: int = 1):
|
|
36
|
+
"""
|
|
37
|
+
Initialize the InfiniteDataLoader.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
data_loader: The DataLoader to wrap.
|
|
41
|
+
max_retries: Maximum number of retry attempts when resetting the data loader (default: 1).
|
|
42
|
+
|
|
43
|
+
Raises:
|
|
44
|
+
ValidationError: If data_loader is None or not a DataLoader instance.
|
|
45
|
+
"""
|
|
46
|
+
if data_loader is None:
|
|
47
|
+
raise ValidationError(
|
|
48
|
+
"data_loader cannot be None", field="data_loader", value=data_loader
|
|
49
|
+
)
|
|
50
|
+
|
|
25
51
|
self.data_loader = data_loader
|
|
26
|
-
self.
|
|
52
|
+
self.max_retries = max_retries
|
|
53
|
+
self._data_iter = iter(data_loader)
|
|
54
|
+
self._iteration_count = 0
|
|
27
55
|
|
|
28
56
|
def __iter__(self):
|
|
57
|
+
"""Reset the iterator to the beginning."""
|
|
58
|
+
self._data_iter = iter(self.data_loader)
|
|
59
|
+
self._iteration_count = 0
|
|
29
60
|
return self
|
|
30
61
|
|
|
31
62
|
def __next__(self):
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
data
|
|
37
|
-
|
|
63
|
+
"""
|
|
64
|
+
Get the next batch, resetting to the beginning when the dataset is exhausted.
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
The next batch from the data loader.
|
|
68
|
+
|
|
69
|
+
Raises:
|
|
70
|
+
RuntimeError: If the data loader consistently fails to produce data.
|
|
71
|
+
"""
|
|
72
|
+
last_exception = None
|
|
73
|
+
for attempt in range(self.max_retries):
|
|
74
|
+
try:
|
|
75
|
+
data = next(self._data_iter)
|
|
76
|
+
return data
|
|
77
|
+
except StopIteration:
|
|
78
|
+
# Dataset exhausted or dataloader is empty, reset to beginning
|
|
79
|
+
self._iteration_count += 1
|
|
80
|
+
try:
|
|
81
|
+
self._data_iter = iter(self.data_loader)
|
|
82
|
+
data = next(self._data_iter)
|
|
83
|
+
return data
|
|
84
|
+
except Exception as e:
|
|
85
|
+
last_exception = e
|
|
86
|
+
continue
|
|
87
|
+
except Exception as e:
|
|
88
|
+
# Handle other potential errors from the data loader
|
|
89
|
+
raise RuntimeError(
|
|
90
|
+
f"Error retrieving data from data loader: [{type(e).__name__}]{e}"
|
|
91
|
+
) from e
|
|
92
|
+
|
|
93
|
+
# If we get here, all attempts failed
|
|
94
|
+
raise RuntimeError(
|
|
95
|
+
f"Failed to retrieve data from data loader after {self.max_retries} attempts. "
|
|
96
|
+
f"Last error: [{type(last_exception).__name__}]{last_exception}. "
|
|
97
|
+
+ (
|
|
98
|
+
f"The data loader appears to be empty."
|
|
99
|
+
if isinstance(last_exception, StopIteration)
|
|
100
|
+
else ""
|
|
101
|
+
)
|
|
102
|
+
) from last_exception
|
|
103
|
+
|
|
104
|
+
def reset(self):
|
|
105
|
+
"""Manually reset the iterator to the beginning of the dataset."""
|
|
106
|
+
self._data_iter = iter(self.data_loader)
|
|
107
|
+
self._iteration_count = 0
|
|
108
|
+
|
|
109
|
+
@property
|
|
110
|
+
def iteration_count(self) -> int:
|
|
111
|
+
"""Get the number of complete iterations through the dataset."""
|
|
112
|
+
return self._iteration_count
|
|
113
|
+
|
|
114
|
+
def __len__(self) -> int:
|
|
115
|
+
"""
|
|
116
|
+
Return the length of the underlying data loader.
|
|
117
|
+
|
|
118
|
+
Returns:
|
|
119
|
+
The number of batches in one complete iteration.
|
|
120
|
+
"""
|
|
121
|
+
return len(self.data_loader)
|
|
38
122
|
|
|
39
123
|
|
|
40
124
|
def load_tensor_from_file(
|
|
@@ -50,7 +134,14 @@ def load_tensor_from_file(
|
|
|
50
134
|
|
|
51
135
|
Returns:
|
|
52
136
|
torch.Tensor: The tensor loaded from the file.
|
|
137
|
+
|
|
138
|
+
Raises:
|
|
139
|
+
ValidationError: If the file doesn't exist
|
|
140
|
+
ValueError: If the file format is unsupported
|
|
53
141
|
"""
|
|
142
|
+
# Validate file exists
|
|
143
|
+
validate_file_exists(file_path)
|
|
144
|
+
|
|
54
145
|
if file_path.endswith(".np"):
|
|
55
146
|
tensor = torch.from_numpy(np.load(file_path)).detach_()
|
|
56
147
|
if file_path.endswith((".pt", ".pth")):
|