sdofmv2-0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. sdofmv2-0.1.0/LICENSE +21 -0
  2. sdofmv2-0.1.0/PKG-INFO +149 -0
  3. sdofmv2-0.1.0/README.md +119 -0
  4. sdofmv2-0.1.0/pyproject.toml +48 -0
  5. sdofmv2-0.1.0/setup.cfg +4 -0
  6. sdofmv2-0.1.0/src/sdofmv2/__init__.py +0 -0
  7. sdofmv2-0.1.0/src/sdofmv2/core/__init__.py +18 -0
  8. sdofmv2-0.1.0/src/sdofmv2/core/attention_map.py +314 -0
  9. sdofmv2-0.1.0/src/sdofmv2/core/basemodule.py +140 -0
  10. sdofmv2-0.1.0/src/sdofmv2/core/datamodule.py +1169 -0
  11. sdofmv2-0.1.0/src/sdofmv2/core/losses.py +87 -0
  12. sdofmv2-0.1.0/src/sdofmv2/core/mae3d.py +491 -0
  13. sdofmv2-0.1.0/src/sdofmv2/core/mae3d_old.py +413 -0
  14. sdofmv2-0.1.0/src/sdofmv2/core/mae_module.py +262 -0
  15. sdofmv2-0.1.0/src/sdofmv2/core/mae_module_old.py +222 -0
  16. sdofmv2-0.1.0/src/sdofmv2/core/pca_analysis.py +49 -0
  17. sdofmv2-0.1.0/src/sdofmv2/core/reconstruction.py +164 -0
  18. sdofmv2-0.1.0/src/sdofmv2/tasks/__init__.py +0 -0
  19. sdofmv2-0.1.0/src/sdofmv2/tasks/f107/__init__.py +2 -0
  20. sdofmv2-0.1.0/src/sdofmv2/tasks/f107/f107_datamodule.py +232 -0
  21. sdofmv2-0.1.0/src/sdofmv2/tasks/f107/f107_module.py +138 -0
  22. sdofmv2-0.1.0/src/sdofmv2/tasks/missing_data/__init__.py +3 -0
  23. sdofmv2-0.1.0/src/sdofmv2/tasks/missing_data/missing_data_module.py +99 -0
  24. sdofmv2-0.1.0/src/sdofmv2/tasks/missing_data/necks.py +154 -0
  25. sdofmv2-0.1.0/src/sdofmv2/tasks/missing_data/wrap_encoder.py +43 -0
  26. sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/__init__.py +4 -0
  27. sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/datamodule.py +579 -0
  28. sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/focal_loss.py +44 -0
  29. sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/head_networks.py +296 -0
  30. sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/model.py +582 -0
  31. sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/visualization.py +623 -0
  32. sdofmv2-0.1.0/src/sdofmv2/utils/__init__.py +17 -0
  33. sdofmv2-0.1.0/src/sdofmv2/utils/constants.py +55 -0
  34. sdofmv2-0.1.0/src/sdofmv2/utils/utils.py +251 -0
  35. sdofmv2-0.1.0/src/sdofmv2.egg-info/PKG-INFO +149 -0
  36. sdofmv2-0.1.0/src/sdofmv2.egg-info/SOURCES.txt +37 -0
  37. sdofmv2-0.1.0/src/sdofmv2.egg-info/dependency_links.txt +1 -0
  38. sdofmv2-0.1.0/src/sdofmv2.egg-info/requires.txt +11 -0
  39. sdofmv2-0.1.0/src/sdofmv2.egg-info/top_level.txt +1 -0
sdofmv2-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Joseph Gallego
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
sdofmv2-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,149 @@
+ Metadata-Version: 2.4
+ Name: sdofmv2
+ Version: 0.1.0
+ Summary: Solar phenomena prediction models
+ Author-email: Joseph Gallego <joaggi@gmail.com>, Daniela Martin <dmartinvega@gmail.com>, Jinsu Hong <jinsuhong.knight@gmail.com>
+ License: MIT
+ Project-URL: Repository, https://github.com/Joaggi/sdofmv2
+ Project-URL: Issues, https://github.com/Joaggi/sdofmv2/issues
+ Keywords: foundation model,solar physics,deep learning,space weather
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Topic :: Scientific/Engineering :: Astronomy
+ Requires-Python: >=3.11
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=2.5.1
+ Requires-Dist: lightning>=2.6.0
+ Requires-Dist: numpy>=2.3.5
+ Requires-Dist: pandas>=2.3.3
+ Requires-Dist: transformers>=4.57.3
+ Requires-Dist: sunpy>=7.0.4
+ Requires-Dist: astropy>=6.0
+ Requires-Dist: timm>=1.0.22
+ Requires-Dist: einops>=0.8.1
+ Requires-Dist: hydra-core>=1.3.2
+ Requires-Dist: wandb>=0.23.1
+ Dynamic: license-file
+
+ # SDO FM v2: [Full Title of the Project/Model]
+
+ [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
+ [![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=flat&logo=PyTorch&logoColor=white)](https://pytorch.org/)
+ [![PyTorch Lightning](https://img.shields.io/badge/PyTorch_Lightning-%23792EE5.svg?style=flat&logo=pytorchlightning&logoColor=white)](https://lightning.ai/docs/pytorch/stable/)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+
+ ## Introduction
+ **SDOFMv2** is an advanced multi-instrument foundation model designed to analyze Solar Dynamics Observatory (SDO) data and drive large-scale, data-driven heliophysics research. Building upon the original SDOFM framework, this version addresses earlier limitations, such as restricted temporal coverage and reconstruction artifacts, improving spatial coherence and global consistency.
+
+ ![Model architecture](sdofmv2.svg)
+ *A Masked Autoencoder (MAE) based on a Vision Transformer (ViT) architecture is used for pretraining. During this phase, a% of the image patches are masked, while the remaining (100 - a)% are processed by the encoder. The decoder block then reconstructs all patches, optimized via a customized loss function.*
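The masking step described in the caption is the core of MAE pretraining. As a minimal sketch (illustrative only; the function name `random_masking` and the tensor shapes are assumptions, not this package's API), per-sample random masking of patch tokens typically looks like this:

```python
import torch

def random_masking(x: torch.Tensor, mask_ratio: float):
    """Illustrative MAE-style masking of patch embeddings x: [B, N, D]."""
    B, N, D = x.shape
    len_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=x.device)        # one random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)        # low score -> kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

    ids_keep = ids_shuffle[:, :len_keep]
    x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    mask = torch.ones(B, N, device=x.device)         # 1 = masked, 0 = kept
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)        # back to original patch order
    return x_kept, mask, ids_restore
```

Only the kept tokens pass through the encoder; the decoder re-inserts learned mask tokens at the positions recorded by `ids_restore` and reconstructs all patches.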
+
+ ---
+
+ ## Getting Started
+
+ ### Prerequisites
+ * Linux or macOS
+ * Python 3.11+
+ * NVIDIA GPU + CUDA toolkit (recommended for training)
+
+ ### Environment Setup
+ We recommend using `mamba` to manage dependencies.
+
+ > **Important Hardware Note:** The `sdofmv2_environment.yml` file is configured for **CUDA 12.8** by default. If your hardware or drivers require a different CUDA version (e.g., CUDA 11.8), open `sdofmv2_environment.yml` and modify the `pip` section at the bottom to match your system (e.g., change `cu128` to `cu118`) before running the setup commands.
+
+ **Using Mamba:**
+ ```bash
+ # Clone the repository
+ git clone https://github.com/Joaggi/sdofmv2.git
+ cd sdofmv2
+
+ # Create and activate the environment (this also installs PyTorch and the local package)
+ mamba env create -f sdofmv2_environment.yml
+ mamba activate sdofmv2
+ ```
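Once the environment is created, a quick sanity check (standard PyTorch attributes, not a project script) confirms that the installed build matches your CUDA driver:

```python
import torch

# Torch build, the CUDA version it was compiled against, and GPU visibility
print(torch.__version__)         # e.g. 2.5.1+cu128
print(torch.version.cuda)        # e.g. "12.8"
print(torch.cuda.is_available())
```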
+ ---
+
+ ## Repository Structure
+
+ ```text
+ .
+ ├── configs/                   # YAML configurations for experiments
+ │   ├── downstream/            # Configs for downstream tasks (F10.7, solar wind)
+ │   └── pretrain/              # Configs for MAE pretraining (AIA, HMI)
+ ├── notebooks/                 # Jupyter notebooks for analysis and visualization
+ │   ├── analysis/              # Attention maps, PCA, and masking analysis
+ │   └── downstream_apps/       # Notebooks demonstrating the F10.7 and missing-data applications
+ ├── scripts/                   # Executable scripts for training and testing
+ │   ├── pretrain.py            # Main pretraining script
+ │   ├── finetuning_*.py        # Scripts for downstream finetuning
+ │   └── test.py                # Script for evaluating checkpoints
+ ├── src/                       # Core source code package
+ │   └── sdofmv2/
+ │       ├── core/              # Base model architectures and modules
+ │       ├── tasks/             # PyTorch Lightning modules (model & data module) for downstream tasks
+ │       └── utils/             # Helper functions, physical constants, and metrics
+ ├── pyproject.toml             # Project metadata and build dependencies
+ └── sdofmv2_environment.yml    # Mamba environment definition file
+ ```
+
+ ---
+
+ ## How to Use
+
+ *(Note: Run all scripts from the repository root so that paths to `configs/` and `src/` resolve correctly.)*
+
+ ### 1. Data Preparation
+ Before training or running inference, you need to prepare the dataset.
+ [Explain where to download the data, or provide a command if you have a script for it.]
+ ```bash
+ python scripts/download_data_cache.py --target_dir ./assets/
+ ```
+
+ ### 2. Training the Model
+ To train the model from scratch, run the pretraining script with the relevant configuration file:
+ ```bash
+ python scripts/pretrain.py --config-name pretrain_mae_AIA.yaml
+ ```
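Since the project depends on `hydra-core` and the scripts take `--config-name`, the entry point presumably follows the standard Hydra pattern; a minimal sketch (the `config_path` and the function body here are assumptions, not the actual script):

```python
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="../configs/pretrain", config_name="pretrain_mae_AIA", version_base=None)
def main(cfg: DictConfig) -> None:
    # Hydra resolves the YAML into a nested config object,
    # e.g. cfg.model.opt.batch_size or cfg.data.num_workers
    print(cfg.model.opt.batch_size)

if __name__ == "__main__":
    main()
```

With Hydra, individual values can also be overridden on the command line, e.g. by appending `model.opt.batch_size=8` to the training command.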
+
+ ### 3. Inference and Evaluation
+ To evaluate a pretrained checkpoint on the test set:
+ ```bash
+ python scripts/test.py --config-name pretrain_mae_AIA.yaml
+ ```
+
+ ### 4. Downstream Finetuning
+ To finetune the model on a specific downstream task (e.g., solar wind forecasting):
+ ```bash
+ python scripts/finetuning_solarwind.py --config-name finetune_solarwind_config.yaml
+ ```
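The downstream modules under `src/sdofmv2/tasks/` are PyTorch Lightning modules; a generic sketch of the finetuning pattern they follow (frozen pretrained backbone, small trainable head; the class and attribute names here are illustrative, not the package's API):

```python
import torch.nn as nn

class DownstreamHead(nn.Module):
    """Illustrative regression head on top of frozen MAE encoder features."""

    def __init__(self, embed_dim: int, out_dim: int = 1):
        super().__init__()
        self.head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, out_dim))

    def forward(self, latent):
        # latent: [B, N, D] patch tokens from the encoder;
        # mean-pool over tokens, then project to the target dimension
        return self.head(latent.mean(dim=1))
```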
+ ---
+
+ ## Results & Visualizations
+ [Include a brief summary of the model's performance. You can add a table of metrics or a sample plot showing predictions vs. ground truth.]
+
+ ![Sample Visualization](notebooks/analysis/SDOFMv2_AIA_results_exp.png)
+ *The first row displays the original ground-truth images. The second and third rows show the model's reconstructions at masking ratios of 0% and 50%, respectively.*
+
+ ---
+
+ ## Citation
+ If you find this repository or model useful in your academic research, please consider citing our work:
+
+ ```bibtex
+ @misc{sdofmv2,
+   author = {Hong, Jinsu and Martin, Daniela and Gallego, Joseph},
+   title = {SDOFMv2: A Multi-Instrument Foundation Model for the Solar Dynamics Observatory with Transferable Downstream Applications},
+   year = {2026},
+   publisher = {GitHub},
+   journal = {GitHub repository},
+   howpublished = {\url{https://github.com/Joaggi/sdofmv2}},
+   note = {Jinsu Hong, Daniela Martin, and Joseph Gallego contributed equally to this work}
+ }
+ ```
+
+ ## Contributing
+ Contributions, bug reports, and feature requests are welcome! Please feel free to check the [issues page](https://github.com/Joaggi/sdofmv2/issues) or submit a pull request.
sdofmv2-0.1.0/README.md ADDED
@@ -0,0 +1,119 @@
+ # SDO FM v2: [Full Title of the Project/Model]
+
+ [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
+ [![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=flat&logo=PyTorch&logoColor=white)](https://pytorch.org/)
+ [![PyTorch Lightning](https://img.shields.io/badge/PyTorch_Lightning-%23792EE5.svg?style=flat&logo=pytorchlightning&logoColor=white)](https://lightning.ai/docs/pytorch/stable/)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+
+ ## Introduction
+ **SDOFMv2** is an advanced multi-instrument foundation model designed to analyze Solar Dynamics Observatory (SDO) data and drive large-scale, data-driven heliophysics research. Building upon the original SDOFM framework, this version addresses earlier limitations, such as restricted temporal coverage and reconstruction artifacts, improving spatial coherence and global consistency.
+
+ ![Model architecture](sdofmv2.svg)
+ *A Masked Autoencoder (MAE) based on a Vision Transformer (ViT) architecture is used for pretraining. During this phase, a% of the image patches are masked, while the remaining (100 - a)% are processed by the encoder. The decoder block then reconstructs all patches, optimized via a customized loss function.*
+
+ ---
+
+ ## Getting Started
+
+ ### Prerequisites
+ * Linux or macOS
+ * Python 3.11+
+ * NVIDIA GPU + CUDA toolkit (recommended for training)
+
+ ### Environment Setup
+ We recommend using `mamba` to manage dependencies.
+
+ > **Important Hardware Note:** The `sdofmv2_environment.yml` file is configured for **CUDA 12.8** by default. If your hardware or drivers require a different CUDA version (e.g., CUDA 11.8), open `sdofmv2_environment.yml` and modify the `pip` section at the bottom to match your system (e.g., change `cu128` to `cu118`) before running the setup commands.
+
+ **Using Mamba:**
+ ```bash
+ # Clone the repository
+ git clone https://github.com/Joaggi/sdofmv2.git
+ cd sdofmv2
+
+ # Create and activate the environment (this also installs PyTorch and the local package)
+ mamba env create -f sdofmv2_environment.yml
+ mamba activate sdofmv2
+ ```
+ ---
+
+ ## Repository Structure
+
+ ```text
+ .
+ ├── configs/                   # YAML configurations for experiments
+ │   ├── downstream/            # Configs for downstream tasks (F10.7, solar wind)
+ │   └── pretrain/              # Configs for MAE pretraining (AIA, HMI)
+ ├── notebooks/                 # Jupyter notebooks for analysis and visualization
+ │   ├── analysis/              # Attention maps, PCA, and masking analysis
+ │   └── downstream_apps/       # Notebooks demonstrating the F10.7 and missing-data applications
+ ├── scripts/                   # Executable scripts for training and testing
+ │   ├── pretrain.py            # Main pretraining script
+ │   ├── finetuning_*.py        # Scripts for downstream finetuning
+ │   └── test.py                # Script for evaluating checkpoints
+ ├── src/                       # Core source code package
+ │   └── sdofmv2/
+ │       ├── core/              # Base model architectures and modules
+ │       ├── tasks/             # PyTorch Lightning modules (model & data module) for downstream tasks
+ │       └── utils/             # Helper functions, physical constants, and metrics
+ ├── pyproject.toml             # Project metadata and build dependencies
+ └── sdofmv2_environment.yml    # Mamba environment definition file
+ ```
+
+ ---
+
+ ## How to Use
+
+ *(Note: Run all scripts from the repository root so that paths to `configs/` and `src/` resolve correctly.)*
+
+ ### 1. Data Preparation
+ Before training or running inference, you need to prepare the dataset.
+ [Explain where to download the data, or provide a command if you have a script for it.]
+ ```bash
+ python scripts/download_data_cache.py --target_dir ./assets/
+ ```
+
+ ### 2. Training the Model
+ To train the model from scratch, run the pretraining script with the relevant configuration file:
+ ```bash
+ python scripts/pretrain.py --config-name pretrain_mae_AIA.yaml
+ ```
+
+ ### 3. Inference and Evaluation
+ To evaluate a pretrained checkpoint on the test set:
+ ```bash
+ python scripts/test.py --config-name pretrain_mae_AIA.yaml
+ ```
+
+ ### 4. Downstream Finetuning
+ To finetune the model on a specific downstream task (e.g., solar wind forecasting):
+ ```bash
+ python scripts/finetuning_solarwind.py --config-name finetune_solarwind_config.yaml
+ ```
+ ---
+
+ ## Results & Visualizations
+ [Include a brief summary of the model's performance. You can add a table of metrics or a sample plot showing predictions vs. ground truth.]
+
+ ![Sample Visualization](notebooks/analysis/SDOFMv2_AIA_results_exp.png)
+ *The first row displays the original ground-truth images. The second and third rows show the model's reconstructions at masking ratios of 0% and 50%, respectively.*
+
+ ---
+
+ ## Citation
+ If you find this repository or model useful in your academic research, please consider citing our work:
+
+ ```bibtex
+ @misc{sdofmv2,
+   author = {Hong, Jinsu and Martin, Daniela and Gallego, Joseph},
+   title = {SDOFMv2: A Multi-Instrument Foundation Model for the Solar Dynamics Observatory with Transferable Downstream Applications},
+   year = {2026},
+   publisher = {GitHub},
+   journal = {GitHub repository},
+   howpublished = {\url{https://github.com/Joaggi/sdofmv2}},
+   note = {Jinsu Hong, Daniela Martin, and Joseph Gallego contributed equally to this work}
+ }
+ ```
+
+ ## Contributing
+ Contributions, bug reports, and feature requests are welcome! Please feel free to check the [issues page](https://github.com/Joaggi/sdofmv2/issues) or submit a pull request.
sdofmv2-0.1.0/pyproject.toml ADDED
@@ -0,0 +1,48 @@
+ [project]
+ name = "sdofmv2"
+ version = "0.1.0"
+ description = "Solar phenomena prediction models"
+ readme = "README.md"
+ requires-python = ">=3.11"
+ license = {text = "MIT"}
+ authors = [
+     {name = "Joseph Gallego", email = "joaggi@gmail.com"},
+     {name = "Daniela Martin", email = "dmartinvega@gmail.com"},
+     {name = "Jinsu Hong", email = "jinsuhong.knight@gmail.com"},
+ ]
+ keywords = ["foundation model", "solar physics", "deep learning", "space weather"]
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+     "Intended Audience :: Science/Research",
+     "Topic :: Scientific/Engineering :: Astronomy",
+ ]
+
+ dependencies = [
+     "torch>=2.5.1",
+     "lightning>=2.6.0",
+     "numpy>=2.3.5",
+     "pandas>=2.3.3",
+     "transformers>=4.57.3",
+     "sunpy>=7.0.4",
+     "astropy>=6.0",
+     "timm>=1.0.22",
+     "einops>=0.8.1",
+     "hydra-core>=1.3.2",
+     "wandb>=0.23.1",
+ ]
+
+ [project.urls]
+ Repository = "https://github.com/Joaggi/sdofmv2"
+ Issues = "https://github.com/Joaggi/sdofmv2/issues"
+
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools]
+ package-dir = {"" = "src"}
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
sdofmv2-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
sdofmv2-0.1.0/src/sdofmv2/__init__.py ADDED
File without changes
sdofmv2-0.1.0/src/sdofmv2/core/__init__.py ADDED
@@ -0,0 +1,18 @@
+ from .basemodule import BaseModule
+ from .datamodule import (
+     SDOMLDataModule,
+     SDOMLDataset,
+     inverse_log_norm,
+     inverse_zscore_norm,
+ )
+ from .losses import (
+     mae_loss,
+     vector_aware_loss,
+     pixel_weight_loss,
+ )
+ from .mae3d import MaskedAutoencoderViT3D
+ from .mae3d_old import MaskedAutoencoderViT3D_old
+ from .mae_module import MAE
+ from .mae_module_old import MAE_old
+ from .pca_analysis import mapping_dense_to_rgb
+ from .attention_map import plot_heads, patch_attn_layers, visualize_head
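These re-exports define the package's public surface, so downstream code can import from `sdofmv2.core` directly rather than from the individual submodules, e.g.:

```python
from sdofmv2.core import MAE, SDOMLDataModule, mae_loss
```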
sdofmv2-0.1.0/src/sdofmv2/core/attention_map.py ADDED
@@ -0,0 +1,314 @@
+ import os
+ import types
+ from typing import List, Optional
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import TwoSlopeNorm
+ import sunpy.visualization.colormaps as sunpycm
+ from sunpy.visualization.colormaps import color_tables
+
+ import torch
+ from loguru import logger as lgr_logger
+ from omegaconf import OmegaConf
+ from timm.layers import maybe_add_mask
+
+ from sdofmv2.core import MAE, SDOMLDataModule
+
+
+ # Attention patching function
+ def patch_attn_layers(model: MAE) -> List[torch.Tensor]:
+     """
+     Monkey-patch the attention layers of a MAE model to store attention maps.
+
+     Returns:
+         attn_maps: List[Tensor] with shape [B, num_heads, N, N] per block
+     """
+     attn_maps: List[torch.Tensor] = []
+
+     def patched_forward(self, x, attn_mask=None):
+         B, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, self.head_dim)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv.unbind(0)
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         # Force unfused attention
+         if getattr(self, "fused_attn", False):
+             self.fused_attn = False
+
+         # Compute attention
+         q = q * self.scale
+         attn = q @ k.transpose(-2, -1)
+         attn = maybe_add_mask(attn, attn_mask)
+         attn = attn.softmax(dim=-1)
+         attn_maps.append(attn.detach().cpu())
+         attn = self.attn_drop(attn)
+         x_out = attn @ v
+
+         x_out = x_out.transpose(1, 2).reshape(B, N, C)
+         x_out = self.norm(x_out)
+         x_out = self.proj(x_out)
+         x_out = self.proj_drop(x_out)
+         return x_out
+
+     # Patch all encoder blocks
+     for blk in model.autoencoder.blocks:
+         blk.attn.forward = types.MethodType(patched_forward, blk.attn)
+
+     return attn_maps
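The patched blocks append to `attn_maps` on every forward pass. A minimal usage sketch, mirroring the `__main__` block at the bottom of this file (it assumes a loaded `MAE` model and an input batch `x` of the shape the encoder expects):

```python
attn_maps = patch_attn_layers(model)  # list fills during the forward pass
with torch.no_grad():
    model.autoencoder.forward_encoder(x, mask_ratio=0)
# One [B, num_heads, N, N] tensor per encoder block
print(len(attn_maps), attn_maps[0].shape)
```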
+
+
+ def patch_id_to_xy(patch_id, patch_size=16, grid=32):
+     row = patch_id // grid
+     col = patch_id % grid
+     x = col * patch_size
+     y = row * patch_size
+     return x, y
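A quick check of the grid arithmetic above (with the default `patch_size=16`, `grid=32`):

```python
assert patch_id_to_xy(0) == (0, 0)     # top-left patch
assert patch_id_to_xy(31) == (496, 0)  # last patch of the first row
assert patch_id_to_xy(33) == (16, 16)  # row 1, col 1
```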
+
+
+ def attn_to_image(attn_vector, visible_patch_ids, img_size=512, patch_size=16):
+     heatmap = np.zeros((img_size, img_size), dtype=np.float32)
+
+     # Normalize for visualization
+     attn_norm = attn_vector / attn_vector.max()
+
+     for w, patch_id in zip(attn_norm, visible_patch_ids):
+         x, y = patch_id_to_xy(patch_id, patch_size)
+         heatmap[y : y + patch_size, x : x + patch_size] = w
+
+     return heatmap
+
+
+ # Load MAE weights
+ def load_mae_weights(ckpt_path: str, masking_ratio: float = 0.5) -> MAE:
+     ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+     hparams = ckpt["hyper_parameters"]
+
+     # if masking_ratio is not None:
+     #     hparams["masking_ratio"] = masking_ratio
+     #     print(f"Overriding masking_ratio to {masking_ratio}")
+
+     # Clean hyperparameters for model construction
+     for key in ["create_embedding_file", "lr", "num_classes"]:
+         if key in hparams:
+             hparams.pop(key)
+
+     if "wavelengths" in hparams:
+         hparams["chan_types"] = hparams.pop("wavelengths")
+
+     model = MAE(**hparams)
+     model.load_state_dict(ckpt["state_dict"], strict=False)
+     return model
+
+
+ def visualize_head(attn_head, ids_keep, img_size=512, patch=16):
+     """
+     attn_head: [num_kept] attention values for ONLY the kept patches
+     ids_keep:  [num_kept] indices of which patches were kept
+     """
+     grid_size = img_size // patch  # 32
+     heatmap = np.zeros((grid_size, grid_size))
+
+     # Normalize attention values
+     attn_norm = (attn_head - attn_head.min()) / (
+         attn_head.max() - attn_head.min() + 1e-8
+     )
+
+     # Place attention values at the correct patch positions
+     for score, patch_idx in zip(attn_norm, ids_keep):
+         row = patch_idx // grid_size
+         col = patch_idx % grid_size
+         heatmap[row, col] = score
+
+     # Resize to full image resolution
+     heatmap_full = np.repeat(np.repeat(heatmap, patch, axis=0), patch, axis=1)
+
+     return heatmap_full
+
+
+ # def visualize_head(attn_head, ids_keep, img_size=512, patch=16):
+ #     heatmap = np.zeros((img_size, img_size))
+
+ #     # attn_head: [549] attention from one query token → each key token
+ #     attn_norm = attn_head / (attn_head.max() + 1e-6)
+
+ #     for score, patch_id in zip(attn_norm, ids_keep):
+ #         y, x = divmod(patch_id.item(), 32)
+ #         y *= patch
+ #         x *= patch
+ #         heatmap[y:y+patch, x:x+patch] = score
+
+ #     return heatmap
+
+
+ def plot_heads(
+     attn_maps, ids_restore, image, channels=["Bx", "By", "Bz"]
+ ):  # , patch_id=16
+     attn = attn_maps[0][0][:, 1:, 1:]  # [num_head, num_patch, num_patch]
+     num_heads = attn.shape[0]
+     full_order = torch.argsort(ids_restore[0])  # Invert argsort
+     num_kept = attn.shape[1]
+     ids_keep = full_order[:num_kept]
+
+     if len(channels) == 3:
+         cmap = [
+             sunpycm.cmlist.get("hmimag"),
+             sunpycm.cmlist.get("hmimag"),
+             sunpycm.cmlist.get("hmimag"),
+         ]
+         norm = TwoSlopeNorm(vmin=-4000, vcenter=0, vmax=4000)
+
+     elif len(channels) == 9:
+         cmap = [
+             sunpycm.cmlist.get("sdoaia131"),
+             sunpycm.cmlist.get("sdoaia1600"),
+             sunpycm.cmlist.get("sdoaia1700"),
+             sunpycm.cmlist.get("sdoaia171"),
+             sunpycm.cmlist.get("sdoaia193"),
+             sunpycm.cmlist.get("sdoaia211"),
+             sunpycm.cmlist.get("sdoaia304"),
+             sunpycm.cmlist.get("sdoaia335"),
+             sunpycm.cmlist.get("sdoaia94"),
+         ]
+         norm = None
+
+     else:
+         raise ValueError(f"Unsupported number of channels: {len(channels)}")
+
+     attn_received = attn.mean(axis=1)
+     # attn_received = attn[:, 0, 1:]  # cls token
+     num_images = image.shape[0]
+     num_channels = image.shape[1]
+
+     fig, axs = plt.subplots(
+         num_images, num_heads + num_channels, figsize=(25, 4), squeeze=False
+     )
+
+     for i in range(image.shape[0]):
+         for i_ch, ch in enumerate(channels):
+             axs[i, i_ch].imshow(image[i, i_ch, :, :], cmap=cmap[i_ch], norm=norm)
+             axs[i, i_ch].set_title(f"{ch}")
+             axs[i, i_ch].axis("off")
+
+         for h in range(num_heads):
+             # head_attn = attn_shuffled[h, patch_id, :]  # weights based on patches
+             head_attn = attn_received[h, :]
+             heatmap = visualize_head(head_attn, ids_keep, 512, 16)
+             axs[i, h + num_channels].imshow(heatmap, cmap="jet")
+             axs[i, h + num_channels].set_title(f"Head {h}")
+             axs[i, h + num_channels].axis("off")
+
+     plt.tight_layout()
+     # plt.savefig("attention_map_no_limb.png", dpi=200)
+     return fig, axs
+
+
+ def plot_heads_no_limb(attn_maps, ids_restore, image, patch_id=16):
+     attn = attn_maps[0][0][:, 1:, 1:]  # [num_head, num_patch, num_patch]
+     num_heads = attn.shape[0]
+     ids_keep = ids_restore[0, : attn.shape[1]]
+     attn_shuffled = attn[:, ids_restore[0]][:, :, ids_restore[0]]
+     attn_received = attn_shuffled.mean(axis=1)
+     num_images = image.shape[0]
+     num_channels = image.shape[1]
+
+     fig, axs = plt.subplots(
+         num_images, num_heads + num_channels, figsize=(25, 4), squeeze=False
+     )
+
+     for i in range(image.shape[0]):
+         for i_ch, ch in enumerate(["Bx", "By", "Bz"]):
+             axs[i, i_ch].imshow(image[i, i_ch, :, :].to("cpu").numpy(), cmap="gray")
+             axs[i, i_ch].set_title(f"Ch: {ch}")
+             axs[i, i_ch].axis("off")
+
+         for h in range(num_heads):
+             # head_attn = attn_shuffled[h, patch_id, :]  # weights based on patches
+             head_attn = attn_received[h, :]  # weights based on mean of patches
+             heatmap = visualize_head(head_attn, ids_keep, 512, 16)
+             axs[i, h + num_channels].imshow(heatmap, cmap="jet")
+             axs[i, h + num_channels].set_title(f"Head {h}")
+             axs[i, h + num_channels].axis("off")
+
+     plt.tight_layout()
+     # plt.savefig("attention_map_no_limb.png", dpi=200)
+     return fig
+
+
+ # Main execution
+ if __name__ == "__main__":
+     cfg = OmegaConf.load(
+         "/home/jhong36/Project/2025-HL-Solar-Wind/solar_phenomena_prediction/configs/pretrain_mae.yaml"
+     )
+
+     # Setup dataset
+     data_module = SDOMLDataModule(
+         hmi_path=(
+             os.path.join(
+                 cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.hmi
+             )
+             if cfg.data.sdoml.sub_directory.hmi
+             else None
+         ),
+         aia_path=(
+             os.path.join(
+                 cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.aia
+             )
+             if cfg.data.sdoml.sub_directory.aia
+             else None
+         ),
+         eve_path=None,
+         components=cfg.data.sdoml.components,
+         wavelengths=cfg.data.sdoml.wavelengths,
+         ions=cfg.data.sdoml.ions,
+         frequency=cfg.data.sdoml.frequency,
+         batch_size=cfg.model.opt.batch_size,
+         num_workers=cfg.data.num_workers,
+         val_months=cfg.data.month_splits.val,
+         test_months=cfg.data.month_splits.test,
+         holdout_months=cfg.data.month_splits.holdout,
+         cache_dir=os.path.join(
+             cfg.data.sdoml.save_directory, cfg.data.sdoml.sub_directory.cache
+         ),
+         min_date=cfg.data.min_date,
+         max_date=cfg.data.max_date,
+         num_frames=cfg.model.mae.num_frames,
+         drop_frame_dim=cfg.data.drop_frame_dim,
+         apply_mask=cfg.data.sdoml.apply_mask,
+         precision=cfg.experiment.precision,
+         normalization=cfg.data.sdoml.normalization,
+     )
+     data_module.setup()
+
+     # Load model
+     base_path = "../../../../assets/check_point/backbone/"
+     model_hmi = load_mae_weights(
+         os.path.join(base_path, "id_xn2c11go_mae_epoch=25-val_loss=0.00.ckpt"),
+         # masking_ratio=0  # full image by default
+     )
+     # model_hmi.autoencoder.ids_limb_mask = None
+     model_hmi.eval()
+
+     # Patch attention layers
+     attn_maps = patch_attn_layers(model_hmi)
+
+     # Forward pass
+     id_input = 0
+     x = data_module.test_ds[id_input][0].unsqueeze(0)
+     lgr_logger.info(f"Input shape: {x.shape}")
+
+     with torch.no_grad():
+         latent, mask, ids_restore = model_hmi.autoencoder.forward_encoder(
+             x, mask_ratio=0
+         )
+
+     # Example: check first attention map
+     if len(attn_maps) > 0:
+         lgr_logger.info(f"First attention map shape: {attn_maps[0].shape}")
+
+     fig, axs = plot_heads(attn_maps, ids_restore, x[:, 0:3, 0, :, :])
+     # fig = plot_heads_no_limb(attn_maps, ids_restore, x[:, 0:3, 0, :, :])