deepgeodemo 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepgeodemo-0.1.4/LICENSE +21 -0
- deepgeodemo-0.1.4/MANIFEST.in +0 -0
- deepgeodemo-0.1.4/PKG-INFO +133 -0
- deepgeodemo-0.1.4/README.md +99 -0
- deepgeodemo-0.1.4/pyproject.toml +51 -0
- deepgeodemo-0.1.4/setup.cfg +4 -0
- deepgeodemo-0.1.4/src/deepgeodemo/__init__.py +2 -0
- deepgeodemo-0.1.4/src/deepgeodemo/activation.py +153 -0
- deepgeodemo-0.1.4/src/deepgeodemo/autoencoder_search.py +97 -0
- deepgeodemo-0.1.4/src/deepgeodemo/autoencoder_train_latent.py +380 -0
- deepgeodemo-0.1.4/src/deepgeodemo/cli.py +72 -0
- deepgeodemo-0.1.4/src/deepgeodemo/kmeans_rapids.py +119 -0
- deepgeodemo-0.1.4/src/deepgeodemo/kmeans_rapids_search.py +146 -0
- deepgeodemo-0.1.4/src/deepgeodemo/kmeans_sklearn.py +113 -0
- deepgeodemo-0.1.4/src/deepgeodemo/kmeans_sklearn_search.py +140 -0
- deepgeodemo-0.1.4/src/deepgeodemo/loggers.py +34 -0
- deepgeodemo-0.1.4/src/deepgeodemo/loss.py +93 -0
- deepgeodemo-0.1.4/src/deepgeodemo/models.py +330 -0
- deepgeodemo-0.1.4/src/deepgeodemo.egg-info/PKG-INFO +133 -0
- deepgeodemo-0.1.4/src/deepgeodemo.egg-info/SOURCES.txt +27 -0
- deepgeodemo-0.1.4/src/deepgeodemo.egg-info/dependency_links.txt +1 -0
- deepgeodemo-0.1.4/src/deepgeodemo.egg-info/entry_points.txt +2 -0
- deepgeodemo-0.1.4/src/deepgeodemo.egg-info/requires.txt +17 -0
- deepgeodemo-0.1.4/src/deepgeodemo.egg-info/top_level.txt +1 -0
- deepgeodemo-0.1.4/tests/test_activation.py +83 -0
- deepgeodemo-0.1.4/tests/test_autoencoder_search.py +92 -0
- deepgeodemo-0.1.4/tests/test_autoencoder_train_latent.py +33 -0
- deepgeodemo-0.1.4/tests/test_loss.py +105 -0
- deepgeodemo-0.1.4/tests/test_models.py +95 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Stef De Sabbata
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
File without changes
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: deepgeodemo
|
|
3
|
+
Version: 0.1.4
|
|
4
|
+
Summary: A library and CLI tool for geodemographic classification using deep autoencoders.
|
|
5
|
+
Author-email: Stef De Sabbata <stef@stefdesabbata.io>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://stefdesabbata.github.io/deepgeodemo/
|
|
8
|
+
Project-URL: Repository, https://github.com/stefdesabbata/deepgeodemo
|
|
9
|
+
Keywords: geodemographic,geodemographics,autoencoders
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: GIS
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Requires-Python: >=3.12
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: pyyaml>=6.0
|
|
18
|
+
Requires-Dist: numpy>=2.0
|
|
19
|
+
Requires-Dist: pandas>=2.2
|
|
20
|
+
Requires-Dist: scikit-learn>=1.5
|
|
21
|
+
Requires-Dist: torch>=2.4
|
|
22
|
+
Requires-Dist: torchvision
|
|
23
|
+
Requires-Dist: torchaudio
|
|
24
|
+
Requires-Dist: tensorboard>=2.18
|
|
25
|
+
Requires-Dist: lightning>=2.4
|
|
26
|
+
Requires-Dist: matplotlib>=3.9
|
|
27
|
+
Requires-Dist: seaborn>=0.13
|
|
28
|
+
Requires-Dist: clustergram>=0.8
|
|
29
|
+
Requires-Dist: pytest>=7.0
|
|
30
|
+
Provides-Extra: gpu
|
|
31
|
+
Requires-Dist: cudf-cu12>=24.10; extra == "gpu"
|
|
32
|
+
Requires-Dist: cuml-cu12>=24.10; extra == "gpu"
|
|
33
|
+
Dynamic: license-file
|
|
34
|
+
|
|
35
|
+
# DeepGeoDemo: Deep Embedding Geodemographics Made Simple
|
|
36
|
+
|
|
37
|
+
## Overview
|
|
38
|
+
|
|
39
|
+
DeepGeoDemo makes it easy to build geodemographic classifications powered by deep embeddings. Configure your autoencoder in a simple YAML file and run the full pipeline from the command line, or import DeepGeoDemo as a Python library when you need more control.
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
## Installation
|
|
44
|
+
|
|
45
|
+
The package is currently under development and can be installed from the repository. The package requires Python 3.12. The package can be installed using the following commands. This will install the package with the CPU dependencies --- i.e., the cpu version of [PyTorch](https://pytorch.org/) and [scikit-learn](https://scikit-learn.org/stable/index.html) for clustering.
|
|
46
|
+
|
|
47
|
+
Create a virtual environment, e.g. using conda.
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
conda create -n deepgeodemo python=3.12
|
|
51
|
+
conda activate deepgeodemo
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
Install the package via `pip`.
|
|
55
|
+
|
|
56
|
+
```bash
|
|
57
|
+
pip install deepgeodemo
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
Alternatively, the package can be installed with the GPU dependencies if available and [RAPIDS](https://docs.rapids.ai/) for the clustering.
|
|
61
|
+
|
|
62
|
+
```bash
|
|
63
|
+
pip install --extra-index-url=https://pypi.nvidia.com deepgeodemo[gpu]
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
## Usage
|
|
68
|
+
|
|
69
|
+
The commands below illustrate how to use the Command-Line Interface (CLI) to run the tool using the example configuration file `example/example.yml` and data (eight random blobs in sixteen dimensions). The `-t` flag is used to train the autoencoder and `l` to create the latent representation. The `-s` flag is used to search for the best k. The `-c` flag is used to run clustering using k-means. If available, the cuml backend is used for clustering. The `-v` flag is optional and is used to display the progress of the process.
|
|
70
|
+
|
|
71
|
+
Train the autoencoder.
|
|
72
|
+
|
|
73
|
+
```bash
|
|
74
|
+
deepgeodemo -tv example/example.yml
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
Create latent representation using the previously trained autoencoder. Note that this will load a model from disk, and thus it will raise a warning message, as that can result in **arbitrary code execution**. Do it only if you got the file from a **trusted** source -- e.g., a model file you trained yourself, using the command above.
|
|
78
|
+
|
|
79
|
+
```bash
|
|
80
|
+
deepgeodemo -lv example/example.yml
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
Alternatively, you can train the autoencoder and create the latent representation in one go. In this case, the autoencoder will still be saved, but the latent representation will be created directly with the model in memory (rather than loading from the disk).
|
|
84
|
+
|
|
85
|
+
```bash
|
|
86
|
+
deepgeodemo -tlv example/example.yml
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
Run clustering in test mode to search for best k. Add `-r` for RAPIDS backend, default is scikit-learn.
|
|
90
|
+
|
|
91
|
+
```bash
|
|
92
|
+
deepgeodemo -sv example/example.yml
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
Run clustering using k-means. Add `-r` for RAPIDS backend, default is scikit-learn.
|
|
96
|
+
|
|
97
|
+
```bash
|
|
98
|
+
deepgeodemo -cv example/example.yml
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
Alternatively, you can run everything in one go as well.
|
|
102
|
+
|
|
103
|
+
```bash
|
|
104
|
+
deepgeodemo -tlscv example/example.yml
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
For a more concrete example, you can test the tool using the 2021 OAC data available from [Jakub Wyszomierski's repo](https://github.com/jakubwyszomierski/OAC2021-2). Download the [Clean data](https://liveuclac-my.sharepoint.com/:f:/g/personal/zcfajwy_ucl_ac_uk/Eqd1EV2WgOFJmZ7kLx-oDYMBdxqNe9IJmli6M8S-e91F0g?e=M9wh5j) used to create the [2021 OAC](https://data.cdrc.ac.uk/dataset/output-area-classification-2021), unzip the file and set the value of `data: source` to the location of one of the file datasets on your computer. It is advisable to normalise the data before training the autoencoder, e.g., using min-max scaling. Please increase the number of epochs and the number of clustering iterations to get meaningful results.
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
## Unit Tests
|
|
112
|
+
|
|
113
|
+
If you want to run the unit tests, you can install the package in editable mode.
|
|
114
|
+
|
|
115
|
+
```bash
|
|
116
|
+
# Clone the repository
|
|
117
|
+
gh repo clone stefdesabbata/deepgeodemo
|
|
118
|
+
cd deepgeodemo
|
|
119
|
+
|
|
120
|
+
# Install the package
|
|
121
|
+
pip install -e .
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
Then run the tests.
|
|
125
|
+
|
|
126
|
+
```bash
|
|
127
|
+
python -m pytest tests/ -v
|
|
128
|
+
```
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
## Acknowledgement
|
|
132
|
+
|
|
133
|
+
Many thanks to [Owen Goodwin](https://github.com/ogoodwin505), [Pengyuan Liu](https://github.com/PengyuanLiu1993) and [Alex Singleton](https://github.com/alexsingleton) for their collaboration on this project and for testing the pre-alpha versions of the tool.
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# DeepGeoDemo: Deep Embedding Geodemographics Made Simple
|
|
2
|
+
|
|
3
|
+
## Overview
|
|
4
|
+
|
|
5
|
+
DeepGeoDemo makes it easy to build geodemographic classifications powered by deep embeddings. Configure your autoencoder in a simple YAML file and run the full pipeline from the command line, or import DeepGeoDemo as a Python library when you need more control.
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
## Installation
|
|
10
|
+
|
|
11
|
+
The package is currently under development and can be installed from the repository. The package requires Python 3.12. The package can be installed using the following commands. This will install the package with the CPU dependencies --- i.e., the cpu version of [PyTorch](https://pytorch.org/) and [scikit-learn](https://scikit-learn.org/stable/index.html) for clustering.
|
|
12
|
+
|
|
13
|
+
Create a virtual environment, e.g. using conda.
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
conda create -n deepgeodemo python=3.12
|
|
17
|
+
conda activate deepgeodemo
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
Install the package via `pip`.
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
pip install deepgeodemo
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
Alternatively, the package can be installed with the GPU dependencies if available and [RAPIDS](https://docs.rapids.ai/) for the clustering.
|
|
27
|
+
|
|
28
|
+
```bash
|
|
29
|
+
pip install --extra-index-url=https://pypi.nvidia.com deepgeodemo[gpu]
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
## Usage
|
|
34
|
+
|
|
35
|
+
The commands below illustrate how to use the Command-Line Interface (CLI) to run the tool using the example configuration file `example/example.yml` and data (eight random blobs in sixteen dimensions). The `-t` flag is used to train the autoencoder and `l` to create the latent representation. The `-s` flag is used to search for the best k. The `-c` flag is used to run clustering using k-means. If available, the cuml backend is used for clustering. The `-v` flag is optional and is used to display the progress of the process.
|
|
36
|
+
|
|
37
|
+
Train the autoencoder.
|
|
38
|
+
|
|
39
|
+
```bash
|
|
40
|
+
deepgeodemo -tv example/example.yml
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
Create latent representation using the previously trained autoencoder. Note that this will load a model from disk, and thus it will raise a warning message, as that can result in **arbitrary code execution**. Do it only if you got the file from a **trusted** source -- e.g., a model file you trained yourself, using the command above.
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
deepgeodemo -lv example/example.yml
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
Alternatively, you can train the autoencoder and create the latent representation in one go. In this case, the autoencoder will still be saved, but the latent representation will be created directly with the model in memory (rather than loading from the disk).
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
deepgeodemo -tlv example/example.yml
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
Run clustering in test mode to search for best k. Add `-r` for RAPIDS backend, default is scikit-learn.
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
deepgeodemo -sv example/example.yml
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
Run clustering using k-means. Add `-r` for RAPIDS backend, default is scikit-learn.
|
|
62
|
+
|
|
63
|
+
```bash
|
|
64
|
+
deepgeodemo -cv example/example.yml
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
Alternatively, you can run everything in one go as well.
|
|
68
|
+
|
|
69
|
+
```bash
|
|
70
|
+
deepgeodemo -tlscv example/example.yml
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
For a more concrete example, you can test the tool using the 2021 OAC data available from [Jakub Wyszomierski's repo](https://github.com/jakubwyszomierski/OAC2021-2). Download the [Clean data](https://liveuclac-my.sharepoint.com/:f:/g/personal/zcfajwy_ucl_ac_uk/Eqd1EV2WgOFJmZ7kLx-oDYMBdxqNe9IJmli6M8S-e91F0g?e=M9wh5j) used to create the [2021 OAC](https://data.cdrc.ac.uk/dataset/output-area-classification-2021), unzip the file and set the value of `data: source` to the location of one of the file datasets on your computer. It is advisable to normalise the data before training the autoencoder, e.g., using min-max scaling. Please increase the number of epochs and the number of clustering iterations to get meaningful results.
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
## Unit Tests
|
|
78
|
+
|
|
79
|
+
If you want to run the unit tests, you can install the package in editable mode.
|
|
80
|
+
|
|
81
|
+
```bash
|
|
82
|
+
# Clone the repository
|
|
83
|
+
gh repo clone stefdesabbata/deepgeodemo
|
|
84
|
+
cd deepgeodemo
|
|
85
|
+
|
|
86
|
+
# Install the package
|
|
87
|
+
pip install -e .
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
Then run the tests.
|
|
91
|
+
|
|
92
|
+
```bash
|
|
93
|
+
python -m pytest tests/ -v
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
## Acknowledgement
|
|
98
|
+
|
|
99
|
+
Many thanks to [Owen Goodwin](https://github.com/ogoodwin505), [Pengyuan Liu](https://github.com/PengyuanLiu1993) and [Alex Singleton](https://github.com/alexsingleton) for their collaboration on this project and for testing the pre-alpha versions of the tool.
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=75", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "deepgeodemo"
|
|
7
|
+
version = "0.1.4"
|
|
8
|
+
description = "A library and CLI tool for geodemographic classification using deep autoencoders."
|
|
9
|
+
keywords = ["geodemographic", "geodemographics", "autoencoders"]
|
|
10
|
+
readme = "README.md"
|
|
11
|
+
requires-python = ">=3.12"
|
|
12
|
+
license = { text = "MIT" }
|
|
13
|
+
authors = [
|
|
14
|
+
{name = "Stef De Sabbata", email = "stef@stefdesabbata.io"}
|
|
15
|
+
]
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Development Status :: 3 - Alpha",
|
|
18
|
+
"Intended Audience :: Science/Research",
|
|
19
|
+
"Topic :: Scientific/Engineering :: GIS",
|
|
20
|
+
"Programming Language :: Python :: 3.12"
|
|
21
|
+
]
|
|
22
|
+
dependencies = [
|
|
23
|
+
"pyyaml>=6.0",
|
|
24
|
+
"numpy>=2.0",
|
|
25
|
+
"pandas>=2.2",
|
|
26
|
+
"scikit-learn>=1.5",
|
|
27
|
+
"torch>=2.4",
|
|
28
|
+
"torchvision",
|
|
29
|
+
"torchaudio",
|
|
30
|
+
"tensorboard>=2.18",
|
|
31
|
+
"lightning>=2.4",
|
|
32
|
+
"matplotlib>=3.9",
|
|
33
|
+
"seaborn>=0.13",
|
|
34
|
+
"clustergram>=0.8",
|
|
35
|
+
"pytest>=7.0"
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
[project.optional-dependencies]
|
|
39
|
+
# Optional RAPIDS libraries
|
|
40
|
+
gpu = [
|
|
41
|
+
"cudf-cu12>=24.10",
|
|
42
|
+
"cuml-cu12>=24.10"
|
|
43
|
+
]
|
|
44
|
+
|
|
45
|
+
[project.urls]
|
|
46
|
+
Homepage = "https://stefdesabbata.github.io/deepgeodemo/"
|
|
47
|
+
Repository = "https://github.com/stefdesabbata/deepgeodemo"
|
|
48
|
+
|
|
49
|
+
[project.scripts]
|
|
50
|
+
deepgeodemo = "deepgeodemo.cli:main"
|
|
51
|
+
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from typing import Any, Literal
|
|
2
|
+
import torch
|
|
3
|
+
from torch import nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Based on TopK activation function
|
|
7
|
+
# by Gao et al (2024)
|
|
8
|
+
# https://arxiv.org/abs/2406.04093
|
|
9
|
+
#
|
|
10
|
+
# Based on
|
|
11
|
+
# https://github.com/openai/sparse_autoencoder
|
|
12
|
+
# MIT license
|
|
13
|
+
|
|
14
|
+
class TopK(nn.Module):
    """TopK activation: keep only the k largest post-activation values.

    Based on the TopK activation function by Gao et al. (2024),
    https://arxiv.org/abs/2406.04093 and the OpenAI sparse_autoencoder
    reference implementation (MIT license).

    Args:
        postact_fn (nn.Module): Activation applied before the top-k selection.
        k (int): Number of activations to keep per row (last dimension).
    """

    def __init__(self,
                 postact_fn: nn.Module,
                 k: int
                 ) -> None:
        super().__init__()
        self.postact_fn = postact_fn
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply postact_fn, then zero out everything but the top-k values."""
        activated = self.postact_fn(x)
        # Pick the k largest values along the feature (last) dimension.
        top_values, top_indices = torch.topk(activated, k=self.k, dim=-1)
        # Start from an all-zero tensor and write back only the winners.
        sparse = torch.zeros_like(activated)
        sparse.scatter_(-1, top_indices, top_values)
        return sparse
|
35
|
+
|
|
36
|
+
# Based on Jumping Ahead's JumpReLU activation function
|
|
37
|
+
# by Rajamanoharan et al (2024)
|
|
38
|
+
# https://arxiv.org/abs/2407.14435
|
|
39
|
+
#
|
|
40
|
+
# Based on
|
|
41
|
+
# https://github.com/jbloomAus/SAELens/blob/abcf9a603acf9344d249f0a595e89be45b77b7cf/sae_lens/training/training_sae.py#L64
|
|
42
|
+
# MIT license
|
|
43
|
+
#
|
|
44
|
+
# Note: suggested hyperparameters for JumpReLU
|
|
45
|
+
# jumprelu_bandwidth=0.001
|
|
46
|
+
# jumprelu_init_threshold=0.001
|
|
47
|
+
|
|
48
|
+
def _rectangle(x: torch.Tensor) -> torch.Tensor:
|
|
49
|
+
return ((x > -0.5) & (x < 0.5)).to(x.dtype)
|
|
50
|
+
|
|
51
|
+
class _JumpReLUFunction(torch.autograd.Function):
    """
    Implements the JumpReLU activation function from Appendix J of https://arxiv.org/abs/2407.14435

    Forward: x * 1[x > threshold] (hard gate, zero below the threshold).
    Backward: straight-through gradient for x; pseudo-derivative for the
    threshold via a rectangle-kernel straight-through estimator (STE).
    """

    @staticmethod
    def forward(
        x: torch.Tensor,
        threshold: torch.Tensor,
        bandwidth: float,
    ) -> torch.Tensor:
        # Validate threshold tensor: the JumpReLU threshold must be strictly
        # positive (callers such as JumpReLU guarantee this via exp(log θ)).
        if not (threshold > 0).all():
            raise ValueError("All values in the threshold tensor must be positive.")
        # Return the JumpReLU activation: values at or below the threshold
        # are gated to zero; the comparison broadcasts threshold over x.
        return (x * (x > threshold)).to(x.dtype)

    @staticmethod
    def setup_context(
        ctx: Any,
        inputs: tuple[torch.Tensor, torch.Tensor, float],
        output: torch.Tensor
    ) -> None:
        # Save the input tensors and bandwidth for backward pass.
        # (New-style autograd.Function: forward() takes no ctx, so saving
        # happens here instead.)
        x, threshold, bandwidth = inputs
        del output  # not needed for the backward computation
        ctx.save_for_backward(x, threshold)
        # bandwidth is a plain float, so it is stashed on ctx directly
        # rather than via save_for_backward.
        ctx.bandwidth = bandwidth

    @staticmethod
    def backward(
        ctx: Any,
        grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, None]:
        # Retrieve saved tensors and bandwidth
        x, threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        # Gradient w.r.t. x: pass grad_output through wherever the gate was
        # open (x > threshold), zero elsewhere.
        x_grad = (x > threshold).to(x.dtype) * grad_output
        # Pseudo-derivative for the threshold using the STE: a rectangle
        # kernel of width `bandwidth` centred on the threshold lets gradient
        # flow only for inputs near the jump; summed over the batch (dim=0).
        threshold_grad = torch.sum(
            -(threshold / bandwidth)
            * _rectangle((x - threshold) / bandwidth)
            * grad_output,
            dim=0,
        )
        # No gradient for bandwidth (it is a non-tensor hyperparameter).
        return x_grad, threshold_grad, None
|
97
|
+
|
|
98
|
+
def jump_relu(
    x: torch.Tensor,
    threshold: torch.Tensor,
    bandwidth: float = 0.001
) -> torch.Tensor:
    """
    Functional wrapper for the JumpReLU activation function.

    Args:
        x (torch.Tensor): The input tensor (pre-activations).
        threshold (torch.Tensor): The trainable threshold parameter (θ).
        bandwidth (float): The kernel bandwidth hyperparameter (ε).

    Returns:
        torch.Tensor: The output of the JumpReLU activation.
    """
    result = _JumpReLUFunction.apply(x, threshold, bandwidth)
    return result
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class JumpReLU(nn.Module):
    """
    A PyTorch nn.Module for the JumpReLU activation function
    (Rajamanoharan et al., 2024, https://arxiv.org/abs/2407.14435).

    Args:
        num_features (int): The number of features in the input tensor (e.g., M).
        initial_threshold (float): The initial value for the threshold θ.
        bandwidth (float): The kernel bandwidth hyperparameter ε.
    """

    def __init__(self,
                 num_features: int,
                 initial_threshold: float = 0.001,
                 bandwidth: float = 0.001
                 ) -> None:
        super().__init__()
        # Kernel bandwidth for the STE pseudo-derivative.
        self.bandwidth = bandwidth
        # The threshold must stay positive, so its logarithm is the
        # trainable parameter and θ = exp(log θ) is recovered in forward().
        initial_log_threshold = torch.log(torch.tensor(initial_threshold))
        self.log_threshold = nn.Parameter(
            torch.full((num_features,), initial_log_threshold)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the JumpReLU activation to the input tensor.

        Args:
            x (torch.Tensor): The input tensor of pre-activations.

        Returns:
            The output tensor.
        """
        # Recover the strictly-positive threshold for the forward pass.
        threshold = torch.exp(self.log_threshold)
        # Pre-activations are rectified first, matching the paper's SAE
        # implementation (https://arxiv.org/abs/2407.14435).
        rectified = torch.relu(x)
        return jump_relu(rectified, threshold, self.bandwidth)
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import collections.abc
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import torch
|
|
5
|
+
from .autoencoder_train_latent import train_latent
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Utilities for generating configurations to search ----------------------
|
|
9
|
+
|
|
10
|
+
# Split keys in the format 'key: value' into a list
|
|
11
|
+
def key_splitter(key):
    """Split a 'section: option' style key into its path components.

    str.split already returns [key] unchanged when the separator is
    absent, so the original `':' in key` guard was redundant — and
    inconsistent, since it tested for ':' but split on ': '.

    Args:
        key (str): Flat configuration key, e.g. 'autoencoder: version'.

    Returns:
        list[str]: The key path, e.g. ['autoencoder', 'version'].
    """
    return key.split(': ')
|
13
|
+
|
|
14
|
+
# Generate dictionaries of options from product
|
|
15
|
+
# By Seth Johnson via stackoverflow
|
|
16
|
+
# https://stackoverflow.com/questions/5228158/cartesian-product-of-a-dictionary-of-lists
|
|
17
|
+
def product_dict(**kwargs):
    """Yield one dict per combination (Cartesian product) of the option lists.

    By Seth Johnson via stackoverflow:
    https://stackoverflow.com/questions/5228158/cartesian-product-of-a-dictionary-of-lists
    """
    names = list(kwargs)
    for combo in itertools.product(*kwargs.values()):
        yield dict(zip(names, combo))
|
21
|
+
|
|
22
|
+
# Generate configurations based on a base configuration and options
|
|
23
|
+
def generate_configs(base, options, state_subversion_of=None):
    """Build nested config dicts from a flat base plus every option combination.

    Fixes: the original ended with `del config, description, ...`, which
    raised NameError whenever the options product was empty (e.g. an option
    mapped to an empty list) because the loop body never ran and the names
    were unbound. The cleanup was unnecessary and has been removed.

    Args:
        base (dict): Flat description with 'section: option' style keys.
        options (dict): Maps flat keys to lists of candidate values; one
            config is produced per combination (Cartesian product).
        state_subversion_of: When not None, each config is tagged with an
            'autoencoder: version' of the form '<state_subversion_of>-<i>'.

    Returns:
        list[dict]: One nested configuration per option combination.
    """
    configs = []
    # Enumerate from 1 so subversion numbers are human-friendly.
    for subv, option in enumerate(product_dict(**options), start=1):
        description = {**base, **option}
        # Add a version number to the configuration if specified.
        if state_subversion_of is not None:
            description['autoencoder: version'] = f'{state_subversion_of}-{subv}'
        # Expand flat 'a: b: c' keys into nested dicts.
        config = {}
        for descr_key, descr_value in description.items():
            key_path = key_splitter(descr_key)
            # Traverse/create the nested structure down to the parent node.
            node = config
            for part in key_path[:-1]:
                node = node.setdefault(part, {})
            # Set the value at the final position.
            node[key_path[-1]] = descr_value
        configs.append(config)
    return configs
|
52
|
+
|
|
53
|
+
# Flatten dictionaries for reporting
|
|
54
|
+
def flatten_dict(d, parent_key=''):
    """Flatten a nested mapping into one level with underscore-joined keys.

    Tensors are converted to Python scalars (single element) or nested
    lists; any remaining non-int/float value is stringified so every
    entry is report-friendly.
    """
    flat = {}
    for key, value in d.items():
        compound_key = f'{parent_key}_{key}' if parent_key else key
        if isinstance(value, collections.abc.MutableMapping) and value:
            # Recurse into non-empty sub-mappings.
            flat.update(flatten_dict(value, compound_key))
        else:
            # Handle tensors first, then normalise anything non-numeric.
            if isinstance(value, torch.Tensor):
                if value.numel() == 1:
                    value = value.item()
                else:
                    value = value.detach().cpu().tolist()
            if not isinstance(value, (int, float)):
                value = str(value)
            flat[compound_key] = value
    return flat
|
70
|
+
|
|
71
|
+
# Generate final report from configurations and reports
|
|
72
|
+
def generate_final_report(configs, reports):
    """Merge each flattened config with its flattened report into a DataFrame.

    Args:
        configs (list[dict]): Nested configurations, one per run.
        reports (list[dict]): Matching per-run training reports.

    Returns:
        pandas.DataFrame: One row per run; report columns override
        config columns on key collision, as in the original dict merge.
    """
    rows = [
        {**flatten_dict(config), **flatten_dict(report)}
        for config, report in zip(configs, reports)
    ]
    return pd.DataFrame(rows)
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# Main method searching across configurations -----------------------------
|
|
83
|
+
|
|
84
|
+
def explore_configs(ae_base, ae_options, state_subversion_of=0, create_latent=False, verbose=True):
    """Train an autoencoder for every configuration and collate the reports.

    Args:
        ae_base (dict): Flat base configuration.
        ae_options (dict): Option lists swept over via Cartesian product.
        state_subversion_of: Version prefix recorded in each generated config.
        create_latent (bool): Also produce the latent representation per run.
        verbose (bool): Forwarded to train_latent to show progress.

    Returns:
        pandas.DataFrame: One row per configuration, combining config
        values and training-report metrics.
    """
    # Generate configurations based on the base and options.
    ae_configs = generate_configs(ae_base, ae_options, state_subversion_of)
    separator = '-' * 76
    ae_reports = []
    # Run the training for each configuration.
    for run_number, ae_config in enumerate(ae_configs, start=1):
        print(f'\n\n{separator}')
        print(f'Exploring configuration {run_number} of {len(ae_configs)}')
        print(f'{separator}\n')
        ae_reports.append(
            train_latent(ae_config, create_latent=create_latent, verbose=verbose)
        )
    return generate_final_report(ae_configs, ae_reports)
|