at-gan 0.11.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. at_gan-0.11.6/LICENSE +21 -0
  2. at_gan-0.11.6/MANIFEST.in +3 -0
  3. at_gan-0.11.6/PKG-INFO +481 -0
  4. at_gan-0.11.6/README.md +441 -0
  5. at_gan-0.11.6/pyproject.toml +70 -0
  6. at_gan-0.11.6/setup.cfg +4 -0
  7. at_gan-0.11.6/src/at_gan/__init__.py +8 -0
  8. at_gan-0.11.6/src/at_gan/api.py +179 -0
  9. at_gan-0.11.6/src/at_gan/callbacks/__init__.py +0 -0
  10. at_gan-0.11.6/src/at_gan/callbacks/console_callback.py +68 -0
  11. at_gan-0.11.6/src/at_gan/callbacks/gan_callback.py +344 -0
  12. at_gan-0.11.6/src/at_gan/callbacks/wand_callback.py +37 -0
  13. at_gan-0.11.6/src/at_gan/cli.py +220 -0
  14. at_gan-0.11.6/src/at_gan/data/__init__.py +0 -0
  15. at_gan-0.11.6/src/at_gan/data/preprocessor.py +427 -0
  16. at_gan-0.11.6/src/at_gan/engine/__init__.py +0 -0
  17. at_gan-0.11.6/src/at_gan/engine/core.py +267 -0
  18. at_gan-0.11.6/src/at_gan/engine/sweeper.py +132 -0
  19. at_gan-0.11.6/src/at_gan/engine/synthesizer.py +59 -0
  20. at_gan-0.11.6/src/at_gan/eval/__init__.py +0 -0
  21. at_gan-0.11.6/src/at_gan/eval/dcr.py +79 -0
  22. at_gan-0.11.6/src/at_gan/eval/sdv.py +62 -0
  23. at_gan-0.11.6/src/at_gan/eval/tstr.py +165 -0
  24. at_gan-0.11.6/src/at_gan/models/__init__.py +0 -0
  25. at_gan-0.11.6/src/at_gan/models/discriminator.py +52 -0
  26. at_gan-0.11.6/src/at_gan/models/gan.py +184 -0
  27. at_gan-0.11.6/src/at_gan/models/generator.py +101 -0
  28. at_gan-0.11.6/src/at_gan/training/__init__.py +0 -0
  29. at_gan-0.11.6/src/at_gan/training/trainer.py +198 -0
  30. at_gan-0.11.6/src/at_gan/utils/__init__.py +0 -0
  31. at_gan-0.11.6/src/at_gan/utils/logger.py +18 -0
  32. at_gan-0.11.6/src/at_gan/utils/paths.py +63 -0
  33. at_gan-0.11.6/src/at_gan.egg-info/PKG-INFO +481 -0
  34. at_gan-0.11.6/src/at_gan.egg-info/SOURCES.txt +36 -0
  35. at_gan-0.11.6/src/at_gan.egg-info/dependency_links.txt +1 -0
  36. at_gan-0.11.6/src/at_gan.egg-info/entry_points.txt +2 -0
  37. at_gan-0.11.6/src/at_gan.egg-info/requires.txt +17 -0
  38. at_gan-0.11.6/src/at_gan.egg-info/top_level.txt +1 -0
at_gan-0.11.6/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Jonas Miesenböck
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
at_gan-0.11.6/MANIFEST.in ADDED
@@ -0,0 +1,3 @@
+ exclude tests/*
+ exclude tests/**
+ prune tests
at_gan-0.11.6/PKG-INFO ADDED
@@ -0,0 +1,481 @@
+ Metadata-Version: 2.4
+ Name: at-gan
+ Version: 0.11.6
+ Summary: Training Framework for Arbitrary Tabular Generative Adversarial Networks
+ Author-email: Jonas Miesenböck <jonas@miesenboeck.at>
+ Project-URL: Homepage, https://github.com/Jns-M/at-gan
+ Project-URL: Repository, https://github.com/Jns-M/at-gan
+ Project-URL: Issues, https://github.com/Jns-M/at-gan/issues
+ Project-URL: Documentation, https://github.com/Jns-M/at-gan#readme
+ Keywords: gan,tabular-data,synthetic-data,synthetic-tabular-data,tensorflow,keras,machine-learning,generative-adversarial-network,generative-adversarial-networks
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Python: <3.13,>=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: tensorflow>=2.16.0
+ Requires-Dist: pandas>=2.0.0
+ Requires-Dist: scikit-learn>=1.3.0
+ Requires-Dist: wandb>=0.16.0
+ Requires-Dist: pyyaml>=6.0.0
+ Requires-Dist: numpy>=1.24.0
+ Requires-Dist: joblib>=1.3.0
+ Requires-Dist: typer>=0.9.0
+ Requires-Dist: seaborn>=0.12.0
+ Requires-Dist: matplotlib>=2.0.0
+ Requires-Dist: sdmetrics>=0.20.0
+ Requires-Dist: pandas-stubs~=2.3.3
+ Provides-Extra: dev
+ Requires-Dist: build; extra == "dev"
+ Requires-Dist: twine; extra == "dev"
+ Requires-Dist: pytest; extra == "dev"
+ Dynamic: license-file
+
+ <div align="center">
+
+ # AT-GAN
+
+ ### Arbitrary Tabular Generative Adversarial Network
+
+ *A GAN framework for generating synthetic data from arbitrary mixed-type tabular datasets.*
+
+ [![Python](https://img.shields.io/badge/python-3.10--3.12-blue.svg)](https://www.python.org/)
+ [![TensorFlow](https://img.shields.io/badge/TensorFlow-2.x-orange.svg)](https://www.tensorflow.org/)
+ [![Keras](https://img.shields.io/badge/Keras-3.x-red.svg)](https://keras.io/)
+ [![W&B](https://img.shields.io/badge/tracking-Weights%20%26%20Biases-yellow.svg)](https://wandb.ai/)
+ [![Typer](https://img.shields.io/badge/CLI-Typer-green.svg)](https://typer.tiangolo.com/)
+ ![Status](https://img.shields.io/badge/status-research-purple.svg)
+
+ </div>
+
+ ---
+
+ ## Table of Contents
+
+ 1. [Overview](#overview)
+ 1. [Key Features](#key-features)
+ 1. [Installation](#installation)
+ 1. [CLI Usage](#cli-usage)
+ 1. [API Usage](#api-usage)
+ 1. [Configuration Reference](#configuration-reference)
+ 1. [In-Training Evaluation Suite](#in-training-evaluation-suite-1)
+ 1. [Synthetic Data Evaluation (Post-Training)](#synthetic-data-evaluation-post-training-1)
+
+ ---
+
+ ## Overview
+
+ **at-gan** is a framework for training Generative Adversarial Networks on **arbitrary tabular data**. It is designed to handle *continuous*, *binary*, *discrete count*, and *categorical* features within a single pipeline.
+
+ The framework combines a **multi-branch generator** (G), a **PacGAN-style discriminator** (D), an integrated **evaluation suite**, and **[Weights & Biases](https://wandb.ai/)** (W&B) integration for sweep orchestration, experiment tracking, and training monitoring and visualization.
+
+ > **Goal:** Train a GAN that produces realistic synthetic tabular data from a given dataset, with minimal manual tuning and a transparent, observable training process.
+
+ ---
+
+ ## Key Features
+
+ ### Dynamic, Config-Driven Architectures
+ - Generator and Discriminator are built **entirely from the YAML config**.
+ - Configurable number of layers and units per layer (`units`).
+ - Configurable activations: `relu`, `leaky_relu`, `elu`, or any other activation supported by Keras.
+ - Configurable dropout (`dropout`).
+ - Optional batch normalization for G (`batch_norm`).
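The idea of a config-driven generator with per-type output branches can be pictured with a small NumPy forward pass. This is a hypothetical sketch, not at-gan's Keras code: `GEN_CFG` mirrors the `units`/`activation` keys of the base config, while the branch layout, the weight initialisation and the helper names (`build_hidden_layers`, `generator_forward`) are assumptions made for illustration. Dropout and batch normalization are left out.

```python
import numpy as np

# Hypothetical fragment mirroring the `model.generator` keys of the base config.
GEN_CFG = {"units": [64, 64], "activation": "relu"}

ACTIVATIONS = {
    "relu": lambda x: np.maximum(x, 0.0),
    "leaky_relu": lambda x: np.where(x > 0, x, 0.2 * x),
    "elu": lambda x: np.where(x > 0, x, np.expm1(np.minimum(x, 0.0))),
}


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x):
    z = np.exp(x - x.max(axis=1, keepdims=True))  # shift for numerical stability
    return z / z.sum(axis=1, keepdims=True)


def build_hidden_layers(input_dim, units, rng):
    """One (weights, bias) pair per entry in `units` (He-style init)."""
    layers, fan_in = [], input_dim
    for width in units:
        layers.append((rng.normal(0.0, np.sqrt(2.0 / fan_in), (fan_in, width)), np.zeros(width)))
        fan_in = width
    return layers


def generator_forward(z, hidden, heads, activation):
    """Shared hidden trunk, then one linear head + output activation per feature branch."""
    h = z
    for w, b in hidden:
        h = ACTIVATIONS[activation](h @ w + b)
    return np.concatenate([out_fn(h @ w + b) for w, b, out_fn in heads], axis=1)


rng = np.random.default_rng(0)
latent_dim = 32
hidden = build_hidden_layers(latent_dim, GEN_CFG["units"], rng)

# Hypothetical layout: 3 continuous, 2 binary, 1 count column, one 4-level categorical.
branches = [(3, np.tanh), (2, sigmoid), (1, sigmoid), (4, softmax)]
last = GEN_CFG["units"][-1]
heads = [(rng.normal(0.0, 0.1, (last, width)), np.zeros(width), fn) for width, fn in branches]

fake = generator_forward(rng.normal(size=(5, latent_dim)), hidden, heads, GEN_CFG["activation"])
print(fake.shape)  # (5, 10)
```

Splitting the head per feature type lets each slice of the output carry the range its encoding expects (`tanh` for `[-1, 1]`, `sigmoid` for `[0, 1]`, `softmax` for one-hot groups).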
+
+ ### Mixed-Type Data Handling
+ - The `TabularPreprocessor` handles **four types of input features**:
+   - **Continuous** → `MinMaxScaler(-1, 1)` → `tanh` output branch.
+   - **Discrete Count** → `MinMaxScaler(0, 1)` → `sigmoid` output branch.
+   - **Binary** → 0/1 encoding with optional β-distributed noise → `sigmoid` output branch.
+   - **Categorical** → one-hot encoding with optional label-preserving smoothing → `softmax` output branch.
+ - Per-column decimal precision is preserved.
+ - Scalers and encoders are stored and reused for inference.
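The per-type encoding, and the inverse mapping applied to generator output, can be sketched in plain NumPy. This is a hypothetical illustration, not the `TabularPreprocessor` API: it hand-rolls the two min-max ranges instead of using scikit-learn's `MinMaxScaler`, skips β-noise and categorical smoothing, and assumes the decoding rules (rounding, a 0.5 threshold for binary columns, `argmax` per one-hot group).

```python
import numpy as np


def minmax_transform(col, lo, hi, out_min, out_max):
    """Map [lo, hi] linearly onto [out_min, out_max]."""
    span = (hi - lo) or 1.0  # guard against constant columns
    return out_min + (col - lo) / span * (out_max - out_min)


def minmax_inverse(scaled, lo, hi, out_min, out_max):
    """Undo minmax_transform."""
    span = (hi - lo) or 1.0
    return lo + (scaled - out_min) / (out_max - out_min) * span


# Toy columns named after the example config; the values are made up.
age = np.array([20.0, 40.0, 60.0])          # continuous
cigs_per_day = np.array([0.0, 10.0, 20.0])  # discrete count
male = np.array([0, 1, 1])                  # binary
education = ["primary", "tertiary", "primary"]
edu_levels = ["primary", "secondary", "tertiary"]

age_lo, age_hi = age.min(), age.max()
cig_lo, cig_hi = cigs_per_day.min(), cigs_per_day.max()

encoded = np.column_stack([
    minmax_transform(age, age_lo, age_hi, -1.0, 1.0),          # tanh range
    minmax_transform(cigs_per_day, cig_lo, cig_hi, 0.0, 1.0),  # sigmoid range
    male.astype(float),                                        # 0/1, no beta noise here
    np.eye(len(edu_levels))[[edu_levels.index(v) for v in education]],  # one-hot
])


def decode(block):
    """Map generator-style output back to the original column types."""
    age_out = np.round(minmax_inverse(block[:, 0], age_lo, age_hi, -1.0, 1.0), 1)
    cigs_out = np.rint(minmax_inverse(block[:, 1], cig_lo, cig_hi, 0.0, 1.0)).astype(int)
    male_out = (block[:, 2] >= 0.5).astype(int)
    edu_out = [edu_levels[i] for i in block[:, 3:].argmax(axis=1)]
    return age_out, cigs_out, male_out, edu_out


print(encoded.shape)  # (3, 6)
```

Keeping the fitted ranges (`age_lo`, `age_hi`, ...) around is what makes decoding possible later, which is the role the README assigns to the stored scalers and encoders.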
+
+ ### GAN Training and Stabilization Techniques
+ | Technique | Controlled by | What it does |
+ |-------------------------------------|-------------------------------------|-------------------------------------------------------------------------------------------------|
+ | **PacGAN packing** | `discriminator.pack_size` | Concatenates *k* rows into a single D input → fights mode collapse |
+ | **One-sided label smoothing** | `discriminator.label_smoothing_min` | Real labels sampled from `[min, 1.0]` instead of hard `1.0` |
+ | **Label flipping** | `discriminator.label_flipping` | Random fraction of real labels flipped to `0` to prevent D overconfidence |
+ | **TTUR** | `g_lr` / `d_lr` | Different LRs for G and D. Sweeps auto-clamp `d_lr ≤ g_lr` |
+ | **G:D update ratio** | `g_updates_per_epoch` | Number of G steps per D step, to balance the training process |
+ | **LR cosine decay + warm restarts** | `lr_cosine_decay` | `CosineDecayRestarts` schedule with a configurable `alpha` floor for the learning rates of G and D |
+ | **Adam `beta_1` override** | `adam_beta_1` | Typically lowered from the default `0.9` for training stability |
+ | **Gradient clipping** | *always-on* | `clipnorm=1.0` on both Adam optimizers |
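Three of the discriminator-side techniques above (PacGAN packing, one-sided label smoothing and label flipping) reduce to a few array operations. The NumPy sketch below is an illustration under assumptions, not at-gan's training step: `pack` and `real_labels` are hypothetical names, leftover rows that do not fill a pack are dropped, and flipping is modelled as an independent coin flip per label with probability `label_flipping`.

```python
import numpy as np


def pack(rows, pack_size):
    """PacGAN packing: concatenate `pack_size` consecutive rows into one D input."""
    n, d = rows.shape
    usable = (n // pack_size) * pack_size  # drop rows that do not fill a whole pack
    return rows[:usable].reshape(usable // pack_size, pack_size * d)


def real_labels(n, smoothing_min, flip_frac, rng):
    """Real targets ~ U[smoothing_min, 1.0), then a random subset set to 0.
    Fake targets stay at 0, which is what makes the smoothing one-sided."""
    labels = rng.uniform(smoothing_min, 1.0, size=n)
    labels[rng.random(n) < flip_frac] = 0.0
    return labels


rng = np.random.default_rng(1130)
batch = rng.normal(size=(512, 10))   # batch_size 512, 10 encoded feature columns
packed = pack(batch, pack_size=3)    # 512 // 3 = 170 packed samples
targets = real_labels(len(packed), smoothing_min=0.9, flip_frac=0.05, rng=rng)
print(packed.shape)  # (170, 30)
```

Because D sees `pack_size` rows at once, it can notice when several generated rows look alike, which is how packing pushes back against mode collapse. With `pack_size=1` the batch passes through unchanged.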
+
+ ### In-Training Evaluation Suite
+ Runs every `eval_frequency` epochs on held-out real samples, logs results to W&B, and saves the **best** checkpoint by Total Error. See [In-Training Evaluation Suite](#in-training-evaluation-suite-1).
+
+ ### Experiment Tracking
+ [Weights & Biases](https://wandb.ai/) integration:
+ - Per-epoch loss/metric logging via a dedicated `WandbCallback`.
+ - Training visuals: **correlation heatmaps** and **PCA overlap scatter plots**.
+ - Local-only mode when `--no-wandb` is set (uses `run_id="offline_run"`).
+
+ ### Sweeps & Neural Architecture Search
+ - W&B sweeps for **Neural Architecture Search** (NAS) and **Hyperparameter Optimization**.
+ - Existing W&B sweeps (and single runs) can be resumed.
+
+ ### Synthetic Data Evaluation (Post-Training)
+ - **Privacy**: Distance to Closest Record (DCR)
+ - **Statistical Fidelity**: [Synthetic Data Vault](https://github.com/sdv-dev/sdv) (SDV)
+ - **Utility Retention**: Train on Synthetic, Test on Real (TSTR)
+
+ ### Usage Modes
+ - 🖥️ **CLI**: `train`, `sweep`, `generate`, `evaluate`.
+ - 🐍 **Python API** (`at_gan.api`): `train`, `sweep`, `generate`, `evaluate`.
+
+ ---
+
+ ## Installation
+
+ **Requirements:** Python `3.10 – 3.12` and the dependencies listed in `pyproject.toml`.
+
+ ### Option A: Install from PyPI (recommended)
+
+ ```shell
+ pip install at-gan
+ ```
+
+ ### Option B: Editable install
+
+ 1. Clone this repository.
+ 1. Run the following command from the repository root:
+
+ ```shell
+ pip install -e .
+ ```
+
+ ### Verify installation
+
+ ```shell
+ at-gan --help
+ python -c "import at_gan; print(at_gan.__version__)"
+ ```
+
+ ### Weights & Biases Login (one-time)
+
+ ```shell
+ wandb login
+ ```
+
+ > 💡 You can use this framework without W&B by passing `--no-wandb` (CLI) or `enable_wandb=False` (API).
+
+ ---
+
+ ## CLI Usage
+
+ ```shell
+ at-gan --help
+ ```
+
+ ### `train`: Run or resume a single GAN training run
+
+ | Flag | Short | Default | Description |
+ |--------------------------|------------|------------|------------------------------------|
+ | `--config` | `-c` | *required* | Path to the YAML experiment config |
+ | `--wandb / --no-wandb` | `-w / -nw` | `--wandb` | Toggle W&B tracking |
+ | `--export / --no-export` | `-e / -ne` | `--export` | Export the generator as a `.keras` file |
+ | `--generate-samples` | `-g` | `1000` | Auto-generate *N* samples post-training |
+
+ **Example:**
+
+ ```shell
+ at-gan train -c configs/config.yaml -w -e -g 5000
+ ```
+
+ Note: A run can be resumed via the `resume_run_id` config key. See [Configuration Reference](#configuration-reference).
+
+ ---
+
+ ### `sweep`: Run or resume a W&B sweep
+
+ | Flag | Short | Description |
+ |---|---|--------------------------------------------------|
+ | `--base-config` | `-c` | Baseline experiment config |
+ | `--sweep-config` | `-s` | W&B sweep config (required for new sweeps) |
+ | `--count` | `-n` | Maximum number of runs this agent executes |
+ | `--sweep-id` | `-id` | Resume an existing sweep instead of creating one |
+
+ ```shell
+ # Launch a new 50-run sweep
+ at-gan sweep -c configs/config.yaml -s configs/sweep_config.yaml -n 50
+
+ # Resume an existing sweep
+ at-gan sweep -c configs/config.yaml -id abc123 -n 20
+ ```
+
+ ---
+
+ ### `generate`: Generate synthetic samples from a trained generator
+
+ | Flag | Short | Description |
+ |---|---|------------------------------------------------|
+ | `--config` | `-c` | YAML used during the **original** training run |
+ | `--run-id` | `-r` | W&B run ID or `"offline_run"` |
+ | `--samples` | `-n` | Number of samples to generate |
+ | `--output` | `-o` | Optional override for CSV output path |
+
+ ```shell
+ at-gan generate -c configs/config.yaml -r a1b2c3 -n 10000 -o synthetic_data.csv
+ ```
+
+ Note: `generate` always loads **`best_generator.keras`**, not the latest checkpoint.
+
+ ---
+
+ ### `evaluate`: Run synthetic data evaluation (post-training)
+
+ | Flag | Short | Description |
+ |---------------|-------|--------------------------------------------|
+ | `--real` | `-r` | Path to the real data CSV |
+ | `--synthetic` | `-s` | Path to the synthetic data CSV |
+ | `--target` | `-t` | Optional target column for TSTR evaluation |
+
+ ```shell
+ at-gan evaluate -r real_data.csv -s synthetic_data.csv -t target_column
+ ```
+
+ ---
+
+ ## API Usage
+
+ The Python API exposes the same primary functions as the CLI (`train`, `sweep`, `generate`, `evaluate`), making it easy to integrate into existing projects.
+
+ See `examples/api_example.py` and `examples/api_example.ipynb` for a full API usage example.
+
+ > Note: The `train` entry point also accepts a `dict` instead of a path to a YAML file.
+
+ ---
+
+ ## Configuration Reference
+
+ Experiments are driven by **two YAML files**: a base config and, for sweeps, a sweep config.
+
+ See `configs/config.yaml` and `configs/sweep_config.yaml` for examples and recommended default values that suit most datasets.
+
+ ### Base Config Reference
+
+ ```yaml
+ # =============================================================
+ # EXPERIMENT META
+ # =============================================================
+ experiment_name: "test_experiment" # also output directory name
+ resume_run_id: null # W&B run id to resume from checkpoint (optional)
+ seed: 1130 # seeds Python, NumPy, TensorFlow
+
+ # =============================================================
+ # DATA
+ # =============================================================
+ data:
+   dataset_path: "datasets/example.csv"
+   output_path: "experiments/" # run artifacts are written to 'output_path/experiment_name/run_id/'
+
+   # Column routing — every column the GAN should learn MUST be listed here
+   continuous_cols: ["age", "heart_rate", "glucose"]
+   binary_cols: ["male", "smoker"]
+   discrete_count_cols: ["cigs_per_day"]
+   categorical_cols: ["education"]
+
+   # Preprocessing toggles
+   treat_bin_as_cat: false # route binary cols through OHE + softmax
+   beta_noise: true # apply Beta-distributed noise to binary cols
+   smooth_categorical: true # apply label-preserving noise to OHE groups
+
+ # =============================================================
+ # MODEL
+ # =============================================================
+ model:
+   latent_dim: 32
+
+   generator:
+     units: [64, 64]
+     dropout: 0.0
+     activation: "relu" # relu | leaky_relu | elu | ...
+     batch_norm: true # BatchNorm after each Dense layer
+     # negative_slope: 0.2 # used only when activation == "leaky_relu"
+
+   discriminator:
+     units: [256, 256]
+     dropout: 0.2
+     activation: "leaky_relu"
+     negative_slope: 0.2
+     pack_size: 3 # PacGAN packing factor (1 disables packing)
+     label_smoothing_min: 0.9 # e.g. real labels ~ [0.9, 1.0]
+     label_flipping: 0.05 # e.g. 5% of real labels flipped to 0 each step
+
+ # =============================================================
+ # TRAINING
+ # =============================================================
+ training:
+   device: "cpu" # "cpu" or "gpu"
+   epochs: 2000
+   batch_size: 512
+   g_updates_per_epoch: 2 # G steps per D step
+
+   # Optimizers
+   adam_beta_1: 0.5 # GAN-stable Adam beta_1
+   g_lr: 0.0002 # G learning rate
+   d_lr: 0.0003 # D learning rate
+
+   # LR schedule
+   lr_cosine_decay: true
+   lr_cosine_decay_restart_epochs: 2000 # restart every N epochs
+   g_lr_decay_alpha: 0.1 # minimum G LR fraction (floor)
+   d_lr_decay_alpha: 0.1 # minimum D LR fraction (floor)
+
+   # Evaluation & checkpointing
+   checkpoint_frequency: 100 # save "latest" every N epochs
+   eval_frequency: 100 # run evaluation suite every N epochs
+   test_split_pct: 0.2 # fraction of data held out for in-training evaluation
+ ```
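With `lr_cosine_decay: true`, each learning rate follows a cosine curve from its initial value down to `alpha × initial` and then jumps back up. The pure-Python sketch below reproduces the shape of Keras' `CosineDecayRestarts` for the special case `t_mul = m_mul = 1` (every cycle has the same length and peak). Whether at-gan uses these multipliers, and how it converts `lr_cosine_decay_restart_epochs` into optimizer steps, are assumptions here, so `step` is simply counted in epochs. Keras itself defaults to `t_mul=2.0`, which doubles the length of each successive cycle.

```python
import math


def cosine_restart_lr(step, initial_lr, period, alpha):
    """Cosine decay with warm restarts and a fixed period (t_mul = m_mul = 1)."""
    fraction = (step % period) / period                    # position inside the current cycle
    cosine = 0.5 * (1.0 + math.cos(math.pi * fraction))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)   # alpha is the floor fraction


# Values from the example base config, with one "step" per epoch for illustration.
g_lr, period, alpha = 0.0002, 2000, 0.1
for epoch in (0, 500, 1000, 1999, 2000):
    print(epoch, cosine_restart_lr(epoch, g_lr, period, alpha))
```

With the example values (`epochs: 2000` and `lr_cosine_decay_restart_epochs: 2000`), a run covers exactly one cycle, so no restart happens before training ends. The sweep's restart range of 100 to 1000 epochs yields several cycles per run.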
+
+ ### Sweep Config Reference
+
+ ```yaml
+ # =============================================================
+ # SWEEP STRATEGY & METRICS
+ # =============================================================
+ method: bayes
+
+ metric:
+   name: Eval/Total_Error # W&B log key
+   goal: minimize
+
+ early_terminate:
+   type: hyperband # Kills unpromising runs early to save compute time
+   min_iter: 300 # Don't kill any run before e.g. epoch 300
+   eta: 3 # The halving rate for the Hyperband brackets
+
+ # =============================================================
+ # PARAMETERS
+ # =============================================================
+ parameters:
+
+   # `values` lists a fixed set of candidates; `value` pins a single one
+   model.latent_dim:
+     values: [ 16, 32, 64, 128, 256 ]
+
+   # -----------------------------------------------------------
+   # Generator Architecture
+   # -----------------------------------------------------------
+   generator.num_hidden_layers:
+     values: [ 2, 3, 4 ]
+   generator.base_units:
+     values: [ 32, 64, 128, 256, 512 ]
+   generator.max_units:
+     value: 512
+   generator.architecture_shape:
+     values: [ "block", "ascending", "descending" ]
+
+   generator.dropout:
+     value: 0.0 # fixed at 0.0
+   generator.activation:
+     values: [ 'relu', 'leaky_relu' ]
+   generator.batch_norm:
+     values: [ true, false ]
+
+   # -----------------------------------------------------------
+   # Discriminator Architecture
+   # -----------------------------------------------------------
+   discriminator.num_hidden_layers:
+     values: [ 2, 3, 4 ]
+   discriminator.base_units:
+     values: [ 32, 64, 128, 256, 512 ]
+   discriminator.max_units:
+     value: 512
+   discriminator.architecture_shape:
+     values: [ "block", "ascending", "descending" ]
+
+   discriminator.dropout:
+     values: [ 0.0, 0.2, 0.3, 0.5 ]
+   discriminator.activation:
+     values: [ 'relu', 'leaky_relu' ]
+   discriminator.negative_slope:
+     values: [ 0.1, 0.2, 0.3 ]
+   discriminator.pack_size:
+     values: [ 1, 3 ]
+   discriminator.label_smoothing_min:
+     values: [ 0.85, 0.9, 0.95, 1.0 ]
+   discriminator.label_flipping:
+     values: [ 0.0, 0.05, 0.1 ]
+
+   # -----------------------------------------------------------
+   # Training Loop & Optimizers
+   # -----------------------------------------------------------
+   training.batch_size:
+     values: [ 64, 128, 256, 512 ]
+   training.g_updates_per_epoch:
+     values: [ 1, 2, 3 ]
+   training.adam_beta_1:
+     values: [ 0.2, 0.5, 0.7, 0.9 ]
+
+   # Learning Rates
+   training.g_lr:
+     distribution: log_uniform_values
+     min: 0.00001
+     max: 0.001
+   training.d_lr:
+     distribution: log_uniform_values
+     min: 0.000005
+     max: 0.0005 # at-gan ensures d_lr <= g_lr
+
+   # Cosine Decay Warm Restart Parameters
+   training.lr_cosine_decay_restart_epochs:
+     distribution: int_uniform
+     min: 100
+     max: 1000
+   training.g_lr_decay_alpha:
+     distribution: log_uniform_values
+     min: 0.01 # Decay to 1% of max LR
+     max: 1 # No decay
+   training.d_lr_decay_alpha:
+     distribution: log_uniform_values
+     min: 0.01
+     max: 1
+ ```
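Each sweep run receives a flat set of dotted keys that has to be merged into the nested base config before training, followed by the `d_lr ≤ g_lr` clamp mentioned in the TTUR row of the feature table. The sketch below shows one generic way to do that. `apply_sweep_params` is a hypothetical helper that simply follows each dotted path. How at-gan actually routes keys such as `generator.num_hidden_layers`, or turns `base_units`/`architecture_shape` into a `units` list, is not documented here, so the example only uses `model.*` and `training.*` keys.

```python
import copy


def apply_sweep_params(base_cfg, sweep_params):
    """Return a copy of base_cfg with dotted sweep keys written into nested sections."""
    cfg = copy.deepcopy(base_cfg)  # never mutate the baseline config
    for dotted_key, value in sweep_params.items():
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    training = cfg.get("training", {})
    if "g_lr" in training and "d_lr" in training:
        training["d_lr"] = min(training["d_lr"], training["g_lr"])  # clamp d_lr <= g_lr
    return cfg


base = {"model": {"latent_dim": 32}, "training": {"g_lr": 0.0002, "d_lr": 0.0003, "batch_size": 512}}
run_params = {"model.latent_dim": 64, "training.g_lr": 0.0001, "training.d_lr": 0.0004}
merged = apply_sweep_params(base, run_params)
print(merged["training"])  # {'g_lr': 0.0001, 'd_lr': 0.0001, 'batch_size': 512}
```

Copying the baseline first matters when one agent executes many runs: each run starts from the same base config instead of inheriting the previous run's overrides.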
+
+ ---
+
+ ## In-Training Evaluation Suite
+
+ Every `eval_frequency` epochs, `GANCallback` generates synthetic samples and evaluates them against the held-out real samples to guide the hyperparameter sweep:
+
+ | Metric | Computation |
+ |-------------------|---------------------------------------------------------------------------------------------------------------------------|
+ | PCA Error | Wasserstein-1 distance between real and synthetic data across the first five PCA components |
+ | Adversarial Error | Absolute AUC deviation of a Random Forest classifier trained to distinguish real from synthetic data (`\|AUC - 0.5\| × 2`) |
+ | **Total Error** | `sqrt((pca_error² + adv_error²) / 2.0)` |
+
+ Raw errors are passed through a squashing function (`1 - exp(-x)`) so that each component lies in `[0, 1)`.
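The scoring in the table can be reproduced in a few lines. This NumPy sketch rests on assumptions the README does not spell out: PCA is fitted on the real hold-out only, the per-component Wasserstein-1 distances are averaged, real and synthetic samples have the same size (so W1 reduces to the mean gap between sorted values), and both raw components are squashed before they are combined. If at-gan only squashes the unbounded PCA error, drop the second `squash` call. The AUC values are placeholders rather than output of a trained Random Forest, and all function names are hypothetical.

```python
import numpy as np


def squash(x):
    """Map a non-negative raw error onto [0, 1)."""
    return 1.0 - np.exp(-x)


def wasserstein_1d(a, b):
    """W1 between two equal-size 1-D samples: mean absolute gap of the sorted values."""
    return float(np.mean(np.abs(np.sort(a) - np.sort(b))))


def pca_error(real, synth, n_components=5):
    """Average W1 along the first principal axes of the real data."""
    mean = real.mean(axis=0)
    _, _, vt = np.linalg.svd(real - mean, full_matrices=False)  # rows of vt = principal axes
    axes = vt[:n_components].T
    real_proj, synth_proj = (real - mean) @ axes, (synth - mean) @ axes
    return float(np.mean([wasserstein_1d(real_proj[:, i], synth_proj[:, i])
                          for i in range(axes.shape[1])]))


def adversarial_error(auc):
    """0 when the real-vs-synthetic classifier is at chance level, 1 when it is perfect."""
    return abs(auc - 0.5) * 2.0


def total_error(raw_pca, raw_adv):
    p, a = squash(raw_pca), squash(raw_adv)
    return float(np.sqrt((p**2 + a**2) / 2.0))


rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
close = rng.normal(size=(500, 8))             # same distribution as `real`
shifted = rng.normal(loc=1.5, size=(500, 8))  # clearly different distribution

good = total_error(pca_error(real, close), adversarial_error(0.52))   # placeholder AUC
bad = total_error(pca_error(real, shifted), adversarial_error(0.97))  # placeholder AUC
print(round(good, 3), round(bad, 3))
```

For unequal sample sizes, `scipy.stats.wasserstein_distance` computes the general 1-D case and could replace `wasserstein_1d`.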
+
+ ### Visual artifacts (auto-logged to W&B)
+ - **Correlation heatmaps**: real, synthetic, and absolute difference.
+ - **PCA scatter overlay**: first two principal components of real vs. synthetic.
+
+ ---
+
+ ## Synthetic Data Evaluation (Post-Training)
+
+ The `evaluate` command runs a benchmark suite that assesses the quality of the synthetic data generated by the GAN:
+
+ 1. **Privacy (DCR)**: Distance to Closest Record. Measures the minimum Euclidean distance (in standard deviations) between synthetic rows and real training rows. `Min. DCR > 0` guarantees that no synthetic row is an exact copy of a real training row (near-copies are still possible).
+ 2. **Statistical Fidelity (SDV)**: Uses the [Synthetic Data Vault](https://github.com/sdv-dev/SDV) (`sdmetrics` package) to generate a Quality Report that compares 1D marginal distributions (Column Shapes) and 2D correlations (Column Pair Trends).
+ 3. **Utility Retention (TSTR)**: Train on Synthetic, Test on Real.
+    * Splits real data into `real_train` (80%) and `real_test` (20%).
+    * Trains a **TRTR** (Train on Real, Test on Real) baseline (`RandomForest`, `GradientBoosting`, `LogisticRegression`) on `real_train` → baseline F1 on `real_test`.
+    * Trains **TSTR** models on the entire synthetic set → F1 on the **same** `real_test`.
+    * Reports `TSTR_Mean_F1 / TRTR_Mean_F1 × 100` (F1-score retention in %).
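The two headline numbers of the post-training report are easy to reproduce for a sanity check. This NumPy sketch assumes that "in standard deviations" means z-scoring every column with the real data's mean and standard deviation before taking Euclidean distances. It also uses made-up F1 scores for the retention ratio. `min_dcr` and `f1_retention` are hypothetical names, not at-gan functions. The brute-force distance matrix has `n_synth × n_real` entries, so large datasets would call for chunking or a nearest-neighbour index instead.

```python
import numpy as np


def min_dcr(real, synth):
    """Smallest Euclidean distance from any synthetic row to any real row, after z-scoring."""
    mean, std = real.mean(axis=0), real.std(axis=0)
    std[std == 0] = 1.0  # constant columns: avoid division by zero
    r, s = (real - mean) / std, (synth - mean) / std
    dists = np.linalg.norm(s[:, None, :] - r[None, :, :], axis=2)  # shape (n_synth, n_real)
    return float(dists.min())


def f1_retention(tstr_f1, trtr_f1):
    """TSTR mean F1 as a percentage of the TRTR (train on real, test on real) mean F1."""
    return 100.0 * float(np.mean(tstr_f1)) / float(np.mean(trtr_f1))


rng = np.random.default_rng(7)
real_train = rng.normal(size=(200, 4))
synthetic = rng.normal(size=(100, 4))
leaky = np.vstack([synthetic, real_train[:1]])  # one memorised row slipped in

print(min_dcr(real_train, synthetic) > 0)  # True: no exact copies
print(min_dcr(real_train, leaky))          # 0.0: an exact copy is present

# Hypothetical per-model F1 scores (RandomForest, GradientBoosting, LogisticRegression).
print(f1_retention([0.70, 0.68, 0.66], [0.80, 0.78, 0.76]))
```

The `leaky` case shows why `Min. DCR > 0` is only a floor: a single copied row drives the minimum to zero, while rows that sit very close to real records still pass the check.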