pmf-dark 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pmf_dark-0.1.0/PKG-INFO +233 -0
- pmf_dark-0.1.0/README.md +214 -0
- pmf_dark-0.1.0/pyproject.toml +26 -0
- pmf_dark-0.1.0/src/pmf_dark/__init__.py +10 -0
- pmf_dark-0.1.0/src/pmf_dark/darkdiv.py +264 -0
- pmf_dark-0.1.0/src/pmf_dark/inference.py +107 -0
- pmf_dark-0.1.0/src/pmf_dark/models.py +229 -0
pmf_dark-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,233 @@
Metadata-Version: 2.4
Name: pmf-dark
Version: 0.1.0
Summary: Modelling species dark diversity using Bayesian Probabilistic Matrix Factorisation
License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Author: davidyshen
Author-email: contact@davidyshen.com
Requires-Python: >=3.13,<3.15
Classifier: License :: Other/Proprietary License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Requires-Dist: matplotlib (>=3.10.9,<4.0.0)
Requires-Dist: numpy (>=2.4.6,<3.0.0)
Requires-Dist: pandas (>=3.0.3,<4.0.0)
Requires-Dist: pyro-ppl (>=1.9.0,<2.0.0)
Description-Content-Type: text/markdown
pmf_dark-0.1.0/README.md
ADDED
@@ -0,0 +1,214 @@
# PMF-dark: Using matrix factorisation for dark diversity estimation

## Overview

This repository implements **PMF-dark**, which uses Bayesian Probabilistic Matrix Factorisation to estimate **dark diversity**: the set of species absent from a site despite suitable environmental conditions. The method uses **counterfactual predictions** to reconstruct the potential species pool, separating the effects of measured environmental conditions from unmeasured drivers of absence (e.g. land-use degradation, dispersal limitation, biotic interactions).

## The Problem: What is Dark Diversity?

Traditional biodiversity assessments count only the observed species (alpha diversity). However, many species are absent from sites where they *could* thrive given the environmental conditions. This "**dark diversity**" includes:

- Species lost to historical or ongoing land-use degradation
- Species unable to reach suitable sites because of dispersal limitation
- Species suppressed by biotic interactions

Quantifying dark diversity is crucial for:

- Conservation planning and assessment of restoration potential
- Understanding true biodiversity patterns
- Identifying areas with the highest restoration value

## Methodology

### Core Model

For site $i$ and species $j$, the framework decomposes the occurrence probability $p_{ij}$ into **three additive components** on the logit scale (for count data, the code uses a log link instead):

$$\text{logit}(p_{ij}) = \underbrace{\alpha_j}_{\text{Intercept}} + \underbrace{f_j(\mathbf{x}_i)}_{\text{Environmental Effects}} + \underbrace{\mathbf{w}_i^\top \mathbf{z}_j}_{\text{Latent Factors}}$$

Where:

- **$\alpha_j$**: Species-specific baseline prevalence (on the logit scale)
- **$f_j(\mathbf{x}_i)$**: Species-specific response to the measured abiotic variables $\mathbf{x}_i$ at site $i$ (temperature, pH, elevation, etc.), which can be modelled as linear, a Gaussian niche, or non-linear, e.g. a Bayesian neural network (`model_type="linear"`, `"gaussian"` or `"bnn"`)
- **$\mathbf{w}_i^\top \mathbf{z}_j$**: Latent factors capturing unmeasured drivers of absence, with site loadings $\mathbf{w}_i$ and species loadings $\mathbf{z}_j$
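
For reference, this is how the three environmental response options appear to be computed in `compute_predictions` (in `src/pmf_dark/darkdiv.py`), written in the notation above. It is our reading of the code rather than documented behaviour: $k$ indexes the $K$ standardised covariates, and for the BNN, $\mathbf{V}$, $\mathbf{b}$, $\mathbf{u}_j$ and $c_j$ stand for the code's `w1`, `b1`, column $j$ of `w2`, and `b2`.

$$
\begin{aligned}
\text{linear:}\quad & f_j(\mathbf{x}_i) = \sum_{k=1}^{K} \beta_{kj}\, x_{ik} \\
\text{gaussian:}\quad & f_j(\mathbf{x}_i) = -\sum_{k=1}^{K} \gamma_{kj}\,(x_{ik} - \mu_{kj})^2 \\
\text{bnn:}\quad & \alpha_j + f_j(\mathbf{x}_i) = \mathbf{u}_j^\top \tanh\!\left(\mathbf{V}^\top \mathbf{x}_i + \mathbf{b}\right) + c_j
\end{aligned}
$$

In the BNN variant the network output replaces $\alpha_j + f_j(\mathbf{x}_i)$ as a whole, with the output bias $c_j$ acting as the species intercept. Predictions are computed for each posterior sample and, with `return_means=True`, averaged over samples on the probability scale.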

### Key Innovation: Counterfactual Predictions

1. **Full Predictions**: Include all components (environment + latent factors)
   - Represent observed diversity, with all drivers active

2. **Environment-Only Predictions**: Exclude the latent factors, i.e. set $\mathbf{w}_i^\top \mathbf{z}_j = 0$
   - Represent potential diversity

3. **Dark Diversity Proxy**: Environment-only minus full predictions, $p^{\text{env}}_{ij} - p^{\text{full}}_{ij}$
   - Quantifies the occurrence probability lost to unmeasured stressors
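
The three quantities above can be illustrated with a minimal NumPy sketch for the linear model. It assumes point estimates of the parameters (the package instead evaluates every posterior sample and averages the probabilities), and all function and array names are illustrative rather than the package's API.

```python
import numpy as np


def sigmoid(eta):
    """Inverse logit."""
    return 1.0 / (1.0 + np.exp(-eta))


def counterfactual_predictions(x, alpha, beta, W, Z):
    """Full, environment-only and dark-diversity-proxy probabilities.

    Shapes (illustrative, linear model with point estimates):
      x     (sites, covariates)   standardised predictors
      alpha (species,)            intercepts
      beta  (covariates, species) environmental coefficients
      W     (sites, factors)      site loadings w_i
      Z     (species, factors)    species loadings z_j
    """
    eta_env = alpha[None, :] + x @ beta  # alpha_j + f_j(x_i)
    eta_full = eta_env + W @ Z.T         # ... + w_i^T z_j
    p_full = sigmoid(eta_full)
    p_env = sigmoid(eta_env)             # counterfactual: latent term set to 0
    proxy = p_env - p_full               # probability lost to unmeasured drivers
    return p_full, p_env, proxy


if __name__ == "__main__":
    x = np.array([[1.0], [-1.0]])   # 2 sites, 1 covariate
    alpha = np.array([0.0, 0.0])    # 2 species
    beta = np.array([[2.0, 0.0]])
    W = np.array([[-3.0], [0.0]])   # first site pushed down by a latent factor
    Z = np.array([[1.0], [1.0]])
    p_full, p_env, proxy = counterfactual_predictions(x, alpha, beta, W, Z)
    print(np.round(proxy, 3))
```

Only the linear predictor changes between the two predictions, so any difference in probability is attributed to the latent factors; at the second site, where the latent term is zero, the proxy is zero.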

### Inference: Stochastic Variational Inference (SVI)

The model is fitted with **stochastic variational inference (SVI) in Pyro**, which:

- Handles high-dimensional ecological matrices efficiently
- Treats inference as an optimisation problem (maximising the evidence lower bound, ELBO)
- Is designed to scale to thousands of sites and species
- Typically needs far less computation than MCMC, which the package also supports via `method="mcmc"` (except for the Bayesian neural network model)
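
For reference, the objective SVI maximises is the standard evidence lower bound. Here $\boldsymbol{\theta}$ collects all unobserved quantities ($\alpha_j$, the environmental parameters, $\mathbf{w}_i$ and $\mathbf{z}_j$), $\mathbf{Y}$ is the observed site × species matrix, and $q_\phi$ is the variational distribution (Pyro's *guide*; the guide used in `inference.py` is not shown here):

$$\text{ELBO}(\phi) = \mathbb{E}_{q_\phi(\boldsymbol{\theta})}\left[\log p(\mathbf{Y} \mid \boldsymbol{\theta})\right] - \mathrm{KL}\left(q_\phi(\boldsymbol{\theta}) \,\|\, p(\boldsymbol{\theta})\right) \le \log p(\mathbf{Y})$$

Pyro's ELBO estimators (e.g. `Trace_ELBO`) report the *negative* ELBO as the loss, so a decreasing training loss corresponds to an increasing ELBO.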

## Repository Structure

```
PMF_dark/
├── README.md                    # This file
├── pyproject.toml               # Package metadata and dependencies
├── mat_fact_dark_div.ipynb      # Main analysis notebook
├── src/
│   └── pmf_dark/                # Python package (__init__.py, darkdiv.py, inference.py, models.py)
├── data/
│   ├── survey.csv               # Species presence/absence matrix (sites × species)
│   ├── env.csv                  # Environmental predictors (sites × covariates)
│   └── truth.csv                # Ground truth data (if available)
└── output/
    ├── mat_fact_predicted_probabilities_full.csv      # Full model predictions
    ├── mat_fact_predicted_probabilities_env_only.csv  # Environment-only predictions
    └── mat_fact_dark_diversity_proxy.csv              # Dark diversity estimates
```

## Installation

### Requirements

- Python 3.13 or 3.14
- PyTorch (with CUDA support if using a GPU)
- Pyro (pyro-ppl)
- pandas
- NumPy
- Matplotlib
- scikit-learn
- SciPy

### Setup

To make sure PyTorch is built for the right CUDA version on your system, install PyTorch manually **first**, before installing the package or its other dependencies.

#### 1. Set Up a Virtual Environment
```bash
# Clone the repository
git clone https://github.com/davidyshen/PMF_dark.git
cd PMF_dark

# Create a virtual environment (optional but recommended)
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate
```

#### 2. Install PyTorch
Visit the [PyTorch Getting Started guide](https://pytorch.org/get-started/locally/) to select the correct command for your CUDA version and OS. For example, to install PyTorch with CUDA 12.4 support on Windows/Linux:
```bash
pip install torch --index-url https://download.pytorch.org/whl/cu124
```

If you do not need GPU support, install a CPU-only build instead (the guide above lists the command for each OS); on Windows/Linux:
```bash
pip install torch --index-url https://download.pytorch.org/whl/cpu
```

#### 3. Install remaining dependencies
If installing via pip, run from the repository root:
```bash
# Install the package and its core dependencies (numpy, pandas, matplotlib, pyro-ppl)
pip install .

# Additional packages listed above for the analysis notebook
pip install scikit-learn scipy jupyter
```

If using Poetry:
```bash
# This will install the package and its remaining dependencies into your environment
poetry install
```
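
After installation, a quick sanity check can confirm that the dependencies import and whether PyTorch can see a GPU (the package itself prints a similar CUDA message when `pmf_dark` is imported). `check_environment` below is a hypothetical convenience script, not part of the package.

```python
import importlib.util

# Import names of the packages listed under Requirements (pip name pyro-ppl -> pyro)
REQUIRED = ("torch", "pyro", "numpy", "pandas", "matplotlib")


def check_environment(modules=REQUIRED):
    """Map each module name to whether it is importable, plus a 'cuda' entry."""
    report = {name: importlib.util.find_spec(name) is not None for name in modules}
    report["cuda"] = False
    if report.get("torch"):
        import torch  # only imported when it is installed

        report["cuda"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name:<12}{'yes' if ok else 'no'}")
```

`importlib.util.find_spec` locates a module without importing it, so the check stays fast even when a heavy package such as PyTorch is present.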

## Usage

### Running the Full Analysis

1. Prepare your data in the `data/` directory:
   - `survey.csv`: Species presence/absence (rows = sites, columns = species, values = 0/1)
   - `env.csv`: Environmental predictors (rows = sites, columns = variables)

2. Open and run the Jupyter notebook:
   ```bash
   jupyter notebook mat_fact_dark_div.ipynb
   ```

3. The notebook will:
   - Load and standardise the data
   - Fit the matrix factorisation model (2,500 iterations)
   - Generate predictions and save the CSV outputs
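
The standardisation step can be reproduced outside the notebook. `prepare_data` in `darkdiv.py` z-scores each covariate with pandas' `x.mean()` and `x.std()`, i.e. using the sample standard deviation (`ddof=1`). A NumPy equivalent, as a sketch (the function name `standardise` is hypothetical):

```python
import numpy as np


def standardise(x):
    """Z-score each column of a (sites, covariates) array using the sample
    standard deviation (ddof=1), as pandas' DataFrame.std() does by default."""
    x = np.asarray(x, dtype=float)
    sd = x.std(axis=0, ddof=1)
    if np.any(sd == 0):
        raise ValueError("constant covariate(s) cannot be standardised")
    return (x - x.mean(axis=0)) / sd


if __name__ == "__main__":
    print(standardise([[1, 10], [2, 20], [3, 30]]))
```

Unlike `prepare_data`, the sketch raises an error for a constant covariate instead of silently producing a column of NaN.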
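
### Customisation

Key parameters in the notebook:

```python
from pyro.optim import Adam

# Model parameters
num_factors = 5        # Number of latent factors (adjust based on data complexity)
num_iterations = 2500  # Number of SVI iterations

# Optimiser (variable name illustrative)
optimizer = Adam({"lr": 0.01})  # Learning rate; adjust if convergence is slow
```

When calling the package directly, the equivalent choices are arguments of `compute_dark_diversity`, e.g. `num_factors`, `model_type` (`"linear"`, `"gaussian"` or `"bnn"`), `method` (`"svi"` or `"mcmc"`) and `cuda`; any further keyword arguments are passed through to the fitting routine.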
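
One way to decide whether to raise `num_iterations` or change the learning rate is to check whether the loss has plateaued. The helper below is a hypothetical sketch: it assumes you have recorded the per-iteration losses (for example the values returned by Pyro's `SVI.step`) and compares the mean loss over the last two windows of iterations.

```python
import numpy as np


def has_converged(losses, window=100, rtol=1e-3):
    """True if the mean loss changed by less than `rtol` (relative) between
    the last two windows of `window` iterations."""
    losses = np.asarray(losses, dtype=float)
    if losses.size < 2 * window:
        return False  # not enough history to judge
    previous = losses[-2 * window:-window].mean()
    latest = losses[-window:].mean()
    return abs(previous - latest) <= rtol * abs(previous)


if __name__ == "__main__":
    t = np.arange(3000)
    losses = 1000 * np.exp(-t / 100) + 50  # synthetic decaying loss curve
    print(has_converged(losses[:300]), has_converged(losses))
```

Averaging over windows smooths out the iteration-to-iteration noise of the stochastic ELBO estimate, which would make a single-step comparison unreliable.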

## Output Files

- **mat_fact_predicted_probabilities_full.csv**: Predicted species occurrence probabilities including all effects
- **mat_fact_predicted_probabilities_env_only.csv**: Predicted probabilities using only the environmental effects (latent factors set to zero)
- **mat_fact_dark_diversity_proxy.csv**: Dark diversity estimates (env_only - full)

## Data Format

### survey.csv
```
site_id,species_1,species_2,...,species_n,ID,x,y
site_1,0,1,0,...,1,id_1,100.5,200.3
site_2,1,0,1,...,0,id_2,101.2,201.5
...
```
- Rows: Sites/locations
- Columns: Species (0/1 presence/absence) + ID + spatial coordinates (`x`, `y`)
- **Note**: The ID and spatial coordinate columns are automatically extracted and dropped before model fitting

### env.csv
```
site_id,temp,pH,elevation,...,ID,landuse
site_1,15.2,7.1,500,...,id_1,degraded
site_2,14.8,6.9,520,...,id_2,pristine
...
```
- Rows: The same sites as survey.csv, in the same order (`prepare_data` converts both tables to arrays without aligning them on the site index)
- Columns: Environmental predictors + ID + land use
- **Note**: The ID and land-use columns are dropped; only the abiotic predictors are used
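
A minimal pandas sketch of preparing the two inputs as described above: it drops the non-species and non-abiotic columns and reorders the environment rows to match the survey rows. It assumes the column names used in the example layouts (`site_id`, `ID`, `x`, `y`, `landuse`), which you may need to adjust; `load_inputs` is a hypothetical helper, not part of the package.

```python
import io

import pandas as pd


def load_inputs(survey_csv, env_csv,
                survey_drop=("ID", "x", "y"), env_drop=("ID", "landuse")):
    """Return (y, x): species and predictor tables indexed by site_id,
    with the rows of x in the same order as the rows of y."""
    survey = pd.read_csv(survey_csv, index_col="site_id")
    env = pd.read_csv(env_csv, index_col="site_id")
    y = survey.drop(columns=list(survey_drop))
    x = env.drop(columns=list(env_drop))
    missing = y.index.difference(x.index)
    if len(missing) > 0:
        raise ValueError(f"sites missing from env.csv: {list(missing)}")
    return y, x.loc[y.index]


if __name__ == "__main__":
    survey_text = "site_id,sp_a,sp_b,ID,x,y\ns1,0,1,id_1,100.5,200.3\ns2,1,0,id_2,101.2,201.5\n"
    env_text = "site_id,temp,pH,ID,landuse\ns2,14.8,6.9,id_2,pristine\ns1,15.2,7.1,id_1,degraded\n"
    y, x = load_inputs(io.StringIO(survey_text), io.StringIO(env_text))
    print(x)
```

In practice you would pass the file paths (`"data/survey.csv"`, `"data/env.csv"`) instead of in-memory strings.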

## Interpretation of Results

### Dark Diversity Proxy Values
- **High values (close to 1)**: The environment predicts the species should be present, but the full model (including latent factors) predicts it is absent; a candidate for restoration
- **Low values (close to 0)**: The latent factors make little difference, so the species' presence or absence is explained by the environmental conditions
- **Negative values**: The full model predicts a higher occurrence probability than the environment alone, i.e. the species occurs more often than the environment suggests (e.g. where conditions are predicted to be unsuitable); expected to be uncommon
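
To shortlist restoration candidates from `mat_fact_dark_diversity_proxy.csv`, one option is to threshold the proxy. The sketch below assumes the file has been read with `pd.read_csv(path, index_col=0)` (sites as rows, species as columns); `restoration_candidates` is a hypothetical helper, and the 0.5 cut-off is an arbitrary illustrative value, not a recommendation from the method.

```python
import numpy as np
import pandas as pd


def restoration_candidates(proxy, threshold=0.5):
    """Return (site, species, value) for every proxy value above `threshold`,
    sorted from largest to smallest."""
    values = proxy.to_numpy(dtype=float)
    rows, cols = np.nonzero(values > threshold)
    hits = [
        (proxy.index[r], proxy.columns[c], float(values[r, c]))
        for r, c in zip(rows, cols)
    ]
    return sorted(hits, key=lambda hit: hit[2], reverse=True)


if __name__ == "__main__":
    proxy = pd.DataFrame(
        {"sp_a": [0.7, 0.1], "sp_b": [0.9, -0.2]}, index=["s1", "s2"]
    )
    print(restoration_candidates(proxy))
```

Per-site counts of flagged species follow from `(proxy > 0.5).sum(axis=1)`.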

### Key Metrics
- **AUC (Area Under the ROC Curve)**: Overall discrimination between presences and absences (0.5 = random, 1.0 = perfect)
- **Brier Score**: Mean squared difference between predicted probabilities and observed 0/1 outcomes, reflecting both calibration and discrimination (lower is better)
- **F1 Score**: Harmonic mean of precision and recall after thresholding the predicted probabilities
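
These metrics compare predictions with observations (e.g. the full-model probabilities against `survey.csv`, or against `truth.csv` where available). The notebook's evaluation code is not shown here; below is a self-contained NumPy sketch of the three metrics, assuming flattened 0/1 observations and predicted probabilities, with an illustrative 0.5 threshold for F1.

```python
import numpy as np


def auc(y_true, p):
    """ROC AUC via the Mann-Whitney U statistic (ties count as 0.5)."""
    y_true = np.asarray(y_true).ravel().astype(bool)
    p = np.asarray(p, dtype=float).ravel()
    pos, neg = p[y_true], p[~y_true]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("AUC needs both presences and absences")
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (pos.size * neg.size)


def brier(y_true, p):
    """Mean squared error between probabilities and 0/1 outcomes."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    p = np.asarray(p, dtype=float).ravel()
    return float(np.mean((p - y_true) ** 2))


def f1(y_true, p, threshold=0.5):
    """F1 score of the thresholded predictions."""
    y_true = np.asarray(y_true).ravel().astype(bool)
    y_pred = np.asarray(p, dtype=float).ravel() >= threshold
    tp = np.sum(y_pred & y_true)
    fp = np.sum(y_pred & ~y_true)
    fn = np.sum(~y_pred & y_true)
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom else 0.0


if __name__ == "__main__":
    y = [1, 0, 1, 0]
    p = [0.9, 0.1, 0.4, 0.6]
    print(auc(y, p), brier(y, p), f1(y, p))
```

The pairwise AUC comparison uses memory proportional to presences × absences; for large matrices, `sklearn.metrics.roc_auc_score` computes the same quantity more efficiently.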

## Advantages of This Approach

✓ **No subjective benchmarking**: Environmental and unmeasured effects are separated automatically

✓ **Mathematically principled**: Latent factors absorb residual structure that the measured covariates do not explain, such as degradation signals

✓ **Scalable**: SVI can handle thousands of species and sites

✓ **Species-specific**: Each species can have its own environmental response

✓ **Reproducible**: A fully probabilistic framework with explicit assumptions

## Limitations

- Assumes the components combine additively on the link scale (logit for presence/absence, log for counts); with `model_type="linear"`, environmental responses are also linear on that scale
- Requires sufficient environmental variation to estimate effects reliably
- May overestimate dark diversity if detection is imperfect, since missed detections look like absences
- Computational cost increases with the number of species and sites
- Requires careful tuning of the number of latent factors

## References & Theoretical Background

### Key Concepts
- **Joint Species Distribution Models (JSDMs)**: Latent variable models for multivariate species data
- **Matrix Factorisation**: Low-rank decomposition of high-dimensional species matrices
- **Stochastic Variational Inference**: Scalable Bayesian inference for probabilistic models
- **Counterfactual Predictions**: Causal inference approach to estimate potential outcomes
pmf_dark-0.1.0/pyproject.toml
ADDED
@@ -0,0 +1,26 @@
[project]
name = "pmf-dark"
version = "0.1.0"
description = "Modelling species dark diversity using Bayesian Probabilistic Matrix Factorisation"
authors = [
    {name = "davidyshen",email = "contact@davidyshen.com"}
]
license = {text = "Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE."}
readme = "README.md"
requires-python = ">=3.13,<3.15"
dependencies = [
    "numpy (>=2.4.6,<3.0.0)",
    "pandas (>=3.0.3,<4.0.0)",
    "matplotlib (>=3.10.9,<4.0.0)",
    "pyro-ppl (>=1.9.0,<2.0.0)"
]

[tool.poetry]
packages = [{ include = "pmf_dark", from = "src" }]

[tool.poetry.group.dev.dependencies]
ipykernel = "^7.2.0"

[build-system]
requires = ["poetry-core>=2.0.0,<3.0.0"]
build-backend = "poetry.core.masonry.api"
pmf_dark-0.1.0/src/pmf_dark/__init__.py
ADDED
@@ -0,0 +1,10 @@
import torch
from .darkdiv import compute_dark_diversity

# Print CUDA availability on import
if torch.cuda.is_available():
    print(f"pmf-dark: CUDA is available. GPU: {torch.cuda.get_device_name(0)}")
else:
    print("pmf-dark: CUDA is not available. Using CPU.")

__all__ = ["compute_dark_diversity"]
@@ -0,0 +1,264 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import pyro
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def infer_y_type(y):
|
|
10
|
+
|
|
11
|
+
# torch tensor
|
|
12
|
+
if isinstance(y, torch.Tensor):
|
|
13
|
+
|
|
14
|
+
values = y.detach().cpu().numpy()
|
|
15
|
+
|
|
16
|
+
# pandas dataframe
|
|
17
|
+
elif hasattr(y, "to_numpy"):
|
|
18
|
+
|
|
19
|
+
values = y.to_numpy()
|
|
20
|
+
|
|
21
|
+
# numpy or other
|
|
22
|
+
else:
|
|
23
|
+
|
|
24
|
+
values = np.asarray(y)
|
|
25
|
+
|
|
26
|
+
# Check missing values
|
|
27
|
+
if np.isnan(values).any():
|
|
28
|
+
raise ValueError("y contains missing values.")
|
|
29
|
+
|
|
30
|
+
# Binary data
|
|
31
|
+
if np.isin(values, [0, 1]).all():
|
|
32
|
+
return "presence_absence"
|
|
33
|
+
|
|
34
|
+
# Count data
|
|
35
|
+
elif (values >= 0).all():
|
|
36
|
+
return "count"
|
|
37
|
+
|
|
38
|
+
else:
|
|
39
|
+
raise ValueError("y must contain either binary or count data.")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def prepare_data(x, y, cuda=False):
|
|
43
|
+
|
|
44
|
+
# Keep names
|
|
45
|
+
x_columns = x.columns
|
|
46
|
+
y_columns = y.columns
|
|
47
|
+
site_index = y.index
|
|
48
|
+
|
|
49
|
+
# Standardize x
|
|
50
|
+
x = (x - x.mean()) / x.std()
|
|
51
|
+
|
|
52
|
+
# Convert to tensors
|
|
53
|
+
x_tensor = torch.tensor(
|
|
54
|
+
x.to_numpy(),
|
|
55
|
+
dtype=torch.float32,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
y_tensor = torch.tensor(
|
|
59
|
+
y.to_numpy(),
|
|
60
|
+
dtype=torch.float32,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
# Device
|
|
64
|
+
device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
|
|
65
|
+
|
|
66
|
+
x_tensor = x_tensor.to(device)
|
|
67
|
+
y_tensor = y_tensor.to(device)
|
|
68
|
+
|
|
69
|
+
return {
|
|
70
|
+
"x": x_tensor,
|
|
71
|
+
"y": y_tensor,
|
|
72
|
+
"x_columns": x_columns,
|
|
73
|
+
"y_columns": y_columns,
|
|
74
|
+
"site_index": site_index,
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def compute_predictions(
|
|
79
|
+
samples, x, model_type="gaussian", include_latent=True, y_type="presence_absence"
|
|
80
|
+
):
|
|
81
|
+
|
|
82
|
+
if model_type == "linear":
|
|
83
|
+
alpha = samples["alpha"].squeeze(1)
|
|
84
|
+
beta = samples["beta"].squeeze(1)
|
|
85
|
+
eta = alpha[:, None, :] + torch.einsum("ij,sjk->sik", x, beta)
|
|
86
|
+
|
|
87
|
+
elif model_type == "gaussian":
|
|
88
|
+
alpha = samples["alpha"].squeeze(1)
|
|
89
|
+
mu = samples["mu"].squeeze(1)
|
|
90
|
+
gamma = samples["gamma"].squeeze(1)
|
|
91
|
+
|
|
92
|
+
# x: [sites, env]
|
|
93
|
+
# mu: [samples, env, species]
|
|
94
|
+
# gamma: [samples, env, species]
|
|
95
|
+
diff = x[None, :, :, None] - mu[:, None, :, :]
|
|
96
|
+
env_effect = -torch.sum(gamma[:, None, :, :] * diff**2, dim=2)
|
|
97
|
+
|
|
98
|
+
eta = alpha[:, None, :] + env_effect
|
|
99
|
+
|
|
100
|
+
elif model_type == "bnn":
|
|
101
|
+
w1 = samples["w1"].squeeze(1)
|
|
102
|
+
b1 = samples["b1"].squeeze(1)
|
|
103
|
+
w2 = samples["w2"].squeeze(1)
|
|
104
|
+
b2 = samples["b2"].squeeze(1)
|
|
105
|
+
|
|
106
|
+
hidden = torch.tanh(torch.einsum("ij,sjh->sih", x, w1) + b1[:, None, :])
|
|
107
|
+
|
|
108
|
+
eta = torch.einsum("sih,shk->sik", hidden, w2) + b2[:, None, :]
|
|
109
|
+
|
|
110
|
+
else:
|
|
111
|
+
raise ValueError(f"Unknown model_type: {model_type}")
|
|
112
|
+
|
|
113
|
+
if include_latent:
|
|
114
|
+
W = samples["W"].squeeze(1)
|
|
115
|
+
Z = samples["Z"].squeeze(1)
|
|
116
|
+
eta = eta + torch.einsum("sik,sjk->sij", W, Z)
|
|
117
|
+
|
|
118
|
+
if y_type == "presence_absence":
|
|
119
|
+
return torch.sigmoid(eta)
|
|
120
|
+
|
|
121
|
+
elif y_type == "count":
|
|
122
|
+
return torch.exp(eta)
|
|
123
|
+
|
|
124
|
+
else:
|
|
125
|
+
raise ValueError("y_type must be 'presence_absence' or 'count'")
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def compute_dark_diversity(
|
|
129
|
+
y,
|
|
130
|
+
x,
|
|
131
|
+
model_type="linear",
|
|
132
|
+
num_factors=1,
|
|
133
|
+
method="svi",
|
|
134
|
+
cuda=False,
|
|
135
|
+
include_latent=True,
|
|
136
|
+
return_means=True,
|
|
137
|
+
batch_size=None,
|
|
138
|
+
pred_batch_size=None,
|
|
139
|
+
**kwargs,
|
|
140
|
+
):
|
|
141
|
+
|
|
142
|
+
if cuda and not torch.cuda.is_available():
|
|
143
|
+
import warnings
|
|
144
|
+
|
|
145
|
+
warnings.warn(
|
|
146
|
+
"CUDA was requested (cuda=True), but PyTorch cannot detect a CUDA-enabled GPU. "
|
|
147
|
+
"Please check your NVIDIA drivers and reinstall PyTorch with CUDA support. "
|
|
148
|
+
"Falling back to CPU."
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
# For bnn only svi is supported
|
|
152
|
+
if method == "mcmc" and model_type == "bnn":
|
|
153
|
+
raise ValueError(
|
|
154
|
+
"MCMC is currently not supported for the Bayesian neural network model. "
|
|
155
|
+
"Please use method='svi' instead."
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
# Check if y is presence/absence or count data
|
|
159
|
+
y_type = infer_y_type(y)
|
|
160
|
+
print(y_type)
|
|
161
|
+
|
|
162
|
+
data = prepare_data(x, y, cuda=cuda)
|
|
163
|
+
|
|
164
|
+
x = data["x"]
|
|
165
|
+
y = data["y"]
|
|
166
|
+
|
|
167
|
+
# Load the model
|
|
168
|
+
if model_type == "linear":
|
|
169
|
+
from .models import linear_model
|
|
170
|
+
|
|
171
|
+
model = linear_model
|
|
172
|
+
elif model_type == "gaussian":
|
|
173
|
+
from .models import gaussian
|
|
174
|
+
|
|
175
|
+
model = gaussian
|
|
176
|
+
elif model_type == "bnn":
|
|
177
|
+
from .models import bnn_model
|
|
178
|
+
|
|
179
|
+
model = bnn_model
+
+    # Inference
+    if method == "svi":
+        from .inference import fit_svi
+
+        fit = fit_svi(
+            model,
+            y,
+            x,
+            num_factors,
+            y_type=y_type,
+            cuda=cuda,
+            batch_size=batch_size,
+            **kwargs,
+        )
+    elif method == "mcmc":
+        from .inference import fit_mcmc
+
+        fit = fit_mcmc(
+            model, y, x, num_factors, y_type=y_type, batch_size=batch_size, **kwargs
+        )
+
+    # Compute probabilities
+    if pred_batch_size is not None:
+        n_sites = x.shape[0]
+        pred_chunks = []
+        for i in range(0, n_sites, pred_batch_size):
+            x_chunk = x[i : i + pred_batch_size]
+            samples_chunk = fit["samples"].copy()
+            if "W" in fit["samples"]:
+                w_tensor = fit["samples"]["W"]
+                if w_tensor.dim() == 4:
+                    samples_chunk["W"] = w_tensor[:, :, i : i + pred_batch_size, :]
+                else:
+                    samples_chunk["W"] = w_tensor[:, i : i + pred_batch_size, :]
+
+            pred_chunk = compute_predictions(
+                samples_chunk,
+                x_chunk,
+                model_type=model_type,
+                include_latent=include_latent,
+                y_type=y_type,
+            )
+
+            if return_means:
+                pred_chunk_processed = pred_chunk.mean(dim=0).detach().cpu().numpy()
+            else:
+                pred_chunk_processed = pred_chunk.detach().cpu().numpy()
+
+            pred_chunks.append(pred_chunk_processed)
+
+        if return_means:
+            pred_np = np.concatenate(pred_chunks, axis=0)
+            pred = pd.DataFrame(
+                pred_np,
+                index=data["site_index"],
+                columns=data["y_columns"],
+            )
+        else:
+            # If returning full samples, shape of each chunk is (num_samples, batch_size, n_species)
+            # We concatenate along the site dimension (axis 1)
+            pred = np.concatenate(pred_chunks, axis=1)
+
+    else:
+        pred = compute_predictions(
+            fit["samples"],
+            x,
+            model_type=model_type,
+            include_latent=include_latent,
+            y_type=y_type,
+        )
+        if return_means:
+            pred = pred.mean(dim=0)
+            pred = pred.detach().cpu().numpy()
+
+            pred = pd.DataFrame(
+                pred,
+                index=data["site_index"],
+                columns=data["y_columns"],
+            )
+
+        else:
+            pred = pred.detach().cpu().numpy()
+
+    return pred
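When pred_batch_size is set, the function above predicts in site chunks: it slices the per-site latent W along its site axis, averages each chunk over posterior draws (or keeps all draws), and then concatenates means along axis 0 (sites) but full draws along axis 1 (the site axis of a (num_samples, sites, species) array). A minimal NumPy sketch of that bookkeeping follows. It assumes W is laid out as (num_samples, n_sites, num_factors) (the 3-D case handled above) and uses a hypothetical toy_predict in place of compute_predictions, whose body is not shown here; Z is simplified to a single (n_species, num_factors) draw.

```python
import numpy as np


def toy_predict(W, Z):
    # Hypothetical stand-in for compute_predictions: sigmoid(W @ Z^T),
    # shape (num_samples, n_sites_in_chunk, n_species).
    return 1.0 / (1.0 + np.exp(-(W @ Z.T)))


def predict_in_chunks(W, Z, pred_batch_size, return_means=True):
    n_sites = W.shape[1]
    chunks = []
    for i in range(0, n_sites, pred_batch_size):
        # Slice the site axis of the per-site latent factors.
        pred_chunk = toy_predict(W[:, i : i + pred_batch_size, :], Z)
        chunks.append(pred_chunk.mean(axis=0) if return_means else pred_chunk)
    # Means are (sites, species): join on axis 0.
    # Full draws are (samples, sites, species): join on axis 1.
    return np.concatenate(chunks, axis=0 if return_means else 1)


rng = np.random.default_rng(0)
W = rng.normal(size=(50, 7, 2))  # 50 draws, 7 sites, 2 latent factors
Z = rng.normal(size=(4, 2))      # 4 species
chunked = predict_in_chunks(W, Z, pred_batch_size=3)
full = toy_predict(W, Z).mean(axis=0)
print(chunked.shape, np.allclose(chunked, full))  # (7, 4) True
```

Because predictions are elementwise in sites, chunking only bounds peak memory; the concatenated result matches an unchunked pass.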
pmf_dark-0.1.0/src/pmf_dark/inference.py
ADDED
@@ -0,0 +1,107 @@
+import numpy as np
+import pyro
+from pyro.infer import SVI, Trace_ELBO, MCMC, NUTS
+from pyro.infer.autoguide import AutoNormal
+from pyro.optim import Adam
+from torch import device
+import torch
+from pyro.infer import Predictive
+
+
+def fit_svi(
+    model,
+    y,
+    x,
+    num_factors,
+    y_type,
+    lr=0.01,
+    num_iterations=2500,
+    cuda=False,
+    batch_size=None,
+    **kwargs,
+):
+
+    device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    # Define create_plates helper for autoguide subsampling
+    def create_plates(*args, **kwargs):
+        # args[1] is the species presence/absence matrix Y
+        y_data = args[1]
+        n_sites = y_data.shape[0]
+        batch_size_val = kwargs.get("batch_size", None)
+        return pyro.plate("sites", n_sites, subsample_size=batch_size_val)
+
+    # Setup Inference
+    pyro.clear_param_store()
+    guide = AutoNormal(model, create_plates=create_plates)
+    optimizer = Adam({"lr": lr})
+    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
+
+    # Training Loop
+    losses = []
+    for i in range(num_iterations):
+        loss = svi.step(x, y, num_factors, y_type, batch_size=batch_size, **kwargs)
+        losses.append(loss)
+        if i % 500 == 0:
+            print(f"Iteration {i} - Loss: {loss:.2f}")
+
+    # Simple convergence check
+    first = np.mean(losses[-200:-100])
+    last = np.mean(losses[-100:])
+    relative_change = abs(last - first) / abs(first)
+    if relative_change < 0.01:
+        print("SVI converged successfully.")
+    else:
+        print("SVI may not have converged.")
+
+    # return samples
+    predictive = Predictive(
+        model,
+        guide=guide,
+        num_samples=1000,
+    )
+
+    samples = predictive(x, y, num_factors, y_type, batch_size=None, **kwargs)
+
+    return {
+        "method": "svi",
+        "guide": guide,
+        "losses": losses,
+        "samples": samples,
+    }
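fit_svi's convergence check compares the mean loss over iterations -200 to -100 with the mean over the last 100 iterations and reports convergence when the relative change is below 1%. The sketch below restates that heuristic as a standalone function that could be applied to the returned fit["losses"]. The function name svi_converged and its window and tol parameters are illustrative, not part of the package, and the length guard is an addition: with 100 or fewer iterations the package's earlier window is empty, np.mean returns nan (with a RuntimeWarning), and the check falls through to "SVI may not have converged."

```python
import statistics


def svi_converged(losses, window=100, tol=0.01):
    """Relative-change heuristic mirroring fit_svi's check (sketch).

    Compares the mean loss of the second-to-last `window` iterations
    with the mean of the last `window` iterations.
    """
    if len(losses) < 2 * window:
        # Not enough history for two full windows.
        return False
    first = statistics.fmean(losses[-2 * window : -window])
    last = statistics.fmean(losses[-window:])
    return abs(last - first) / abs(first) < tol


# A loss curve that flattens out near 1000 versus one still falling steadily.
flat = [1000.0 + 5000.0 / (i + 1) for i in range(2500)]
falling = [10000.0 - 3.0 * i for i in range(2500)]
print(svi_converged(flat), svi_converged(falling))  # True False
```

A small relative change is only a plateau signal; it does not guarantee that the variational approximation is good.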
+
+
+def fit_mcmc(
+    model,
+    y,
+    x,
+    num_factors,
+    y_type,
+    num_samples=1000,
+    warmup_steps=500,
+    num_chains=1,
+    batch_size=None,
+):
+    if batch_size is not None:
+        raise ValueError(
+            "MCMC does not support mini-batching. Please use SVI (method='svi') or set batch_size=None."
+        )
+    pyro.clear_param_store()
+
+    kernel = NUTS(model)
+
+    mcmc = MCMC(
+        kernel,
+        num_samples=num_samples,
+        warmup_steps=warmup_steps,
+        num_chains=num_chains,
+    )
+
+    mcmc.run(x, y, num_factors, y_type)
+
+    return {
+        "method": "mcmc",
+        "mcmc": mcmc,
+        "samples": mcmc.get_samples(),
+    }
pmf_dark-0.1.0/src/pmf_dark/models.py
ADDED
@@ -0,0 +1,229 @@
+# 2. Define the Model
+import pyro
+import pyro.distributions as dist
+import torch
+
+
+def observation_dist(eta, y_type):
+    if y_type == "presence_absence":
+        return dist.Bernoulli(logits=eta)
+
+    elif y_type == "count":
+        return dist.Poisson(rate=torch.exp(torch.clamp(eta, -20, 20)))
+    # elif y_type == "count":
+    #
+    #     mu = torch.exp(torch.clamp(eta, -20, 20))
+    #
+    #     dispersion = pyro.sample(
+    #         "dispersion",
+    #         dist.LogNormal(
+    #             torch.zeros(eta.shape[-1], device=eta.device),
+    #             torch.ones(eta.shape[-1], device=eta.device),
+    #         ).to_event(1),
+    #     )
+    #
+    #     return dist.NegativeBinomial(
+    #         total_count=dispersion,
+    #         logits=torch.log(mu) - torch.log(dispersion),
+    #     )
+
+    else:
+        raise ValueError(
+            "y_type must be 'presence_absence', 'count', or 'zero_inflated_count'"
+        )
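observation_dist turns the linear predictor eta into a likelihood: a logit link for presence/absence data (P(y = 1) = sigmoid(eta)) and a log link for counts (Poisson rate exp(eta), with eta clamped to [-20, 20]). Only these two y_type values are handled in this version, even though the error message also lists 'zero_inflated_count'. A dependency-free sketch of the mean implied by each branch, using a hypothetical helper name expected_response:

```python
import math


def expected_response(eta, y_type):
    """Mean of the observation distribution for one linear predictor value."""
    if y_type == "presence_absence":
        # Bernoulli(logits=eta): P(y = 1) = sigmoid(eta), written in a
        # numerically stable form so large |eta| does not overflow.
        if eta >= 0:
            return 1.0 / (1.0 + math.exp(-eta))
        e = math.exp(eta)
        return e / (1.0 + e)
    elif y_type == "count":
        # Poisson(rate=exp(clamp(eta, -20, 20))): the mean equals the rate.
        return math.exp(min(max(eta, -20.0), 20.0))
    raise ValueError("y_type must be 'presence_absence' or 'count'")


print(expected_response(0.0, "presence_absence"))  # 0.5
print(expected_response(50.0, "count") == math.exp(20.0))  # True: eta is clamped
```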
+
+
+def linear_model(X, Y, num_factors, y_type="presence_absence", batch_size=None):
+    device = X.device
+
+    n_sites, n_species = Y.shape
+    n_env = X.shape[1]
+
+    alpha = pyro.sample(
+        "alpha",
+        dist.Normal(
+            torch.zeros(n_species, device=device),
+            torch.ones(n_species, device=device),
+        ).to_event(1),
+    )
+
+    beta = pyro.sample(
+        "beta",
+        dist.Normal(
+            torch.zeros(n_env, n_species, device=device),
+            torch.ones(n_env, n_species, device=device),
+        ).to_event(2),
+    )
+
+    Z = pyro.sample(
+        "Z",
+        dist.Normal(
+            torch.zeros(n_species, num_factors, device=device),
+            torch.ones(n_species, num_factors, device=device),
+        ).to_event(2),
+    )
+
+    with pyro.plate("sites", n_sites, subsample_size=batch_size) as ind:
+        X_batch = X[ind]
+        Y_batch = Y[ind]
+
+        W_batch = pyro.sample(
+            "W",
+            dist.Normal(
+                torch.zeros(num_factors, device=device),
+                torch.ones(num_factors, device=device),
+            ).to_event(1),
+        )
+
+        eta_batch = (
+            alpha
+            + torch.matmul(X_batch, beta)
+            + torch.matmul(W_batch, Z.transpose(-1, -2))
+        )
+
+        pyro.sample(
+            "obs",
+            observation_dist(eta_batch, y_type).to_event(1),
+            obs=Y_batch,
+        )
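In linear_model the site-by-species linear predictor is eta = alpha + X beta + W Z^T, with alpha of shape (n_species,), beta (n_env, n_species), species loadings Z (n_species, num_factors) and per-site factors W (sites, num_factors), all with standard-normal priors. The W Z^T term is the matrix-factorisation part: species with similar rows in Z tend to co-vary across sites beyond what X explains. A NumPy sketch of one prior draw, showing how the shapes combine (linear_eta is a hypothetical name; sizes are arbitrary):

```python
import numpy as np


def linear_eta(X, alpha, beta, W, Z):
    # alpha: (n_species,) broadcasts across sites
    # X @ beta: (n_sites, n_env) @ (n_env, n_species) -> (n_sites, n_species)
    # W @ Z.T:  (n_sites, k) @ (k, n_species)        -> (n_sites, n_species)
    return alpha + X @ beta + W @ Z.T


rng = np.random.default_rng(1)
n_sites, n_env, n_species, k = 6, 3, 5, 2
X = rng.normal(size=(n_sites, n_env))
# One draw from the standard-normal priors used by linear_model.
alpha = rng.normal(size=n_species)
beta = rng.normal(size=(n_env, n_species))
Z = rng.normal(size=(n_species, k))
W = rng.normal(size=(n_sites, k))

eta = linear_eta(X, alpha, beta, W, Z)
p_presence = 1.0 / (1.0 + np.exp(-eta))  # Bernoulli(logits=eta) mean
print(eta.shape)  # (6, 5)
```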
+
+
+def gaussian(
+    X, Y, num_factors, y_type="presence_absence", batch_size=None
+):
+    device = X.device
+
+    n_sites, n_species = Y.shape
+    n_env = X.shape[1]
+
+    alpha = pyro.sample(
+        "alpha",
+        dist.Normal(
+            torch.zeros(n_species, device=device),
+            torch.ones(n_species, device=device),
+        ).to_event(1),
+    )
+
+    # species-specific environmental optimum
+    mu = pyro.sample(
+        "mu",
+        dist.Normal(
+            torch.zeros(n_env, n_species, device=device),
+            torch.ones(n_env, n_species, device=device),
+        ).to_event(2),
+    )
+
+    # positive niche strength / inverse width
+    gamma = pyro.sample(
+        "gamma",
+        dist.LogNormal(
+            torch.zeros(n_env, n_species, device=device),
+            torch.ones(n_env, n_species, device=device),
+        ).to_event(2),
+    )
+
+    Z = pyro.sample(
+        "Z",
+        dist.Normal(
+            torch.zeros(n_species, num_factors, device=device),
+            torch.ones(n_species, num_factors, device=device),
+        ).to_event(2),
+    )
+
+    with pyro.plate("sites", n_sites, subsample_size=batch_size) as ind:
+        X_batch = X[ind]
+        Y_batch = Y[ind]
+
+        W_batch = pyro.sample(
+            "W",
+            dist.Normal(
+                torch.zeros(num_factors, device=device),
+                torch.ones(num_factors, device=device),
+            ).to_event(1),
+        )
+
+        # quadratic environmental response
+        diff = X_batch[..., None] - mu[..., None, :, :]
+        env_effect = -torch.sum(gamma[..., None, :, :] * diff**2, dim=-2)
+
+        eta_batch = alpha + env_effect + torch.matmul(W_batch, Z.transpose(-1, -2))
+
+        pyro.sample(
+            "obs",
+            observation_dist(eta_batch, y_type).to_event(1),
+            obs=Y_batch,
+        )
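The gaussian model replaces the linear term with a quadratic niche response, env_effect[i, j] = -sum_e gamma[e, j] * (x[i, e] - mu[e, j])**2, which peaks at 0 when a site sits at species j's optimum mu and falls off faster for larger gamma. The sketch below reproduces the broadcasting for a single posterior draw (so mu[None] stands in for the package's mu[..., None, :, :], which also carries sample dimensions) and checks it against an explicit loop; quadratic_env_effect is a hypothetical name.

```python
import numpy as np


def quadratic_env_effect(X_batch, mu, gamma):
    # X_batch[..., None]: (n_sites, n_env, 1)
    # mu[None]:           (1, n_env, n_species)
    diff = X_batch[..., None] - mu[None, :, :]
    # Weighted squared distance, summed over the environment axis.
    return -np.sum(gamma[None, :, :] * diff**2, axis=-2)


rng = np.random.default_rng(2)
n_sites, n_env, n_species = 4, 3, 5
X = rng.normal(size=(n_sites, n_env))
mu = rng.normal(size=(n_env, n_species))
gamma = rng.lognormal(size=(n_env, n_species))  # positive, like LogNormal(0, 1)

effect = quadratic_env_effect(X, mu, gamma)
# Explicit loops for comparison.
loop = np.array([
    [-sum(gamma[e, j] * (X[i, e] - mu[e, j]) ** 2 for e in range(n_env))
     for j in range(n_species)]
    for i in range(n_sites)
])
print(effect.shape, np.allclose(effect, loop))  # (4, 5) True
```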
+
+
+def bnn_model(
+    X, Y, num_factors=1, y_type="presence_absence", hidden_size=10, batch_size=None
+):
+    device = X.device
+
+    n_sites, n_species = Y.shape
+    n_env = X.shape[1]
+
+    w1 = pyro.sample(
+        "w1",
+        dist.Normal(
+            torch.zeros(n_env, hidden_size, device=device),
+            torch.ones(n_env, hidden_size, device=device),
+        ).to_event(2),
+    )
+
+    b1 = pyro.sample(
+        "b1",
+        dist.Normal(
+            torch.zeros(hidden_size, device=device),
+            torch.ones(hidden_size, device=device),
+        ).to_event(1),
+    )
+
+    w2 = pyro.sample(
+        "w2",
+        dist.Normal(
+            torch.zeros(hidden_size, n_species, device=device),
+            torch.ones(hidden_size, n_species, device=device),
+        ).to_event(2),
+    )
+
+    b2 = pyro.sample(
+        "b2",
+        dist.Normal(
+            torch.zeros(n_species, device=device),
+            torch.ones(n_species, device=device),
+        ).to_event(1),
+    )
+
+    Z = None
+    if num_factors > 0:
+        Z = pyro.sample(
+            "Z",
+            dist.Normal(
+                torch.zeros(n_species, num_factors, device=device),
+                torch.ones(n_species, num_factors, device=device),
+            ).to_event(2),
+        )
+
+    with pyro.plate("sites", n_sites, subsample_size=batch_size) as ind:
+        X_batch = X[ind]
+        Y_batch = Y[ind]
+
+        hidden = torch.tanh(X_batch @ w1 + b1)
+        eta_batch = hidden @ w2 + b2
+
+        if num_factors > 0:
+            W_batch = pyro.sample(
+                "W",
+                dist.Normal(
+                    torch.zeros(num_factors, device=device),
+                    torch.ones(num_factors, device=device),
+                ).to_event(1),
+            )
+            eta_batch = eta_batch + W_batch @ Z.transpose(-1, -2)
+
+        pyro.sample(
+            "obs",
+            observation_dist(eta_batch, y_type).to_event(1),
+            obs=Y_batch,
+        )
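In bnn_model the environmental part of eta comes from a one-hidden-layer network, tanh(X w1 + b1) w2 + b2, with standard-normal priors on every weight and bias; the W Z^T latent-factor term is added only when num_factors > 0. A NumPy sketch of one forward pass under a single prior draw, using the package defaults hidden_size=10 and num_factors=1 (bnn_eta is a hypothetical name):

```python
import numpy as np


def bnn_eta(X, w1, b1, w2, b2, W=None, Z=None):
    hidden = np.tanh(X @ w1 + b1)  # (n_sites, hidden_size), each entry in [-1, 1]
    eta = hidden @ w2 + b2         # (n_sites, n_species)
    if W is not None and Z is not None:
        eta = eta + W @ Z.T        # optional latent-factor term
    return eta


rng = np.random.default_rng(3)
n_sites, n_env, hidden_size, n_species, k = 8, 3, 10, 4, 1
X = rng.normal(size=(n_sites, n_env))
w1 = rng.normal(size=(n_env, hidden_size))
b1 = rng.normal(size=hidden_size)
w2 = rng.normal(size=(hidden_size, n_species))
b2 = rng.normal(size=n_species)

eta_env = bnn_eta(X, w1, b1, w2, b2)
eta_full = bnn_eta(
    X, w1, b1, w2, b2,
    W=rng.normal(size=(n_sites, k)),
    Z=rng.normal(size=(n_species, k)),
)
print(eta_env.shape, eta_full.shape)  # (8, 4) (8, 4)
```

Because tanh is bounded, the network term for species j can move eta at most sum_h |w2[h, j]| away from b2[j], whatever the environmental inputs.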