at-gan 0.11.6__tar.gz
- at_gan-0.11.6/LICENSE +21 -0
- at_gan-0.11.6/MANIFEST.in +3 -0
- at_gan-0.11.6/PKG-INFO +481 -0
- at_gan-0.11.6/README.md +441 -0
- at_gan-0.11.6/pyproject.toml +70 -0
- at_gan-0.11.6/setup.cfg +4 -0
- at_gan-0.11.6/src/at_gan/__init__.py +8 -0
- at_gan-0.11.6/src/at_gan/api.py +179 -0
- at_gan-0.11.6/src/at_gan/callbacks/__init__.py +0 -0
- at_gan-0.11.6/src/at_gan/callbacks/console_callback.py +68 -0
- at_gan-0.11.6/src/at_gan/callbacks/gan_callback.py +344 -0
- at_gan-0.11.6/src/at_gan/callbacks/wand_callback.py +37 -0
- at_gan-0.11.6/src/at_gan/cli.py +220 -0
- at_gan-0.11.6/src/at_gan/data/__init__.py +0 -0
- at_gan-0.11.6/src/at_gan/data/preprocessor.py +427 -0
- at_gan-0.11.6/src/at_gan/engine/__init__.py +0 -0
- at_gan-0.11.6/src/at_gan/engine/core.py +267 -0
- at_gan-0.11.6/src/at_gan/engine/sweeper.py +132 -0
- at_gan-0.11.6/src/at_gan/engine/synthesizer.py +59 -0
- at_gan-0.11.6/src/at_gan/eval/__init__.py +0 -0
- at_gan-0.11.6/src/at_gan/eval/dcr.py +79 -0
- at_gan-0.11.6/src/at_gan/eval/sdv.py +62 -0
- at_gan-0.11.6/src/at_gan/eval/tstr.py +165 -0
- at_gan-0.11.6/src/at_gan/models/__init__.py +0 -0
- at_gan-0.11.6/src/at_gan/models/discriminator.py +52 -0
- at_gan-0.11.6/src/at_gan/models/gan.py +184 -0
- at_gan-0.11.6/src/at_gan/models/generator.py +101 -0
- at_gan-0.11.6/src/at_gan/training/__init__.py +0 -0
- at_gan-0.11.6/src/at_gan/training/trainer.py +198 -0
- at_gan-0.11.6/src/at_gan/utils/__init__.py +0 -0
- at_gan-0.11.6/src/at_gan/utils/logger.py +18 -0
- at_gan-0.11.6/src/at_gan/utils/paths.py +63 -0
- at_gan-0.11.6/src/at_gan.egg-info/PKG-INFO +481 -0
- at_gan-0.11.6/src/at_gan.egg-info/SOURCES.txt +36 -0
- at_gan-0.11.6/src/at_gan.egg-info/dependency_links.txt +1 -0
- at_gan-0.11.6/src/at_gan.egg-info/entry_points.txt +2 -0
- at_gan-0.11.6/src/at_gan.egg-info/requires.txt +17 -0
- at_gan-0.11.6/src/at_gan.egg-info/top_level.txt +1 -0
at_gan-0.11.6/LICENSE

MIT License

Copyright (c) 2026 Jonas Miesenböck

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

at_gan-0.11.6/PKG-INFO

Metadata-Version: 2.4
Name: at-gan
Version: 0.11.6
Summary: Training Framework for Arbitrary Tabular Generative Adversarial Networks
Author-email: Jonas Miesenböck <jonas@miesenboeck.at>
Project-URL: Homepage, https://github.com/Jns-M/at-gan
Project-URL: Repository, https://github.com/Jns-M/at-gan
Project-URL: Issues, https://github.com/Jns-M/at-gan/issues
Project-URL: Documentation, https://github.com/Jns-M/at-gan#readme
Keywords: gan,tabular-data,synthetic-data,synthetic-tabular-data,tensorflow,keras,machine-learning,generative-adversarial-network,generative-adversarial-networks
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: <3.13,>=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: tensorflow>=2.16.0
Requires-Dist: pandas>=2.0.0
Requires-Dist: scikit-learn>=1.3.0
Requires-Dist: wandb>=0.16.0
Requires-Dist: pyyaml>=6.0.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: joblib>=1.3.0
Requires-Dist: typer>=0.9.0
Requires-Dist: seaborn>=0.12.0
Requires-Dist: matplotlib>=2.0.0
Requires-Dist: sdmetrics>=0.20.0
Requires-Dist: pandas-stubs~=2.3.3
Provides-Extra: dev
Requires-Dist: build; extra == "dev"
Requires-Dist: twine; extra == "dev"
Requires-Dist: pytest; extra == "dev"
Dynamic: license-file

<div align="center">

# AT-GAN

### Arbitrary Tabular Generative Adversarial Network

*A GAN framework for generating synthetic data from arbitrary mixed-type tabular datasets.*

</div>

---

## Table of Contents

1. [Overview](#overview)
1. [Key Features](#key-features)
1. [Installation](#installation)
1. [CLI Usage](#cli-usage)
1. [API Usage](#api-usage)
1. [Configuration Reference](#configuration-reference)
1. [In-Training Evaluation Suite](#in-training-evaluation-suite-1)
1. [Synthetic Data Evaluation (Post-Training)](#synthetic-data-evaluation-post-training-1)

---

## Overview

**at-gan** is a framework for training Generative Adversarial Networks on **arbitrary tabular data**. It is designed to work with *continuous*, *binary*, *discrete count*, and *categorical* features within a single pipeline.

The framework combines a **multi-branch generator** (G), a **PacGAN-style discriminator** (D), an integrated **evaluation suite**, and **[Weights & Biases](https://wandb.ai/)** (W&B) for sweep orchestration, experiment tracking, and training monitoring and visualization.

> **Goal:** Train a GAN that produces realistic synthetic tabular data from a given dataset, with minimal manual tuning and a transparent, observable training process.

---

## Key Features

### Dynamic, Config-Driven Architectures
- Generator and Discriminator are built **entirely from the YAML config**.
- Configurable number of layers and units per layer (`units`).
- Configurable activations: `relu`, `leaky_relu`, `elu`, or any other activation supported by Keras.
- Configurable `dropout` layers.
- Optional Batch Normalization (`batch_norm`) for G.

### Mixed-Type Data Handling
- The `TabularPreprocessor` handles **four types of input features**:
  - **Continuous** → `MinMaxScaler(-1, 1)` → `tanh` output branch.
  - **Discrete Count** → `MinMaxScaler(0, 1)` → `sigmoid` output branch.
  - **Binary** → 0/1 encoding with optional Beta-distributed noise → `sigmoid` output branch.
  - **Categorical** → one-hot encoding with optional label-preserving smoothing → `softmax` output branch.
- Per-column decimal precision is preserved.
- Fitted scalers and encoders are stored and reused for inference.
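
The continuous and discrete-count routes come down to min-max scaling into the output range of the matching activation. The following pure-Python sketch shows the forward transform, the inverse transform applied to generator output, and one possible way to restore per-column precision (inferring decimals from the raw CSV strings). The helper names are hypothetical and not part of `TabularPreprocessor`, and the precision rule is an assumption; the README only names `MinMaxScaler` for the scaling itself.

```python
def fit_minmax(values, feature_range):
    """Learn the column min/max, as a MinMaxScaler(feature_range) would."""
    return {"min": min(values), "max": max(values), "range": feature_range}


def transform(values, params):
    """Map raw values linearly onto params["range"]."""
    a, b = params["range"]
    span = (params["max"] - params["min"]) or 1.0  # guard against constant columns
    return [a + (v - params["min"]) * (b - a) / span for v in values]


def inverse_transform(scaled, params, decimals):
    """Map scaled (e.g. generator) output back to the raw scale and round."""
    a, b = params["range"]
    span = (params["max"] - params["min"]) or 1.0
    restored = []
    for s in scaled:
        s = min(max(s, a), b)  # clip into the fitted range first
        restored.append(round(params["min"] + (s - a) * span / (b - a), decimals))
    return restored


def infer_decimals(raw_strings):
    """Hypothetical precision rule: max digits after '.' in the raw CSV values."""
    return max((len(s.split(".", 1)[1]) if "." in s else 0) for s in raw_strings)


# Continuous column -> [-1, 1] (matches a tanh output branch)
raw_glucose = ["4.5", "6.25", "5.0"]
glucose = [float(s) for s in raw_glucose]
params = fit_minmax(glucose, (-1.0, 1.0))
scaled = transform(glucose, params)
restored = inverse_transform(scaled, params, infer_decimals(raw_glucose))

# Discrete-count column -> [0, 1] (matches a sigmoid output branch)
count_params = fit_minmax([0, 5, 10], (0.0, 1.0))
```

Clipping before the inverse transform keeps slightly out-of-range generator outputs inside the range seen during fitting.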

### GAN Training and Stabilization Techniques
| Technique                           | Controlled by                       | What it does                                                                                       |
|-------------------------------------|-------------------------------------|----------------------------------------------------------------------------------------------------|
| **PacGAN packing**                  | `discriminator.pack_size`           | Concatenates *k* rows into a single D input to counter mode collapse                               |
| **One-sided label smoothing**       | `discriminator.label_smoothing_min` | Real labels sampled from `[min, 1.0]` instead of a hard `1.0`                                      |
| **Label flipping**                  | `discriminator.label_flipping`      | A random fraction of real labels is flipped to `0` to prevent D overconfidence                     |
| **TTUR**                            | `g_lr` / `d_lr`                     | Different learning rates for G and D; sweeps auto-clamp `d_lr ≤ g_lr`                              |
| **G:D update ratio**                | `g_updates_per_epoch`               | Multiple G steps per D step to balance training                                                    |
| **LR cosine decay + warm restarts** | `lr_cosine_decay`                   | `CosineDecayRestarts` schedule for the learning rates of G and D, with a configurable `alpha` floor |
| **Adam `beta_1` override**          | `adam_beta_1`                       | Typically lowered from the default `0.9` for training stability                                    |
| **Gradient clipping**               | *always-on*                         | `clipnorm=1.0` on both Adam optimizers                                                             |
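
To illustrate the first three rows of the table, here is a minimal pure-Python sketch of PacGAN packing and of how smoothed, partially flipped real labels could be drawn. It assumes packs are built from consecutive rows and that leftover rows are dropped; at-gan's tensor-level implementation may differ. Fake labels stay at `0`, which is what makes the smoothing one-sided.

```python
import random


def pack_rows(rows, pack_size):
    """PacGAN packing: concatenate every `pack_size` consecutive rows into one
    discriminator input. Trailing rows that do not fill a full pack are dropped."""
    usable = len(rows) - len(rows) % pack_size
    return [
        [x for row in rows[i:i + pack_size] for x in row]
        for i in range(0, usable, pack_size)
    ]


def real_labels(n, smoothing_min=0.9, flip_rate=0.05, rng=random):
    """One-sided label smoothing plus label flipping for the real-data packs."""
    labels = []
    for _ in range(n):
        y = rng.uniform(smoothing_min, 1.0)  # smoothed "real" target in [min, 1.0]
        if rng.random() < flip_rate:
            y = 0.0  # flipped: D is told this real pack is fake
        labels.append(y)
    return labels


# Toy batch of 6 rows with 2 features each -> 2 packs of 6 values (pack_size=3)
batch = [[0.1, -0.3], [0.4, 0.9], [-0.2, 0.0], [0.7, 0.5], [0.3, -0.8], [0.6, 0.2]]
packed = pack_rows(batch, pack_size=3)
labels = real_labels(len(packed), smoothing_min=0.9, flip_rate=0.05, rng=random.Random(1130))
```

With `pack_size: 1` packing is a no-op, matching the base config comment that `1` disables it.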

### In-Training Evaluation Suite
Runs every `eval_frequency` epochs on held-out real samples, logs the results to W&B, and saves the **best** checkpoint (lowest error score). See [In-Training Evaluation Suite](#in-training-evaluation-suite-1).

### Experiment Tracking
[Weights & Biases](https://wandb.ai/) integration:
- Per-epoch loss/metric logging via a dedicated `WandbCallback`.
- Training visuals: **correlation heatmaps** and **PCA overlap scatter plots**.
- Local-only mode when `--no-wandb` is set (uses `run_id="offline_run"`).

### Sweeps & Neural Architecture Search
- W&B sweeps for **Neural Architecture Search** (NAS) and **Hyperparameter Optimization**.
- Existing W&B sweeps (and single runs) can be resumed.

### Synthetic Data Evaluation (Post-Training)
- **Privacy**: Distance to Closest Record (DCR)
- **Statistical Fidelity**: [Synthetic Data Vault](https://github.com/sdv-dev/sdv) (SDV)
- **Utility Retention**: Train on Synthetic, Test on Real (TSTR)

### Usage Modes
- 🖥️ **CLI**: `train`, `sweep`, `generate`, `evaluate`.
- 🐍 **Python API** (`at_gan.api`): `train`, `sweep`, `generate`, `evaluate`.

---

## Installation

**Requirements:** Python `3.10 – 3.12` and the dependencies listed in `pyproject.toml`.

### Option A: Install from PyPI (recommended)

```shell script
pip install at-gan
```

### Option B: Editable install

1. Clone this repository.
1. Run the following command:

```bash
pip install -e .
```

### Verify installation

```shell script
at-gan --help
python -c "import at_gan; print(at_gan.__version__)"
```

### Weights & Biases Login (one-time)

```shell script
wandb login
```

> 💡 You can use this framework without W&B by passing `--no-wandb` (CLI) or `enable_wandb=False` (API).

---

## CLI Usage

```shell script
at-gan --help
```

### `train`: Run or resume a single GAN training run

| Flag                     | Short      | Default    | Description                                   |
|--------------------------|------------|------------|-----------------------------------------------|
| `--config`               | `-c`       | *required* | Path to the YAML experiment config            |
| `--wandb / --no-wandb`   | `-w / -nw` | `--wandb`  | Toggle W&B tracking                           |
| `--export / --no-export` | `-e / -ne` | `--export` | Export the generator as a `.keras` file       |
| `--generate-samples`     | `-g`       | `1000`     | Automatically generate *N* samples after training |

**Example:**

```shell script
at-gan train -c configs/config.yaml -w -e -g 5000
```

Note: A run can be resumed by setting the `resume_run_id` config key. See [Configuration Reference](#configuration-reference).

---

### `sweep`: Run or resume a W&B sweep

| Flag             | Short | Description                                      |
|------------------|-------|--------------------------------------------------|
| `--base-config`  | `-c`  | Baseline experiment config                       |
| `--sweep-config` | `-s`  | W&B sweep config (required for new sweeps)       |
| `--count`        | `-n`  | Maximum number of runs this agent will execute   |
| `--sweep-id`     | `-id` | Resume an existing sweep instead of creating one |

```shell script
# Launch a new sweep; this agent executes up to 50 runs
at-gan sweep -c configs/config.yaml -s configs/sweep_config.yaml -n 50

# Resume an existing sweep
at-gan sweep -c configs/config.yaml -id abc123 -n 20
```

---

### `generate`: Generate synthetic samples from a trained generator

| Flag        | Short | Description                                    |
|-------------|-------|------------------------------------------------|
| `--config`  | `-c`  | YAML used during the **original** training run |
| `--run-id`  | `-r`  | W&B run ID or `"offline_run"`                  |
| `--samples` | `-n`  | Number of samples to generate                  |
| `--output`  | `-o`  | Optional override for the CSV output path      |

```shell script
at-gan generate -c configs/config.yaml -r a1b2c3 -n 10000 -o synthetic_data.csv
```

Note: `generate` always loads **`best_generator.keras`**, not the latest checkpoint.
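
Per the `output_path` comment in the base config, each run's artifacts live under `output_path/experiment_name/run_id/`, and runs started with `--no-wandb` use the run ID `offline_run`. A small sketch for locating the generator that `generate` loads, assuming (hypothetically) that `best_generator.keras` sits directly in the run directory; `utils/paths.py` defines the actual layout, which may nest it further.

```python
from pathlib import Path


def run_dir(output_path: str, experiment_name: str, run_id: str) -> Path:
    """Artifact directory of one run: output_path/experiment_name/run_id/."""
    return Path(output_path) / experiment_name / run_id


def best_generator_path(output_path: str, experiment_name: str, run_id: str = "offline_run") -> Path:
    """Hypothetical location of the checkpoint that `generate` loads."""
    return run_dir(output_path, experiment_name, run_id) / "best_generator.keras"


path = best_generator_path("experiments/", "test_experiment", "a1b2c3")
if not path.exists():
    print(f"No best checkpoint at {path}; check the run ID and output_path.")
```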

---

### `evaluate`: Run synthetic data evaluation (post-training)

| Flag          | Short | Description                                |
|---------------|-------|--------------------------------------------|
| `--real`      | `-r`  | Path to the real data CSV                  |
| `--synthetic` | `-s`  | Path to the synthetic data CSV             |
| `--target`    | `-t`  | Optional target column for TSTR evaluation |

```shell script
at-gan evaluate -r real_data.csv -s synthetic_data.csv -t target_column
```

---

## API Usage

The Python API exposes the same primary functions as the CLI, which makes it easy to integrate into existing projects.

See `examples/api_example.py` and `examples/api_example.ipynb` for a full API usage example.

> Note: The `train` entry point also accepts a `dict` instead of a path to a YAML file.
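
Since `train` accepts a `dict`, a config can be assembled in Python. Only columns listed in one of the four `*_cols` keys are learned, so comparing the routing against the CSV header before training catches forgotten or double-listed columns. A minimal sketch; `read_header`, `unrouted_columns`, and `multiply_routed_columns` are hypothetical helpers, not part of `at_gan.api`, and `data_cfg` mirrors the `data` section of the base config in the Configuration Reference. Once the remaining sections are filled in, the dict can be passed to `train`; see `examples/api_example.py` for the exact call.

```python
import csv
from collections import Counter

COLUMN_GROUPS = ("continuous_cols", "binary_cols", "discrete_count_cols", "categorical_cols")

data_cfg = {
    "dataset_path": "datasets/example.csv",
    "output_path": "experiments/",
    "continuous_cols": ["age", "heart_rate", "glucose"],
    "binary_cols": ["male", "smoker"],
    "discrete_count_cols": ["cigs_per_day"],
    "categorical_cols": ["education"],
}


def read_header(lines):
    """First row of a CSV, given a file object or any iterable of text lines."""
    return next(csv.reader(lines))


def unrouted_columns(header, data_cfg):
    """Columns present in the CSV but missing from every *_cols list."""
    routed = {col for group in COLUMN_GROUPS for col in data_cfg.get(group, [])}
    return [col for col in header if col not in routed]


def multiply_routed_columns(data_cfg):
    """Columns listed in more than one *_cols group."""
    counts = Counter(col for group in COLUMN_GROUPS for col in data_cfg.get(group, []))
    return sorted(col for col, n in counts.items() if n > 1)


# Usage against the real file:
# with open(data_cfg["dataset_path"], newline="") as f:
#     missing = unrouted_columns(read_header(f), data_cfg)
```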

---

## Configuration Reference

Experiments are driven by **two YAML files**: a base config and, for sweeps, a sweep config.

See `configs/config.yaml` and `configs/sweep_config.yaml` for examples and recommended defaults that suit most datasets.

### Base Config Reference

```yaml
# =============================================================
# EXPERIMENT META
# =============================================================
experiment_name: "test_experiment"   # also the output directory name
resume_run_id: null                  # W&B run id to resume from checkpoint (optional)
seed: 1130                           # seeds Python, NumPy, TensorFlow

# =============================================================
# DATA
# =============================================================
data:
  dataset_path: "datasets/example.csv"
  output_path: "experiments/"        # run artifacts are written to 'output_path/experiment_name/run_id/'

  # Column routing: every column the GAN should learn MUST be listed here
  continuous_cols: ["age", "heart_rate", "glucose"]
  binary_cols: ["male", "smoker"]
  discrete_count_cols: ["cigs_per_day"]
  categorical_cols: ["education"]

  # Preprocessing toggles
  treat_bin_as_cat: false            # route binary cols through OHE + softmax
  beta_noise: true                   # apply Beta-distributed noise to binary cols
  smooth_categorical: true           # apply label-preserving noise to OHE groups

# =============================================================
# MODEL
# =============================================================
model:
  latent_dim: 32

generator:
  units: [64, 64]
  dropout: 0.0
  activation: "relu"                 # relu | leaky_relu | elu | ...
  batch_norm: true                   # BatchNorm after each Dense layer
  # negative_slope: 0.2              # used only when activation == "leaky_relu"

discriminator:
  units: [256, 256]
  dropout: 0.2
  activation: "leaky_relu"
  negative_slope: 0.2
  pack_size: 3                       # PacGAN packing factor (1 disables packing)
  label_smoothing_min: 0.9           # e.g. real labels ~ [0.9, 1.0]
  label_flipping: 0.05               # e.g. 5% of real labels flipped to 0 each step

# =============================================================
# TRAINING
# =============================================================
training:
  device: "cpu"                      # "cpu" or "gpu"
  epochs: 2000
  batch_size: 512
  g_updates_per_epoch: 2             # G steps per D step

  # Optimizers
  adam_beta_1: 0.5                   # GAN-stable Adam beta_1
  g_lr: 0.0002                       # G learning rate
  d_lr: 0.0003                       # D learning rate

  # LR schedule
  lr_cosine_decay: true
  lr_cosine_decay_restart_epochs: 2000  # restart every N epochs
  g_lr_decay_alpha: 0.1              # minimum G LR fraction (floor)
  d_lr_decay_alpha: 0.1              # minimum D LR fraction (floor)

  # Evaluation & checkpointing
  checkpoint_frequency: 100          # save "latest" every N epochs
  eval_frequency: 100                # run evaluation suite every N epochs
  test_split_pct: 0.2                # fraction of data held out for in-training evaluation
```
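
With `lr_cosine_decay: true`, both learning rates follow a cosine curve from their initial value down to `alpha × initial` and then restart. The sketch below reproduces the formula of Keras' `CosineDecayRestarts` for the special case `t_mul=1.0`, `m_mul=1.0` (a fixed restart period, as the `restart every N epochs` comment suggests) and counts in epochs. Which `t_mul`/`m_mul` at-gan passes and how it converts epochs into optimizer steps are assumptions here.

```python
import math


def cosine_decay_restarts_lr(epoch, initial_lr, restart_epochs, alpha):
    """Learning rate at `epoch` for cosine decay with fixed-period warm restarts.

    Same formula as keras.optimizers.schedules.CosineDecayRestarts with
    t_mul=1.0 and m_mul=1.0: the rate falls from `initial_lr` towards
    `alpha * initial_lr` over `restart_epochs` epochs, then jumps back.
    """
    fraction = (epoch % restart_epochs) / restart_epochs  # position inside the current cycle
    cosine = 0.5 * (1.0 + math.cos(math.pi * fraction))   # 1 -> 0 over one cycle
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


# Values from the base config: g_lr=0.0002, restart every 2000 epochs, alpha=0.1
for epoch in (0, 500, 1000, 1500, 1999):
    print(epoch, f"{cosine_decay_restarts_lr(epoch, 0.0002, 2000, 0.1):.6f}")
```

Under these assumptions, `epochs: 2000` together with `lr_cosine_decay_restart_epochs: 2000` ends the run just before the first restart, so the schedule reduces to a single cosine decay.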

### Sweep Config Reference

```yaml
# =============================================================
# SWEEP STRATEGY & METRICS
# =============================================================
method: bayes

metric:
  name: Eval/Total_Error    # W&B log key
  goal: minimize

early_terminate:
  type: hyperband           # kills unpromising runs early to save compute time
  min_iter: 300             # e.g. don't kill any run before epoch 300
  eta: 3                    # halving rate for the Hyperband brackets

# =============================================================
# PARAMETERS
# =============================================================
parameters:

  # `values`: each run picks one entry from a fixed list
  model.latent_dim:
    values: [ 16, 32, 64, 128, 256 ]

  # -----------------------------------------------------------
  # Generator Architecture
  # -----------------------------------------------------------
  generator.num_hidden_layers:
    values: [ 2, 3, 4 ]
  generator.base_units:
    values: [ 32, 64, 128, 256, 512 ]
  generator.max_units:
    value: 512
  generator.architecture_shape:
    values: [ "block", "ascending", "descending" ]

  generator.dropout:
    value: 0.0              # e.g. fixed to 0.0
  generator.activation:
    values: [ 'relu', 'leaky_relu' ]
  generator.batch_norm:
    values: [ true, false ]

  # -----------------------------------------------------------
  # Discriminator Architecture
  # -----------------------------------------------------------
  discriminator.num_hidden_layers:
    values: [ 2, 3, 4 ]
  discriminator.base_units:
    values: [ 32, 64, 128, 256, 512 ]
  discriminator.max_units:
    value: 512
  discriminator.architecture_shape:
    values: [ "block", "ascending", "descending" ]

  discriminator.dropout:
    values: [ 0.0, 0.2, 0.3, 0.5 ]
  discriminator.activation:
    values: [ 'relu', 'leaky_relu' ]
  discriminator.negative_slope:
    values: [ 0.1, 0.2, 0.3 ]
  discriminator.pack_size:
    values: [ 1, 3 ]
  discriminator.label_smoothing_min:
    values: [ 0.85, 0.9, 0.95, 1.0 ]
  discriminator.label_flipping:
    values: [ 0.0, 0.05, 0.1 ]

  # -----------------------------------------------------------
  # Training Loop & Optimizers
  # -----------------------------------------------------------
  training.batch_size:
    values: [ 64, 128, 256, 512 ]
  training.g_updates_per_epoch:
    values: [ 1, 2, 3 ]
  training.adam_beta_1:
    values: [ 0.2, 0.5, 0.7, 0.9 ]

  # Learning Rates
  training.g_lr:
    distribution: log_uniform_values
    min: 0.00001
    max: 0.001
  training.d_lr:
    distribution: log_uniform_values
    min: 0.000005
    max: 0.0005             # at-gan ensures d_lr <= g_lr

  # Cosine Decay Warm Restart Parameters
  training.lr_cosine_decay_restart_epochs:
    distribution: int_uniform
    min: 100
    max: 1000
  training.g_lr_decay_alpha:
    distribution: log_uniform_values
    min: 0.01               # decay to 1% of max LR
    max: 1                  # no decay
  training.d_lr_decay_alpha:
    distribution: log_uniform_values
    min: 0.01
    max: 1
```
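
The sweep does not sample `units` lists directly; it samples `num_hidden_layers`, `base_units`, `max_units`, and `architecture_shape`. How at-gan turns these into layer widths is not spelled out here, so the sketch below is one hypothetical interpretation (doubling per layer for `ascending`, capped at `max_units`), together with the `d_lr ≤ g_lr` guard written as a simple clamp; the actual guard may resample instead.

```python
def hidden_units(num_hidden_layers, base_units, max_units, shape):
    """Hypothetical reading of the sweep's architecture keys.

    block:      every layer has `base_units` (capped at `max_units`)
    ascending:  units double per layer starting at `base_units`, capped at `max_units`
    descending: the ascending list reversed
    """
    if shape == "block":
        return [min(base_units, max_units)] * num_hidden_layers
    ascending = [min(base_units * 2 ** i, max_units) for i in range(num_hidden_layers)]
    if shape == "ascending":
        return ascending
    if shape == "descending":
        return ascending[::-1]
    raise ValueError(f"unknown architecture_shape: {shape!r}")


def clamp_lrs(g_lr, d_lr):
    """TTUR guard: the sampled D learning rate never exceeds G's."""
    return g_lr, min(d_lr, g_lr)


print(hidden_units(3, 64, 512, "descending"))
print(clamp_lrs(0.0001, 0.0003))
```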

---

## In-Training Evaluation Suite

Every `eval_frequency` epochs, `GANCallback` generates synthetic samples and evaluates them against the held-out real samples; the resulting **Total Error** is the metric the hyperparameter sweep minimizes:

| Metric            | Computation                                                                                                                                   |
|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------------|
| PCA Error         | Wasserstein-1 distance between real and synthetic data, measured on the first five PCA components                                            |
| Adversarial Error | Absolute deviation from chance of the AUC of a Random Forest classifier trained to distinguish real from synthetic data (`\|AUC - 0.5\| × 2`) |
| **Total Error**   | `sqrt((pca_error² + adv_error²) / 2.0)`                                                                                                       |

Raw errors are passed through the squashing function `1 - exp(-x)`, so each component lies in `[0, 1)`.
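
To make the table concrete, here is a pure-Python sketch of how the numbers could be combined. It assumes the PCA projection and the Random Forest AUC have already been computed elsewhere, that the PCA error is the mean of the per-component Wasserstein-1 distances, and that both raw errors are squashed before the RMS-style combination. These aggregation details are assumptions, not taken from `GANCallback`.

```python
import math


def wasserstein_1d(a, b):
    """Wasserstein-1 distance between two equal-size 1-D samples:
    the mean absolute difference of their sorted values."""
    if len(a) != len(b):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


def squash(x):
    """Map a non-negative raw error into [0, 1)."""
    return 1.0 - math.exp(-x)


def total_error(real_pcs, synth_pcs, auc):
    """real_pcs / synth_pcs: one list of PCA scores per component (e.g. 5 components)."""
    raw_pca = sum(wasserstein_1d(r, s) for r, s in zip(real_pcs, synth_pcs)) / len(real_pcs)
    raw_adv = abs(auc - 0.5) * 2
    pca_error, adv_error = squash(raw_pca), squash(raw_adv)
    return math.sqrt((pca_error ** 2 + adv_error ** 2) / 2.0)


# Toy numbers: two PCA components, three held-out samples each
real_pcs = [[0.0, 1.0, 2.0], [0.5, 0.1, -0.2]]
synth_pcs = [[0.1, 1.1, 2.2], [0.4, 0.0, -0.1]]
print(round(total_error(real_pcs, synth_pcs, auc=0.62), 4))
```

A perfect generator (identical distributions, AUC of 0.5) scores `0.0`.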

### Visual artifacts (auto-logged to W&B)
- **Correlation heatmaps**: real, synthetic, and absolute difference.
- **PCA scatter overlay**: first two principal components of real vs. synthetic.

## Synthetic Data Evaluation (Post-Training)

The `evaluate` command runs a benchmark suite that assesses the quality of the synthetic data generated by the GAN:

1. **Privacy (DCR)**: Distance to Closest Record. Measures the minimum Euclidean distance (in standard deviations) between synthetic rows and real training rows. A `Min. DCR > 0` guarantees that no synthetic row is an exact copy of a real row.
2. **Statistical Fidelity (SDV)**: Uses the [Synthetic Data Vault](https://github.com/sdv-dev/SDV) (`sdmetrics` package) to generate a Quality Report comparing 1D marginal distributions (Column Shapes) and 2D correlations (Column Pair Trends).
3. **Utility Retention (TSTR)**: Train on Synthetic, Test on Real.
   * Splits the real data into `real_train` (80%) and `real_test` (20%).
   * Trains a **TRTR** baseline (`RandomForest`, `GradientBoosting`, `LogisticRegression`) on `real_train` → baseline F1 on `real_test`.
   * Trains **TSTR** models on the entire synthetic set → F1 on the **same** `real_test`.
   * Reports `TSTR_Mean_F1 / TRTR_Mean_F1 × 100` (F1-Score Retention in %).
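
A brute-force sketch of the two headline numbers, assuming all columns are already numeric (e.g. after encoding), that distances are standardized with the real data's per-column population standard deviation, and that retention is computed from the mean F1 over the three model families. It is O(n·m) and meant for illustration only; `eval/dcr.py` and `eval/tstr.py` in the package may differ.

```python
import math


def column_stats(rows):
    """Per-column mean and population standard deviation."""
    cols = list(zip(*rows))
    means = [sum(c) / len(c) for c in cols]
    stds = [math.sqrt(sum((x - m) ** 2 for x in c) / len(c)) for c, m in zip(cols, means)]
    return means, stds


def standardize(rows, means, stds):
    # A zero std (constant column) falls back to 1.0 to avoid division by zero
    return [[(x - m) / (s or 1.0) for x, m, s in zip(row, means, stds)] for row in rows]


def min_dcr(real_rows, synth_rows):
    """Minimum over synthetic rows of the Euclidean distance to the closest real
    row, in units of the real data's per-column standard deviation."""
    means, stds = column_stats(real_rows)
    real_z = standardize(real_rows, means, stds)
    synth_z = standardize(synth_rows, means, stds)
    return min(min(math.dist(s, r) for r in real_z) for s in synth_z)


def f1_retention(tstr_f1_scores, trtr_f1_scores):
    """TSTR_Mean_F1 / TRTR_Mean_F1 x 100."""
    tstr = sum(tstr_f1_scores) / len(tstr_f1_scores)
    trtr = sum(trtr_f1_scores) / len(trtr_f1_scores)
    return tstr / trtr * 100.0


# Toy numbers only
real = [[63.0, 120.0], [45.0, 95.0], [52.0, 101.0]]
synthetic = [[60.0, 118.0], [47.0, 97.5]]
print(f"Min. DCR: {min_dcr(real, synthetic):.3f}")
print(f"F1 retention: {f1_retention([0.71, 0.69, 0.66], [0.78, 0.80, 0.74]):.1f} %")
```

A `Min. DCR` of exactly `0.0` flags at least one synthetic row that reproduces a real row.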