deepgeodemo 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. deepgeodemo-0.1.4/LICENSE +21 -0
  2. deepgeodemo-0.1.4/MANIFEST.in +0 -0
  3. deepgeodemo-0.1.4/PKG-INFO +133 -0
  4. deepgeodemo-0.1.4/README.md +99 -0
  5. deepgeodemo-0.1.4/pyproject.toml +51 -0
  6. deepgeodemo-0.1.4/setup.cfg +4 -0
  7. deepgeodemo-0.1.4/src/deepgeodemo/__init__.py +2 -0
  8. deepgeodemo-0.1.4/src/deepgeodemo/activation.py +153 -0
  9. deepgeodemo-0.1.4/src/deepgeodemo/autoencoder_search.py +97 -0
  10. deepgeodemo-0.1.4/src/deepgeodemo/autoencoder_train_latent.py +380 -0
  11. deepgeodemo-0.1.4/src/deepgeodemo/cli.py +72 -0
  12. deepgeodemo-0.1.4/src/deepgeodemo/kmeans_rapids.py +119 -0
  13. deepgeodemo-0.1.4/src/deepgeodemo/kmeans_rapids_search.py +146 -0
  14. deepgeodemo-0.1.4/src/deepgeodemo/kmeans_sklearn.py +113 -0
  15. deepgeodemo-0.1.4/src/deepgeodemo/kmeans_sklearn_search.py +140 -0
  16. deepgeodemo-0.1.4/src/deepgeodemo/loggers.py +34 -0
  17. deepgeodemo-0.1.4/src/deepgeodemo/loss.py +93 -0
  18. deepgeodemo-0.1.4/src/deepgeodemo/models.py +330 -0
  19. deepgeodemo-0.1.4/src/deepgeodemo.egg-info/PKG-INFO +133 -0
  20. deepgeodemo-0.1.4/src/deepgeodemo.egg-info/SOURCES.txt +27 -0
  21. deepgeodemo-0.1.4/src/deepgeodemo.egg-info/dependency_links.txt +1 -0
  22. deepgeodemo-0.1.4/src/deepgeodemo.egg-info/entry_points.txt +2 -0
  23. deepgeodemo-0.1.4/src/deepgeodemo.egg-info/requires.txt +17 -0
  24. deepgeodemo-0.1.4/src/deepgeodemo.egg-info/top_level.txt +1 -0
  25. deepgeodemo-0.1.4/tests/test_activation.py +83 -0
  26. deepgeodemo-0.1.4/tests/test_autoencoder_search.py +92 -0
  27. deepgeodemo-0.1.4/tests/test_autoencoder_train_latent.py +33 -0
  28. deepgeodemo-0.1.4/tests/test_loss.py +105 -0
  29. deepgeodemo-0.1.4/tests/test_models.py +95 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Stef De Sabbata
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
File without changes
@@ -0,0 +1,133 @@
1
+ Metadata-Version: 2.4
2
+ Name: deepgeodemo
3
+ Version: 0.1.4
4
+ Summary: A library and CLI tool for geodemographic classification using deep autoencoders.
5
+ Author-email: Stef De Sabbata <stef@stefdesabbata.io>
6
+ License: MIT
7
+ Project-URL: Homepage, https://stefdesabbata.github.io/deepgeodemo/
8
+ Project-URL: Repository, https://github.com/stefdesabbata/deepgeodemo
9
+ Keywords: geodemographic,geodemographics,autoencoders
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: Topic :: Scientific/Engineering :: GIS
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Requires-Python: >=3.12
15
+ Description-Content-Type: text/markdown
16
+ License-File: LICENSE
17
+ Requires-Dist: pyyaml>=6.0
18
+ Requires-Dist: numpy>=2.0
19
+ Requires-Dist: pandas>=2.2
20
+ Requires-Dist: scikit-learn>=1.5
21
+ Requires-Dist: torch>=2.4
22
+ Requires-Dist: torchvision
23
+ Requires-Dist: torchaudio
24
+ Requires-Dist: tensorboard>=2.18
25
+ Requires-Dist: lightning>=2.4
26
+ Requires-Dist: matplotlib>=3.9
27
+ Requires-Dist: seaborn>=0.13
28
+ Requires-Dist: clustergram>=0.8
29
+ Requires-Dist: pytest>=7.0
30
+ Provides-Extra: gpu
31
+ Requires-Dist: cudf-cu12>=24.10; extra == "gpu"
32
+ Requires-Dist: cuml-cu12>=24.10; extra == "gpu"
33
+ Dynamic: license-file
34
+
35
+ # DeepGeoDemo: Deep Embedding Geodemographics Made Simple
36
+
37
+ ## Overview
38
+
39
+ DeepGeoDemo makes it easy to build geodemographic classifications powered by deep embeddings. Configure your autoencoder in a simple YAML file and run the full pipeline from the command line, or import DeepGeoDemo as a Python library when you need more control.
40
+
41
+
42
+
43
+ ## Installation
44
+
45
+ The package is currently under development and can be installed from the repository. The package requires Python 3.12. The package can be installed using the following commands. This will install the package with the CPU dependencies --- i.e., the cpu version of [PyTorch](https://pytorch.org/) and [scikit-learn](https://scikit-learn.org/stable/index.html) for clustering.
46
+
47
+ Create a virtual environment, e.g. using conda.
48
+
49
+ ```bash
50
+ conda create -n deepgeodemo python=3.12
51
+ conda activate deepgeodemo
52
+ ```
53
+
54
+ Install the package via `pip`.
55
+
56
+ ```bash
57
+ pip install deepgeodemo
58
+ ```
59
+
60
+ Alternatively, the package can be installed with the GPU dependencies if available and [RAPIDS](https://docs.rapids.ai/) for the clustering.
61
+
62
+ ```bash
63
+ pip install --extra-index-url=https://pypi.nvidia.com deepgeodemo[gpu]
64
+ ```
65
+
66
+
67
+ ## Usage
68
+
69
+ The commands below illustrate how to use the Command-Line Interface (CLI) to run the tool using the example configuration file `example/example.yml` and data (eight random blobs in sixteen dimensions). The `-t` flag is used to train the autoencoder and `l` to create the latent representation. The `-s` flag is used to search for the best k. The `-c` flag is used to run clustering using k-means. If available, the cuml backend is used for clustering. The `-v` flag is optional and is used to display the progress of the process.
70
+
71
+ Train the autoencoder.
72
+
73
+ ```bash
74
+ deepgeodemo -tv example/example.yml
75
+ ```
76
+
77
+ Create latent representation using the previously trained autoencoder. Note that this will load a model from disk, and thus it will raise a warning message, as that can result in **arbitrary code execution**. Do it only if you got the file from a **trusted** source -- e.g., a model file you trained yourself, using the command above.
78
+
79
+ ```bash
80
+ deepgeodemo -lv example/example.yml
81
+ ```
82
+
83
+ Alternatively, you can train the autoencoder and create the latent representation in one go. In this case, the autoencoder will still be saved, but the latent representation will be created directly with the model in memory (rather than loading from the disk).
84
+
85
+ ```bash
86
+ deepgeodemo -tlv example/example.yml
87
+ ```
88
+
89
+ Run clustering in test mode to search for best k. Add `-r` for RAPIDS backend, default is scikit-learn.
90
+
91
+ ```bash
92
+ deepgeodemo -sv example/example.yml
93
+ ```
94
+
95
+ Run clustering using k-means. Add `-r` for RAPIDS backend, default is scikit-learn.
96
+
97
+ ```bash
98
+ deepgeodemo -cv example/example.yml
99
+ ```
100
+
101
+ Alternatively, you can run everything in one go as well.
102
+
103
+ ```bash
104
+ deepgeodemo -tlscv example/example.yml
105
+ ```
106
+
107
+ For a more concrete example, you can test the tool using the 2021 OAC data available from [Jakub Wyszomierski's repo](https://github.com/jakubwyszomierski/OAC2021-2). Download the [Clean data](https://liveuclac-my.sharepoint.com/:f:/g/personal/zcfajwy_ucl_ac_uk/Eqd1EV2WgOFJmZ7kLx-oDYMBdxqNe9IJmli6M8S-e91F0g?e=M9wh5j) used to create the [2021 OAC](https://data.cdrc.ac.uk/dataset/output-area-classification-2021), unzip the file and set the value of `data: source` to the location of one of the file datasets on your computer. It is advisable to normalise the data before training the autoencoder, e.g., using min-max scaling. Please increase the number of epochs and the number of clustering iterations to get meaningful results.
108
+
109
+
110
+
111
+ ## Unit Tests
112
+
113
+ If you want to run the unit tests, you can install the package in editable mode.
114
+
115
+ ```bash
116
+ # Clone the repository
117
+ gh repo clone stefdesabbata/deepgeodemo
118
+ cd deepgeodemo
119
+
120
+ # Install the package
121
+ pip install -e .
122
+ ```
123
+
124
+ Then run the tests.
125
+
126
+ ```bash
127
+ python -m pytest tests/ -v
128
+ ```
129
+
130
+
131
+ ## Acknowledgement
132
+
133
+ Many thanks to [Owen Goodwin](https://github.com/ogoodwin505), [Pengyuan Liu](https://github.com/PengyuanLiu1993) and [Alex Singleton](https://github.com/alexsingleton) for their collaboration on this project and for testing the pre-alpha versions of the tool.
@@ -0,0 +1,99 @@
1
+ # DeepGeoDemo: Deep Embedding Geodemographics Made Simple
2
+
3
+ ## Overview
4
+
5
+ DeepGeoDemo makes it easy to build geodemographic classifications powered by deep embeddings. Configure your autoencoder in a simple YAML file and run the full pipeline from the command line, or import DeepGeoDemo as a Python library when you need more control.
6
+
7
+
8
+
9
+ ## Installation
10
+
11
+ The package is currently under development and can be installed from the repository. The package requires Python 3.12. The package can be installed using the following commands. This will install the package with the CPU dependencies --- i.e., the cpu version of [PyTorch](https://pytorch.org/) and [scikit-learn](https://scikit-learn.org/stable/index.html) for clustering.
12
+
13
+ Create a virtual environment, e.g. using conda.
14
+
15
+ ```bash
16
+ conda create -n deepgeodemo python=3.12
17
+ conda activate deepgeodemo
18
+ ```
19
+
20
+ Install the package via `pip`.
21
+
22
+ ```bash
23
+ pip install deepgeodemo
24
+ ```
25
+
26
+ Alternatively, the package can be installed with the GPU dependencies if available and [RAPIDS](https://docs.rapids.ai/) for the clustering.
27
+
28
+ ```bash
29
+ pip install --extra-index-url=https://pypi.nvidia.com deepgeodemo[gpu]
30
+ ```
31
+
32
+
33
+ ## Usage
34
+
35
+ The commands below illustrate how to use the Command-Line Interface (CLI) to run the tool using the example configuration file `example/example.yml` and data (eight random blobs in sixteen dimensions). The `-t` flag is used to train the autoencoder and `l` to create the latent representation. The `-s` flag is used to search for the best k. The `-c` flag is used to run clustering using k-means. If available, the cuml backend is used for clustering. The `-v` flag is optional and is used to display the progress of the process.
36
+
37
+ Train the autoencoder.
38
+
39
+ ```bash
40
+ deepgeodemo -tv example/example.yml
41
+ ```
42
+
43
+ Create latent representation using the previously trained autoencoder. Note that this will load a model from disk, and thus it will raise a warning message, as that can result in **arbitrary code execution**. Do it only if you got the file from a **trusted** source -- e.g., a model file you trained yourself, using the command above.
44
+
45
+ ```bash
46
+ deepgeodemo -lv example/example.yml
47
+ ```
48
+
49
+ Alternatively, you can train the autoencoder and create the latent representation in one go. In this case, the autoencoder will still be saved, but the latent representation will be created directly with the model in memory (rather than loading from the disk).
50
+
51
+ ```bash
52
+ deepgeodemo -tlv example/example.yml
53
+ ```
54
+
55
+ Run clustering in test mode to search for best k. Add `-r` for RAPIDS backend, default is scikit-learn.
56
+
57
+ ```bash
58
+ deepgeodemo -sv example/example.yml
59
+ ```
60
+
61
+ Run clustering using k-means. Add `-r` for RAPIDS backend, default is scikit-learn.
62
+
63
+ ```bash
64
+ deepgeodemo -cv example/example.yml
65
+ ```
66
+
67
+ Alternatively, you can run everything in one go as well.
68
+
69
+ ```bash
70
+ deepgeodemo -tlscv example/example.yml
71
+ ```
72
+
73
+ For a more concrete example, you can test the tool using the 2021 OAC data available from [Jakub Wyszomierski's repo](https://github.com/jakubwyszomierski/OAC2021-2). Download the [Clean data](https://liveuclac-my.sharepoint.com/:f:/g/personal/zcfajwy_ucl_ac_uk/Eqd1EV2WgOFJmZ7kLx-oDYMBdxqNe9IJmli6M8S-e91F0g?e=M9wh5j) used to create the [2021 OAC](https://data.cdrc.ac.uk/dataset/output-area-classification-2021), unzip the file and set the value of `data: source` to the location of one of the file datasets on your computer. It is advisable to normalise the data before training the autoencoder, e.g., using min-max scaling. Please increase the number of epochs and the number of clustering iterations to get meaningful results.
74
+
75
+
76
+
77
+ ## Unit Tests
78
+
79
+ If you want to run the unit tests, you can install the package in editable mode.
80
+
81
+ ```bash
82
+ # Clone the repository
83
+ gh repo clone stefdesabbata/deepgeodemo
84
+ cd deepgeodemo
85
+
86
+ # Install the package
87
+ pip install -e .
88
+ ```
89
+
90
+ Then run the tests.
91
+
92
+ ```bash
93
+ python -m pytest tests/ -v
94
+ ```
95
+
96
+
97
+ ## Acknowledgement
98
+
99
+ Many thanks to [Owen Goodwin](https://github.com/ogoodwin505), [Pengyuan Liu](https://github.com/PengyuanLiu1993) and [Alex Singleton](https://github.com/alexsingleton) for their collaboration on this project and for testing the pre-alpha versions of the tool.
@@ -0,0 +1,51 @@
1
+ [build-system]
2
+ requires = ["setuptools>=75", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "deepgeodemo"
7
+ version = "0.1.4"
8
+ description = "A library and CLI tool for geodemographic classification using deep autoencoders."
9
+ keywords = ["geodemographic", "geodemographics", "autoencoders"]
10
+ readme = "README.md"
11
+ requires-python = ">=3.12"
12
+ license = { text = "MIT" }
13
+ authors = [
14
+ {name = "Stef De Sabbata", email = "stef@stefdesabbata.io"}
15
+ ]
16
+ classifiers = [
17
+ "Development Status :: 3 - Alpha",
18
+ "Intended Audience :: Science/Research",
19
+ "Topic :: Scientific/Engineering :: GIS",
20
+ "Programming Language :: Python :: 3.12"
21
+ ]
22
+ dependencies = [
23
+ "pyyaml>=6.0",
24
+ "numpy>=2.0",
25
+ "pandas>=2.2",
26
+ "scikit-learn>=1.5",
27
+ "torch>=2.4",
28
+ "torchvision",
29
+ "torchaudio",
30
+ "tensorboard>=2.18",
31
+ "lightning>=2.4",
32
+ "matplotlib>=3.9",
33
+ "seaborn>=0.13",
34
+ "clustergram>=0.8",
35
+ "pytest>=7.0"
36
+ ]
37
+
38
+ [project.optional-dependencies]
39
+ # Optional RAPIDS libraries
40
+ gpu = [
41
+ "cudf-cu12>=24.10",
42
+ "cuml-cu12>=24.10"
43
+ ]
44
+
45
+ [project.urls]
46
+ Homepage = "https://stefdesabbata.github.io/deepgeodemo/"
47
+ Repository = "https://github.com/stefdesabbata/deepgeodemo"
48
+
49
+ [project.scripts]
50
+ deepgeodemo = "deepgeodemo.cli:main"
51
+
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,2 @@
1
+ from .models import AutoEncoder
2
+ from .autoencoder_train_latent import train_latent
@@ -0,0 +1,153 @@
1
+ from typing import Any, Literal
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ # Based on TopK activation function
7
+ # by Gao et al (2024)
8
+ # https://arxiv.org/abs/2406.04093
9
+ #
10
+ # Based on
11
+ # https://github.com/openai/sparse_autoencoder
12
+ # MIT license
13
+
14
class TopK(nn.Module):
    """Keep only the k largest post-activation values per sample.

    Applies ``postact_fn`` to the input, then zeroes every entry except the
    k largest along the last dimension (TopK sparse-autoencoder activation,
    Gao et al. 2024, https://arxiv.org/abs/2406.04093).
    """

    def __init__(self,
                 postact_fn: nn.Module,
                 k: int
                 ) -> None:
        super().__init__()
        # Activation applied before the top-k selection (e.g. nn.ReLU()).
        self.postact_fn = postact_fn
        # Number of entries kept along the last dimension.
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the post-activation of ``x`` with all but the top-k entries zeroed."""
        activated = self.postact_fn(x)
        # Find the k largest activations along the last dimension.
        values, indices = torch.topk(activated, k=self.k, dim=-1)
        # Start from an all-zero tensor and write back only the winners.
        sparse = torch.zeros_like(activated)
        sparse.scatter_(-1, indices, values)
        return sparse
35
+
36
+ # Based on Jumping Ahead's JumpReLU activation function
37
+ # by Rajamanoharan et al (2024)
38
+ # https://arxiv.org/abs/2407.14435
39
+ #
40
+ # Based on
41
+ # https://github.com/jbloomAus/SAELens/blob/abcf9a603acf9344d249f0a595e89be45b77b7cf/sae_lens/training/training_sae.py#L64
42
+ # MIT license
43
+ #
44
+ # Note: suggested hyperparameters for JumpReLU
45
+ # jumprelu_bandwidth=0.001
46
+ # jumprelu_init_threshold=0.001
47
+
48
+ def _rectangle(x: torch.Tensor) -> torch.Tensor:
49
+ return ((x > -0.5) & (x < 0.5)).to(x.dtype)
50
+
51
class _JumpReLUFunction(torch.autograd.Function):
    """
    Implements the JumpReLU activation function from Appendix J of https://arxiv.org/abs/2407.14435

    Forward: JumpReLU(x) = x * 1[x > threshold] (hard gate, exact zeros below θ).
    Backward: the gate is non-differentiable in θ, so a straight-through
    estimator with a rectangle kernel of width `bandwidth` supplies a
    pseudo-derivative for the threshold.

    Uses the new-style autograd API: `forward` takes no ctx; tensors needed
    by `backward` are stashed in `setup_context`.
    """

    @staticmethod
    def forward(
            x: torch.Tensor,
            threshold: torch.Tensor,
            bandwidth: float,
    ) -> torch.Tensor:
        # Validate threshold tensor: the STE gradient below divides by bandwidth
        # and scales by threshold, so non-positive thresholds are rejected early.
        if not (threshold > 0).all():
            raise ValueError("All values in the threshold tensor must be positive.")
        # Return the JumpReLU activation: pass x through where it exceeds the
        # threshold, exact zero elsewhere (cast keeps the input dtype after the
        # bool multiply).
        return (x * (x > threshold)).to(x.dtype)

    @staticmethod
    def setup_context(
            ctx: Any,
            inputs: tuple[torch.Tensor, torch.Tensor, float],
            output: torch.Tensor
    ) -> None:
        # Save the input tensors and bandwidth for backward pass.
        x, threshold, bandwidth = inputs
        # The output is not needed to compute gradients.
        del output
        ctx.save_for_backward(x, threshold)
        # bandwidth is a plain float, so it goes on ctx directly rather than
        # through save_for_backward.
        ctx.bandwidth = bandwidth

    @staticmethod
    def backward(
            ctx: Any,
            grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, None]:
        # Retrieve saved tensors and bandwidth.
        x, threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        # Gradient w.r.t. x: pass-through where the gate was open, zero elsewhere.
        x_grad = (x > threshold).to(x.dtype) * grad_output
        # Pseudo-derivative for the threshold using the STE: nonzero only for
        # inputs within `bandwidth` of the threshold (rectangle kernel), scaled
        # by -threshold/bandwidth per Appendix J of the paper.
        # NOTE(review): the sum over dim=0 assumes a (batch, features) input so
        # the result matches the per-feature threshold shape — TODO confirm for
        # higher-rank inputs.
        threshold_grad = torch.sum(
            -(threshold / bandwidth)
            * _rectangle((x - threshold) / bandwidth)
            * grad_output,
            dim=0,
        )
        # Third slot is None: bandwidth is a non-tensor argument.
        return x_grad, threshold_grad, None
97
+
98
def jump_relu(
    x: torch.Tensor,
    threshold: torch.Tensor,
    bandwidth: float = 0.001
) -> torch.Tensor:
    """Apply the JumpReLU activation (https://arxiv.org/abs/2407.14435).

    Thin functional wrapper around the custom autograd function so callers get
    the straight-through-estimator backward pass for the threshold.

    Args:
        x (torch.Tensor): The input tensor (pre-activations).
        threshold (torch.Tensor): The trainable positive threshold parameter (θ).
        bandwidth (float): The kernel bandwidth hyperparameter (ε).

    Returns:
        torch.Tensor: Input passed through where it exceeds θ, zero elsewhere.
    """
    return _JumpReLUFunction.apply(x, threshold, bandwidth)
115
+
116
+
117
class JumpReLU(nn.Module):
    """JumpReLU activation module (Rajamanoharan et al. 2024, https://arxiv.org/abs/2407.14435).

    Learns one threshold per feature. The threshold is parameterised by its
    logarithm so that exponentiation keeps it strictly positive throughout
    training.

    Args:
        num_features (int): The number of features in the input tensor (e.g., M).
        initial_threshold (float): The initial value for the threshold θ.
        bandwidth (float): The kernel bandwidth hyperparameter ε.
    """

    def __init__(self,
                 num_features: int,
                 initial_threshold: float = 0.001,
                 bandwidth: float = 0.001
                 ) -> None:
        super().__init__()
        # Kernel bandwidth for the STE pseudo-derivative.
        self.bandwidth = bandwidth
        # Train log(θ) rather than θ itself so exp() keeps it positive.
        initial_log_threshold = torch.log(torch.tensor(initial_threshold))
        self.log_threshold = nn.Parameter(torch.full((num_features,), initial_log_threshold))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ReLU followed by the learned per-feature JumpReLU gate.

        Args:
            x (torch.Tensor): The input tensor of pre-activations.

        Returns:
            The output tensor.
        """
        # Recover the positive threshold from its trained logarithm.
        threshold = self.log_threshold.exp()
        # ReLU first, matching the paper's SAE implementation, then gate.
        return jump_relu(torch.relu(x), threshold, self.bandwidth)
@@ -0,0 +1,97 @@
1
+ import itertools
2
+ import collections.abc
3
+ import pandas as pd
4
+ import torch
5
+ from .autoencoder_train_latent import train_latent
6
+
7
+
8
+ # Utilities for generating configurations to search ----------------------
9
+
10
# Turn a "section: key" style string into its list of path components.
def key_splitter(key):
    """Return the ': '-separated components of *key* as a list."""
    if ':' in key:
        return key.split(': ')
    return [key]
13
+
14
# Cartesian product over a dict of lists, yielding one dict per combination.
# Adapted from Seth Johnson's answer on Stack Overflow:
# https://stackoverflow.com/questions/5228158/cartesian-product-of-a-dictionary-of-lists
def product_dict(**kwargs):
    """Yield a dict for every combination of the per-key option lists."""
    names = list(kwargs)
    for combination in itertools.product(*kwargs.values()):
        yield dict(zip(names, combination))
21
+
22
# Generate configurations based on a base configuration and options
def generate_configs(base, options, state_subversion_of=None):
    """Build nested config dicts for every combination of *options* over *base*.

    Args:
        base: Flat dict of 'section: key' -> value entries shared by all configs.
        options: Dict mapping 'section: key' strings to lists of candidate values.
        state_subversion_of: When not None, tag each config with an
            'autoencoder: version' of the form '<state_subversion_of>-<n>'.

    Returns:
        A list of nested config dicts, one per option combination.
    """
    configs = []
    subv = 0
    # For each combination of options, create a new configuration
    for option in product_dict(**options):
        description = {**base, **option}
        # Add a version number to the configuration if specified
        if state_subversion_of is not None:
            subv += 1
            description['autoencoder: version'] = str(state_subversion_of) + '-' + str(subv)
        # Expand the flat 'a: b: c' keys into a nested dict structure
        config = {}
        for descr_key, descr_value in description.items():
            path = key_splitter(descr_key)
            config_pointer = config
            # Traverse (and create) intermediate levels down to the leaf's parent
            for k in path[:-1]:
                config_pointer = config_pointer.setdefault(k, {})
            # Set the value at the right position
            config_pointer[path[-1]] = descr_value
        configs.append(config)
    # NOTE: the original `del config, description, ...` cleanup was removed —
    # it raised NameError when base and options were both empty (names never
    # bound), and locals are discarded on return anyway.
    return configs
52
+
53
# Flatten dictionaries for reporting
def flatten_dict(d, parent_key = ''):
    """Flatten a nested mapping into one level, joining key paths with '_'.

    Tensors are converted to plain Python scalars/lists; any remaining
    non-numeric value is stringified so the result is report-friendly.
    """
    flat = {}
    for key, value in d.items():
        full_key = parent_key + '_' + key if parent_key else key
        if isinstance(value, collections.abc.MutableMapping) and value:
            # Recurse into non-empty sub-mappings
            flat.update(flatten_dict(value, full_key))
            continue
        # Convert tensors to native Python values
        if isinstance(value, torch.Tensor):
            value = value.item() if value.numel() == 1 else value.detach().cpu().tolist()
        # Keep numbers as-is; stringify everything else
        flat[full_key] = value if isinstance(value, (int, float)) else str(value)
    return flat
70
+
71
# Generate final report from configurations and reports
def generate_final_report(configs, reports):
    """Pair each flattened config with its flattened report, one DataFrame row apiece."""
    rows = [
        {**flatten_dict(config), **flatten_dict(report)}
        for config, report in zip(configs, reports)
    ]
    return pd.DataFrame(rows)
80
+
81
+
82
+ # Main method searching across configurations -----------------------------
83
+
84
def explore_configs(ae_base, ae_options, state_subversion_of=0, create_latent=False, verbose=True):
    """Train one autoencoder per generated configuration and collect a report.

    Args:
        ae_base: Flat base configuration shared by every run.
        ae_options: Dict of 'section: key' -> list of values to sweep over.
        state_subversion_of: Version prefix used to tag each configuration.
        create_latent: Whether to also create the latent representation.
        verbose: Whether train_latent should display progress.

    Returns:
        A pandas DataFrame with one row per configuration (config + metrics).
    """
    # Generate configurations based on the base and options
    ae_configs = generate_configs(ae_base, ae_options, state_subversion_of)
    ae_reports = []
    rule = '-' * 76
    # Train each configuration in turn, announcing progress on stdout
    for index, ae_config in enumerate(ae_configs, start=1):
        print(f'\n\n{rule}')
        print(f'Exploring configuration {index} of {len(ae_configs)}')
        print(f'{rule}\n')
        ae_reports.append(train_latent(ae_config, create_latent=create_latent, verbose=verbose))
    return generate_final_report(ae_configs, ae_reports)