PVNet 5.0.14.tar.gz → 5.0.16.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {pvnet-5.0.14 → pvnet-5.0.16}/PKG-INFO +1 -2
  2. {pvnet-5.0.14 → pvnet-5.0.16}/PVNet.egg-info/PKG-INFO +1 -2
  3. {pvnet-5.0.14 → pvnet-5.0.16}/PVNet.egg-info/requires.txt +0 -1
  4. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/base_model.py +1 -67
  5. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/encoders/basic_blocks.py +1 -0
  6. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/late_fusion.py +0 -8
  7. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/training/lightning_module.py +11 -16
  8. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/training/train.py +0 -1
  9. pvnet-5.0.16/pvnet/utils.py +157 -0
  10. {pvnet-5.0.14 → pvnet-5.0.16}/pyproject.toml +0 -1
  11. {pvnet-5.0.14 → pvnet-5.0.16}/tests/test_end2end.py +5 -4
  12. pvnet-5.0.14/pvnet/utils.py +0 -88
  13. {pvnet-5.0.14 → pvnet-5.0.16}/LICENSE +0 -0
  14. {pvnet-5.0.14 → pvnet-5.0.16}/PVNet.egg-info/SOURCES.txt +0 -0
  15. {pvnet-5.0.14 → pvnet-5.0.16}/PVNet.egg-info/dependency_links.txt +0 -0
  16. {pvnet-5.0.14 → pvnet-5.0.16}/PVNet.egg-info/top_level.txt +0 -0
  17. {pvnet-5.0.14 → pvnet-5.0.16}/README.md +0 -0
  18. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/__init__.py +0 -0
  19. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/data/__init__.py +0 -0
  20. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/data/base_datamodule.py +0 -0
  21. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/data/site_datamodule.py +0 -0
  22. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/data/uk_regional_datamodule.py +0 -0
  23. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/load_model.py +0 -0
  24. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/__init__.py +0 -0
  25. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/ensemble.py +0 -0
  26. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/__init__.py +0 -0
  27. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/basic_blocks.py +0 -0
  28. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/encoders/__init__.py +0 -0
  29. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/encoders/encoders3d.py +0 -0
  30. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/linear_networks/__init__.py +0 -0
  31. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/linear_networks/basic_blocks.py +0 -0
  32. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/linear_networks/networks.py +0 -0
  33. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/site_encoders/__init__.py +0 -0
  34. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/site_encoders/basic_blocks.py +0 -0
  35. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/site_encoders/encoders.py +0 -0
  36. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/optimizers.py +0 -0
  37. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/training/__init__.py +0 -0
  38. {pvnet-5.0.14 → pvnet-5.0.16}/pvnet/training/plots.py +0 -0
  39. {pvnet-5.0.14 → pvnet-5.0.16}/setup.cfg +0 -0

{pvnet-5.0.14 → pvnet-5.0.16}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: PVNet
- Version: 5.0.14
+ Version: 5.0.16
  Summary: PVNet
  Author-email: Peter Dudfield <info@openclimatefix.org>
  Requires-Python: >=3.11
@@ -14,7 +14,6 @@ Requires-Dist: xarray
  Requires-Dist: h5netcdf
  Requires-Dist: torch>=2.0.0
  Requires-Dist: lightning
- Requires-Dist: torchvision
  Requires-Dist: typer
  Requires-Dist: sqlalchemy
  Requires-Dist: fsspec[s3]

{pvnet-5.0.14 → pvnet-5.0.16}/PVNet.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: PVNet
- Version: 5.0.14
+ Version: 5.0.16
  Summary: PVNet
  Author-email: Peter Dudfield <info@openclimatefix.org>
  Requires-Python: >=3.11
@@ -14,7 +14,6 @@ Requires-Dist: xarray
  Requires-Dist: h5netcdf
  Requires-Dist: torch>=2.0.0
  Requires-Dist: lightning
- Requires-Dist: torchvision
  Requires-Dist: typer
  Requires-Dist: sqlalchemy
  Requires-Dist: fsspec[s3]

{pvnet-5.0.14 → pvnet-5.0.16}/PVNet.egg-info/requires.txt
@@ -6,7 +6,6 @@ xarray
  h5netcdf
  torch>=2.0.0
  lightning
- torchvision
  typer
  sqlalchemy
  fsspec[s3]

{pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/base_model.py
@@ -1,5 +1,4 @@
  """Base model for all PVNet submodels"""
- import copy
  import logging
  import os
  import shutil
@@ -12,9 +11,7 @@ import torch
  import yaml
  from huggingface_hub import ModelCard, ModelCardData, snapshot_download
  from huggingface_hub.hf_api import HfApi
- from ocf_data_sampler.numpy_sample.common_types import TensorBatch
  from safetensors.torch import load_file, save_file
- from torchvision.transforms.functional import center_crop

  from pvnet.utils import (
      DATA_CONFIG_NAME,
@@ -437,69 +434,6 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
          else:
              self.num_output_features = self.forecast_len

-     def _adapt_batch(self, batch: TensorBatch) -> TensorBatch:
-         """Slice batches into appropriate shapes for model.
-
-         Returns a new batch dictionary with adapted data, leaving the original batch unchanged.
-         We make some specific assumptions about the original batch and the derived sliced batch:
-         - We are only limiting the future projections. I.e. we are never shrinking the batch from
-           the left hand side of the time axis, only slicing it from the right
-         - We are only shrinking the spatial crop of the satellite and NWP data
-
-         """
-         # Create a copy of the batch to avoid modifying the original
-         new_batch = {key: copy.deepcopy(value) for key, value in batch.items()}
-
-         if "gsp" in new_batch.keys():
-             # Slice off the end of the GSP data
-             gsp_len = self.forecast_len + self.history_len + 1
-             new_batch["gsp"] = new_batch["gsp"][:, :gsp_len]
-             new_batch["gsp_time_utc"] = new_batch["gsp_time_utc"][:, :gsp_len]
-
-         if "site" in new_batch.keys():
-             # Slice off the end of the site data
-             site_len = self.forecast_len + self.history_len + 1
-             new_batch["site"] = new_batch["site"][:, :site_len]
-
-             # Slice all site related datetime coordinates and features
-             site_time_keys = [
-                 "site_time_utc",
-                 "site_date_sin",
-                 "site_date_cos",
-                 "site_time_sin",
-                 "site_time_cos",
-             ]
-
-             for key in site_time_keys:
-                 if key in new_batch.keys():
-                     new_batch[key] = new_batch[key][:, :site_len]
-
-         if self.include_sat:
-             # Slice off the end of the satellite data and spatially crop
-             # Shape: batch_size, seq_length, channel, height, width
-             new_batch["satellite_actual"] = center_crop(
-                 new_batch["satellite_actual"][:, : self.sat_sequence_len],
-                 output_size=self.sat_encoder.image_size_pixels,
-             )
-
-         if self.include_nwp:
-             # Slice off the end of the NWP data and spatially crop
-             for nwp_source in self.nwp_encoders_dict:
-                 # shape: batch_size, seq_len, n_chans, height, width
-                 new_batch["nwp"][nwp_source]["nwp"] = center_crop(
-                     new_batch["nwp"][nwp_source]["nwp"],
-                     output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels,
-                 )[:, : self.nwp_encoders_dict[nwp_source].sequence_length]
-
-         if self.include_sun:
-             sun_len = self.forecast_len + self.history_len + 1
-             # Slice off end of solar coords
-             for s in ["solar_azimuth", "solar_elevation"]:
-                 if s in new_batch.keys():
-                     new_batch[s] = new_batch[s][:, :sun_len]
-
-         return new_batch
-
      def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
          """
          Convert network prediction into a point prediction.
@@ -517,4 +451,4 @@ class BaseModel(torch.nn.Module, HuggingfaceMixin):
          """
          # y_quantiles Shape: batch_size, seq_length, num_quantiles
          idx = self.output_quantiles.index(0.5)
-         return y_quantiles[..., idx]
+         return y_quantiles[..., idx]
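
For reference, the point-prediction logic in the _quantiles_to_prediction context above simply picks the 0.5-quantile channel of the network output. A minimal, standalone sketch of the same idea (the quantile list and tensor sizes here are illustrative and not taken from the package):

    import torch

    # Illustrative quantile levels; in PVNet these come from the model's output_quantiles setting.
    output_quantiles = [0.1, 0.5, 0.9]

    # Stand-in network output with shape (batch_size, seq_length, num_quantiles).
    y_quantiles = torch.randn(4, 16, len(output_quantiles))

    # The point prediction is the median (0.5 quantile) slice.
    idx = output_quantiles.index(0.5)
    y_median = y_quantiles[..., idx]  # shape: (4, 16)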

{pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/encoders/basic_blocks.py
@@ -31,6 +31,7 @@ class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
          self.out_features = out_features
          self.image_size_pixels = image_size_pixels
          self.sequence_length = sequence_length
+         self.in_channels = in_channels

      @abstractmethod
      def forward(self):

{pvnet-5.0.14 → pvnet-5.0.16}/pvnet/models/late_fusion/late_fusion.py
@@ -61,7 +61,6 @@ class LateFusionModel(BaseModel):
          nwp_interval_minutes: DictConfig | None = None,
          pv_interval_minutes: int = 5,
          sat_interval_minutes: int = 5,
-         adapt_batches: bool = False,
      ):
          """Neural network which combines information from different sources.

@@ -110,9 +109,6 @@
                  data for each source
              pv_interval_minutes: The interval between each sample of the PV data
              sat_interval_minutes: The interval between each sample of the satellite data
-             adapt_batches: If set to true, we attempt to slice the batches to the expected shape for
-                 the model to use. This allows us to overprepare batches and slice from them for the
-                 data we need for a model run.
          """
          super().__init__(
              history_minutes=history_minutes,
@@ -134,7 +130,6 @@
          self.add_image_embedding_channel = add_image_embedding_channel
          self.interval_minutes = interval_minutes
          self.min_sat_delay_minutes = min_sat_delay_minutes
-         self.adapt_batches = adapt_batches

          if self.location_id_mapping is None:
              logger.warning(
@@ -272,9 +267,6 @@
      def forward(self, x: TensorBatch) -> torch.Tensor:
          """Run model forward"""

-         if self.adapt_batches:
-             x = self._adapt_batch(x)
-
          if self.use_id_embedding:
              # eg: x['gsp_id'] = [1] with location_id_mapping = {1:0}, would give [0]
              id = torch.tensor(

{pvnet-5.0.14 → pvnet-5.0.16}/pvnet/training/lightning_module.py
@@ -15,6 +15,7 @@ from pvnet.data.base_datamodule import collate_fn
  from pvnet.models.base_model import BaseModel
  from pvnet.optimizers import AbstractOptimizer
  from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot
+ from pvnet.utils import validate_batch_against_config


  class PVNetLightningModule(pl.LightningModule):
@@ -42,9 +43,6 @@ class PVNetLightningModule(pl.LightningModule):
          # This setting is only used when lr is tuned with callback
          self.lr = None

-         # Set up store for all all validation results so we can log these
-         self.save_all_validation_results = save_all_validation_results
-
      def transfer_batch_to_device(
          self,
          batch: TensorBatch,
@@ -105,10 +103,6 @@ class PVNetLightningModule(pl.LightningModule):
          """Run training step"""
          y_hat = self.model(batch)

-         # Batch may be adapted in the model forward method, would need adapting here too
-         if self.model.adapt_batches:
-             batch = self.model._adapt_batch(batch)
-
          y = batch[self.model._target_key][:, -self.model.forecast_len :]

          losses = self._calculate_common_losses(y, y_hat)
@@ -193,7 +187,7 @@ class PVNetLightningModule(pl.LightningModule):
          self._val_horizon_maes: list[np.array] = []
          if self.current_epoch==0:
              self._val_persistence_horizon_maes: list[np.array] = []
-
+
          # Plot some sample forecasts
          val_dataset = self.trainer.val_dataloaders.dataset

@@ -209,13 +203,17 @@ class PVNetLightningModule(pl.LightningModule):

              batch = collate_fn([val_dataset[i] for i in idxs])
              batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
+
+             # Batch validation check only during sanity check phase - use first batch
+             if self.trainer.sanity_checking and plot_num == 0:
+                 validate_batch_against_config(
+                     batch=batch,
+                     model=self.model
+                 )
+
              with torch.no_grad():
                  y_hat = self.model(batch)

-             # Batch may be adapted in the model forward method, would need adapting here too
-             if self.model.adapt_batches:
-                 batch = self.model._adapt_batch(batch)
-
              fig = plot_sample_forecasts(
                  batch,
                  y_hat,
@@ -235,9 +233,6 @@ class PVNetLightningModule(pl.LightningModule):
          """Run validation step"""

          y_hat = self.model(batch)
-         # Batch may be adapted in the model forward method, would need adapting here too
-         if self.model.adapt_batches:
-             batch = self.model._adapt_batch(batch)

          # Internally store the val predictions
          self._store_val_predictions(batch, y_hat)
@@ -317,7 +312,7 @@ class PVNetLightningModule(pl.LightningModule):
          self.log_dict(extreme_error_metrics, on_step=False, on_epoch=True)

          # Optionally save all validation results - these are overridden each epoch
-         if self.save_all_validation_results:
+         if self.hparams.save_all_validation_results:
              # Add attributes
              ds_val_results.attrs["epoch"] = self.current_epoch

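
A note on the change from self.save_all_validation_results to self.hparams.save_all_validation_results above: in Lightning, self.hparams is only populated when the module calls self.save_hyperparameters() in its __init__, which this diff does not show. A minimal sketch of that pattern, under the assumption that PVNetLightningModule does so (the module below is hypothetical, for illustration only):

    import lightning.pytorch as pl


    class ExampleModule(pl.LightningModule):
        """Hypothetical module illustrating the hparams pattern assumed by the diff."""

        def __init__(self, save_all_validation_results: bool = False):
            super().__init__()
            # Stores the __init__ arguments so they are available as
            # self.hparams.<name> and get saved into checkpoints.
            self.save_hyperparameters()

        def on_validation_epoch_end(self):
            # Mirrors the changed line: the flag is read from hparams rather
            # than from a manually assigned attribute.
            if self.hparams.save_all_validation_results:
                pass  # e.g. write the collected validation results to disk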

{pvnet-5.0.14 → pvnet-5.0.16}/pvnet/training/train.py
@@ -26,7 +26,6 @@ from pvnet.utils import (
  log = logging.getLogger(__name__)


-
  def resolve_monitor_loss(output_quantiles: list | None) -> str:
      """Return the desired metric to monitor based on whether quantile regression is being used.


pvnet-5.0.16/pvnet/utils.py
@@ -0,0 +1,157 @@
+ """Utils"""
+ import logging
+ from typing import TYPE_CHECKING
+
+ import rich.syntax
+ import rich.tree
+ from lightning.pytorch.utilities import rank_zero_only
+ from omegaconf import DictConfig, OmegaConf
+
+ if TYPE_CHECKING:
+     from pvnet.models.base_model import BaseModel
+
+ logger = logging.getLogger(__name__)
+
+
+ PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
+ MODEL_CONFIG_NAME = "model_config.yaml"
+ DATA_CONFIG_NAME = "data_config.yaml"
+ DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
+ FULL_CONFIG_NAME = "full_experiment_config.yaml"
+ MODEL_CARD_NAME = "README.md"
+
+
+
+ def run_config_utilities(config: DictConfig) -> None:
+     """A couple of optional utilities.
+
+     Controlled by main config file:
+     - forcing debug friendly configuration
+
+     Modifies DictConfig in place.
+
+     Args:
+         config (DictConfig): Configuration composed by Hydra.
+     """
+
+     # Enable adding new keys to config
+     OmegaConf.set_struct(config, False)
+
+     # Force debugger friendly configuration if <config.trainer.fast_dev_run=True>
+     if config.trainer.get("fast_dev_run"):
+         logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
+         # Debuggers don't like GPUs or multiprocessing
+         if config.trainer.get("gpus"):
+             config.trainer.gpus = 0
+         if config.datamodule.get("pin_memory"):
+             config.datamodule.pin_memory = False
+         if config.datamodule.get("num_workers"):
+             config.datamodule.num_workers = 0
+         if config.datamodule.get("prefetch_factor"):
+             config.datamodule.prefetch_factor = None
+
+     # Disable adding new keys to config
+     OmegaConf.set_struct(config, True)
+
+
+ @rank_zero_only
+ def print_config(
+     config: DictConfig,
+     fields: tuple[str] = (
+         "trainer",
+         "model",
+         "datamodule",
+         "callbacks",
+         "logger",
+         "seed",
+     ),
+     resolve: bool = True,
+ ) -> None:
+     """Prints content of DictConfig using Rich library and its tree structure.
+
+     Args:
+         config (DictConfig): Configuration composed by Hydra.
+         fields (Sequence[str], optional): Determines which main fields from config will
+             be printed and in what order.
+         resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+     """
+
+     style = "dim"
+     tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+     for field in fields:
+         branch = tree.add(field, style=style, guide_style=style)
+
+         config_section = config.get(field)
+
+         branch_content = str(config_section)
+         if isinstance(config_section, DictConfig):
+             branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
+
+         branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+     rich.print(tree)
+
+
+ def validate_batch_against_config(
+     batch: dict,
+     model: "BaseModel",
+ ) -> None:
+     """Validates tensor shapes in batch against model configuration."""
+     logger.info("Performing batch shape validation against model config.")
+
+     # NWP validation
+     if hasattr(model, 'nwp_encoders_dict'):
+         if "nwp" not in batch:
+             raise ValueError(
+                 "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
+             )
+
+         for source, nwp_data in batch["nwp"].items():
+             if source in model.nwp_encoders_dict:
+
+                 enc = model.nwp_encoders_dict[source]
+                 expected_channels = enc.in_channels
+                 if model.add_image_embedding_channel:
+                     expected_channels -= 1
+
+                 expected = (nwp_data["nwp"].shape[0], enc.sequence_length,
+                             expected_channels, enc.image_size_pixels, enc.image_size_pixels)
+                 if tuple(nwp_data["nwp"].shape) != expected:
+                     actual_shape = tuple(nwp_data['nwp'].shape)
+                     raise ValueError(
+                         f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
+                     )
+
+     # Satellite validation
+     if hasattr(model, 'sat_encoder'):
+         if "satellite_actual" not in batch:
+             raise ValueError(
+                 "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
+             )
+
+         enc = model.sat_encoder
+         expected_channels = enc.in_channels
+         if model.add_image_embedding_channel:
+             expected_channels -= 1
+
+         expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels,
+                     enc.image_size_pixels, enc.image_size_pixels)
+         if tuple(batch["satellite_actual"].shape) != expected:
+             actual_shape = tuple(batch['satellite_actual'].shape)
+             raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
+
+     # GSP/Site validation
+     key = model._target_key
+     if key in batch:
+         total_minutes = model.history_minutes + model.forecast_minutes
+         interval = model.interval_minutes
+         expected_len = total_minutes // interval + 1
+         expected = (batch[key].shape[0], expected_len)
+         if tuple(batch[key].shape) != expected:
+             actual_shape = tuple(batch[key].shape)
+             raise ValueError(
+                 f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
+             )
+
+     logger.info("Batch shape validation successful!")
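
The new validate_batch_against_config above is invoked from PVNetLightningModule.on_validation_epoch_start during Lightning's sanity-check phase (see the lightning_module.py hunk earlier). A minimal sketch of calling it directly, assuming a loaded BaseModel subclass and a batch collated with pvnet.data.base_datamodule.collate_fn; the check_batch wrapper below is hypothetical:

    from pvnet.utils import validate_batch_against_config


    def check_batch(batch: dict, model) -> None:
        """Raise early if a collated TensorBatch disagrees with the model config."""
        try:
            validate_batch_against_config(batch=batch, model=model)
        except ValueError as err:
            # Raised when a tensor's shape disagrees with the model's configured
            # sequence length, image size, or channel count.
            raise RuntimeError(f"Batch/config mismatch: {err}") from err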

{pvnet-5.0.14 → pvnet-5.0.16}/pyproject.toml
@@ -20,7 +20,6 @@ dependencies = [
      "h5netcdf",
      "torch>=2.0.0",
      "lightning",
-     "torchvision",
      "typer",
      "sqlalchemy",
      "fsspec[s3]",

{pvnet-5.0.14 → pvnet-5.0.16}/tests/test_end2end.py
@@ -1,7 +1,8 @@
  import lightning
- from pvnet.data import UKRegionalStreamedDataModule
- from pvnet.training.lightning_module import PVNetLightningModule
+
+ from pvnet.data import UKRegionalStreamedDataModule
  from pvnet.optimizers import EmbAdamWReduceLROnPlateau
+ from pvnet.training.lightning_module import PVNetLightningModule


  def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_model):
@@ -15,7 +16,7 @@ def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_mo
          dataset_pickle_dir=f"{session_tmp_path}/dataset_pickles"
      )

-     ligtning_model = PVNetLightningModule(
+     lightning_model = PVNetLightningModule(
          model=late_fusion_model,
          optimizer=EmbAdamWReduceLROnPlateau(),
      )
@@ -29,4 +30,4 @@ def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_mo
          logger=False,
          enable_checkpointing=False,
      )
-     trainer.fit(model=ligtning_model, datamodule=datamodule)
+     trainer.fit(model=lightning_model, datamodule=datamodule)

pvnet-5.0.14/pvnet/utils.py
@@ -1,88 +0,0 @@
- """Utils"""
- import logging
-
- import rich.syntax
- import rich.tree
- from lightning.pytorch.utilities import rank_zero_only
- from omegaconf import DictConfig, OmegaConf
-
- logger = logging.getLogger(__name__)
-
-
- PYTORCH_WEIGHTS_NAME = "model_weights.safetensors"
- MODEL_CONFIG_NAME = "model_config.yaml"
- DATA_CONFIG_NAME = "data_config.yaml"
- DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
- FULL_CONFIG_NAME = "full_experiment_config.yaml"
- MODEL_CARD_NAME = "README.md"
-
-
-
- def run_config_utilities(config: DictConfig) -> None:
-     """A couple of optional utilities.
-
-     Controlled by main config file:
-     - forcing debug friendly configuration
-
-     Modifies DictConfig in place.
-
-     Args:
-         config (DictConfig): Configuration composed by Hydra.
-     """
-
-     # Enable adding new keys to config
-     OmegaConf.set_struct(config, False)
-
-     # Force debugger friendly configuration if <config.trainer.fast_dev_run=True>
-     if config.trainer.get("fast_dev_run"):
-         logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
-         # Debuggers don't like GPUs or multiprocessing
-         if config.trainer.get("gpus"):
-             config.trainer.gpus = 0
-         if config.datamodule.get("pin_memory"):
-             config.datamodule.pin_memory = False
-         if config.datamodule.get("num_workers"):
-             config.datamodule.num_workers = 0
-         if config.datamodule.get("prefetch_factor"):
-             config.datamodule.prefetch_factor = None
-
-     # Disable adding new keys to config
-     OmegaConf.set_struct(config, True)
-
-
- @rank_zero_only
- def print_config(
-     config: DictConfig,
-     fields: tuple[str] = (
-         "trainer",
-         "model",
-         "datamodule",
-         "callbacks",
-         "logger",
-         "seed",
-     ),
-     resolve: bool = True,
- ) -> None:
-     """Prints content of DictConfig using Rich library and its tree structure.
-
-     Args:
-         config (DictConfig): Configuration composed by Hydra.
-         fields (Sequence[str], optional): Determines which main fields from config will
-             be printed and in what order.
-         resolve (bool, optional): Whether to resolve reference fields of DictConfig.
-     """
-
-     style = "dim"
-     tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
-
-     for field in fields:
-         branch = tree.add(field, style=style, guide_style=style)
-
-         config_section = config.get(field)
-         branch_content = str(config_section)
-         if isinstance(config_section, DictConfig):
-             branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
-
-         branch.add(rich.syntax.Syntax(branch_content, "yaml"))
-
-     rich.print(tree)