pmf-dark 0.1.0__tar.gz

@@ -0,0 +1,233 @@
1
+ Metadata-Version: 2.4
2
+ Name: pmf-dark
3
+ Version: 0.1.0
4
+ Summary: Modelling species dark diversity using Bayesian Probabilistic Matrix Factorisation
5
+ License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
6
+ Author: davidyshen
7
+ Author-email: contact@davidyshen.com
8
+ Requires-Python: >=3.13,<3.15
9
+ Classifier: License :: Other/Proprietary License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.13
12
+ Classifier: Programming Language :: Python :: 3.14
13
+ Requires-Dist: matplotlib (>=3.10.9,<4.0.0)
14
+ Requires-Dist: numpy (>=2.4.6,<3.0.0)
15
+ Requires-Dist: pandas (>=3.0.3,<4.0.0)
16
+ Requires-Dist: pyro-ppl (>=1.9.0,<2.0.0)
17
+ Description-Content-Type: text/markdown
18
+
19
+ # PMF-dark: Using matrix factorisation for dark diversity estimation
20
+
21
+ ## Overview
22
+
23
+ This repository implements **PMF-dark**, which uses Bayesian Probabilistic Matrix Factorisation to estimate **dark diversity**: the set of species absent from a site even though its environmental conditions are suitable for them. The method uses **counterfactual predictions** to reconstruct the potential species pool, separating the effects of measured environmental variables from unmeasured drivers of absence (e.g., land-use degradation, dispersal limitation, biotic interactions).
24
+
25
+ ## The Problem: What is Dark Diversity?
26
+
27
+ Traditional biodiversity assessments only count observed species (alpha diversity). However, many species are absent from sites where they *could* thrive based on environmental conditions. This "**dark diversity**" represents:
28
+
29
+ - Species lost due to historical or ongoing land-use degradation
30
+ - Species unable to reach suitable sites due to dispersal limitation
31
+ - Species suppressed by biotic interactions
32
+
33
+ Quantifying dark diversity is crucial for:
34
+ - Conservation planning and restoration potential assessment
35
+ - Understanding true biodiversity patterns
36
+ - Identifying areas with highest restoration value
37
+
38
+ ## Methodology
39
+
40
+ ### Core Model
41
+
42
+ The framework decomposes the logit of each species' occurrence probability into **three additive components**:
43
+
44
+ $$\text{logit}(p_{ij}) = \underbrace{\alpha_j}_{\text{Intercept}} + \underbrace{f_j(\mathbf{x}_i)}_{\text{Environmental Effects}} + \underbrace{\mathbf{w}_i^\top \mathbf{z}_j}_{\text{Latent Factors}}$$
45
+
46
+ Where:
47
+ - **$\alpha_j$**: Species-specific intercept (baseline prevalence on the logit scale)
48
+ - **$f_j(\mathbf{x}_i)$**: Response of species $j$ to the measured abiotic variables at site $i$ (temperature, pH, elevation, etc.), modelled as linear (`model_type="linear"`), a Gaussian niche (`model_type="gaussian"`), or a Bayesian neural network (`model_type="bnn"`)
49
+ - **$\mathbf{w}_i^\top \mathbf{z}_j$**: Product of site loadings $\mathbf{w}_i$ and species loadings $\mathbf{z}_j$ on a small number of latent factors, capturing unmeasured drivers of absence
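A minimal NumPy sketch of the linear predictor above for a single set of parameter values (not posterior samples). It assumes the linear form $f_j(\mathbf{x}_i) = \mathbf{x}_i^\top \boldsymbol\beta_j$ and the Gaussian niche form $f_j(\mathbf{x}_i) = -\sum_k \gamma_{kj}(x_{ik} - \mu_{kj})^2$, which mirror the `"linear"` and `"gaussian"` branches of the package's `compute_predictions`; the function names, array names and random values are illustrative only.

```python
import numpy as np


def linear_env_effect(X, beta):
    """Linear response f_j(x_i) = x_i^T beta_j; beta has shape (env, species)."""
    return X @ beta


def gaussian_env_effect(X, mu, gamma):
    """Gaussian niche f_j(x_i) = -sum_k gamma_kj (x_ik - mu_kj)^2; mu, gamma: (env, species)."""
    diff = X[:, :, None] - mu[None, :, :]  # (sites, env, species)
    return -np.sum(gamma[None, :, :] * diff**2, axis=1)


def linear_predictor(alpha, env_effect, W, Z):
    """eta_ij = alpha_j + f_j(x_i) + w_i^T z_j, shape (sites, species).

    alpha: (species,), env_effect: (sites, species),
    W: (sites, factors) site loadings, Z: (species, factors) species loadings.
    """
    return alpha[None, :] + env_effect + W @ Z.T


def sigmoid(eta):
    return 1.0 / (1.0 + np.exp(-eta))


rng = np.random.default_rng(0)
n_sites, n_species, n_env, n_factors = 6, 4, 3, 2
X = rng.standard_normal((n_sites, n_env))                # standardised predictors
alpha = rng.standard_normal(n_species)
beta = rng.standard_normal((n_env, n_species))
mu = rng.standard_normal((n_env, n_species))             # niche optima
gamma = np.abs(rng.standard_normal((n_env, n_species)))  # niche curvature, non-negative here
W = rng.standard_normal((n_sites, n_factors))
Z = rng.standard_normal((n_species, n_factors))

eta_linear = linear_predictor(alpha, linear_env_effect(X, beta), W, Z)
eta_gauss = linear_predictor(alpha, gaussian_env_effect(X, mu, gamma), W, Z)
p_linear = sigmoid(eta_linear)  # occurrence probabilities, (sites, species)
```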
50
+
51
+ ### Key Innovation: Counterfactual Predictions
52
+
53
+ 1. **Full Predictions**: Include all components (environment + latent factors)
54
+ - Represents observed diversity with all drivers active
55
+
56
+ 2. **Environment-Only Predictions**: Exclude latent factors
57
+ - Represents potential diversity (setting $\mathbf{w}_i^\top \mathbf{z}_j = 0$)
58
+
59
+ 3. **Dark Diversity Proxy**: Difference between environment-only and full predictions (potential minus realised occurrence probability)
60
+ - Quantifies species lost to unmeasured stressors
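A NumPy sketch of the three counterfactual quantities for one set of parameter values, assuming the linear predictor above with the environmental term already computed. It subtracts the full prediction from the environment-only prediction (potential minus realised), so that positive values correspond to the "high values" reading under Interpretation of Results; this sign convention is our assumption. Summing the proxy over species gives, by linearity of expectation, the difference in expected richness between the environment-only and full models at each site.

```python
import numpy as np


def sigmoid(eta):
    return 1.0 / (1.0 + np.exp(-eta))


def counterfactual_predictions(alpha, env_effect, W, Z):
    """Return (full, env_only) occurrence probabilities, each (sites, species)."""
    eta_env = alpha[None, :] + env_effect  # latent term w_i^T z_j set to zero
    eta_full = eta_env + W @ Z.T           # all components active
    return sigmoid(eta_full), sigmoid(eta_env)


rng = np.random.default_rng(1)
n_sites, n_species, n_factors = 5, 3, 2
alpha = rng.standard_normal(n_species)
env_effect = rng.standard_normal((n_sites, n_species))
W = rng.standard_normal((n_sites, n_factors))
Z = rng.standard_normal((n_species, n_factors))

p_full, p_env = counterfactual_predictions(alpha, env_effect, W, Z)

# Assumed sign convention: potential (env-only) minus realised (full).
dark_proxy = p_env - p_full

# Expected richness under env-only minus expected richness under the full model.
site_dark_richness = dark_proxy.sum(axis=1)
```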
61
+
62
+ ### Inference: Stochastic Variational Inference (SVI)
63
+
64
+ The model is fit using **Pyro-based SVI**, which:
65
+ - Handles high-dimensional ecological matrices efficiently
66
+ - Treats inference as an optimisation problem (ELBO maximisation)
67
+ - Scales to thousands of sites and species
68
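+ - Is typically much cheaper than MCMC and supports mini-batching over sites (`batch_size`); NUTS-based MCMC is also available (`method="mcmc"`) for the linear and Gaussian models, without mini-batching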
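The package's `fit_svi` (in `inference.py`) reports convergence when the mean loss over the last 100 iterations differs from the mean over the preceding 100 by less than 1%. The sketch below reproduces that rule on a list of recorded losses; `svi_converged` is a hypothetical helper, and its `None` return for runs shorter than 200 iterations is our addition (the original check would average a short or empty window).

```python
import numpy as np


def svi_converged(losses, window=100, tol=0.01):
    """Relative change in mean loss between the last two windows of iterations.

    Returns True/False, or None if fewer than 2 * window losses were recorded.
    """
    if len(losses) < 2 * window:
        return None
    losses = np.asarray(losses, dtype=float)
    previous = losses[-2 * window:-window].mean()
    last = losses[-window:].mean()
    relative_change = abs(last - previous) / abs(previous)
    return bool(relative_change < tol)


# A loss curve that has flattened out after 2,500 iterations (the package default).
plateaued = [1000 * np.exp(-i / 200) + 500 for i in range(2500)]
print(svi_converged(plateaued))
```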
69
+
70
+ ## Repository Structure
71
+
72
+ ```
73
+ PMF_dark/
74
+ ├── README.md # This file
75
+ ├── mat_fact_dark_div.ipynb # Main analysis notebook
76
+ ├── data/
77
+ │ ├── survey.csv # Species presence/absence matrix (sites × species)
78
+ │ ├── env.csv # Environmental predictors (sites × covariates)
79
+ │ └── truth.csv # Ground truth data (if available)
80
+ └── output/
81
+ ├── mat_fact_predicted_probabilities_full.csv # Full model predictions
82
+ ├── mat_fact_predicted_probabilities_env_only.csv # Environment-only predictions
83
+ └── mat_fact_dark_diversity_proxy.csv # Dark diversity estimates
84
+ ```
85
+
86
+ ## Installation
87
+
88
+ ### Requirements
89
+
90
+ - Python 3.13 or 3.14
91
+ - PyTorch (with CUDA support if using GPU)
92
+ - Pyro (pyro-ppl)
93
+ - Pandas
94
+ - NumPy
+ - Matplotlib
95
+ - scikit-learn
96
+ - scipy
97
+
98
+ ### Setup
99
+
100
+ Install PyTorch manually **before** installing the package or its other dependencies. Because pyro-ppl depends on PyTorch, installing it first would pull in the default PyTorch build, which may not match your system's CUDA version.
101
+
102
+ #### 1. Clone the Repository and Set Up a Virtual Environment
103
+ ```bash
104
+ # Clone the repository
105
+ git clone https://github.com/davidyshen/PMF_dark.git
106
+ cd PMF_dark
107
+
108
+ # Create virtual environment (optional but recommended)
109
+ python -m venv .venv
110
+ source .venv/bin/activate # On Windows: .venv\Scripts\activate
111
+ ```
112
+
113
+ #### 2. Install PyTorch with CUDA
114
+ Visit the [PyTorch Getting Started guide](https://pytorch.org/get-started/locally/) to select the correct command for your CUDA version and OS. For example, to install PyTorch with CUDA 12.4 support on Windows/Linux:
115
+ ```bash
116
+ pip install torch --index-url https://download.pytorch.org/whl/cu124
117
+ ```
118
+
119
+ If not using CUDA, install the CPU-only build:
120
+ ```bash
121
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
122
+ ```
123
+
124
+
125
+ #### 3. Install remaining dependencies
126
+ If installing via pip:
127
+ ```bash
128
+ pip install pyro-ppl pandas numpy matplotlib scikit-learn scipy jupyter
129
+ ```
130
+
131
+ If using Poetry:
132
+ ```bash
133
+ # This will install the package and its remaining dependencies into your environment
134
+ poetry install
135
+ ```
136
+
137
+ ## Usage
138
+
139
+ ### Running the Full Analysis
140
+
141
+ 1. Prepare your data in the `data/` directory:
142
+ - `survey.csv`: Species presence/absence (rows = sites, columns = species, values = 0/1)
143
+ - `env.csv`: Environmental predictors (rows = sites, columns = variables)
144
+
145
+ 2. Open and run the Jupyter notebook:
146
+ ```bash
147
+ jupyter notebook mat_fact_dark_div.ipynb
148
+ ```
149
+
150
+ 3. The notebook will:
151
+ - Load and standardise data
152
+ - Fit the matrix factorisation model (2,500 iterations)
153
+ - Generate predictions and save CSV outputs
154
+
155
+ ### Customisation
156
+
157
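+ Key parameters in the notebook (when calling `compute_dark_diversity` directly, `num_factors` is a function argument, and `lr` and `num_iterations` are passed through as keyword arguments to the SVI fit):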
158
+
159
+ ```python
160
+ # Model parameters
161
+ num_factors = 5 # Number of latent factors (adjust based on data complexity)
162
+ num_iterations = 2500 # Model iterations
163
+
164
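+ # Learning rate (Pyro's Adam optimiser)
164
+ from pyro.optim import Adam
+ optimizer = Adam({"lr": 0.01})  # Increase if convergence is slow; decrease if the loss oscillates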
166
+ ```
167
+
168
+ ## Output Files
169
+
170
+ - **mat_fact_predicted_probabilities_full.csv**: Predicted species occurrence probabilities including all effects
171
+ - **mat_fact_predicted_probabilities_env_only.csv**: Predicted probabilities using only environmental effects
172
+ - **mat_fact_dark_diversity_proxy.csv**: Dark diversity estimates (env_only - full)
173
+
174
+ ## Data Format
175
+
176
+ ### survey.csv
177
+ ```
178
+ site_id,species_1,species_2,...,species_n,ID,x,y
179
+ site_1,0,1,0,...,1,id_1,100.5,200.3
180
+ site_2,1,0,1,...,0,id_2,101.2,201.5
181
+ ...
182
+ ```
183
+ - Rows: Sites/locations
184
+ - Columns: Species (0/1 presence/absence) + ID + spatial coordinates
185
+ - **Note**: ID and spatial coordinates are automatically extracted/dropped
186
+
187
+ ### env.csv
188
+ ```
189
+ site_id,temp,pH,elevation,...,ID,landuse
190
+ site_1,15.2,7.1,500,...,id_1,degraded
191
+ site_2,14.8,6.9,520,...,id_2,pristine
192
+ ...
193
+ ```
194
+ - Rows: Sites, in the same order as in survey.csv (the package's `prepare_data` pairs rows by position, not by site ID)
195
+ - Columns: Environmental predictors + ID + land-use
196
+ - **Note**: ID and land-use columns are dropped; only abiotic predictors are used
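The package's `prepare_data` standardises every column of `x`, converts both `x` and `y` to float tensors, and pairs rows by position, so it expects numeric predictor columns only, species columns only, and matching row order. A minimal pandas sketch of that preparation, assuming the column names shown in the examples above (`site_id`, `ID`, `x`, `y`, `landuse`); `load_inputs` is a hypothetical helper, not part of the package.

```python
import io

import pandas as pd


def load_inputs(survey_csv, env_csv):
    """Return (y, x) DataFrames indexed by site_id, with matching row order."""
    survey = pd.read_csv(survey_csv, index_col="site_id")
    env = pd.read_csv(env_csv, index_col="site_id")

    # Drop identifier, coordinate and land-use columns (names assumed from the examples).
    y = survey.drop(columns=["ID", "x", "y"])
    x = env.drop(columns=["ID", "landuse"])

    # Align predictor rows to survey rows by site ID before positional pairing.
    x = x.reindex(y.index)
    if x.isna().any().any():
        raise ValueError("env.csv is missing predictors for some survey sites.")
    return y, x


survey_text = """site_id,sp_a,sp_b,ID,x,y
site_1,0,1,id_1,100.5,200.3
site_2,1,0,id_2,101.2,201.5
"""
# Deliberately listed in a different order to show the alignment step.
env_text = """site_id,temp,pH,ID,landuse
site_2,14.8,6.9,id_2,pristine
site_1,15.2,7.1,id_1,degraded
"""
y, x = load_inputs(io.StringIO(survey_text), io.StringIO(env_text))
```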
197
+
198
+ ## Interpretation of Results
199
+
200
+ ### Dark Diversity Proxy Values
201
+ - **High values (close to 1)**: Species should be present based on environment but are absent; candidates for restoration
202
+ - **Low values (close to 0)**: Little difference between the full and environment-only predictions; any absence is explained by environmental conditions
203
+ - **Negative values**: The full model predicts a higher occurrence probability than the environment alone, i.e. unmeasured drivers favour the species at that site (expected to be rare)
204
+
205
+ ### Key Metrics
206
+ - **AUC (Area Under ROC Curve)**: Overall model discrimination (0.5 = random, 1.0 = perfect)
207
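+ - **Brier Score**: Mean squared difference between predicted probabilities and observed 0/1 outcomes (lower is better); reflects both calibration and discrimination
208
+ - **F1 Score**: Harmonic mean of precision and recall, computed after thresholding the predicted probabilities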
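scikit-learn, listed in the requirements, provides these metrics; the NumPy-only sketch below shows what each one computes on observed 0/1 values and predicted probabilities, flattened across sites and species. The rank-based AUC (ties counted as half) and the 0.5 threshold for F1 are our assumptions, and the pairwise AUC comparison uses memory proportional to presences × absences, so it only suits small matrices.

```python
import numpy as np


def auc(y_true, p):
    """ROC AUC: probability that a random presence outranks a random absence."""
    y_true = np.asarray(y_true).ravel()
    p = np.asarray(p, dtype=float).ravel()
    pos, neg = p[y_true == 1], p[y_true == 0]
    # Compare every presence with every absence; ties count as half a win.
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


def brier(y_true, p):
    """Mean squared difference between probabilities and 0/1 outcomes."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    p = np.asarray(p, dtype=float).ravel()
    return np.mean((p - y_true) ** 2)


def f1(y_true, p, threshold=0.5):
    """Harmonic mean of precision and recall after thresholding probabilities."""
    y_true = np.asarray(y_true).ravel().astype(bool)
    y_pred = np.asarray(p, dtype=float).ravel() >= threshold
    tp = np.sum(y_pred & y_true)
    fp = np.sum(y_pred & ~y_true)
    fn = np.sum(~y_pred & y_true)
    if tp == 0:
        return 0.0
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


y_obs = np.array([[1, 0, 1], [0, 0, 1]])  # sites x species
p_hat = np.array([[0.9, 0.2, 0.4], [0.3, 0.6, 0.8]])
print(auc(y_obs, p_hat), brier(y_obs, p_hat), f1(y_obs, p_hat))
```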
209
+
210
+ ## Advantages of This Approach
211
+
212
+ ✓ **No subjective benchmarking**: Automated separation of environmental vs. unmeasured effects
213
+ ✓ **Mathematically principled**: Latent factors can absorb degradation signals that the measured covariates do not capture
214
+ ✓ **Scalable**: SVI handles thousands of species and sites
215
+ ✓ **Species-specific**: Each species can have unique environmental responses
216
+ ✓ **Reproducible**: Fully probabilistic framework with clear assumptions
217
+
218
+ ## Limitations
219
+
220
+ - Assumes the components are additive on the link scale (logit for presence/absence data, log for counts)
221
+ - Requires sufficient environmental variation to estimate effects reliably
222
+ - May overestimate dark diversity if detection is imperfect
223
+ - Computational cost increases with number of species and sites
224
+ - Requires careful tuning of number of latent factors
225
+
226
+ ## References & Theoretical Background
227
+
228
+ ### Key Concepts
229
+ - **Joint Species Distribution Models (JSDMs)**: Latent variable models for multivariate species data
230
+ - **Matrix Factorisation**: Low-rank decomposition of high-dimensional species matrices
231
+ - **Stochastic Variational Inference**: Scalable Bayesian inference for probabilistic models
232
+ - **Counterfactual Predictions**: Causal inference approach to estimate potential outcomes
233
+
@@ -0,0 +1,26 @@
1
+ [project]
2
+ name = "pmf-dark"
3
+ version = "0.1.0"
4
+ description = "Modelling species dark diversity using Bayesian Probabilistic Matrix Factorisation"
5
+ authors = [
6
+ {name = "davidyshen",email = "contact@davidyshen.com"}
7
+ ]
8
+ license = {text = "Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE."}
9
+ readme = "README.md"
10
+ requires-python = ">=3.13,<3.15"
11
+ dependencies = [
12
+ "numpy (>=2.4.6,<3.0.0)",
13
+ "pandas (>=3.0.3,<4.0.0)",
14
+ "matplotlib (>=3.10.9,<4.0.0)",
15
+ "pyro-ppl (>=1.9.0,<2.0.0)"
16
+ ]
17
+
18
+ [tool.poetry]
19
+ packages = [{ include = "pmf_dark", from = "src" }]
20
+
21
+ [tool.poetry.group.dev.dependencies]
22
+ ipykernel = "^7.2.0"
23
+
24
+ [build-system]
25
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
26
+ build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,10 @@
1
+ import torch
2
+ from .darkdiv import compute_dark_diversity
3
+
4
+ # Print CUDA availability on import
5
+ if torch.cuda.is_available():
6
+ print(f"pmf-dark: CUDA is available. GPU: {torch.cuda.get_device_name(0)}")
7
+ else:
8
+ print("pmf-dark: CUDA is not available. Using CPU.")
9
+
10
+ __all__ = ["compute_dark_diversity"]
@@ -0,0 +1,264 @@
1
+ import torch
2
+ import pyro
3
+ import matplotlib.pyplot as plt
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+
9
+ def infer_y_type(y):
10
+
11
+ # torch tensor
12
+ if isinstance(y, torch.Tensor):
13
+
14
+ values = y.detach().cpu().numpy()
15
+
16
+ # pandas dataframe
17
+ elif hasattr(y, "to_numpy"):
18
+
19
+ values = y.to_numpy()
20
+
21
+ # numpy or other
22
+ else:
23
+
24
+ values = np.asarray(y)
25
+
26
+ # Check missing values
27
+ if np.isnan(values).any():
28
+ raise ValueError("y contains missing values.")
29
+
30
+ # Binary data
31
+ if np.isin(values, [0, 1]).all():
32
+ return "presence_absence"
33
+
34
+ # Count data: non-negative integers
35
+ elif (values >= 0).all() and (values == np.round(values)).all():
36
+ return "count"
37
+
38
+ else:
39
+ raise ValueError("y must contain either binary (0/1) or non-negative integer count data.")
40
+
41
+
42
+ def prepare_data(x, y, cuda=False):
43
+
44
+ # Keep names
45
+ x_columns = x.columns
46
+ y_columns = y.columns
47
+ site_index = y.index
48
+
49
+ # Standardize x
50
+ x = (x - x.mean()) / x.std()
51
+
52
+ # Convert to tensors
53
+ x_tensor = torch.tensor(
54
+ x.to_numpy(),
55
+ dtype=torch.float32,
56
+ )
57
+
58
+ y_tensor = torch.tensor(
59
+ y.to_numpy(),
60
+ dtype=torch.float32,
61
+ )
62
+
63
+ # Device
64
+ device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
65
+
66
+ x_tensor = x_tensor.to(device)
67
+ y_tensor = y_tensor.to(device)
68
+
69
+ return {
70
+ "x": x_tensor,
71
+ "y": y_tensor,
72
+ "x_columns": x_columns,
73
+ "y_columns": y_columns,
74
+ "site_index": site_index,
75
+ }
76
+
77
+
78
+ def compute_predictions(
79
+ samples, x, model_type="gaussian", include_latent=True, y_type="presence_absence"
80
+ ):
81
+
82
+ if model_type == "linear":
83
+ alpha = samples["alpha"].squeeze(1)
84
+ beta = samples["beta"].squeeze(1)
85
+ eta = alpha[:, None, :] + torch.einsum("ij,sjk->sik", x, beta)
86
+
87
+ elif model_type == "gaussian":
88
+ alpha = samples["alpha"].squeeze(1)
89
+ mu = samples["mu"].squeeze(1)
90
+ gamma = samples["gamma"].squeeze(1)
91
+
92
+ # x: [sites, env]
93
+ # mu: [samples, env, species]
94
+ # gamma: [samples, env, species]
95
+ diff = x[None, :, :, None] - mu[:, None, :, :]
96
+ env_effect = -torch.sum(gamma[:, None, :, :] * diff**2, dim=2)
97
+
98
+ eta = alpha[:, None, :] + env_effect
99
+
100
+ elif model_type == "bnn":
101
+ w1 = samples["w1"].squeeze(1)
102
+ b1 = samples["b1"].squeeze(1)
103
+ w2 = samples["w2"].squeeze(1)
104
+ b2 = samples["b2"].squeeze(1)
105
+
106
+ hidden = torch.tanh(torch.einsum("ij,sjh->sih", x, w1) + b1[:, None, :])
107
+
108
+ eta = torch.einsum("sih,shk->sik", hidden, w2) + b2[:, None, :]
109
+
110
+ else:
111
+ raise ValueError(f"Unknown model_type: {model_type}")
112
+
113
+ if include_latent:
114
+ W = samples["W"].squeeze(1)
115
+ Z = samples["Z"].squeeze(1)
116
+ eta = eta + torch.einsum("sik,sjk->sij", W, Z)
117
+
118
+ if y_type == "presence_absence":
119
+ return torch.sigmoid(eta)
120
+
121
+ elif y_type == "count":
122
+ return torch.exp(eta)
123
+
124
+ else:
125
+ raise ValueError("y_type must be 'presence_absence' or 'count'")
126
+
127
+
128
+ def compute_dark_diversity(
129
+ y,
130
+ x,
131
+ model_type="linear",
132
+ num_factors=1,
133
+ method="svi",
134
+ cuda=False,
135
+ include_latent=True,
136
+ return_means=True,
137
+ batch_size=None,
138
+ pred_batch_size=None,
139
+ **kwargs,
140
+ ):
141
+
142
+ if cuda and not torch.cuda.is_available():
143
+ import warnings
144
+
145
+ warnings.warn(
146
+ "CUDA was requested (cuda=True), but PyTorch cannot detect a CUDA-enabled GPU. "
147
+ "Please check your NVIDIA drivers and reinstall PyTorch with CUDA support. "
148
+ "Falling back to CPU."
149
+ )
150
+
151
+ # For bnn only svi is supported
152
+ if method == "mcmc" and model_type == "bnn":
153
+ raise ValueError(
154
+ "MCMC is currently not supported for the Bayesian neural network model. "
155
+ "Please use method='svi' instead."
156
+ )
157
+
158
+ # Check if y is presence/absence or count data
159
+ y_type = infer_y_type(y)
160
+ print(f"Detected y type: {y_type}")
161
+
162
+ data = prepare_data(x, y, cuda=cuda)
163
+
164
+ x = data["x"]
165
+ y = data["y"]
166
+
167
+ # Load the model
168
+ if model_type == "linear":
169
+ from .models import linear_model
170
+
171
+ model = linear_model
172
+ elif model_type == "gaussian":
173
+ from .models import gaussian
174
+
175
+ model = gaussian
176
+ elif model_type == "bnn":
177
+ from .models import bnn_model
178
+
179
+ model = bnn_model
+ else:
+ raise ValueError(f"Unknown model_type: {model_type}")
180
+
181
+ # Inference
182
+ if method == "svi":
183
+ from .inference import fit_svi
184
+
185
+ fit = fit_svi(
186
+ model,
187
+ y,
188
+ x,
189
+ num_factors,
190
+ y_type=y_type,
191
+ cuda=cuda,
192
+ batch_size=batch_size,
193
+ **kwargs,
194
+ )
195
+ elif method == "mcmc":
196
+ from .inference import fit_mcmc
197
+
198
+ fit = fit_mcmc(
199
+ model, y, x, num_factors, y_type=y_type, batch_size=batch_size, **kwargs
200
+ )
+ else:
+ raise ValueError(f"Unknown method: {method}. Use 'svi' or 'mcmc'.")
201
+
202
+ # Compute probabilities
203
+ if pred_batch_size is not None:
204
+ n_sites = x.shape[0]
205
+ pred_chunks = []
206
+ for i in range(0, n_sites, pred_batch_size):
207
+ x_chunk = x[i : i + pred_batch_size]
208
+ samples_chunk = fit["samples"].copy()
209
+ if "W" in fit["samples"]:
210
+ w_tensor = fit["samples"]["W"]
211
+ if w_tensor.dim() == 4:
212
+ samples_chunk["W"] = w_tensor[:, :, i : i + pred_batch_size, :]
213
+ else:
214
+ samples_chunk["W"] = w_tensor[:, i : i + pred_batch_size, :]
215
+
216
+ pred_chunk = compute_predictions(
217
+ samples_chunk,
218
+ x_chunk,
219
+ model_type=model_type,
220
+ include_latent=include_latent,
221
+ y_type=y_type,
222
+ )
223
+
224
+ if return_means:
225
+ pred_chunk_processed = pred_chunk.mean(dim=0).detach().cpu().numpy()
226
+ else:
227
+ pred_chunk_processed = pred_chunk.detach().cpu().numpy()
228
+
229
+ pred_chunks.append(pred_chunk_processed)
230
+
231
+ if return_means:
232
+ pred_np = np.concatenate(pred_chunks, axis=0)
233
+ pred = pd.DataFrame(
234
+ pred_np,
235
+ index=data["site_index"],
236
+ columns=data["y_columns"],
237
+ )
238
+ else:
239
+ # If returning full samples, shape of each chunk is (num_samples, batch_size, n_species)
240
+ # We concatenate along the site dimension (axis 1)
241
+ pred = np.concatenate(pred_chunks, axis=1)
242
+
243
+ else:
244
+ pred = compute_predictions(
245
+ fit["samples"],
246
+ x,
247
+ model_type=model_type,
248
+ include_latent=include_latent,
249
+ y_type=y_type,
250
+ )
251
+ if return_means:
252
+ pred = pred.mean(dim=0)
253
+ pred = pred.detach().cpu().numpy()
254
+
255
+ pred = pd.DataFrame(
256
+ pred,
257
+ index=data["site_index"],
258
+ columns=data["y_columns"],
259
+ )
260
+
261
+ else:
262
+ pred = pred.detach().cpu().numpy()
263
+
264
+ return pred
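`compute_predictions` evaluates the Gaussian niche term for all posterior samples at once by broadcasting to a `(samples, sites, env, species)` array before summing over the environmental axis; that four-dimensional intermediate is what makes `pred_batch_size` useful on large matrices. A NumPy sketch of the same broadcasting, checked against an explicit loop and against site-wise chunking like the `pred_batch_size` path (the latent term and its `W` slicing are left out here; names and shapes are illustrative):

```python
import numpy as np


def gaussian_eta(x, alpha, mu, gamma):
    """Broadcast version of the model_type="gaussian" branch.

    x: (sites, env); alpha: (S, species); mu, gamma: (S, env, species).
    Returns eta with shape (S, sites, species).
    """
    diff = x[None, :, :, None] - mu[:, None, :, :]  # (S, sites, env, species)
    env_effect = -np.sum(gamma[:, None, :, :] * diff**2, axis=2)
    return alpha[:, None, :] + env_effect


def gaussian_eta_loop(x, alpha, mu, gamma):
    """Same quantity with explicit loops, used to check the broadcast version."""
    n_samples, n_species = alpha.shape
    n_sites = x.shape[0]
    out = np.empty((n_samples, n_sites, n_species))
    for s in range(n_samples):
        for i in range(n_sites):
            for j in range(n_species):
                out[s, i, j] = alpha[s, j] - np.sum(gamma[s, :, j] * (x[i] - mu[s, :, j]) ** 2)
    return out


def gaussian_eta_chunked(x, alpha, mu, gamma, chunk):
    """Evaluate blocks of sites separately, then concatenate on the site axis."""
    parts = [gaussian_eta(x[i:i + chunk], alpha, mu, gamma) for i in range(0, x.shape[0], chunk)]
    return np.concatenate(parts, axis=1)


rng = np.random.default_rng(2)
n_samples, n_sites, n_env, n_species = 3, 7, 2, 4
x = rng.standard_normal((n_sites, n_env))
alpha = rng.standard_normal((n_samples, n_species))
mu = rng.standard_normal((n_samples, n_env, n_species))
gamma = np.abs(rng.standard_normal((n_samples, n_env, n_species)))
eta = gaussian_eta(x, alpha, mu, gamma)
```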
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+ import pyro
3
+ from pyro.infer import SVI, Trace_ELBO, MCMC, NUTS
4
+ from pyro.infer.autoguide import AutoNormal
5
+ from pyro.optim import Adam
6
+ from torch import device
7
+ import torch
8
+ from pyro.infer import Predictive
9
+
10
+
11
+ def fit_svi(
12
+ model,
13
+ y,
14
+ x,
15
+ num_factors,
16
+ y_type,
17
+ lr=0.01,
18
+ num_iterations=2500,
19
+ cuda=False,
20
+ batch_size=None,
21
+ **kwargs,
22
+ ):
23
+
24
+ device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
25
+ print(f"Using device: {device}")
26
+
27
+ # Define create_plates helper for autoguide subsampling
28
+ def create_plates(*args, **kwargs):
29
+ # args[1] is the species presence/absence matrix Y
30
+ y_data = args[1]
31
+ n_sites = y_data.shape[0]
32
+ batch_size_val = kwargs.get("batch_size", None)
33
+ return pyro.plate("sites", n_sites, subsample_size=batch_size_val)
34
+
35
+ # Setup Inference
36
+ pyro.clear_param_store()
37
+ guide = AutoNormal(model, create_plates=create_plates)
38
+ optimizer = Adam({"lr": lr})
39
+ svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
40
+
41
+ # Training Loop
42
+ losses = []
43
+ for i in range(num_iterations):
44
+ loss = svi.step(x, y, num_factors, y_type, batch_size=batch_size, **kwargs)
45
+ losses.append(loss)
46
+ if i % 500 == 0:
47
+ print(f"Iteration {i} - Loss: {loss:.2f}")
48
+
49
+ # Simple convergence check
50
+ first = np.mean(losses[-200:-100])
51
+ last = np.mean(losses[-100:])
52
+ relative_change = abs(last - first) / abs(first)
53
+ if relative_change < 0.01:
54
+ print("SVI converged successfully.")
55
+ else:
56
+ print("SVI may not have converged.")
57
+
58
+ # return samples
59
+ predictive = Predictive(
60
+ model,
61
+ guide=guide,
62
+ num_samples=1000,
63
+ )
64
+
65
+ samples = predictive(x, y, num_factors, y_type, batch_size=None, **kwargs)
66
+
67
+ return {
68
+ "method": "svi",
69
+ "guide": guide,
70
+ "losses": losses,
71
+ "samples": samples,
72
+ }
+
+
+ def fit_mcmc(
+     model,
+     y,
+     x,
+     num_factors,
+     y_type,
+     num_samples=1000,
+     warmup_steps=500,
+     num_chains=1,
+     batch_size=None,
+ ):
+     if batch_size is not None:
+         raise ValueError(
+             "MCMC does not support mini-batching. Please use SVI (method='svi') or set batch_size=None."
+         )
+     pyro.clear_param_store()
+
+     kernel = NUTS(model)
+
+     mcmc = MCMC(
+         kernel,
+         num_samples=num_samples,
+         warmup_steps=warmup_steps,
+         num_chains=num_chains,
+     )
+
+     mcmc.run(x, y, num_factors, y_type)
+
+     return {
+         "method": "mcmc",
+         "mcmc": mcmc,
+         "samples": mcmc.get_samples(),
+     }
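Both the SVI and MCMC fitting functions return their posterior draws under the `"samples"` key. A minimal sketch of summarising such a dictionary follows, assuming each entry has first been converted to a NumPy array whose first axis indexes draws (for example `{k: v.detach().cpu().numpy() for k, v in result["samples"].items()}`). `summarise_draws` and its `prob` argument are hypothetical, and the draws below are synthetic.

```python
import numpy as np


def summarise_draws(samples, prob=0.95):
    """Hypothetical helper: posterior mean and central credible interval.

    `samples` maps site names to arrays with posterior draws along axis 0.
    """
    tail = (1.0 - prob) / 2.0
    summary = {}
    for name, draws in samples.items():
        summary[name] = {
            "mean": draws.mean(axis=0),  # average over draws
            "lower": np.quantile(draws, tail, axis=0),  # lower interval bound
            "upper": np.quantile(draws, 1.0 - tail, axis=0),  # upper interval bound
        }
    return summary


# Synthetic draws for an "alpha" site with two species, 101 draws each
draws = np.stack([np.linspace(0.0, 1.0, 101), np.linspace(10.0, 20.0, 101)], axis=1)
summary = summarise_draws({"alpha": draws})
print(summary["alpha"]["mean"])
```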
@@ -0,0 +1,229 @@
+ # 2. Define the Model
+ import pyro
+ import pyro.distributions as dist
+ import torch
+
+
+ def observation_dist(eta, y_type):
+     if y_type == "presence_absence":
+         return dist.Bernoulli(logits=eta)
+
+     elif y_type == "count":
+         return dist.Poisson(rate=torch.exp(torch.clamp(eta, -20, 20)))
+     # elif y_type == "count":
+     #
+     #     mu = torch.exp(torch.clamp(eta, -20, 20))
+     #
+     #     dispersion = pyro.sample(
+     #         "dispersion",
+     #         dist.LogNormal(
+     #             torch.zeros(eta.shape[-1], device=eta.device),
+     #             torch.ones(eta.shape[-1], device=eta.device),
+     #         ).to_event(1),
+     #     )
+     #
+     #     return dist.NegativeBinomial(
+     #         total_count=dispersion,
+     #         logits=torch.log(mu) - torch.log(dispersion),
+     #     )
+
+     else:
+         raise ValueError(
+             "y_type must be 'presence_absence' or 'count'"
+         )
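In the commented-out negative-binomial branch, `total_count=dispersion` and `logits=log(mu) - log(dispersion)` are chosen so the mean stays at `mu`: under PyTorch's parameterisation the mean is `total_count * exp(logits)`, which equals `mu`, and the variance is `mu + mu**2 / dispersion`. The plain-Python sketch below checks that arithmetic; `nb_moments` is a hypothetical helper, not part of the package.

```python
import math


def nb_moments(mu, r):
    """Hypothetical helper: (logits, mean, variance) of a negative binomial
    with total_count=r and logits=log(mu) - log(r), using PyTorch's convention
    that mean = total_count * p / (1 - p) with p = sigmoid(logits)."""
    logits = math.log(mu) - math.log(r)
    p = 1.0 / (1.0 + math.exp(-logits))  # sigmoid(logits)
    mean = r * p / (1.0 - p)
    variance = mean / (1.0 - p)
    return logits, mean, variance


logits, mean, variance = nb_moments(mu=4.0, r=2.0)
print(round(mean, 6), round(variance, 6))  # 4.0 12.0
```

As the dispersion `r` grows, the variance approaches `mu`, which is the Poisson case used by the active `count` branch.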
+
+
+ def linear_model(X, Y, num_factors, y_type="presence_absence", batch_size=None):
+     device = X.device
+
+     n_sites, n_species = Y.shape
+     n_env = X.shape[1]
+
+     alpha = pyro.sample(
+         "alpha",
+         dist.Normal(
+             torch.zeros(n_species, device=device),
+             torch.ones(n_species, device=device),
+         ).to_event(1),
+     )
+
+     beta = pyro.sample(
+         "beta",
+         dist.Normal(
+             torch.zeros(n_env, n_species, device=device),
+             torch.ones(n_env, n_species, device=device),
+         ).to_event(2),
+     )
+
+     Z = pyro.sample(
+         "Z",
+         dist.Normal(
+             torch.zeros(n_species, num_factors, device=device),
+             torch.ones(n_species, num_factors, device=device),
+         ).to_event(2),
+     )
+
+     with pyro.plate("sites", n_sites, subsample_size=batch_size) as ind:
+         X_batch = X[ind]
+         Y_batch = Y[ind]
+
+         W_batch = pyro.sample(
+             "W",
+             dist.Normal(
+                 torch.zeros(num_factors, device=device),
+                 torch.ones(num_factors, device=device),
+             ).to_event(1),
+         )
+
+         eta_batch = (
+             alpha
+             + torch.matmul(X_batch, beta)
+             + torch.matmul(W_batch, Z.transpose(-1, -2))
+         )
+
+         pyro.sample(
+             "obs",
+             observation_dist(eta_batch, y_type).to_event(1),
+             obs=Y_batch,
+         )
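`linear_model` builds a site-by-species linear predictor from species intercepts `alpha`, environmental slopes `beta` and a low-rank latent term: `eta[i, j] = alpha[j] + sum_k X[i, k] * beta[k, j] + sum_f W[i, f] * Z[j, f]`. The NumPy sketch below uses random stand-ins for a single posterior draw (the sizes are arbitrary) and checks that the vectorised expression matches the element-wise formula for one site and species.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_env, n_species, num_factors = 5, 3, 4, 2

# Random stand-ins for one posterior draw (shapes follow linear_model)
alpha = rng.normal(size=n_species)  # species intercepts
beta = rng.normal(size=(n_env, n_species))  # environmental slopes
Z = rng.normal(size=(n_species, num_factors))  # species loadings
X_batch = rng.normal(size=(batch, n_env))  # site covariates
W_batch = rng.normal(size=(batch, num_factors))  # site latent factors

# Vectorised predictor, as in linear_model
eta = alpha + X_batch @ beta + W_batch @ Z.T
print(eta.shape)  # (5, 4)

# The same entry written out term by term for site i and species j
i, j = 2, 1
eta_ij = (
    alpha[j]
    + sum(X_batch[i, k] * beta[k, j] for k in range(n_env))
    + sum(W_batch[i, f] * Z[j, f] for f in range(num_factors))
)

# Under the presence_absence likelihood, eta is a log-odds
prob = 1.0 / (1.0 + np.exp(-eta))
```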
+
+
+ def gaussian(
+     X, Y, num_factors, y_type="presence_absence", batch_size=None
+ ):
+     device = X.device
+
+     n_sites, n_species = Y.shape
+     n_env = X.shape[1]
+
+     alpha = pyro.sample(
+         "alpha",
+         dist.Normal(
+             torch.zeros(n_species, device=device),
+             torch.ones(n_species, device=device),
+         ).to_event(1),
+     )
+
+     # species-specific environmental optimum
+     mu = pyro.sample(
+         "mu",
+         dist.Normal(
+             torch.zeros(n_env, n_species, device=device),
+             torch.ones(n_env, n_species, device=device),
+         ).to_event(2),
+     )
+
+     # positive niche strength / inverse width
+     gamma = pyro.sample(
+         "gamma",
+         dist.LogNormal(
+             torch.zeros(n_env, n_species, device=device),
+             torch.ones(n_env, n_species, device=device),
+         ).to_event(2),
+     )
+
+     Z = pyro.sample(
+         "Z",
+         dist.Normal(
+             torch.zeros(n_species, num_factors, device=device),
+             torch.ones(n_species, num_factors, device=device),
+         ).to_event(2),
+     )
+
+     with pyro.plate("sites", n_sites, subsample_size=batch_size) as ind:
+         X_batch = X[ind]
+         Y_batch = Y[ind]
+
+         W_batch = pyro.sample(
+             "W",
+             dist.Normal(
+                 torch.zeros(num_factors, device=device),
+                 torch.ones(num_factors, device=device),
+             ).to_event(1),
+         )
+
+         # quadratic environmental response: -sum_k gamma[k, j] * (x[i, k] - mu[k, j])**2
+         diff = X_batch[..., None] - mu[..., None, :, :]  # (batch, n_env, n_species)
+         env_effect = -torch.sum(gamma[..., None, :, :] * diff**2, dim=-2)  # (batch, n_species)
+
+         eta_batch = alpha + env_effect + torch.matmul(W_batch, Z.transpose(-1, -2))
+
+         pyro.sample(
+             "obs",
+             observation_dist(eta_batch, y_type).to_event(1),
+             obs=Y_batch,
+         )
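In `gaussian`, each species responds to each covariate through a quadratic penalty, `-gamma[k, j] * (x[i, k] - mu[k, j])**2`, summed over covariates. Its contribution to the linear predictor is therefore largest (zero) when a site sits exactly at the species' optimum `mu` and falls off faster for larger `gamma`. The NumPy sketch below reproduces the two broadcasting lines in a hypothetical `quadratic_effect` helper and compares them with an explicit loop; the sizes and random values are illustrative only.

```python
import numpy as np


def quadratic_effect(X_batch, mu, gamma):
    """Hypothetical NumPy mirror of the env_effect computation in gaussian()."""
    diff = X_batch[..., None] - mu[..., None, :, :]  # (batch, n_env, n_species)
    return -np.sum(gamma[..., None, :, :] * diff**2, axis=-2)  # (batch, n_species)


rng = np.random.default_rng(1)
batch, n_env, n_species = 4, 3, 5
X_batch = rng.normal(size=(batch, n_env))
mu = rng.normal(size=(n_env, n_species))
gamma = rng.lognormal(size=(n_env, n_species))  # positive, like the LogNormal prior

effect = quadratic_effect(X_batch, mu, gamma)

# Explicit loop over sites, species and covariates for comparison
loop = np.zeros((batch, n_species))
for i in range(batch):
    for j in range(n_species):
        loop[i, j] = -sum(
            gamma[k, j] * (X_batch[i, k] - mu[k, j]) ** 2 for k in range(n_env)
        )

# A site placed exactly at species 0's optimum gets that species' maximum, 0
peak = quadratic_effect(mu[:, 0][None, :], mu, gamma)
```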
+
+
+ def bnn_model(
+     X, Y, num_factors=1, y_type="presence_absence", hidden_size=10, batch_size=None
+ ):
+     device = X.device
+
+     n_sites, n_species = Y.shape
+     n_env = X.shape[1]
+
+     w1 = pyro.sample(
+         "w1",
+         dist.Normal(
+             torch.zeros(n_env, hidden_size, device=device),
+             torch.ones(n_env, hidden_size, device=device),
+         ).to_event(2),
+     )
+
+     b1 = pyro.sample(
+         "b1",
+         dist.Normal(
+             torch.zeros(hidden_size, device=device),
+             torch.ones(hidden_size, device=device),
+         ).to_event(1),
+     )
+
+     w2 = pyro.sample(
+         "w2",
+         dist.Normal(
+             torch.zeros(hidden_size, n_species, device=device),
+             torch.ones(hidden_size, n_species, device=device),
+         ).to_event(2),
+     )
+
+     b2 = pyro.sample(
+         "b2",
+         dist.Normal(
+             torch.zeros(n_species, device=device),
+             torch.ones(n_species, device=device),
+         ).to_event(1),
+     )
+
+     Z = None
+     if num_factors > 0:
+         Z = pyro.sample(
+             "Z",
+             dist.Normal(
+                 torch.zeros(n_species, num_factors, device=device),
+                 torch.ones(n_species, num_factors, device=device),
+             ).to_event(2),
+         )
+
+     with pyro.plate("sites", n_sites, subsample_size=batch_size) as ind:
+         X_batch = X[ind]
+         Y_batch = Y[ind]
+
+         hidden = torch.tanh(X_batch @ w1 + b1)
+         eta_batch = hidden @ w2 + b2
+
+         if num_factors > 0:
+             W_batch = pyro.sample(
+                 "W",
+                 dist.Normal(
+                     torch.zeros(num_factors, device=device),
+                     torch.ones(num_factors, device=device),
+                 ).to_event(1),
+             )
+             eta_batch = eta_batch + W_batch @ Z.transpose(-1, -2)
+
+         pyro.sample(
+             "obs",
+             observation_dist(eta_batch, y_type).to_event(1),
+             obs=Y_batch,
+         )
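`bnn_model` replaces the linear environmental term with a one-hidden-layer `tanh` network and adds the latent-factor term only when `num_factors > 0`. For one fixed draw of the weights, the forward pass reduces to the NumPy sketch below; `bnn_eta` is a hypothetical helper, and the weights are random placeholders rather than posterior samples.

```python
import numpy as np


def bnn_eta(X_batch, w1, b1, w2, b2, W_batch=None, Z=None):
    """Hypothetical NumPy mirror of bnn_model's linear predictor for fixed weights."""
    hidden = np.tanh(X_batch @ w1 + b1)  # (batch, hidden_size)
    eta = hidden @ w2 + b2  # (batch, n_species)
    if W_batch is not None:
        # Latent-factor term, only present when num_factors > 0
        eta = eta + W_batch @ Z.T
    return eta


rng = np.random.default_rng(2)
n_sites, n_env, hidden_size, n_species, num_factors = 6, 3, 10, 4, 2
X = rng.normal(size=(n_sites, n_env))
w1 = rng.normal(size=(n_env, hidden_size))
b1 = rng.normal(size=hidden_size)
w2 = rng.normal(size=(hidden_size, n_species))
b2 = rng.normal(size=n_species)
W = rng.normal(size=(n_sites, num_factors))
Z = rng.normal(size=(n_species, num_factors))

eta = bnn_eta(X, w1, b1, w2, b2)  # num_factors = 0 case
eta_latent = bnn_eta(X, w1, b1, w2, b2, W_batch=W, Z=Z)

# With all first-layer weights and biases at zero, tanh(0) = 0,
# so every site gets the same predictor, b2
flat = bnn_eta(X, np.zeros_like(w1), np.zeros_like(b1), w2, b2)
```

Because `tanh` is bounded in (-1, 1), the network part of the predictor for species `j` can never exceed `sum_h |w2[h, j]| + |b2[j]|` in magnitude, however extreme the covariates.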