at-gan-core 0.13.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. at_gan_core-0.13.1/LICENSE +21 -0
  2. at_gan_core-0.13.1/PKG-INFO +484 -0
  3. at_gan_core-0.13.1/README.md +443 -0
  4. at_gan_core-0.13.1/pyproject.toml +81 -0
  5. at_gan_core-0.13.1/setup.cfg +4 -0
  6. at_gan_core-0.13.1/src/at_gan/__init__.py +8 -0
  7. at_gan_core-0.13.1/src/at_gan/api.py +183 -0
  8. at_gan_core-0.13.1/src/at_gan/callbacks/__init__.py +0 -0
  9. at_gan_core-0.13.1/src/at_gan/callbacks/console_callback.py +68 -0
  10. at_gan_core-0.13.1/src/at_gan/callbacks/gan_callback.py +354 -0
  11. at_gan_core-0.13.1/src/at_gan/callbacks/wand_callback.py +37 -0
  12. at_gan_core-0.13.1/src/at_gan/cli.py +220 -0
  13. at_gan_core-0.13.1/src/at_gan/data/__init__.py +0 -0
  14. at_gan_core-0.13.1/src/at_gan/data/preprocessor.py +427 -0
  15. at_gan_core-0.13.1/src/at_gan/engine/__init__.py +0 -0
  16. at_gan_core-0.13.1/src/at_gan/engine/core.py +267 -0
  17. at_gan_core-0.13.1/src/at_gan/engine/sweeper.py +132 -0
  18. at_gan_core-0.13.1/src/at_gan/engine/synthesizer.py +59 -0
  19. at_gan_core-0.13.1/src/at_gan/eval/__init__.py +0 -0
  20. at_gan_core-0.13.1/src/at_gan/eval/dcr.py +81 -0
  21. at_gan_core-0.13.1/src/at_gan/eval/plots.py +450 -0
  22. at_gan_core-0.13.1/src/at_gan/eval/sdv.py +62 -0
  23. at_gan_core-0.13.1/src/at_gan/eval/tstr.py +165 -0
  24. at_gan_core-0.13.1/src/at_gan/models/__init__.py +0 -0
  25. at_gan_core-0.13.1/src/at_gan/models/discriminator.py +52 -0
  26. at_gan_core-0.13.1/src/at_gan/models/gan.py +184 -0
  27. at_gan_core-0.13.1/src/at_gan/models/generator.py +101 -0
  28. at_gan_core-0.13.1/src/at_gan/training/__init__.py +0 -0
  29. at_gan_core-0.13.1/src/at_gan/training/trainer.py +198 -0
  30. at_gan_core-0.13.1/src/at_gan/utils/__init__.py +0 -0
  31. at_gan_core-0.13.1/src/at_gan/utils/logger.py +18 -0
  32. at_gan_core-0.13.1/src/at_gan/utils/paths.py +63 -0
  33. at_gan_core-0.13.1/src/at_gan_core.egg-info/PKG-INFO +484 -0
  34. at_gan_core-0.13.1/src/at_gan_core.egg-info/SOURCES.txt +37 -0
  35. at_gan_core-0.13.1/src/at_gan_core.egg-info/dependency_links.txt +1 -0
  36. at_gan_core-0.13.1/src/at_gan_core.egg-info/entry_points.txt +2 -0
  37. at_gan_core-0.13.1/src/at_gan_core.egg-info/requires.txt +19 -0
  38. at_gan_core-0.13.1/src/at_gan_core.egg-info/top_level.txt +1 -0
  39. at_gan_core-0.13.1/tests/test_package.py +57 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Jonas Miesenböck
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,484 @@
1
+ Metadata-Version: 2.4
2
+ Name: at-gan-core
3
+ Version: 0.13.1
4
+ Summary: Training Framework for Arbitrary Tabular Generative Adversarial Networks
5
+ Author-email: Jonas Miesenböck <jonas@miesenboeck.at>
6
+ Project-URL: Homepage, https://github.com/Jns-M/at-gan
7
+ Project-URL: Repository, https://github.com/Jns-M/at-gan
8
+ Project-URL: Issues, https://github.com/Jns-M/at-gan/issues
9
+ Project-URL: Documentation, https://github.com/Jns-M/at-gan#readme
10
+ Keywords: gan,tabular-data,synthetic-data,synthetic-tabular-data,tensorflow,keras,machine-learning,generative-adversarial-network,generative-adversarial-networks
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
15
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.10
18
+ Classifier: Programming Language :: Python :: 3.11
19
+ Classifier: Programming Language :: Python :: 3.12
20
+ Requires-Python: <3.13,>=3.10
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: pandas>=2.0.0
24
+ Requires-Dist: scikit-learn>=1.3.0
25
+ Requires-Dist: wandb>=0.16.0
26
+ Requires-Dist: pyyaml>=6.0.0
27
+ Requires-Dist: numpy>=1.24.0
28
+ Requires-Dist: joblib>=1.3.0
29
+ Requires-Dist: typer>=0.9.0
30
+ Requires-Dist: seaborn>=0.12.0
31
+ Requires-Dist: matplotlib>=2.0.0
32
+ Requires-Dist: sdmetrics>=0.20.0
33
+ Requires-Dist: pandas-stubs~=2.3.3
34
+ Provides-Extra: dev
35
+ Requires-Dist: build; extra == "dev"
36
+ Requires-Dist: twine; extra == "dev"
37
+ Requires-Dist: pytest; extra == "dev"
38
+ Provides-Extra: tf
39
+ Requires-Dist: tensorflow>=2.16.0; extra == "tf"
40
+ Dynamic: license-file
41
+
42
+ <div align="center">
43
+
44
+ # AT-GAN
45
+
46
+ ### Arbitrary Tabular Generative Adversarial Network
47
+
48
+ *A Tabular GAN framework for generating synthetic tabular data from arbitrary mixed-type tabular datasets.*
49
+
50
+ [![Python](https://img.shields.io/badge/python-3.10--3.12-blue.svg)](https://www.python.org/)
51
+ [![TensorFlow](https://img.shields.io/badge/TensorFlow-2.x-purple.svg)](https://www.tensorflow.org/)
52
+ [![Keras](https://img.shields.io/badge/Keras-3.x-red.svg)](https://keras.io/)
53
+ [![W&B](https://img.shields.io/badge/tracking-Weights%20%26%20Biases-yellow.svg)](https://wandb.ai/)
54
+ [![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/Jns-M/at-gan/blob/main/LICENSE)
55
+ [![PyPI](https://img.shields.io/pypi/v/at-gan)](https://pypi.org/project/at-gan/)
56
+
57
+ </div>
58
+
59
+ ---
60
+
61
+ ## Table of Contents
62
+
63
+ 1. [Overview](#overview)
64
+ 1. [Key Features](#key-features)
65
+ 1. [Installation](#installation)
66
+ 1. [CLI Usage](#cli-usage)
67
+ 1. [API Usage](#api-usage)
68
+ 1. [Configuration Reference](#configuration-reference)
69
+ 1. [In-Training Evaluation Suite](#in-training-evaluation-suite-1)
70
+ 1. [Synthetic Data Evaluation (Post-Training)](#synthetic-data-evaluation-post-training-1)
71
+
72
+ ---
73
+
74
+ ## Overview
75
+
76
+ **at-gan** is a framework for training Generative Adversarial Networks on **arbitrary tabular data**. It is designed to work with *continuous*, *binary*, *discrete count*, and *categorical* features within a single pipeline.
77
+
78
+ The framework combines a **multi-branch generator** (G), a **PacGAN-style discriminator** (D), an integrated **evaluation
79
+ suite**, and **[Weights & Biases](https://wandb.ai/)** (W&B) integration for sweep orchestration, experiment tracking, and training monitoring and visualization.
80
+
81
+
82
+ > **Goal:** Train a GAN that produces realistic synthetic tabular data from a given dataset with minimal manual tuning and a transparent, observable training process.
83
+
84
+ ---
85
+
86
+ ## Key Features
87
+
88
+ ### Dynamic, Config-Driven Architectures
89
+ - Generator and Discriminator built **entirely from YAML-config**.
90
+ - Configurable number of `layers` and `units`.
91
+ - Configurable activations: `relu`, `leaky_relu`, `elu`, or any other activation supported in Keras.
92
+ - Configurable `dropout` layers.
93
+ - Optional `Batch Normalization` for G.
94
+
95
+ ### Mixed-Type Data Handling
96
+ - The `TabularPreprocessor` handles **four types of input features**:
97
+ - **Continuous** → `MinMaxScaler(-1, 1)` → `tanh` output branch.
98
+ - **Discrete Count** → `MinMaxScaler(0, 1)` → `sigmoid` output branch.
99
+ - **Binary** → 0/1 and optional β-distributed noise application → `sigmoid` output branch.
100
+ - **Categorical** → One-hot encoding and optional label-preserving smoothing → `softmax` output branch.
101
+ - Per-column decimal precision preservation.
102
+ - Scalers and encoders are stored and reused for inference.
103
+
104
+ ### GAN Training and Stabilization Techniques
105
+ | Technique | Controlled by | What it does |
106
+ |-------------------------------------|-------------------------------------|-------------------------------------------------------------------------------------------------|
107
+ | **PacGAN packing** | `discriminator.pack_size` | Concatenates *k* rows into a single D input → fights mode collapse |
108
+ | **One-sided label smoothing** | `discriminator.label_smoothing_min` | Real labels sampled from `[min, 1.0]` instead of hard `1.0` |
109
+ | **Label flipping** | `discriminator.label_flipping` | Random fraction of real labels flipped to `0` to prevent D overconfidence |
110
+ | **TTUR** | `g_lr` / `d_lr` | Different LRs for G and D. Sweeps auto-clamp `d_lr ≤ g_lr` |
111
+ | **G:D update ratio** | `g_updates_per_epoch` | Multiple G steps per D step to balance the training process |
112
+ | **LR Cosine decay + warm restarts** | `lr_cosine_decay` | `CosineDecayRestarts` schedule with configurable `alpha` floor for the learning rate of G and D |
113
+ | **Adam `beta_1` override** | `adam_beta_1` | Typically lowered from default `0.9` for training stability |
114
+ | **Gradient clipping** | *always-on* | `clipnorm=1.0` on both Adam optimizers |
115
+
116
+ ### In-Training Evaluation Suite
117
+ Runs every `eval_frequency` epochs on held-out real samples, logs results to W&B, and saves the **best** checkpoint by error score. See [In-Training Evaluation Suite](#in-training-evaluation-suite-1).
118
+
119
+ ### Experiment Tracking
120
+ [Weights & Biases](https://wandb.ai/) integration:
121
+ - Per-epoch loss/metric logging via a dedicated `WandbCallback`.
122
+ - Training visuals: **correlation heatmaps** + **PCA overlap scatter plots**.
123
+ - Local-only mode when `--no-wandb` is set (uses `run_id="offline_run"`).
124
+
125
+ ### Sweeps & Neural Architecture Search
126
+ - W&B sweeps for **Neural Architecture Search** (NAS) and **Hyperparameter Optimization**.
127
+ - Ability to resume existing W&B sweeps (and single runs).
128
+
129
+ ### Synthetic Data Evaluation (Post-Training)
130
+ - **Privacy**: Distance to Closest Record (DCR)
131
+ - **Statistical Fidelity**: [Synthetic Data Vault](https://github.com/sdv-dev/sdv) (SDV)
132
+ - **Utility Retention**: Train on Synthetic, Test on Real (TSTR)
133
+
134
+ ### Usage Modes
135
+ - 🖥️ **CLI**: `train`, `sweep`, `generate`, `evaluate`.
136
+ - 🐍 **Python API** (`at_gan.api`): `train`, `sweep`, `generate`, `evaluate`.
137
+
138
+ ---
139
+
140
+ ## Installation
141
+
142
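+ **Requirements:** Python `3.10 – 3.12` and the dependencies listed in `pyproject.toml`. TensorFlow (`>=2.16.0`) is declared only as the optional `tf` extra, so install it separately if it is not already present in your environment.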
143
+
144
+ ### Option A: Install from PyPI (recommended)
145
+
146
+ ```shell script
147
+ pip install at-gan
148
+ ```
149
+
150
+ ### Option B: Editable install from the [GitHub Repository](https://github.com/Jns-M/at-gan)
151
+
152
+ 1. Clone the [GitHub repository](https://github.com/Jns-M/at-gan)
153
+ 1. Run the following command:
154
+
155
+ ```bash
156
+ pip install -e .
157
+ ```
158
+
159
+
160
+ ### Verify installation
161
+
162
+ ```shell script
163
+ at-gan --help
164
+ python -c "import at_gan; print(at_gan.__version__)"
165
+ ```
166
+
167
+
168
+ ### Weights & Biases Login (one-time)
169
+
170
+ ```shell script
171
+ wandb login
172
+ ```
173
+
174
+
175
+ > 💡 You can use this framework without W&B by passing `--no-wandb` (CLI) or `enable_wandb=False` (API).
176
+
177
+ ---
178
+
179
+ ## CLI Usage
180
+
181
+ ```shell script
182
+ at-gan --help
183
+ ```
184
+
185
+
186
+ ### `train`: Run or resume a single GAN training run
187
+
188
+ | Flag | Short | Default | Description |
189
+ |--------------------------|------------|------------|------------------------------------|
190
+ | `--config` | `-c` | *required* | Path to the YAML experiment config |
191
+ | `--wandb / --no-wandb` | `-w / -nw` | `--wandb` | Toggle W&B tracking |
192
+ | `--export / --no-export` | `-e / -ne` | `--export` | Save `.keras` generator file |
193
+ | `--generate-samples` | `-g` | `1000` | Auto-generate *N* samples post-training |
194
+
195
+ **Examples:**
196
+
197
+ ```shell script
198
+ at-gan train -c configs/config.yaml -w -e -g 5000
199
+ ```
200
+
201
+ Note: A run can be resumed via the `resume_run_id` config key. See [Configuration Reference](#configuration-reference).
202
+
203
+ ---
204
+
205
+ ### `sweep`: Run or resume a W&B sweep
206
+
207
+
208
+ | Flag | Short | Description |
209
+ |---|---|--------------------------------------------------|
210
+ | `--base-config` | `-c` | Baseline experiment config |
211
+ | `--sweep-config` | `-s` | W&B sweep config (required for new sweeps) |
212
+ | `--count` | `-n` | Max runs this agent will execute |
213
+ | `--sweep-id` | `-id` | Resume an existing sweep instead of creating one |
214
+
215
+ ```shell script
216
+ # Launch a new 50-run sweep
217
+ at-gan sweep -c configs/config.yaml -s configs/sweep_config.yaml -n 50
218
+
219
+ # Resume an existing sweep
220
+ at-gan sweep -c configs/config.yaml -id abc123 -n 20
221
+ ```
222
+
223
+ ---
224
+
225
+ ### `generate`: Generate synthetic samples from a trained generator
226
+
227
+ | Flag | Short | Description |
228
+ |---|---|------------------------------------------------|
229
+ | `--config` | `-c` | YAML used during the **original** training run |
230
+ | `--run-id` | `-r` | W&B run ID or `"offline_run"` |
231
+ | `--samples` | `-n` | Number of samples to generate |
232
+ | `--output` | `-o` | Optional override for CSV output path |
233
+
234
+ ```shell script
235
+ at-gan generate -c configs/config.yaml -r a1b2c3 -n 10000 -o synthetic_data.csv
236
+ ```
237
+
238
+ Note: `generate` always loads **`best_generator.keras`**, not the latest.
239
+
240
+ ---
241
+
242
+ ### `evaluate`: Run synthetic data evaluation (post-training)
243
+
244
+ | Flag | Short | Description |
245
+ |---------------|-------|-------------------------------------------------------|
246
+ | `--real` | `-r` | Path to the real data CSV |
247
+ | `--synthetic` | `-s` | Path to the synthetic data CSV |
248
+ | `--target` | `-t` | Discrete target column for (optional) TSTR evaluation |
249
+
250
+ ```shell script
251
+ at-gan evaluate -r real_data.csv -s synthetic_data.csv -t target_column
252
+ ```
253
+
254
+ Note: TSTR evaluation is only performed if a discrete feature (i.e. binary or categorical) is specified as the target.
255
+
256
+ ---
257
+
258
+ ## API Usage
259
+
260
+ The Python API exposes the same primary functions as the CLI, making it easy to integrate into existing projects.
261
+
262
+ See `examples/api_example.py` and `examples/api_example.ipynb` in the [GitHub Repository](https://github.com/Jns-M/at-gan) for a full API usage example.
263
+
264
+ > Note: The `train` entry point also accepts a `dict` instead of a path to a YAML file as input.
265
+
266
+ ---
267
+
268
+ ## Configuration Reference
269
+
270
+ Experiments are driven by **two YAML files**: a base config and a sweep config.
271
+
272
+ See `configs/config.yaml` and `configs/sweep_config.yaml` in the [GitHub Repository](https://github.com/Jns-M/at-gan) for examples and recommended default values for most datasets.
273
+
274
+ ### Base Config Reference
275
+
276
+ ```yaml
277
+ # =============================================================
278
+ # EXPERIMENT META
279
+ # =============================================================
280
+ experiment_name: "test_experiment" # also output directory name
281
+ resume_run_id: null # W&B run id to resume from checkpoint (optional)
282
+ seed: 1130 # seeds Python, NumPy, TensorFlow
283
+
284
+ # =============================================================
285
+ # DATA
286
+ # =============================================================
287
+ data:
288
+ dataset_path: "datasets/example.csv"
289
+ output_path: "experiments/" # run artifacts found in 'output_path/experiment_name/run_id/'
290
+
291
+ # Column routing — every column the GAN should learn MUST be listed here
292
+ continuous_cols: ["age", "heart_rate", "glucose"]
293
+ binary_cols: ["male", "smoker"]
294
+ discrete_count_cols: ["cigs_per_day"]
295
+ categorical_cols: ["education"]
296
+
297
+ # Preprocessing toggles
298
+ treat_bin_as_cat: false # route binary cols through OHE + softmax
299
+ beta_noise: true # Apply Beta-distributed noise on binary cols
300
+ smooth_categorical: true # Apply label-preserving noise on OHE groups
301
+
302
+ # =============================================================
303
+ # MODEL
304
+ # =============================================================
305
+ model:
306
+ latent_dim: 32
307
+
308
+ generator:
309
+ units: [64, 64]
310
+ dropout: 0.0
311
+ activation: "relu" # relu | leaky_relu | elu | ...
312
+ batch_norm: true # BatchNorm after each Dense layer
313
+ # negative_slope: 0.2 # used only when activation == "leaky_relu"
314
+
315
+ discriminator:
316
+ units: [256, 256]
317
+ dropout: 0.2
318
+ activation: "leaky_relu"
319
+ negative_slope: 0.2
320
+ pack_size: 3 # PacGAN packing factor (1 disables packing)
321
+ label_smoothing_min: 0.9 # e.g. real labels ~ [0.9, 1.0]
322
+ label_flipping: 0.05 # e.g. 5% of real labels flipped to 0 each step
323
+
324
+ # =============================================================
325
+ # TRAINING
326
+ # =============================================================
327
+ training:
328
+ device: "cpu" # "cpu" or "gpu"
329
+ epochs: 2000
330
+ batch_size: 512
331
+ g_updates_per_epoch: 2 # G steps per D step
332
+
333
+ # Optimizers
334
+ adam_beta_1: 0.5 # GAN-stable Adam beta_1
335
+ g_lr: 0.0002 # G Learning Rate
336
+ d_lr: 0.0003 # D Learning Rate
337
+
338
+ # LR schedule
339
+ lr_cosine_decay: true
340
+ lr_cosine_decay_restart_epochs: 2000 # restart every N epochs
341
+ g_lr_decay_alpha: 0.1 # minimum G LR fraction (floor)
342
+ d_lr_decay_alpha: 0.1 # minimum D LR fraction (floor)
343
+
344
+ # Evaluation & checkpointing
345
+ checkpoint_frequency: 100 # save "latest" every N epochs
346
+ eval_frequency: 100 # run evaluation suite every N epochs
347
+ test_split_pct: 0.2 # fraction of data (0–1) to hold out for in-training evaluation
348
+ ```
349
+
350
+
351
+ ### Sweep Config Reference
352
+
353
+ ```yaml
354
+ # =============================================================
355
+ # SWEEP STRATEGY & METRICS
356
+ # =============================================================
357
+ method: bayes
358
+
359
+ metric:
360
+ name: Eval/Total_Error # W&B log key
361
+ goal: minimize
362
+
363
+ early_terminate:
364
+ type: hyperband # Kills unpromising runs early to save compute time
365
+ min_iter: 300 # Don't kill any run before e.g. epoch 300
366
+ eta: 3 # The halving rate for the Hyperband brackets
367
+
368
+ # =============================================================
369
+ # PARAMETERS
370
+ # =============================================================
371
+ parameters:
372
+
373
+ # Sweeps choose from a fixed set of hyperparameter values
374
+ model.latent_dim:
375
+ values: [ 16, 32, 64, 128, 256 ]
376
+
377
+ # -----------------------------------------------------------
378
+ # Generator Architecture
379
+ # -----------------------------------------------------------
380
+ generator.num_hidden_layers:
381
+ values: [ 2, 3, 4 ]
382
+ generator.base_units:
383
+ values: [ 32, 64, 128, 256, 512 ]
384
+ generator.max_units:
385
+ value: 512
386
+ generator.architecture_shape:
387
+ values: [ "block", "ascending", "descending" ]
388
+
389
+ generator.dropout:
390
+ value: 0.0 # e.g. Fixed to 0.0
391
+ generator.activation:
392
+ values: [ 'relu', 'leaky_relu' ]
393
+ generator.batch_norm:
394
+ values: [ true, false ]
395
+
396
+ # -----------------------------------------------------------
397
+ # Discriminator Architecture
398
+ # -----------------------------------------------------------
399
+ discriminator.num_hidden_layers:
400
+ values: [ 2, 3, 4 ]
401
+ discriminator.base_units:
402
+ values: [ 32, 64, 128, 256, 512 ]
403
+ discriminator.max_units:
404
+ value: 512
405
+ discriminator.architecture_shape:
406
+ values: [ "block", "ascending", "descending" ]
407
+
408
+ discriminator.dropout:
409
+ values: [ 0.0, 0.2, 0.3, 0.5 ]
410
+ discriminator.activation:
411
+ values: [ 'relu', 'leaky_relu' ]
412
+ discriminator.negative_slope:
413
+ values: [ 0.1, 0.2, 0.3 ]
414
+ discriminator.pack_size:
415
+ values: [ 1, 3 ]
416
+ discriminator.label_smoothing_min:
417
+ values: [ 0.85, 0.9, 0.95, 1.0 ]
418
+ discriminator.label_flipping:
419
+ values: [ 0.0, 0.05, 0.1 ]
420
+
421
+ # -----------------------------------------------------------
422
+ # Training Loop & Optimizers
423
+ # -----------------------------------------------------------
424
+ training.batch_size:
425
+ values: [ 64, 128, 256, 512 ]
426
+ training.g_updates_per_epoch:
427
+ values: [ 1, 2, 3 ]
428
+ training.adam_beta_1:
429
+ values: [ 0.2, 0.5, 0.7, 0.9 ]
430
+
431
+ # Learning Rates
432
+ training.g_lr:
433
+ distribution: log_uniform_values
434
+ min: 0.00001
435
+ max: 0.001
436
+ training.d_lr:
437
+ distribution: log_uniform_values
438
+ min: 0.000005
439
+ max: 0.0005 # at-gan ensures d_lr <= g_lr
440
+
441
+ # Cosine Decay Warm Restart Parameters
442
+ training.lr_cosine_decay_restart_epochs:
443
+ distribution: int_uniform
444
+ min: 100
445
+ max: 1000
446
+ training.g_lr_decay_alpha:
447
+ distribution: log_uniform_values
448
+ min: 0.01 # Decay to 1% of max LR
449
+ max: 1 # No decay
450
+ training.d_lr_decay_alpha:
451
+ distribution: log_uniform_values
452
+ min: 0.01
453
+ max: 1
454
+ ```
455
+
456
+ ---
457
+
458
+ ## In-Training Evaluation Suite
459
+
460
+ Every `eval_frequency` epochs, `GANCallback` generates synthetic samples and runs an evaluation against the held-out real samples to guide the hyperparameter sweep:
461
+
462
+ | Metric | Computation |
463
+ |-------------------|---------------------------------------------------------------------------------------------------------------------------|
464
+ | PCA Error | Wasserstein-1 (earth mover's) distance between real and synthetic data across the first five PCA components |
465
+ | Adversarial Error | Absolute AUC deviation of a Random Forest classifier trained to distinguish real and synthetic data (`\|AUC - 0.5\| × 2`) |
466
+ | **Total Error** | `sqrt((pca_error² + adv_error²) / 2.0)` |
467
+
468
+ Raw errors are passed through a squashing function (`1 - exp(-x)`) so that each component lies within `[0, 1]`.
469
+
470
+ ### Visual artifacts (auto-logged to W&B)
471
+ - **Correlation heatmaps**: real, synthetic, and absolute difference.
472
+ - **PCA scatter overlay**: first two principal components of real vs. synthetic.
473
+
474
+ ## Synthetic Data Evaluation (Post-Training)
475
+
476
+ The `evaluate` command runs a comprehensive benchmark suite that assesses the quality of the synthetic data generated by the GAN:
477
+
478
+ 1. **Privacy (DCR)**: Distance to Closest Record. Measures the minimum Euclidean distance (in standard deviations) between synthetic rows and real training rows. A `Min. DCR > 0` guarantees that no synthetic row is an exact copy of a real training row.
479
+ 2. **Statistical Fidelity (SDV)**: Uses the [Synthetic Data Vault](https://github.com/sdv-dev/SDV) (`sdmetrics` package) to generate a Quality Report, comparing 1D marginal distributions (Column Shapes) and 2D correlations (Column Pair Trends).
480
+ 3. **Utility Retention (TSTR)**: Train on Synthetic, Test on Real.
481
+ * Splits real data into `real_train` (80%) and `real_test` (20%).
482
+ * Trains a **TRTR** baseline (`RandomForest`, `GradientBoosting`, `LogisticRegression`) on `real_train` → baseline F1 on `real_test`.
483
+ * Trains **TSTR** models on the entire synthetic set → F1 on the **same** `real_test`.
484
+ * Reports `TSTR_Mean_F1 / TRTR_Mean_F1 × 100` (F1-Score Retention in %).