sdofmv2 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdofmv2-0.1.0/LICENSE +21 -0
- sdofmv2-0.1.0/PKG-INFO +149 -0
- sdofmv2-0.1.0/README.md +119 -0
- sdofmv2-0.1.0/pyproject.toml +48 -0
- sdofmv2-0.1.0/setup.cfg +4 -0
- sdofmv2-0.1.0/src/sdofmv2/__init__.py +0 -0
- sdofmv2-0.1.0/src/sdofmv2/core/__init__.py +18 -0
- sdofmv2-0.1.0/src/sdofmv2/core/attention_map.py +314 -0
- sdofmv2-0.1.0/src/sdofmv2/core/basemodule.py +140 -0
- sdofmv2-0.1.0/src/sdofmv2/core/datamodule.py +1169 -0
- sdofmv2-0.1.0/src/sdofmv2/core/losses.py +87 -0
- sdofmv2-0.1.0/src/sdofmv2/core/mae3d.py +491 -0
- sdofmv2-0.1.0/src/sdofmv2/core/mae3d_old.py +413 -0
- sdofmv2-0.1.0/src/sdofmv2/core/mae_module.py +262 -0
- sdofmv2-0.1.0/src/sdofmv2/core/mae_module_old.py +222 -0
- sdofmv2-0.1.0/src/sdofmv2/core/pca_analysis.py +49 -0
- sdofmv2-0.1.0/src/sdofmv2/core/reconstruction.py +164 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/__init__.py +0 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/f107/__init__.py +2 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/f107/f107_datamodule.py +232 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/f107/f107_module.py +138 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/missing_data/__init__.py +3 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/missing_data/missing_data_module.py +99 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/missing_data/necks.py +154 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/missing_data/wrap_encoder.py +43 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/__init__.py +4 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/datamodule.py +579 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/focal_loss.py +44 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/head_networks.py +296 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/model.py +582 -0
- sdofmv2-0.1.0/src/sdofmv2/tasks/solar_wind/visualization.py +623 -0
- sdofmv2-0.1.0/src/sdofmv2/utils/__init__.py +17 -0
- sdofmv2-0.1.0/src/sdofmv2/utils/constants.py +55 -0
- sdofmv2-0.1.0/src/sdofmv2/utils/utils.py +251 -0
- sdofmv2-0.1.0/src/sdofmv2.egg-info/PKG-INFO +149 -0
- sdofmv2-0.1.0/src/sdofmv2.egg-info/SOURCES.txt +37 -0
- sdofmv2-0.1.0/src/sdofmv2.egg-info/dependency_links.txt +1 -0
- sdofmv2-0.1.0/src/sdofmv2.egg-info/requires.txt +11 -0
- sdofmv2-0.1.0/src/sdofmv2.egg-info/top_level.txt +1 -0
sdofmv2-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Joseph Gallego
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
sdofmv2-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sdofmv2
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Solar phenomena prediction models
|
|
5
|
+
Author-email: Joseph Gallego <joaggi@gmail.com>, Daniela Martin <dmartinvega@gmail.com>, Jinsu Hong <jinsuhong.knight@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Repository, https://github.com/Joaggi/sdofmv2
|
|
8
|
+
Project-URL: Issues, https://github.com/Joaggi/sdofmv2/issues
|
|
9
|
+
Keywords: foundation model,solar physics,deep learning,space weather
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Astronomy
|
|
15
|
+
Requires-Python: >=3.11
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
Requires-Dist: torch>=2.5.1
|
|
19
|
+
Requires-Dist: lightning>=2.6.0
|
|
20
|
+
Requires-Dist: numpy>=2.3.5
|
|
21
|
+
Requires-Dist: pandas>=2.3.3
|
|
22
|
+
Requires-Dist: transformers>=4.57.3
|
|
23
|
+
Requires-Dist: sunpy>=7.0.4
|
|
24
|
+
Requires-Dist: astropy>=6.0
|
|
25
|
+
Requires-Dist: timm>=1.0.22
|
|
26
|
+
Requires-Dist: einops>=0.8.1
|
|
27
|
+
Requires-Dist: hydra-core>=1.3.2
|
|
28
|
+
Requires-Dist: wandb>=0.23.1
|
|
29
|
+
Dynamic: license-file
|
|
30
|
+
|
|
31
|
+
# SDOFMv2: A Multi-Instrument Foundation Model for the Solar Dynamics Observatory with Transferable Downstream Applications
|
|
32
|
+
|
|
33
|
+
[](https://www.python.org/downloads/)
|
|
34
|
+
[](https://pytorch.org/)
|
|
35
|
+
[](https://lightning.ai/docs/pytorch/stable/)
|
|
36
|
+
[](https://opensource.org/licenses/MIT)
|
|
37
|
+
|
|
38
|
+
## Introduction
|
|
39
|
+
**SDOFMv2** is an advanced multi-instrument foundation model designed to analyze Solar Dynamics Observatory (SDO) data and drive large-scale, data-driven heliophysics research. Building upon the original SDOFM framework, this version addresses previous limitations like restricted temporal coverage and reconstruction artifacts to significantly improve spatial coherence and global consistency.
|
|
40
|
+
|
|
41
|
+

|
|
42
|
+
*A Masked Autoencoder (MAE) built on a Vision Transformer (ViT) backbone is used for pretraining. During pretraining, a fixed percentage of the image patches (the masking ratio) is masked out, and only the remaining visible patches are processed by the encoder. The decoder then reconstructs all patches, and training minimizes a customized reconstruction loss.*
|
|
43
|
+
|
|
44
|
+
---
|
|
45
|
+
|
|
46
|
+
## Getting Started
|
|
47
|
+
|
|
48
|
+
### Prerequisites
|
|
49
|
+
* Linux or macOS
|
|
50
|
+
* Python 3.11+
|
|
51
|
+
* NVIDIA GPU + CUDA toolkit (Recommended for training)
|
|
52
|
+
|
|
53
|
+
### Environment Setup
|
|
54
|
+
We recommend using `mamba` to manage dependencies.
|
|
55
|
+
|
|
56
|
+
> **Important hardware note:** The `sdofmv2_environment.yml` file is configured for **CUDA 12.8** by default. If your hardware or drivers require a different CUDA version (e.g., CUDA 11.8), open `sdofmv2_environment.yml` and edit the `pip` section at the bottom to match your system (e.g., change `cu128` to `cu118`) before running the setup commands.
|
|
57
|
+
|
|
58
|
+
**Using Mamba:**
|
|
59
|
+
```bash
|
|
60
|
+
# Clone the repository
|
|
61
|
+
git clone https://github.com/Joaggi/sdofmv2.git
|
|
62
|
+
cd sdofmv2
|
|
63
|
+
|
|
64
|
+
# Create and activate the environment (This automatically installs PyTorch and the local package)
|
|
65
|
+
mamba env create -f sdofmv2_environment.yml
|
|
66
|
+
mamba activate sdofmv2
|
|
67
|
+
```
|
|
68
|
+
---
|
|
69
|
+
|
|
70
|
+
## Repository Structure
|
|
71
|
+
|
|
72
|
+
```text
|
|
73
|
+
.
|
|
74
|
+
├── configs/ # YAML configurations for experiments
|
|
75
|
+
│ ├── downstream/ # Configs for downstream tasks (F10.7, solar wind)
|
|
76
|
+
│ └── pretrain/ # Configs for MAE pretraining (AIA, HMI)
|
|
77
|
+
├── notebooks/ # Jupyter notebooks for analysis and visualization
|
|
78
|
+
│ ├── analysis/ # Attention maps, PCA, and masking analysis
|
|
79
|
+
│ └── downstream_apps/ # How to use downstream scripts (Notebooks) for F10.7 and missing data applications
|
|
80
|
+
├── scripts/ # Executable scripts for training and testing
|
|
81
|
+
│ ├── pretrain.py # Main pretraining script
|
|
82
|
+
│ ├── finetuning_*.py # Scripts for downstream finetuning
|
|
83
|
+
│ └── test.py # Script for evaluating checkpoints
|
|
84
|
+
├── src/ # Core source code package
|
|
85
|
+
│ └── sdofmv2/
|
|
86
|
+
│ ├── core/ # Base model architectures and modules
|
|
87
|
+
│ ├── tasks/ # PyTorch Lightning modules (model & data module) for downstream tasks
|
|
88
|
+
│ └── utils/ # Helper functions, physical constants and metrics
|
|
89
|
+
├── pyproject.toml # Project metadata and build dependencies
|
|
90
|
+
└── sdofmv2_environment.yml # Mamba environment definition file
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
---
|
|
94
|
+
|
|
95
|
+
## How to Use
|
|
96
|
+
|
|
97
|
+
*(Note: It is recommended to run all scripts from the root directory of the repository so that file paths to `configs/` and `src/` resolve correctly.)*
|
|
98
|
+
|
|
99
|
+
### 1. Data Preparation
|
|
100
|
+
Before training or running inference, you need to prepare the dataset.
|
|
101
|
+
Download the cached dataset into `./assets/` with the provided script:
|
|
102
|
+
```bash
|
|
103
|
+
python scripts/download_data_cache.py --target_dir ./assets/
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
### 2. Training the Model
|
|
107
|
+
To train the model from scratch, execute the pretraining script and pass the relevant configuration file.
|
|
108
|
+
```bash
|
|
109
|
+
python scripts/pretrain.py --config-name pretrain_mae_AIA.yaml
|
|
110
|
+
```
|
|
111
|
+
|
|
112
|
+
### 3. Inference and Evaluation
|
|
113
|
+
To evaluate a pre-trained checkpoint on the test set:
|
|
114
|
+
```bash
|
|
115
|
+
python scripts/test.py --config-name pretrain_mae_AIA.yaml
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
### 4. Downstream Finetuning
|
|
119
|
+
To finetune the model on a specific downstream task (e.g., solar wind forecasting):
|
|
120
|
+
```bash
|
|
121
|
+
python scripts/finetuning_solarwind.py --config-name finetune_solarwind_config.yaml
|
|
122
|
+
```
|
|
123
|
+
---
|
|
124
|
+
|
|
125
|
+
## Results & Visualizations
|
|
126
|
+
Sample reconstructions from the pretrained MAE are shown below.
|
|
127
|
+
|
|
128
|
+

|
|
129
|
+
*The first row displays the original ground-truth images. The second and third rows show the model's reconstructed images using masking ratios of 0% and 50%, respectively.*
|
|
130
|
+
|
|
131
|
+
---
|
|
132
|
+
|
|
133
|
+
## Citation
|
|
134
|
+
If you find this repository or model useful in your academic research, please consider citing our work:
|
|
135
|
+
|
|
136
|
+
```bibtex
|
|
137
|
+
@misc{sdofmv2,
|
|
138
|
+
author = {Hong, Jinsu and Martin, Daniela and Gallego, Joseph},
|
|
139
|
+
title = {SDOFMv2: A Multi-Instrument Foundation Model for the Solar Dynamics Observatory with Transferable Downstream Applications},
|
|
140
|
+
year = {2026},
|
|
141
|
+
publisher = {GitHub},
|
|
142
|
+
journal = {GitHub repository},
|
|
143
|
+
howpublished = {\url{https://github.com/Joaggi/sdofmv2}},
|
|
144
|
+
note = {Jinsu Hong, Daniela Martin, and Joseph Gallego contributed equally to this work}
|
|
145
|
+
}
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
## Contributing
|
|
149
|
+
Contributions, bug reports, and feature requests are welcome! Please feel free to check the [issues page](https://github.com/Joaggi/sdofmv2/issues) or submit a pull request.
|
sdofmv2-0.1.0/README.md
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# SDOFMv2: A Multi-Instrument Foundation Model for the Solar Dynamics Observatory with Transferable Downstream Applications
|
|
2
|
+
|
|
3
|
+
[](https://www.python.org/downloads/)
|
|
4
|
+
[](https://pytorch.org/)
|
|
5
|
+
[](https://lightning.ai/docs/pytorch/stable/)
|
|
6
|
+
[](https://opensource.org/licenses/MIT)
|
|
7
|
+
|
|
8
|
+
## Introduction
|
|
9
|
+
**SDOFMv2** is an advanced multi-instrument foundation model designed to analyze Solar Dynamics Observatory (SDO) data and drive large-scale, data-driven heliophysics research. Building upon the original SDOFM framework, this version addresses previous limitations like restricted temporal coverage and reconstruction artifacts to significantly improve spatial coherence and global consistency.
|
|
10
|
+
|
|
11
|
+

|
|
12
|
+
*A Masked Autoencoder (MAE) built on a Vision Transformer (ViT) backbone is used for pretraining. During pretraining, a fixed percentage of the image patches (the masking ratio) is masked out, and only the remaining visible patches are processed by the encoder. The decoder then reconstructs all patches, and training minimizes a customized reconstruction loss.*
|
|
13
|
+
|
|
14
|
+
---
|
|
15
|
+
|
|
16
|
+
## Getting Started
|
|
17
|
+
|
|
18
|
+
### Prerequisites
|
|
19
|
+
* Linux or macOS
|
|
20
|
+
* Python 3.11+
|
|
21
|
+
* NVIDIA GPU + CUDA toolkit (Recommended for training)
|
|
22
|
+
|
|
23
|
+
### Environment Setup
|
|
24
|
+
We recommend using `mamba` to manage dependencies.
|
|
25
|
+
|
|
26
|
+
> **Important hardware note:** The `sdofmv2_environment.yml` file is configured for **CUDA 12.8** by default. If your hardware or drivers require a different CUDA version (e.g., CUDA 11.8), open `sdofmv2_environment.yml` and edit the `pip` section at the bottom to match your system (e.g., change `cu128` to `cu118`) before running the setup commands.
|
|
27
|
+
|
|
28
|
+
**Using Mamba:**
|
|
29
|
+
```bash
|
|
30
|
+
# Clone the repository
|
|
31
|
+
git clone https://github.com/Joaggi/sdofmv2.git
|
|
32
|
+
cd sdofmv2
|
|
33
|
+
|
|
34
|
+
# Create and activate the environment (This automatically installs PyTorch and the local package)
|
|
35
|
+
mamba env create -f sdofmv2_environment.yml
|
|
36
|
+
mamba activate sdofmv2
|
|
37
|
+
```
|
|
38
|
+
---
|
|
39
|
+
|
|
40
|
+
## Repository Structure
|
|
41
|
+
|
|
42
|
+
```text
|
|
43
|
+
.
|
|
44
|
+
├── configs/ # YAML configurations for experiments
|
|
45
|
+
│ ├── downstream/ # Configs for downstream tasks (F10.7, solar wind)
|
|
46
|
+
│ └── pretrain/ # Configs for MAE pretraining (AIA, HMI)
|
|
47
|
+
├── notebooks/ # Jupyter notebooks for analysis and visualization
|
|
48
|
+
│ ├── analysis/ # Attention maps, PCA, and masking analysis
|
|
49
|
+
│ └── downstream_apps/ # How to use downstream scripts (Notebooks) for F10.7 and missing data applications
|
|
50
|
+
├── scripts/ # Executable scripts for training and testing
|
|
51
|
+
│ ├── pretrain.py # Main pretraining script
|
|
52
|
+
│ ├── finetuning_*.py # Scripts for downstream finetuning
|
|
53
|
+
│ └── test.py # Script for evaluating checkpoints
|
|
54
|
+
├── src/ # Core source code package
|
|
55
|
+
│ └── sdofmv2/
|
|
56
|
+
│ ├── core/ # Base model architectures and modules
|
|
57
|
+
│ ├── tasks/ # PyTorch Lightning modules (model & data module) for downstream tasks
|
|
58
|
+
│ └── utils/ # Helper functions, physical constants and metrics
|
|
59
|
+
├── pyproject.toml # Project metadata and build dependencies
|
|
60
|
+
└── sdofmv2_environment.yml # Mamba environment definition file
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
---
|
|
64
|
+
|
|
65
|
+
## How to Use
|
|
66
|
+
|
|
67
|
+
*(Note: It is recommended to run all scripts from the root directory of the repository so that file paths to `configs/` and `src/` resolve correctly.)*
|
|
68
|
+
|
|
69
|
+
### 1. Data Preparation
|
|
70
|
+
Before training or running inference, you need to prepare the dataset.
|
|
71
|
+
Download the cached dataset into `./assets/` with the provided script:
|
|
72
|
+
```bash
|
|
73
|
+
python scripts/download_data_cache.py --target_dir ./assets/
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
### 2. Training the Model
|
|
77
|
+
To train the model from scratch, execute the pretraining script and pass the relevant configuration file.
|
|
78
|
+
```bash
|
|
79
|
+
python scripts/pretrain.py --config-name pretrain_mae_AIA.yaml
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
### 3. Inference and Evaluation
|
|
83
|
+
To evaluate a pre-trained checkpoint on the test set:
|
|
84
|
+
```bash
|
|
85
|
+
python scripts/test.py --config-name pretrain_mae_AIA.yaml
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
### 4. Downstream Finetuning
|
|
89
|
+
To finetune the model on a specific downstream task (e.g., solar wind forecasting):
|
|
90
|
+
```bash
|
|
91
|
+
python scripts/finetuning_solarwind.py --config-name finetune_solarwind_config.yaml
|
|
92
|
+
```
|
|
93
|
+
---
|
|
94
|
+
|
|
95
|
+
## Results & Visualizations
|
|
96
|
+
Sample reconstructions from the pretrained MAE are shown below.
|
|
97
|
+
|
|
98
|
+

|
|
99
|
+
*The first row displays the original ground-truth images. The second and third rows show the model's reconstructed images using masking ratios of 0% and 50%, respectively.*
|
|
100
|
+
|
|
101
|
+
---
|
|
102
|
+
|
|
103
|
+
## Citation
|
|
104
|
+
If you find this repository or model useful in your academic research, please consider citing our work:
|
|
105
|
+
|
|
106
|
+
```bibtex
|
|
107
|
+
@misc{sdofmv2,
|
|
108
|
+
author = {Hong, Jinsu and Martin, Daniela and Gallego, Joseph},
|
|
109
|
+
title = {SDOFMv2: A Multi-Instrument Foundation Model for the Solar Dynamics Observatory with Transferable Downstream Applications},
|
|
110
|
+
year = {2026},
|
|
111
|
+
publisher = {GitHub},
|
|
112
|
+
journal = {GitHub repository},
|
|
113
|
+
howpublished = {\url{https://github.com/Joaggi/sdofmv2}},
|
|
114
|
+
note = {Jinsu Hong, Daniela Martin, and Joseph Gallego contributed equally to this work}
|
|
115
|
+
}
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
## Contributing
|
|
119
|
+
Contributions, bug reports, and feature requests are welcome! Please feel free to check the [issues page](https://github.com/Joaggi/sdofmv2/issues) or submit a pull request.
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "sdofmv2"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Solar phenomena prediction models"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.11"
|
|
7
|
+
license = {text = "MIT"}
|
|
8
|
+
authors = [
|
|
9
|
+
{name = "Joseph Gallego", email = "joaggi@gmail.com"},
|
|
10
|
+
{name = "Daniela Martin", email = "dmartinvega@gmail.com"},
|
|
11
|
+
{name = "Jinsu Hong", email = "jinsuhong.knight@gmail.com"},
|
|
12
|
+
]
|
|
13
|
+
keywords = ["foundation model", "solar physics", "deep learning", "space weather"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Programming Language :: Python :: 3",
|
|
16
|
+
"License :: OSI Approved :: MIT License",
|
|
17
|
+
"Operating System :: OS Independent",
|
|
18
|
+
"Intended Audience :: Science/Research",
|
|
19
|
+
"Topic :: Scientific/Engineering :: Astronomy",
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
dependencies = [
    "torch>=2.5.1",
    "lightning>=2.6.0",
    "numpy>=2.3.5",
    "pandas>=2.3.3",
    "transformers>=4.57.3",
    "sunpy>=7.0.4",
    "astropy>=6.0",
    "timm>=1.0.22",
    "einops>=0.8.1",
    "hydra-core>=1.3.2",
    "wandb>=0.23.1",
    # Imported at module level by sdofmv2.core (attention_map.py).
    "loguru",
    "matplotlib",
    "omegaconf",
]
|
|
35
|
+
|
|
36
|
+
[project.urls]
|
|
37
|
+
Repository = "https://github.com/Joaggi/sdofmv2"
|
|
38
|
+
Issues = "https://github.com/Joaggi/sdofmv2/issues"
|
|
39
|
+
|
|
40
|
+
[build-system]
|
|
41
|
+
requires = ["setuptools>=61.0"]
|
|
42
|
+
build-backend = "setuptools.build_meta"
|
|
43
|
+
|
|
44
|
+
[tool.setuptools]
|
|
45
|
+
package-dir = {"" = "src"}
|
|
46
|
+
|
|
47
|
+
[tool.setuptools.packages.find]
|
|
48
|
+
where = ["src"]
|
sdofmv2-0.1.0/setup.cfg
ADDED
|
File without changes
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from .basemodule import BaseModule
|
|
2
|
+
from .datamodule import (
|
|
3
|
+
SDOMLDataModule,
|
|
4
|
+
SDOMLDataset,
|
|
5
|
+
inverse_log_norm,
|
|
6
|
+
inverse_zscore_norm,
|
|
7
|
+
)
|
|
8
|
+
from .losses import (
|
|
9
|
+
mae_loss,
|
|
10
|
+
vector_aware_loss,
|
|
11
|
+
pixel_weight_loss,
|
|
12
|
+
)
|
|
13
|
+
from .mae3d import MaskedAutoencoderViT3D
|
|
14
|
+
from .mae3d_old import MaskedAutoencoderViT3D_old
|
|
15
|
+
from .mae_module import MAE
|
|
16
|
+
from .mae_module_old import MAE_old
|
|
17
|
+
from .pca_analysis import mapping_dense_to_rgb
|
|
18
|
+
from .attention_map import plot_heads, patch_attn_layers, visualize_head
|
|
@@ -0,0 +1,314 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import types
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
from matplotlib.colors import TwoSlopeNorm
|
|
8
|
+
import sunpy.visualization.colormaps as sunpycm
|
|
9
|
+
from sunpy.visualization.colormaps import color_tables
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
from loguru import logger as lgr_logger
|
|
13
|
+
from omegaconf import OmegaConf
|
|
14
|
+
from timm.layers import maybe_add_mask
|
|
15
|
+
|
|
16
|
+
from sdofmv2.core import MAE, SDOMLDataModule
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Attention patching function
def patch_attn_layers(model: MAE) -> List[torch.Tensor]:
    """
    Swap in a recording ``forward`` on every encoder attention layer.

    The replacement computes attention explicitly (never through a fused
    kernel) so that the post-softmax weight matrix exists as a tensor. It
    appends a detached CPU copy of that matrix to the returned list, then
    continues with dropout, value mixing, norm and output projection.

    Args:
        model: MAE wrapper whose ``autoencoder.blocks`` each hold an ``attn``
            module exposing qkv, q_norm, k_norm, scale, attn_drop, norm, proj
            and proj_drop.

    Returns:
        attn_maps: A list populated in place. Every call of a patched layer
        appends one tensor of shape [B, num_heads, N, N]. The list is never
        cleared here, so it keeps growing across forward passes.
    """
    attn_maps: List[torch.Tensor] = []

    def recording_forward(self, x, attn_mask=None):
        batch, tokens, width = x.shape

        # [B, N, 3*H*D] -> [B, N, 3, H, D] -> [3, B, H, N, D] -> q, k, v
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)
        query = self.q_norm(query)
        key = self.k_norm(key)

        # Disable the fused path so the attention matrix is materialized.
        if getattr(self, "fused_attn", False):
            self.fused_attn = False

        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = maybe_add_mask(scores, attn_mask).softmax(dim=-1)
        attn_maps.append(weights.detach().cpu())

        mixed = self.attn_drop(weights) @ value
        mixed = mixed.transpose(1, 2).reshape(batch, tokens, width)
        return self.proj_drop(self.proj(self.norm(mixed)))

    for block in model.autoencoder.blocks:
        block.attn.forward = types.MethodType(recording_forward, block.attn)

    return attn_maps
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def patch_id_to_xy(patch_id, patch_size=16, grid=32):
    """
    Convert a flat, row-major patch index into the pixel offset of the
    patch's top-left corner.

    Args:
        patch_id: Flat patch index, i.e. ``row * grid + col``.
        patch_size: Side length of a square patch, in pixels.
        grid: Number of patches per row.

    Returns:
        (x, y): Column offset and row offset of the patch, in pixels.
    """
    return (patch_id % grid) * patch_size, (patch_id // grid) * patch_size
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def attn_to_image(attn_vector, visible_patch_ids, img_size=512, patch_size=16):
    """
    Paint per-patch attention weights onto a full-resolution canvas.

    Each weight is written as a constant ``patch_size x patch_size`` square at
    its patch location. Patches not listed in ``visible_patch_ids`` stay 0.

    Args:
        attn_vector: 1-D array/tensor of attention weights, aligned
            element-wise with ``visible_patch_ids``.
        visible_patch_ids: Flat, row-major patch indices on a grid that is
            ``img_size // patch_size`` patches wide.
        img_size: Side length of the square output image, in pixels.
        patch_size: Side length of a square patch, in pixels.

    Returns:
        float32 array of shape (img_size, img_size). Weights are divided by
        their maximum, so the largest becomes 1. The result is all zeros when
        the maximum is not positive or ``attn_vector`` is empty.
    """
    heatmap = np.zeros((img_size, img_size), dtype=np.float32)
    if len(attn_vector) == 0:
        return heatmap

    # Normalize for visualization. Guard against a zero peak, which would
    # otherwise fill the canvas with NaN/inf.
    peak = attn_vector.max()
    attn_norm = attn_vector / peak if peak > 0 else attn_vector * 0

    # The grid width must follow img_size / patch_size. Relying on the
    # helper's default of 32 misplaces patches for any other layout.
    grid = img_size // patch_size
    for w, patch_id in zip(attn_norm, visible_patch_ids):
        x, y = patch_id_to_xy(patch_id, patch_size, grid)
        heatmap[y : y + patch_size, x : x + patch_size] = w

    return heatmap
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Load MAE weights
def load_mae_weights(ckpt_path: str, masking_ratio: Optional[float] = None) -> MAE:
    """
    Rebuild an ``MAE`` module from a checkpoint file and load its weights.

    Args:
        ckpt_path: Path to a checkpoint dict with ``"hyper_parameters"`` and
            ``"state_dict"`` entries.
        masking_ratio: When given, overrides the ``masking_ratio``
            hyperparameter passed to ``MAE``. ``None`` (the default) keeps the
            value stored in the checkpoint. Earlier versions accepted this
            argument but ignored it, so callers that omit it see no change.

    Returns:
        The constructed ``MAE`` with checkpoint weights loaded non-strictly.
        Missing or unexpected state-dict keys are logged as warnings instead
        of being silently dropped.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects. Only
    # load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # Work on a copy so the loaded checkpoint dict is not mutated below.
    hparams = dict(ckpt["hyper_parameters"])

    if masking_ratio is not None:
        hparams["masking_ratio"] = masking_ratio
        lgr_logger.info(f"Overriding masking_ratio to {masking_ratio}")

    # Remove keys that are not passed on to the MAE constructor.
    for key in ("create_embedding_file", "lr", "num_classes"):
        hparams.pop(key, None)

    # The checkpoint may store the channel list as "wavelengths"; the
    # constructor receives it as "chan_types".
    if "wavelengths" in hparams:
        hparams["chan_types"] = hparams.pop("wavelengths")

    model = MAE(**hparams)
    result = model.load_state_dict(ckpt["state_dict"], strict=False)
    if result.missing_keys:
        lgr_logger.warning(
            f"{len(result.missing_keys)} missing keys when loading {ckpt_path}: "
            f"{result.missing_keys}"
        )
    if result.unexpected_keys:
        lgr_logger.warning(
            f"{len(result.unexpected_keys)} unexpected keys when loading {ckpt_path}: "
            f"{result.unexpected_keys}"
        )
    return model
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def visualize_head(attn_head, ids_keep, img_size=512, patch=16):
    """
    Expand one attention head's per-patch scores into a pixel heatmap.

    Args:
        attn_head: [num_kept] attention values for ONLY the kept patches.
        ids_keep: [num_kept] flat, row-major indices of the kept patches.
        img_size: Side length of the target image, in pixels.
        patch: Side length of a square patch, in pixels.

    Returns:
        Array of shape (cells * patch, cells * patch), where
        ``cells = img_size // patch``. Scores are min-max normalized to
        roughly [0, 1]; patches not in ``ids_keep`` remain 0.
    """
    cells = img_size // patch  # patches per side (32 for 512 / 16)
    grid = np.zeros((cells, cells))

    # Min-max scaling; the epsilon keeps a constant input from dividing by 0.
    low = attn_head.min()
    span = attn_head.max() - low + 1e-8
    scaled = (attn_head - low) / span

    for value, flat_idx in zip(scaled, ids_keep):
        grid[flat_idx // cells, flat_idx % cells] = value

    # Nearest-neighbour upsampling: each grid cell becomes a patch x patch block.
    return grid.repeat(patch, axis=0).repeat(patch, axis=1)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# def visualize_head(attn_head, ids_keep, img_size=512, patch=16):
|
|
134
|
+
# heatmap = np.zeros((img_size, img_size))
|
|
135
|
+
|
|
136
|
+
# # attn_head: [549] attention from one query token → each key token
|
|
137
|
+
# attn_norm = attn_head / (attn_head.max() + 1e-6)
|
|
138
|
+
|
|
139
|
+
# for score, patch_id in zip(attn_norm, ids_keep):
|
|
140
|
+
# y, x = divmod(patch_id.item(), 32)
|
|
141
|
+
# y *= patch
|
|
142
|
+
# x *= patch
|
|
143
|
+
# heatmap[y:y+patch, x:x+patch] = score
|
|
144
|
+
|
|
145
|
+
# return heatmap
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def plot_heads(
    attn_maps,
    ids_restore,
    image,
    channels=("Bx", "By", "Bz"),
    img_size=512,
    patch_size=16,
):
    """
    Plot each input channel next to per-head maps of the attention that each
    patch receives in the first recorded attention layer.

    Args:
        attn_maps: List filled by ``patch_attn_layers``. Only ``attn_maps[0]``
            is used, with shape [B, num_heads, N, N]. Token 0 is dropped on
            both axes (assumed to be the CLS token — TODO confirm per model).
        ids_restore: [B, L] indices returned by the encoder. ``argsort`` of a
            row gives the shuffle order, whose first entries are the patch ids
            of the encoder tokens.
        image: [num_images, num_channels, H, W] images to display. Row ``i`` is
            paired with batch entry ``i`` of ``attn_maps[0]`` and
            ``ids_restore`` (so the attention batch must cover every row).
        channels: One name per image channel. Three names select the HMI
            magnetogram colormap; nine select the AIA colormaps.
        img_size: Side length, in pixels, of each attention heatmap.
        patch_size: Patch side length, in pixels, used by the encoder.

    Returns:
        (fig, axs): The matplotlib figure and its 2-D axes array.

    Raises:
        ValueError: If ``channels`` has neither 3 nor 9 entries, or does not
            match the number of channels in ``image``.
    """
    channels = list(channels)
    num_images = image.shape[0]
    num_channels = image.shape[1]

    if len(channels) == 3:
        cmap = [sunpycm.cmlist.get("hmimag") for _ in range(3)]
        norm = TwoSlopeNorm(vmin=-4000, vcenter=0, vmax=4000)
    elif len(channels) == 9:
        cmap = [
            sunpycm.cmlist.get(name)
            for name in (
                "sdoaia131",
                "sdoaia1600",
                "sdoaia1700",
                "sdoaia171",
                "sdoaia193",
                "sdoaia211",
                "sdoaia304",
                "sdoaia335",
                "sdoaia94",
            )
        ]
        norm = None
    else:
        raise ValueError(
            f"Channel info is wrong: expected 3 or 9 channel names, got {len(channels)}"
        )

    # The axes layout below is sized from the image, so the names must match it.
    if num_channels != len(channels):
        raise ValueError(
            f"Channel info is wrong: {len(channels)} channel names {channels} "
            f"but image has {num_channels} channels"
        )

    layer_attn = attn_maps[0]  # first recorded layer: [B, num_heads, N, N]
    num_heads = layer_attn.shape[1]

    fig, axs = plt.subplots(
        num_images,
        num_heads + num_channels,
        figsize=(25, 4 * num_images),
        squeeze=False,
    )

    for i in range(num_images):
        # Drop token 0 on both axes -> [num_heads, num_kept, num_kept].
        attn = layer_attn[i][:, 1:, 1:]
        num_kept = attn.shape[1]
        # Invert ids_restore to get the shuffle order; its first num_kept
        # entries are the patch ids of the encoder tokens, in token order.
        ids_keep = torch.argsort(ids_restore[i])[:num_kept]
        # Mean over queries: attention received by each kept patch, per head.
        attn_received = attn.mean(axis=1)

        for i_ch, ch in enumerate(channels):
            ax = axs[i, i_ch]
            ax.imshow(image[i, i_ch, :, :], cmap=cmap[i_ch], norm=norm)
            ax.set_title(f"{ch}")
            ax.axis("off")

        for h in range(num_heads):
            heatmap = visualize_head(attn_received[h, :], ids_keep, img_size, patch_size)
            ax = axs[i, h + num_channels]
            ax.imshow(heatmap, cmap="jet")
            ax.set_title(f"Head {h}")
            ax.axis("off")

    plt.tight_layout()
    return fig, axs
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def plot_heads_no_limb(attn_maps, ids_restore, image, patch_id=16):
    """
    Plot each image channel (grayscale) next to per-head maps of the
    attention each patch receives in the first recorded attention layer.

    Previously the attention was un-shuffled with ``ids_restore`` and then
    placed at ``ids_restore`` positions, which applied the permutation twice
    and scrambled the heatmap. Placement now mirrors ``plot_heads``: token j
    maps to patch ``argsort(ids_restore)[j]``. Channel panels are driven by
    ``image.shape[1]`` instead of a fixed list of three names.

    Args:
        attn_maps: List filled by ``patch_attn_layers``. Only ``attn_maps[0]``
            ([B, num_heads, N, N]) is used, with token 0 dropped (assumed CLS).
        ids_restore: [B, L] indices returned by the encoder.
        image: [num_images, num_channels, H, W] tensor. Row ``i`` is paired
            with batch entry ``i`` of the attention and ``ids_restore``.
        patch_id: Unused; kept for backward compatibility.

    Returns:
        The matplotlib figure.
    """
    channel_names = ("Bx", "By", "Bz")
    layer_attn = attn_maps[0]  # [B, num_heads, N, N]
    num_heads = layer_attn.shape[1]
    num_images = image.shape[0]
    num_channels = image.shape[1]

    fig, axs = plt.subplots(
        num_images,
        num_heads + num_channels,
        figsize=(25, 4 * num_images),
        squeeze=False,
    )

    for i in range(num_images):
        attn = layer_attn[i][:, 1:, 1:]  # [num_heads, num_kept, num_kept]
        num_kept = attn.shape[1]
        # Shuffle order = argsort(ids_restore); its first num_kept entries are
        # the patch ids of the encoder tokens, in token order.
        ids_keep = torch.argsort(ids_restore[i])[:num_kept]
        # Mean over queries: attention received by each kept patch, per head.
        attn_received = attn.mean(axis=1)

        for i_ch in range(num_channels):
            ch = channel_names[i_ch] if i_ch < len(channel_names) else str(i_ch)
            axs[i, i_ch].imshow(image[i, i_ch, :, :].to("cpu").numpy(), cmap="gray")
            axs[i, i_ch].set_title(f"Ch: {ch}")
            axs[i, i_ch].axis("off")

        for h in range(num_heads):
            heatmap = visualize_head(attn_received[h, :], ids_keep, 512, 16)
            axs[i, h + num_channels].imshow(heatmap, cmap="jet")
            axs[i, h + num_channels].set_title(f"Head {h}")
            axs[i, h + num_channels].axis("off")

    plt.tight_layout()
    return fig
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# Main execution
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the paths previously hard-coded in this script.
    default_config = (
        "/home/jhong36/Project/2025-HL-Solar-Wind/solar_phenomena_prediction/"
        "configs/pretrain_mae.yaml"
    )
    default_ckpt = os.path.join(
        "../../../../assets/check_point/backbone/",
        "id_xn2c11go_mae_epoch=25-val_loss=0.00.ckpt",
    )

    parser = argparse.ArgumentParser(
        description="Plot per-head encoder attention maps of a pretrained MAE."
    )
    parser.add_argument("--config", default=default_config, help="Pretraining YAML config.")
    parser.add_argument("--ckpt", default=default_ckpt, help="MAE checkpoint to load.")
    parser.add_argument("--index", type=int, default=0, help="Test-set sample index.")
    parser.add_argument(
        "--output", default=None, help="Optional path to save the figure to."
    )
    args = parser.parse_args()

    cfg = OmegaConf.load(args.config)
    sdoml = cfg.data.sdoml

    def _join_or_none(sub_dir):
        """Join sub_dir onto the SDOML base directory, or return None if unset."""
        return os.path.join(sdoml.base_directory, sub_dir) if sub_dir else None

    # Setup dataset
    data_module = SDOMLDataModule(
        hmi_path=_join_or_none(sdoml.sub_directory.hmi),
        aia_path=_join_or_none(sdoml.sub_directory.aia),
        eve_path=None,
        components=sdoml.components,
        wavelengths=sdoml.wavelengths,
        ions=sdoml.ions,
        frequency=sdoml.frequency,
        batch_size=cfg.model.opt.batch_size,
        num_workers=cfg.data.num_workers,
        val_months=cfg.data.month_splits.val,
        test_months=cfg.data.month_splits.test,
        holdout_months=cfg.data.month_splits.holdout,
        cache_dir=os.path.join(sdoml.save_directory, sdoml.sub_directory.cache),
        min_date=cfg.data.min_date,
        max_date=cfg.data.max_date,
        num_frames=cfg.model.mae.num_frames,
        drop_frame_dim=cfg.data.drop_frame_dim,
        apply_mask=sdoml.apply_mask,
        precision=cfg.experiment.precision,
        normalization=sdoml.normalization,
    )
    data_module.setup()

    # Load model (checkpoint masking_ratio is kept; full image is encoded below)
    model_hmi = load_mae_weights(args.ckpt)
    model_hmi.eval()

    # Patch attention layers so the forward pass records attention weights
    attn_maps = patch_attn_layers(model_hmi)

    # Forward pass
    x = data_module.test_ds[args.index][0].unsqueeze(0)
    lgr_logger.info(f"Input shape: {x.shape}")

    with torch.no_grad():
        _latent, _mask, ids_restore = model_hmi.autoencoder.forward_encoder(
            x, mask_ratio=0
        )

    if len(attn_maps) > 0:
        lgr_logger.info(f"First attention map shape: {attn_maps[0].shape}")

    # plot_heads needs one image channel per channel name (default Bx, By, Bz).
    # The previous x[:, 0:1, 0] slice kept a single channel and made plot_heads
    # index past it. Assumes x is [B, C, T, H, W] — TODO confirm; frame 0 is
    # shown.
    fig, axs = plot_heads(attn_maps, ids_restore, x[:, :, 0, :, :])
    # fig = plot_heads_no_limb(attn_maps, ids_restore, x[:, :, 0, :, :])

    if args.output:
        fig.savefig(args.output, dpi=200)
        lgr_logger.info(f"Saved attention figure to {args.output}")
|