fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. fusion_bench/constants/__init__.py +5 -1
  2. fusion_bench/constants/runtime.py +111 -7
  3. fusion_bench/dataset/gsm8k.py +6 -2
  4. fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
  5. fusion_bench/method/__init__.py +10 -2
  6. fusion_bench/method/base_algorithm.py +29 -19
  7. fusion_bench/method/classification/image_classification_finetune.py +1 -2
  8. fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
  9. fusion_bench/metrics/model_kinship/__init__.py +2 -0
  10. fusion_bench/metrics/model_kinship/calculate.py +77 -0
  11. fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
  12. fusion_bench/metrics/model_kinship/utility.py +184 -0
  13. fusion_bench/metrics/nyuv2/__init__.py +31 -0
  14. fusion_bench/metrics/nyuv2/depth.py +30 -0
  15. fusion_bench/metrics/nyuv2/loss.py +40 -0
  16. fusion_bench/metrics/nyuv2/noise.py +24 -0
  17. fusion_bench/metrics/nyuv2/normal.py +34 -1
  18. fusion_bench/metrics/nyuv2/segmentation.py +35 -1
  19. fusion_bench/mixins/clip_classification.py +30 -2
  20. fusion_bench/mixins/lightning_fabric.py +46 -5
  21. fusion_bench/mixins/rich_live.py +76 -0
  22. fusion_bench/modelpool/base_pool.py +86 -5
  23. fusion_bench/models/masks/mask_model.py +8 -2
  24. fusion_bench/models/open_clip/modeling.py +7 -0
  25. fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
  26. fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
  27. fusion_bench/scripts/cli.py +14 -0
  28. fusion_bench/scripts/webui.py +250 -17
  29. fusion_bench/utils/__init__.py +14 -0
  30. fusion_bench/utils/data.py +100 -9
  31. fusion_bench/utils/devices.py +3 -1
  32. fusion_bench/utils/fabric.py +185 -4
  33. fusion_bench/utils/instantiate_utils.py +29 -18
  34. fusion_bench/utils/json.py +6 -0
  35. fusion_bench/utils/misc.py +16 -0
  36. fusion_bench/utils/rich_utils.py +123 -6
  37. fusion_bench/utils/validation.py +197 -0
  38. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
  39. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
  40. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
  41. fusion_bench_config/llama_full_finetune.yaml +4 -16
  42. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  43. fusion_bench_config/nyuv2_config.yaml +4 -13
  44. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  45. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  46. fusion_bench/utils/auto.py +0 -31
  47. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
  48. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
  49. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
  50. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  """
2
- TODO: Per-session state management (use AppState)
2
+ Web UI for FusionBench Command Generator with per-session state management.
3
3
  """
4
4
 
5
5
  import argparse
@@ -13,39 +13,94 @@ import hydra
13
13
  import yaml
14
14
  from colorama import Fore, Style # For cross-platform color support
15
15
  from hydra import compose, initialize_config_dir
16
+ from hydra.core.hydra_config import HydraConfig
16
17
  from omegaconf import DictConfig, ListConfig, OmegaConf
17
18
 
18
19
  from fusion_bench.scripts.cli import _get_default_config_path
19
20
 
20
21
 
22
+ def escape_overrides(value: str) -> str:
23
+ """
24
+ Escapes special characters in Hydra command-line override values.
25
+
26
+ Adds quotes around values containing spaces and escapes equals signs
27
+ to prevent them from being interpreted as key-value separators.
28
+
29
+ Args:
30
+ value (str): The override value to escape.
31
+
32
+ Returns:
33
+ str: The escaped value ready for use in command-line overrides.
34
+ """
35
+ if " " in value and not (value.startswith('"') or value.startswith("'")):
36
+ return f"'{value}'"
37
+ if "=" in value:
38
+ return value.replace("=", "\\=")
39
+ return value
40
+
41
+
21
42
  class ConfigGroupNode:
43
+ """
44
+ Represents a node in the configuration directory tree.
45
+
46
+ This class recursively builds a tree structure representing the Hydra
47
+ configuration directory hierarchy, including subdirectories (child groups)
48
+ and YAML configuration files.
49
+
50
+ Attributes:
51
+ name (str): Name of the configuration group (directory name).
52
+ path (Path): Full path to the directory.
53
+ parent (Optional[ConfigGroupNode]): Parent node in the tree.
54
+ children (List[ConfigGroupNode]): Child directory nodes.
55
+ configs (List[str]): List of YAML config file names (without extension).
56
+ """
57
+
22
58
  name: str
23
59
  path: Path
24
- parent: Optional["ConfigGroupNode"] = None
60
+ parent: Optional["ConfigGroupNode"]
25
61
  children: List["ConfigGroupNode"]
26
62
  configs: List[str]
27
63
 
28
- def __init__(self, path: str | Path):
64
+ def __init__(self, path: str | Path, parent: Optional["ConfigGroupNode"] = None):
65
+ """
66
+ Initialize a ConfigGroupNode.
67
+
68
+ Args:
69
+ path: Path to the configuration directory.
70
+ parent: Parent node in the tree (None for root).
71
+ """
29
72
  self.path = Path(path)
30
73
  assert self.path.is_dir()
31
74
  self.name = self.path.stem
75
+ self.parent = parent
32
76
  self.children = []
33
77
  self.configs = []
34
78
  for child in self.path.iterdir():
35
79
  if child.is_dir():
36
- child_node = ConfigGroupNode(child)
37
- child_node.parent = self
80
+ child_node = ConfigGroupNode(child, parent=self)
38
81
  self.children.append(child_node)
39
82
  elif child.is_file() and child.suffix == ".yaml":
40
83
  self.configs.append(child.stem)
41
84
 
42
85
  def __repr__(self):
43
86
  """
44
- Return string of the tree structure
87
+ Return a colored string representation of the tree structure.
88
+
89
+ Returns:
90
+ str: Tree structure with colored group names.
45
91
  """
46
92
  return f"{Fore.BLUE}{self.name}{Style.RESET_ALL}\n" + self._repr_indented()
47
93
 
48
94
  def _repr_indented(self, prefix=""):
95
+ """
96
+ Generate indented tree representation recursively.
97
+
98
+ Args:
99
+ prefix: String prefix for indentation.
100
+
101
+ Returns:
102
+ str: Indented tree structure.
103
+ """
49
104
  result = ""
50
105
 
51
106
  items = self.configs + self.children
@@ -63,9 +118,30 @@ class ConfigGroupNode:
63
118
  return result
64
119
 
65
120
  def has_child_group(self, name: str) -> bool:
121
+ """
122
+ Check if this node has a child group with the given name.
123
+
124
+ Args:
125
+ name: Name of the child group to check.
126
+
127
+ Returns:
128
+ bool: True if child group exists, False otherwise.
129
+ """
66
130
  return any(child.name == name for child in self.children)
67
131
 
68
132
  def __getitem__(self, key: str) -> Union["ConfigGroupNode", str]:
133
+ """
134
+ Get a child group or config by name.
135
+
136
+ Args:
137
+ key: Name of the child group or config file.
138
+
139
+ Returns:
140
+ Union[ConfigGroupNode, str]: Child node or config file name.
141
+
142
+ Raises:
143
+ KeyError: If no child group or config with that name exists.
144
+ """
69
145
  for child in self.children:
70
146
  if child.name == key:
71
147
  return child
@@ -76,12 +152,28 @@ class ConfigGroupNode:
76
152
 
77
153
  @functools.cached_property
78
154
  def prefix(self) -> str:
155
+ """
156
+ Get the dot-separated prefix path from root to this node.
157
+
158
+ Returns:
159
+ str: Prefix path (e.g., "method.modelpool.").
160
+ """
79
161
  if self.parent is None:
80
162
  return ""
81
163
  return self.parent.prefix + self.name + "."
82
164
 
83
165
 
84
166
  def priority_iterable(iter, priority_keys):
167
+ """
168
+ Iterate over items with priority keys first, then remaining items.
169
+
170
+ Args:
171
+ iter: Iterable to process.
172
+ priority_keys: Keys to yield first.
173
+
174
+ Yields:
175
+ Items from iter, with priority_keys first.
176
+ """
85
177
  items = list(iter)
86
178
  for key in priority_keys:
87
179
  if key in items:
@@ -93,7 +185,16 @@ def priority_iterable(iter, priority_keys):
93
185
 
94
186
  class AppState:
95
187
  """
96
- Pre-session state of the app
188
+ Per-session state of the app.
189
+
190
+ Manages the current configuration state including the selected config name,
191
+ overrides, and the composed Hydra configuration.
192
+
193
+ Attributes:
194
+ config_name (str): Name of the root configuration file.
195
+ hydra_options (List[str]): Hydra-specific command line options.
196
+ overrides (List[str]): List of configuration overrides.
197
+ config (DictConfig): The composed OmegaConf configuration.
97
198
  """
98
199
 
99
200
  config_name: str
@@ -108,6 +209,15 @@ class AppState:
108
209
  hydra_options: List[str] = [],
109
210
  overrides: List[str] = [],
110
211
  ) -> None:
212
+ """
213
+ Initialize the application state.
214
+
215
+ Args:
216
+ config_path: Path to the config directory.
217
+ config_name: Name of the root config file.
218
+ hydra_options: Hydra command line options.
219
+ overrides: Initial configuration overrides.
220
+ """
111
221
  super().__init__()
112
222
  self.config_path = config_path
113
223
  self.config_name = config_name
@@ -117,6 +227,12 @@ class AppState:
117
227
 
118
228
  @property
119
229
  def config_str(self):
230
+ """
231
+ Get the YAML string representation of the current configuration.
232
+
233
+ Returns:
234
+ str: YAML formatted configuration.
235
+ """
120
236
  return OmegaConf.to_yaml(self.config)
121
237
 
122
238
  def update_config(
@@ -124,6 +240,16 @@ class AppState:
124
240
  config_name: str = None,
125
241
  overrides: List[str] = None,
126
242
  ) -> "AppState":
243
+ """
244
+ Update the configuration with new name and/or overrides.
245
+
246
+ Args:
247
+ config_name: New root config name (optional).
248
+ overrides: New list of overrides (optional).
249
+
250
+ Returns:
251
+ AppState: Self for method chaining.
252
+ """
127
253
  if config_name is not None:
128
254
  self.config_name = config_name
129
255
  if overrides is not None:
@@ -133,11 +259,21 @@ class AppState:
133
259
  self.config = ""
134
260
  else:
135
261
  self.config = compose(
136
- config_name=self.config_name, overrides=self.overrides
262
+ config_name=self.config_name,
263
+ overrides=self.overrides,
264
+ return_hydra_config=True,
137
265
  )
266
+ HydraConfig().set_config(self.config)
267
+ del self.config.hydra
138
268
  return self
139
269
 
140
270
  def generate_command(self):
271
+ """
272
+ Generate the fusion_bench CLI command from current state.
273
+
274
+ Returns:
275
+ str: Complete command ready to execute in shell.
276
+ """
141
277
  # Generate the command according to `config_name` and `overrides` (a list of strings)
142
278
  command = "fusion_bench \\\n"
143
279
  if self.config_path is not None:
@@ -152,23 +288,66 @@ class AppState:
152
288
 
153
289
  @property
154
290
  def config_str_and_command(self):
291
+ """
292
+ Get both config string and command as a tuple.
293
+
294
+ Returns:
295
+ Tuple[str, str]: (YAML config, shell command).
296
+ """
155
297
  return self.config_str, self.generate_command()
156
298
 
157
299
  def get_override(self, key: str):
300
+ """
301
+ Get the override value for a specific key.
302
+
303
+ Args:
304
+ key: Configuration key to look up.
305
+
306
+ Returns:
307
+ Optional[str]: Override value or None if not found.
308
+ """
158
309
  for ov in self.overrides:
159
310
  if ov.startswith(f"{key}="):
160
311
  return "".join(ov.split("=")[1:])
161
312
  return None
162
313
 
163
314
  def update_override(self, key: str, value):
315
+ """
316
+ Update or add an override for a specific key.
317
+
318
+ Args:
319
+ key: Configuration key to override.
320
+ value: New value for the key.
321
+
322
+ Returns:
323
+ AppState: Updated state after recomposing config.
324
+ """
164
325
  self.overrides = [ov for ov in self.overrides if not ov.startswith(f"{key}=")]
165
326
  if value:
166
- self.overrides.append(f"{key}={value}")
327
+ self.overrides.append(f"{key}={escape_overrides(value)}")
167
328
  return self.update_config()
168
329
 
169
330
 
170
331
  class App:
332
+ """
333
+ Main application class for the FusionBench WebUI.
334
+
335
+ Manages the Gradio interface, configuration tree, and application state.
336
+
337
+ Attributes:
338
+ args: Command line arguments.
339
+ group_tree (ConfigGroupNode): Root of the config directory tree.
340
+ init_config_name (str): Initial configuration name.
341
+ app_state (AppState): Current application state.
342
+ """
343
+
171
344
  def __init__(self, args):
345
+ """
346
+ Initialize the application.
347
+
348
+ Args:
349
+ args: Parsed command line arguments.
350
+ """
172
351
  super().__init__()
173
352
  self.args = args
174
353
  group_tree = ConfigGroupNode(self.config_path)
@@ -177,8 +356,8 @@ class App:
177
356
 
178
357
  self.group_tree = group_tree
179
358
 
180
- if "example_config" in group_tree.configs:
181
- self.init_config_name = "example_config"
359
+ if "fabric_model_fusion" in group_tree.configs:
360
+ self.init_config_name = "fabric_model_fusion"
182
361
  else:
183
362
  self.init_config_name = group_tree.configs[0]
184
363
 
@@ -197,31 +376,64 @@ class App:
197
376
 
198
377
  @functools.cached_property
199
378
  def config_path(self):
379
+ """
380
+ Get the configuration directory path.
381
+
382
+ Returns:
383
+ Path: Path to the config directory.
384
+ """
200
385
  if self.args.config_path:
201
386
  return Path(self.args.config_path)
202
387
  else:
203
388
  return _get_default_config_path()
204
389
 
205
390
  def __getattr__(self, name):
391
+ """
392
+ Delegate attribute access to app_state if not found in App.
393
+
394
+ Args:
395
+ name: Attribute name.
396
+
397
+ Returns:
398
+ Attribute value from app_state.
399
+
400
+ Raises:
401
+ AttributeError: If attribute not found in app_state either.
402
+ """
206
403
  if hasattr(self.app_state, name):
207
404
  return getattr(self.app_state, name)
208
405
  raise AttributeError(f"App object has no attribute {name}")
209
406
 
210
407
  def generate_ui(self):
408
+ """
409
+ Generate the Gradio user interface.
410
+
411
+ Creates interactive UI components for configuration selection,
412
+ parameter editing, and command generation.
413
+
414
+ Returns:
415
+ gr.Blocks: Gradio application instance.
416
+ """
211
417
  with gr.Blocks() as app:
212
418
  gr.Markdown("# FusionBench Command Generator")
213
419
 
214
420
  # 1. Choose a root config file
215
- root_configs = gr.Dropdown(
216
- choices=self.group_tree.configs,
217
- value=self.config_name,
218
- label="Root Config",
219
- )
421
+ with gr.Row(equal_height=True):
422
+ root_configs = gr.Dropdown(
423
+ choices=self.group_tree.configs,
424
+ value=self.config_name,
425
+ label="Root Config",
426
+ scale=4,
427
+ )
428
+ reset_button = gr.Button("Reset", scale=1)
220
429
 
221
430
  with gr.Row():
222
431
  with gr.Column(scale=2):
223
432
  command_output = gr.Code(
224
- language="shell", label="Generated Command"
433
+ value=self.app_state.generate_command(),
434
+ language="shell",
435
+ label="Generated Command",
436
+ interactive=False,
225
437
  )
226
438
 
227
439
  @gr.render(inputs=[root_configs, command_output])
@@ -352,10 +564,26 @@ class App:
352
564
  outputs=[config_output, command_output],
353
565
  )
354
566
 
567
+ def reset_app(config_name):
568
+ # Reset overrides and update config
569
+ self.app_state.overrides = []
570
+ return self.app_state.update_config(config_name).config_str_and_command
571
+
572
+ reset_button.click(
573
+ reset_app,
574
+ inputs=[root_configs],
575
+ outputs=[config_output, command_output],
576
+ )
355
577
  return app
356
578
 
357
579
 
358
580
  def parse_args():
581
+ """
582
+ Parse command line arguments for the WebUI.
583
+
584
+ Returns:
585
+ argparse.Namespace: Parsed arguments.
586
+ """
359
587
  parser = argparse.ArgumentParser(description="FusionBench Command Generator")
360
588
  parser.add_argument(
361
589
  "--config-path",
@@ -390,6 +618,11 @@ def parse_args():
390
618
 
391
619
 
392
620
  def main() -> None:
621
+ """
622
+ Main entry point for the FusionBench WebUI application.
623
+
624
+ Parses arguments, initializes the app, and launches the Gradio interface.
625
+ """
393
626
  args = parse_args()
394
627
 
395
628
  app = App(args).generate_ui()
@@ -93,6 +93,13 @@ _import_structure = {
93
93
  "StateDictType",
94
94
  "TorchModelType",
95
95
  ],
96
+ "validation": [
97
+ "validate_path_exists",
98
+ "validate_file_exists",
99
+ "validate_directory_exists",
100
+ "validate_model_name",
101
+ "ValidationError",
102
+ ],
96
103
  }
97
104
 
98
105
  if TYPE_CHECKING:
@@ -159,6 +166,13 @@ if TYPE_CHECKING:
159
166
  )
160
167
  from .timer import timeit_context
161
168
  from .type import BoolStateDictType, StateDictType, TorchModelType
169
+ from .validation import (
170
+ ValidationError,
171
+ validate_directory_exists,
172
+ validate_file_exists,
173
+ validate_model_name,
174
+ validate_path_exists,
175
+ )
162
176
 
163
177
  else:
164
178
  sys.modules[__name__] = LazyImporter(
@@ -7,6 +7,8 @@ import torch
7
7
  import torch.utils.data
8
8
  from torch.utils.data import DataLoader, Dataset
9
9
 
10
+ from fusion_bench.utils.validation import ValidationError, validate_file_exists
11
+
10
12
 
11
13
  class InfiniteDataLoader:
12
14
  """
@@ -18,23 +20,105 @@ class InfiniteDataLoader:
18
20
 
19
21
  Attributes:
20
22
  data_loader (DataLoader): The DataLoader to wrap.
21
- data_iter (iterator): An iterator over the DataLoader.
23
+ _data_iter (iterator): An iterator over the DataLoader.
24
+ _iteration_count (int): Number of complete iterations through the dataset.
25
+
26
+ Example:
27
+ >>> train_loader = DataLoader(dataset, batch_size=32)
28
+ >>> infinite_loader = InfiniteDataLoader(train_loader)
29
+ >>> for i, batch in enumerate(infinite_loader):
30
+ ... if i >= 1000: # Train for 1000 steps
31
+ ... break
32
+ ... train_step(batch)
22
33
  """
23
34
 
24
- def __init__(self, data_loader: DataLoader):
35
+ def __init__(self, data_loader: DataLoader, max_retries: int = 1):
36
+ """
37
+ Initialize the InfiniteDataLoader.
38
+
39
+ Args:
40
+ data_loader: The DataLoader to wrap.
41
+ max_retries: Maximum number of retry attempts when resetting the data loader (default: 1).
42
+
43
+ Raises:
44
+ ValidationError: If data_loader is None or not a DataLoader instance.
45
+ """
46
+ if data_loader is None:
47
+ raise ValidationError(
48
+ "data_loader cannot be None", field="data_loader", value=data_loader
49
+ )
50
+
25
51
  self.data_loader = data_loader
26
- self.data_iter = iter(data_loader)
52
+ self.max_retries = max_retries
53
+ self._data_iter = iter(data_loader)
54
+ self._iteration_count = 0
27
55
 
28
56
  def __iter__(self):
57
+ """Reset the iterator to the beginning."""
58
+ self._data_iter = iter(self.data_loader)
59
+ self._iteration_count = 0
29
60
  return self
30
61
 
31
62
  def __next__(self):
32
- try:
33
- data = next(self.data_iter)
34
- except StopIteration:
35
- self.data_iter = iter(self.data_loader) # Reset the data loader
36
- data = next(self.data_iter)
37
- return data
63
+ """
64
+ Get the next batch, resetting to the beginning when the dataset is exhausted.
65
+
66
+ Returns:
67
+ The next batch from the data loader.
68
+
69
+ Raises:
70
+ RuntimeError: If the data loader consistently fails to produce data.
71
+ """
72
+ last_exception = None
73
+ for attempt in range(self.max_retries):
74
+ try:
75
+ data = next(self._data_iter)
76
+ return data
77
+ except StopIteration:
78
+ # Dataset exhausted or dataloader is empty, reset to beginning
79
+ self._iteration_count += 1
80
+ try:
81
+ self._data_iter = iter(self.data_loader)
82
+ data = next(self._data_iter)
83
+ return data
84
+ except Exception as e:
85
+ last_exception = e
86
+ continue
87
+ except Exception as e:
88
+ # Handle other potential errors from the data loader
89
+ raise RuntimeError(
90
+ f"Error retrieving data from data loader: [{type(e).__name__}]{e}"
91
+ ) from e
92
+
93
+ # If we get here, all attempts failed
94
+ raise RuntimeError(
95
+ f"Failed to retrieve data from data loader after {self.max_retries} attempts. "
96
+ f"Last error: [{type(last_exception).__name__}]{last_exception}. "
97
+ + (
98
+ f"The data loader appears to be empty."
99
+ if isinstance(last_exception, StopIteration)
100
+ else ""
101
+ )
102
+ ) from last_exception
103
+
104
+ def reset(self):
105
+ """Manually reset the iterator to the beginning of the dataset."""
106
+ self._data_iter = iter(self.data_loader)
107
+ self._iteration_count = 0
108
+
109
+ @property
110
+ def iteration_count(self) -> int:
111
+ """Get the number of complete iterations through the dataset."""
112
+ return self._iteration_count
113
+
114
+ def __len__(self) -> int:
115
+ """
116
+ Return the length of the underlying data loader.
117
+
118
+ Returns:
119
+ The number of batches in one complete iteration.
120
+ """
121
+ return len(self.data_loader)
38
122
 
39
123
 
40
124
  def load_tensor_from_file(
@@ -50,7 +134,14 @@ def load_tensor_from_file(
50
134
 
51
135
  Returns:
52
136
  torch.Tensor: The tensor loaded from the file.
137
+
138
+ Raises:
139
+ ValidationError: If the file doesn't exist
140
+ ValueError: If the file format is unsupported
53
141
  """
142
+ # Validate file exists
143
+ validate_file_exists(file_path)
144
+
54
145
  if file_path.endswith(".np"):
55
146
  tensor = torch.from_numpy(np.load(file_path)).detach_()
56
147
  if file_path.endswith((".pt", ".pth")):
@@ -32,11 +32,13 @@ def clear_cuda_cache():
32
32
  Clears the CUDA memory cache to free up GPU memory.
33
33
  Works only if CUDA is available.
34
34
  """
35
+
35
36
  gc.collect()
36
37
  if torch.cuda.is_available():
37
38
  torch.cuda.empty_cache()
39
+ gc.collect()
38
40
  else:
39
- log.warning("CUDA is not available. No cache to clear.")
41
+ log.debug("CUDA is not available. No cache to clear.")
40
42
 
41
43
 
42
44
  def to_device(