sdgym 0.11.1.dev0__tar.gz → 0.11.2.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sdgym-0.11.1.dev0/sdgym.egg-info → sdgym-0.11.2.dev0}/PKG-INFO +3 -3
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/README.md +2 -2
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/pyproject.toml +1 -1
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/__init__.py +11 -6
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/benchmark.py +20 -16
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/__main__.py +2 -2
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/dataset_explorer.py +33 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/datasets.py +0 -15
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/result_explorer/result_handler.py +10 -1
- sdgym-0.11.2.dev0/sdgym/synthesizers/__init__.py +32 -0
- sdgym-0.11.2.dev0/sdgym/synthesizers/base.py +92 -0
- sdgym-0.11.2.dev0/sdgym/synthesizers/generate.py +110 -0
- sdgym-0.11.2.dev0/sdgym/synthesizers/sdv.py +107 -0
- sdgym-0.11.2.dev0/sdgym/synthesizers/utils.py +51 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/utils.py +15 -6
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0/sdgym.egg-info}/PKG-INFO +3 -3
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/SOURCES.txt +1 -0
- sdgym-0.11.1.dev0/sdgym/synthesizers/__init__.py +0 -43
- sdgym-0.11.1.dev0/sdgym/synthesizers/base.py +0 -154
- sdgym-0.11.1.dev0/sdgym/synthesizers/generate.py +0 -287
- sdgym-0.11.1.dev0/sdgym/synthesizers/sdv.py +0 -126
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/LICENSE +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/_dataset_utils.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/__init__.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/collect.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/summary.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/utils.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/errors.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/metrics.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/progress.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/result_explorer/__init__.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/result_explorer/result_explorer.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/result_writer.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/s3.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/synthesizers/column.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/synthesizers/identity.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/synthesizers/realtabformer.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/synthesizers/uniform.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/dependency_links.txt +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/entry_points.txt +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/requires.txt +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/top_level.txt +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/setup.cfg +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/tests/test_scripts.py +0 -0
- {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/tests/test_tasks.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sdgym
|
|
3
|
-
Version: 0.11.
|
|
3
|
+
Version: 0.11.2.dev0
|
|
4
4
|
Summary: Benchmark tabular synthetic data generators using a variety of datasets
|
|
5
5
|
Author-email: "DataCebo, Inc." <info@sdv.dev>
|
|
6
6
|
License: BSL-1.1
|
|
@@ -194,10 +194,10 @@ Learn more in the [Custom Synthesizers Guide](https://docs.sdv.dev/sdgym/customi
|
|
|
194
194
|
## Customizing your datasets
|
|
195
195
|
|
|
196
196
|
The SDGym library includes many publicly available datasets that you can include right away.
|
|
197
|
-
List these using the ``
|
|
197
|
+
List these using the ``list_datasets`` feature.
|
|
198
198
|
|
|
199
199
|
```python
|
|
200
|
-
sdgym.
|
|
200
|
+
sdgym.dataset_explorer.DatasetExplorer().list_datasets()
|
|
201
201
|
```
|
|
202
202
|
|
|
203
203
|
```
|
|
@@ -103,10 +103,10 @@ Learn more in the [Custom Synthesizers Guide](https://docs.sdv.dev/sdgym/customi
|
|
|
103
103
|
## Customizing your datasets
|
|
104
104
|
|
|
105
105
|
The SDGym library includes many publicly available datasets that you can include right away.
|
|
106
|
-
List these using the ``
|
|
106
|
+
List these using the ``list_datasets`` feature.
|
|
107
107
|
|
|
108
108
|
```python
|
|
109
|
-
sdgym.
|
|
109
|
+
sdgym.dataset_explorer.DatasetExplorer().list_datasets()
|
|
110
110
|
```
|
|
111
111
|
|
|
112
112
|
```
|
|
@@ -144,7 +144,7 @@ namespaces = false
|
|
|
144
144
|
version = {attr = 'sdgym.__version__'}
|
|
145
145
|
|
|
146
146
|
[tool.bumpversion]
|
|
147
|
-
current_version = "0.11.
|
|
147
|
+
current_version = "0.11.2.dev0"
|
|
148
148
|
parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
|
|
149
149
|
serialize = [
|
|
150
150
|
'{major}.{minor}.{patch}.{release}{candidate}',
|
|
@@ -8,16 +8,20 @@ __author__ = 'DataCebo, Inc.'
|
|
|
8
8
|
__copyright__ = 'Copyright (c) 2022 DataCebo, Inc.'
|
|
9
9
|
__email__ = 'info@sdv.dev'
|
|
10
10
|
__license__ = 'BSL-1.1'
|
|
11
|
-
__version__ = '0.11.
|
|
11
|
+
__version__ = '0.11.2.dev0'
|
|
12
12
|
|
|
13
13
|
import logging
|
|
14
14
|
|
|
15
|
-
from sdgym.benchmark import benchmark_single_table
|
|
15
|
+
from sdgym.benchmark import benchmark_single_table, benchmark_single_table_aws
|
|
16
16
|
from sdgym.cli.collect import collect_results
|
|
17
17
|
from sdgym.cli.summary import make_summary_spreadsheet
|
|
18
18
|
from sdgym.dataset_explorer import DatasetExplorer
|
|
19
|
-
from sdgym.datasets import
|
|
20
|
-
from sdgym.synthesizers import
|
|
19
|
+
from sdgym.datasets import load_dataset
|
|
20
|
+
from sdgym.synthesizers import (
|
|
21
|
+
create_synthesizer_variant,
|
|
22
|
+
create_single_table_synthesizer,
|
|
23
|
+
create_multi_table_synthesizer,
|
|
24
|
+
)
|
|
21
25
|
from sdgym.result_explorer import ResultsExplorer
|
|
22
26
|
|
|
23
27
|
# Clear the logging wrongfully configured by tensorflow/absl
|
|
@@ -28,10 +32,11 @@ __all__ = [
|
|
|
28
32
|
'DatasetExplorer',
|
|
29
33
|
'ResultsExplorer',
|
|
30
34
|
'benchmark_single_table',
|
|
35
|
+
'benchmark_single_table_aws',
|
|
31
36
|
'collect_results',
|
|
32
|
-
'
|
|
37
|
+
'create_synthesizer_variant',
|
|
33
38
|
'create_single_table_synthesizer',
|
|
34
|
-
'
|
|
39
|
+
'create_multi_table_synthesizer',
|
|
35
40
|
'load_dataset',
|
|
36
41
|
'make_summary_spreadsheet',
|
|
37
42
|
]
|
|
@@ -52,7 +52,7 @@ from sdgym.s3 import (
|
|
|
52
52
|
write_csv,
|
|
53
53
|
write_file,
|
|
54
54
|
)
|
|
55
|
-
from sdgym.synthesizers import
|
|
55
|
+
from sdgym.synthesizers import UniformSynthesizer
|
|
56
56
|
from sdgym.synthesizers.base import BaselineSynthesizer
|
|
57
57
|
from sdgym.utils import (
|
|
58
58
|
calculate_score_time,
|
|
@@ -67,7 +67,7 @@ from sdgym.utils import (
|
|
|
67
67
|
)
|
|
68
68
|
|
|
69
69
|
LOGGER = logging.getLogger(__name__)
|
|
70
|
-
DEFAULT_SYNTHESIZERS = [GaussianCopulaSynthesizer, CTGANSynthesizer, UniformSynthesizer]
|
|
70
|
+
DEFAULT_SYNTHESIZERS = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer', 'UniformSynthesizer']
|
|
71
71
|
DEFAULT_DATASETS = [
|
|
72
72
|
'adult',
|
|
73
73
|
'alarm',
|
|
@@ -271,7 +271,11 @@ def _generate_job_args_list(
|
|
|
271
271
|
if additional_datasets_folder is None
|
|
272
272
|
else get_dataset_paths(
|
|
273
273
|
modality='single_table',
|
|
274
|
-
bucket=
|
|
274
|
+
bucket=(
|
|
275
|
+
additional_datasets_folder
|
|
276
|
+
if is_s3_path(additional_datasets_folder)
|
|
277
|
+
else os.path.join(additional_datasets_folder, 'single_table')
|
|
278
|
+
),
|
|
275
279
|
aws_access_key_id=aws_access_key_id,
|
|
276
280
|
aws_secret_access_key=aws_secret_access_key_key,
|
|
277
281
|
)
|
|
@@ -861,6 +865,7 @@ def _directory_exists(bucket_name, s3_file_path):
|
|
|
861
865
|
|
|
862
866
|
|
|
863
867
|
def _check_write_permissions(s3_client, bucket_name):
|
|
868
|
+
s3_client = s3_client or boto3.client('s3')
|
|
864
869
|
try:
|
|
865
870
|
s3_client.put_object(Bucket=bucket_name, Key='__test__', Body=b'')
|
|
866
871
|
write_permission = True
|
|
@@ -881,7 +886,7 @@ def _create_sdgym_script(params, output_filepath):
|
|
|
881
886
|
bucket_name, key_prefix = parse_s3_path(output_filepath)
|
|
882
887
|
if not _directory_exists(bucket_name, key_prefix):
|
|
883
888
|
raise ValueError(f'Directories in {key_prefix} do not exist')
|
|
884
|
-
if not _check_write_permissions(bucket_name):
|
|
889
|
+
if not _check_write_permissions(None, bucket_name):
|
|
885
890
|
raise ValueError('No write permissions allowed for the bucket.')
|
|
886
891
|
|
|
887
892
|
# Add quotes to parameter strings
|
|
@@ -893,23 +898,22 @@ def _create_sdgym_script(params, output_filepath):
|
|
|
893
898
|
params['output_filepath'] = "'" + params['output_filepath'] + "'"
|
|
894
899
|
|
|
895
900
|
# Generate the output script to run on the e2 instance
|
|
896
|
-
|
|
897
|
-
|
|
901
|
+
synthesizers = params.get('synthesizers', [])
|
|
902
|
+
names = []
|
|
903
|
+
for synthesizer in synthesizers:
|
|
898
904
|
if isinstance(synthesizer, str):
|
|
899
|
-
|
|
905
|
+
names.append(synthesizer)
|
|
906
|
+
elif hasattr(synthesizer, '__name__'):
|
|
907
|
+
names.append(synthesizer.__name__)
|
|
900
908
|
else:
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
synthesizer_string
|
|
909
|
+
names.append(synthesizer.__class__.__name__)
|
|
910
|
+
|
|
911
|
+
all_names = '", "'.join(names)
|
|
912
|
+
synthesizer_string = f'synthesizers=["{all_names}"]'
|
|
905
913
|
# The indentation of the string is important for the python script
|
|
906
914
|
script_content = f"""import boto3
|
|
907
915
|
from io import StringIO
|
|
908
916
|
import sdgym
|
|
909
|
-
from sdgym.synthesizers.sdv import (CopulaGANSynthesizer, CTGANSynthesizer,
|
|
910
|
-
GaussianCopulaSynthesizer, HMASynthesizer, PARSynthesizer, SDVRelationalSynthesizer,
|
|
911
|
-
SDVTabularSynthesizer, TVAESynthesizer)
|
|
912
|
-
from sdgym.synthesizers import RealTabFormerSynthesizer
|
|
913
917
|
|
|
914
918
|
results = sdgym.benchmark_single_table(
|
|
915
919
|
{synthesizer_string}, custom_synthesizers={params['custom_synthesizers']},
|
|
@@ -1186,7 +1190,7 @@ def benchmark_single_table(
|
|
|
1186
1190
|
custom_synthesizers (list[class] or ``None``):
|
|
1187
1191
|
A list of custom synthesizer classes to use. These can be completely custom or
|
|
1188
1192
|
they can be synthesizer variants (the output from ``create_single_table_synthesizer``
|
|
1189
|
-
or ``
|
|
1193
|
+
or ``create_synthesizer_variant``). Defaults to ``None``.
|
|
1190
1194
|
sdv_datasets (list[str] or ``None``):
|
|
1191
1195
|
Names of the SDV demo datasets to use for the benchmark. Defaults to
|
|
1192
1196
|
``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news,
|
|
@@ -97,7 +97,7 @@ def _download_datasets(args):
|
|
|
97
97
|
_env_setup(args.logfile, args.verbose)
|
|
98
98
|
datasets = args.datasets
|
|
99
99
|
if not datasets:
|
|
100
|
-
datasets = sdgym.datasets.
|
|
100
|
+
datasets = sdgym.datasets._get_available_datasets(
|
|
101
101
|
args.bucket, args.aws_access_key_id, args.aws_secret_access_key
|
|
102
102
|
)['name']
|
|
103
103
|
|
|
@@ -118,7 +118,7 @@ def _list_downloaded(args):
|
|
|
118
118
|
|
|
119
119
|
|
|
120
120
|
def _list_available(args):
|
|
121
|
-
datasets = sdgym.datasets.
|
|
121
|
+
datasets = sdgym.datasets._get_available_datasets(
|
|
122
122
|
args.bucket, args.aws_access_key_id, args.aws_secret_access_key
|
|
123
123
|
)
|
|
124
124
|
_print_table(datasets, args.sort, args.reverse, {'size': humanfriendly.format_size})
|
|
@@ -275,3 +275,36 @@ class DatasetExplorer:
|
|
|
275
275
|
dataset_summary.to_csv(output_filepath, index=False)
|
|
276
276
|
|
|
277
277
|
return dataset_summary
|
|
278
|
+
|
|
279
|
+
def list_datasets(self, modality, output_filepath=None):
|
|
280
|
+
"""List available datasets for a modality using metainfo only.
|
|
281
|
+
|
|
282
|
+
This is a lightweight alternative to ``summarize_datasets`` that does not load
|
|
283
|
+
the actual data. It reads dataset information from the ``metainfo.yaml`` files
|
|
284
|
+
in the bucket and returns a table equivalent to the legacy
|
|
285
|
+
``get_available_datasets`` output.
|
|
286
|
+
|
|
287
|
+
Args:
|
|
288
|
+
modality (str):
|
|
289
|
+
It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
|
|
290
|
+
output_filepath (str, optional):
|
|
291
|
+
Full path to a ``.csv`` file where the resulting table will be written.
|
|
292
|
+
If not provided, the table is only returned.
|
|
293
|
+
|
|
294
|
+
Returns:
|
|
295
|
+
pd.DataFrame:
|
|
296
|
+
A DataFrame with columns: ``['dataset_name', 'size_MB', 'num_tables']``.
|
|
297
|
+
"""
|
|
298
|
+
self._validate_output_filepath(output_filepath)
|
|
299
|
+
_validate_modality(modality)
|
|
300
|
+
|
|
301
|
+
dataframe = _get_available_datasets(
|
|
302
|
+
modality=modality,
|
|
303
|
+
bucket=self._bucket_name,
|
|
304
|
+
aws_access_key_id=self.aws_access_key_id,
|
|
305
|
+
aws_secret_access_key=self.aws_secret_access_key,
|
|
306
|
+
)
|
|
307
|
+
if output_filepath:
|
|
308
|
+
dataframe.to_csv(output_filepath, index=False)
|
|
309
|
+
|
|
310
|
+
return dataframe
|
|
@@ -254,21 +254,6 @@ def load_dataset(
|
|
|
254
254
|
return data, metadata_dict
|
|
255
255
|
|
|
256
256
|
|
|
257
|
-
def get_available_datasets(modality='single_table'):
|
|
258
|
-
"""Get available single_table datasets.
|
|
259
|
-
|
|
260
|
-
Args:
|
|
261
|
-
modality (str):
|
|
262
|
-
It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
|
|
263
|
-
|
|
264
|
-
Return:
|
|
265
|
-
pd.DataFrame:
|
|
266
|
-
Table of available datasets and their sizes.
|
|
267
|
-
"""
|
|
268
|
-
_validate_modality(modality)
|
|
269
|
-
return _get_available_datasets(modality)
|
|
270
|
-
|
|
271
|
-
|
|
272
257
|
def get_dataset_paths(
|
|
273
258
|
modality,
|
|
274
259
|
datasets=None,
|
|
@@ -16,6 +16,7 @@ RESULTS_FOLDER_PREFIX = 'SDGym_results_'
|
|
|
16
16
|
metainfo_PREFIX = 'metainfo'
|
|
17
17
|
RESULTS_FILE_PREFIX = 'results'
|
|
18
18
|
NUM_DIGITS_DATE = 10
|
|
19
|
+
REGEX_SYNTHESIZER_NAME = r'\s*\(\d+\)\s*$'
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class ResultsHandler(ABC):
|
|
@@ -120,7 +121,15 @@ class ResultsHandler(ABC):
|
|
|
120
121
|
def _process_results(self, results):
|
|
121
122
|
"""Process results to ensure they are unique and each dataset has all synthesizers."""
|
|
122
123
|
aggregated_results = pd.concat(results, ignore_index=True)
|
|
123
|
-
aggregated_results
|
|
124
|
+
aggregated_results['Synthesizer'] = (
|
|
125
|
+
aggregated_results['Synthesizer']
|
|
126
|
+
.astype(str)
|
|
127
|
+
.str.replace(REGEX_SYNTHESIZER_NAME, '', regex=True)
|
|
128
|
+
.str.strip()
|
|
129
|
+
)
|
|
130
|
+
aggregated_results = aggregated_results.drop_duplicates(
|
|
131
|
+
subset=['Dataset', 'Synthesizer'], keep='first'
|
|
132
|
+
)
|
|
124
133
|
all_synthesizers = aggregated_results['Synthesizer'].unique()
|
|
125
134
|
dataset_synth_counts = aggregated_results.groupby('Dataset')['Synthesizer'].nunique()
|
|
126
135
|
valid_datasets = dataset_synth_counts[dataset_synth_counts == len(all_synthesizers)].index
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Synthesizers module."""
|
|
2
|
+
|
|
3
|
+
from sdgym.synthesizers.generate import (
|
|
4
|
+
create_synthesizer_variant,
|
|
5
|
+
create_single_table_synthesizer,
|
|
6
|
+
create_multi_table_synthesizer,
|
|
7
|
+
)
|
|
8
|
+
from sdgym.synthesizers.identity import DataIdentity
|
|
9
|
+
from sdgym.synthesizers.column import ColumnSynthesizer
|
|
10
|
+
from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
|
|
11
|
+
from sdgym.synthesizers.uniform import UniformSynthesizer
|
|
12
|
+
from sdgym.synthesizers.utils import (
|
|
13
|
+
get_available_single_table_synthesizers,
|
|
14
|
+
get_available_multi_table_synthesizers,
|
|
15
|
+
)
|
|
16
|
+
from sdgym.synthesizers.sdv import create_sdv_synthesizer_class, _get_all_sdv_synthesizers
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
'DataIdentity',
|
|
21
|
+
'ColumnSynthesizer',
|
|
22
|
+
'UniformSynthesizer',
|
|
23
|
+
'RealTabFormerSynthesizer',
|
|
24
|
+
'create_single_table_synthesizer',
|
|
25
|
+
'create_multi_table_synthesizer',
|
|
26
|
+
'create_synthesizer_variant',
|
|
27
|
+
'get_available_single_table_synthesizers',
|
|
28
|
+
'get_available_multi_table_synthesizers',
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
for sdv_name in _get_all_sdv_synthesizers():
|
|
32
|
+
create_sdv_synthesizer_class(sdv_name)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Base classes for synthesizers."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
import logging
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
from sdv.metadata import Metadata
|
|
8
|
+
|
|
9
|
+
LOGGER = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaselineSynthesizer(abc.ABC):
|
|
13
|
+
"""Base class for all the ``SDGym`` baselines."""
|
|
14
|
+
|
|
15
|
+
_MODEL_KWARGS = {}
|
|
16
|
+
_NATIVELY_SUPPORTED = True
|
|
17
|
+
|
|
18
|
+
@classmethod
|
|
19
|
+
def get_subclasses(cls, include_parents=False):
|
|
20
|
+
"""Recursively find subclasses of this Baseline.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
include_parents (bool):
|
|
24
|
+
Whether to include subclasses which are parents to
|
|
25
|
+
other classes. Defaults to ``False``.
|
|
26
|
+
"""
|
|
27
|
+
subclasses = {}
|
|
28
|
+
for child in cls.__subclasses__():
|
|
29
|
+
grandchildren = child.get_subclasses(include_parents)
|
|
30
|
+
subclasses.update(grandchildren)
|
|
31
|
+
if include_parents or not grandchildren:
|
|
32
|
+
subclasses[child.__name__] = child
|
|
33
|
+
|
|
34
|
+
return subclasses
|
|
35
|
+
|
|
36
|
+
@classmethod
|
|
37
|
+
def _get_supported_synthesizers(cls):
|
|
38
|
+
"""Get the natively supported synthesizer class names."""
|
|
39
|
+
subclasses = cls.get_subclasses(include_parents=True)
|
|
40
|
+
synthesizers = set()
|
|
41
|
+
for name, subclass in subclasses.items():
|
|
42
|
+
if subclass._NATIVELY_SUPPORTED:
|
|
43
|
+
synthesizers.add(name)
|
|
44
|
+
|
|
45
|
+
return sorted(synthesizers)
|
|
46
|
+
|
|
47
|
+
@classmethod
|
|
48
|
+
def get_baselines(cls):
|
|
49
|
+
"""Get baseline classes."""
|
|
50
|
+
subclasses = cls.get_subclasses(include_parents=True)
|
|
51
|
+
synthesizers = []
|
|
52
|
+
for _, subclass in subclasses.items():
|
|
53
|
+
if abc.ABC not in subclass.__bases__:
|
|
54
|
+
synthesizers.append(subclass)
|
|
55
|
+
|
|
56
|
+
return synthesizers
|
|
57
|
+
|
|
58
|
+
def get_trained_synthesizer(self, data, metadata):
|
|
59
|
+
"""Get a synthesizer that has been trained on the provided data and metadata.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
data (pandas.DataFrame):
|
|
63
|
+
The data to train on.
|
|
64
|
+
metadata (dict):
|
|
65
|
+
The metadata dictionary.
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
obj:
|
|
69
|
+
The synthesizer object.
|
|
70
|
+
"""
|
|
71
|
+
metadata_object = Metadata()
|
|
72
|
+
with warnings.catch_warnings():
|
|
73
|
+
warnings.simplefilter('ignore', UserWarning)
|
|
74
|
+
metadata = metadata_object.load_from_dict(metadata)
|
|
75
|
+
|
|
76
|
+
return self._get_trained_synthesizer(data, metadata)
|
|
77
|
+
|
|
78
|
+
def sample_from_synthesizer(self, synthesizer, n_samples):
|
|
79
|
+
"""Sample data from the provided synthesizer.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
synthesizer (obj):
|
|
83
|
+
The synthesizer object to sample data from.
|
|
84
|
+
n_samples (int):
|
|
85
|
+
The number of samples to create.
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
pandas.DataFrame or dict:
|
|
89
|
+
The sampled data. If single-table, should be a DataFrame. If multi-table,
|
|
90
|
+
should be a dict mapping table name to DataFrame.
|
|
91
|
+
"""
|
|
92
|
+
return self._sample_from_synthesizer(synthesizer, n_samples)
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""Helpers to create SDGym synthesizer variants."""
|
|
2
|
+
|
|
3
|
+
from sdgym.synthesizers.base import BaselineSynthesizer
|
|
4
|
+
from sdgym.synthesizers.utils import _get_supported_synthesizers
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def create_synthesizer_variant(display_name, synthesizer_class, synthesizer_parameters):
|
|
8
|
+
"""Create a new synthesizer variant.
|
|
9
|
+
|
|
10
|
+
Args:
|
|
11
|
+
display_name (str):
|
|
12
|
+
Name of this synthesizer, used for display purposes in results.
|
|
13
|
+
synthesizer_class (str):
|
|
14
|
+
Name of the SDV synthesizer class to wrap.
|
|
15
|
+
synthesizer_parameters (dict):
|
|
16
|
+
A dictionary of the parameter names and values that will be used for the synthesizer.
|
|
17
|
+
|
|
18
|
+
Returns:
|
|
19
|
+
class:
|
|
20
|
+
The synthesizer class.
|
|
21
|
+
"""
|
|
22
|
+
if synthesizer_class not in _get_supported_synthesizers():
|
|
23
|
+
raise ValueError(f"Synthesizer '{synthesizer_class}' is not a SDGym supported synthesizer.")
|
|
24
|
+
|
|
25
|
+
base_class = BaselineSynthesizer.get_subclasses().get(synthesizer_class)
|
|
26
|
+
NewSynthesizer = type(
|
|
27
|
+
f'Variant:{display_name}',
|
|
28
|
+
(base_class,),
|
|
29
|
+
{
|
|
30
|
+
'__module__': __name__,
|
|
31
|
+
'_MODEL_KWARGS': synthesizer_parameters,
|
|
32
|
+
'_NATIVELY_SUPPORTED': False,
|
|
33
|
+
},
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
return NewSynthesizer
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, sample_arg_name):
|
|
40
|
+
"""Create a synthesizer class.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
display_name(string):
|
|
44
|
+
A string with the name of this synthesizer, used for display purposes only when
|
|
45
|
+
the results are generated
|
|
46
|
+
get_trained_synthesizer_fn (callable):
|
|
47
|
+
A function to generate and train a synthesizer, given the real data and metadata.
|
|
48
|
+
sample_from_synthesizer (callable):
|
|
49
|
+
A function to sample from the given synthesizer.
|
|
50
|
+
sample_arg_name (str):
|
|
51
|
+
The name of the argument used to specify the number of samples to generate.
|
|
52
|
+
Either 'num_samples' for single-table synthesizers, or 'scale' for multi-table
|
|
53
|
+
synthesizers.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
class:
|
|
57
|
+
The synthesizer class.
|
|
58
|
+
"""
|
|
59
|
+
class_name = f'Custom:{display_name}'
|
|
60
|
+
|
|
61
|
+
def get_trained_synthesizer(self, data, metadata):
|
|
62
|
+
return get_trained_fn(data, metadata)
|
|
63
|
+
|
|
64
|
+
if sample_arg_name == 'num_samples':
|
|
65
|
+
|
|
66
|
+
def sample_from_synthesizer(self, synthesizer, num_samples):
|
|
67
|
+
return sample_fn(synthesizer, num_samples)
|
|
68
|
+
|
|
69
|
+
else:
|
|
70
|
+
|
|
71
|
+
def sample_from_synthesizer(self, synthesizer, scale):
|
|
72
|
+
return sample_fn(synthesizer, scale)
|
|
73
|
+
|
|
74
|
+
CustomSynthesizer = type(
|
|
75
|
+
class_name,
|
|
76
|
+
(BaselineSynthesizer,),
|
|
77
|
+
{
|
|
78
|
+
'__module__': __name__,
|
|
79
|
+
'_NATIVELY_SUPPORTED': False,
|
|
80
|
+
'get_trained_synthesizer': get_trained_synthesizer,
|
|
81
|
+
'sample_from_synthesizer': sample_from_synthesizer,
|
|
82
|
+
},
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
globals()[class_name] = CustomSynthesizer
|
|
86
|
+
return CustomSynthesizer
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def create_single_table_synthesizer(
|
|
90
|
+
display_name, get_trained_synthesizer_fn, sample_from_synthesizer_fn
|
|
91
|
+
):
|
|
92
|
+
"""Create a single-table synthesizer class."""
|
|
93
|
+
return _create_synthesizer_class(
|
|
94
|
+
display_name,
|
|
95
|
+
get_trained_synthesizer_fn,
|
|
96
|
+
sample_from_synthesizer_fn,
|
|
97
|
+
sample_arg_name='num_samples',
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def create_multi_table_synthesizer(
|
|
102
|
+
display_name, get_trained_synthesizer_fn, sample_from_synthesizer_fn
|
|
103
|
+
):
|
|
104
|
+
"""Create a multi-table synthesizer class."""
|
|
105
|
+
return _create_synthesizer_class(
|
|
106
|
+
display_name,
|
|
107
|
+
get_trained_synthesizer_fn,
|
|
108
|
+
sample_from_synthesizer_fn,
|
|
109
|
+
sample_arg_name='scale',
|
|
110
|
+
)
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""SDV synthesizers wrappers for SDGym."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
from importlib import import_module
|
|
6
|
+
|
|
7
|
+
from sdv import multi_table, single_table
|
|
8
|
+
|
|
9
|
+
from sdgym.synthesizers.base import BaselineSynthesizer
|
|
10
|
+
|
|
11
|
+
LOGGER = logging.getLogger(__name__)
|
|
12
|
+
UNSUPPORTED_SDV_SYNTHESIZERS = ['DayZSynthesizer']
|
|
13
|
+
MODALITY_TO_MODULE = {
|
|
14
|
+
'single_table': single_table,
|
|
15
|
+
'multi_table': multi_table,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _validate_modality(modality):
|
|
20
|
+
"""Validate that the modality is correct."""
|
|
21
|
+
if modality not in ['single_table', 'multi_table']:
|
|
22
|
+
raise ValueError("`modality` must be one of 'single_table' or 'multi_table'.")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _get_sdv_synthesizers(modality):
|
|
26
|
+
_validate_modality(modality)
|
|
27
|
+
module = MODALITY_TO_MODULE[modality]
|
|
28
|
+
available_synthesizer = {name for name, cls in module.__dict__.items() if isinstance(cls, type)}
|
|
29
|
+
available_synthesizer = available_synthesizer - set(UNSUPPORTED_SDV_SYNTHESIZERS)
|
|
30
|
+
return sorted(available_synthesizer)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _get_all_sdv_synthesizers():
|
|
34
|
+
"""Get all available SDV synthesizers."""
|
|
35
|
+
synthesizers = set()
|
|
36
|
+
for modality in MODALITY_TO_MODULE.keys():
|
|
37
|
+
synthesizers.update(_get_sdv_synthesizers(modality))
|
|
38
|
+
|
|
39
|
+
return sorted(synthesizers)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_trained_synthesizer(self, data, metadata):
|
|
43
|
+
LOGGER.info('Fitting %s', self.__class__.__name__)
|
|
44
|
+
sdv_class = getattr(import_module(f'sdv.{self.modality}'), self.SDV_NAME)
|
|
45
|
+
synthesizer = sdv_class(metadata=metadata, **self._MODEL_KWARGS)
|
|
46
|
+
synthesizer.fit(data)
|
|
47
|
+
return synthesizer
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _sample_from_synthesizer(self, synthesizer, sample_arg):
|
|
51
|
+
LOGGER.info('Sampling %s', self.__class__.__name__)
|
|
52
|
+
if self.modality == 'multi_table':
|
|
53
|
+
return synthesizer.sample(scale=sample_arg)
|
|
54
|
+
|
|
55
|
+
return synthesizer.sample(num_rows=sample_arg)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _retrieve_sdv_class(sdv_name):
|
|
59
|
+
current_module = sys.modules[__name__]
|
|
60
|
+
if hasattr(current_module, sdv_name):
|
|
61
|
+
existing_class = getattr(current_module, sdv_name)
|
|
62
|
+
if isinstance(existing_class, type):
|
|
63
|
+
return existing_class
|
|
64
|
+
|
|
65
|
+
return None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _get_modality(sdv_name):
|
|
69
|
+
"""Get the modality of a SDV synthesizer."""
|
|
70
|
+
st_synthesizers = _get_sdv_synthesizers('single_table')
|
|
71
|
+
if sdv_name in st_synthesizers:
|
|
72
|
+
return 'single_table'
|
|
73
|
+
|
|
74
|
+
mt_synthesizers = _get_sdv_synthesizers('multi_table')
|
|
75
|
+
if sdv_name in mt_synthesizers:
|
|
76
|
+
return 'multi_table'
|
|
77
|
+
|
|
78
|
+
raise ValueError(f"Synthesizer '{sdv_name}' is not a SDV synthesizer.")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _create_sdv_class(sdv_name):
|
|
82
|
+
"""Create a SDV synthesizer class dynamically."""
|
|
83
|
+
current_module = sys.modules[__name__]
|
|
84
|
+
modality = _get_modality(sdv_name)
|
|
85
|
+
synthesizer_class = type(
|
|
86
|
+
sdv_name,
|
|
87
|
+
(BaselineSynthesizer,),
|
|
88
|
+
{
|
|
89
|
+
'__module__': __name__,
|
|
90
|
+
'SDV_NAME': sdv_name,
|
|
91
|
+
'modality': modality,
|
|
92
|
+
'_MODEL_KWARGS': {},
|
|
93
|
+
'_get_trained_synthesizer': _get_trained_synthesizer,
|
|
94
|
+
'_sample_from_synthesizer': _sample_from_synthesizer,
|
|
95
|
+
},
|
|
96
|
+
)
|
|
97
|
+
setattr(current_module, sdv_name, synthesizer_class)
|
|
98
|
+
|
|
99
|
+
return synthesizer_class
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def create_sdv_synthesizer_class(sdv_name):
|
|
103
|
+
"""Factory for dynamically creating or retrieving SDV synthesizer classes."""
|
|
104
|
+
if sdv_name not in _get_all_sdv_synthesizers():
|
|
105
|
+
raise ValueError(f"Synthesizer '{sdv_name}' is not a supported SDV synthesizer.")
|
|
106
|
+
|
|
107
|
+
return _retrieve_sdv_class(sdv_name) or _create_sdv_class(sdv_name)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Utility functions for synthesizers in SDGym."""
|
|
2
|
+
|
|
3
|
+
from sdgym.synthesizers.base import BaselineSynthesizer
|
|
4
|
+
from sdgym.synthesizers.sdv import _get_all_sdv_synthesizers, _get_sdv_synthesizers
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _get_sdgym_synthesizers():
|
|
8
|
+
"""Get SDGym synthesizers.
|
|
9
|
+
|
|
10
|
+
Returns:
|
|
11
|
+
list:
|
|
12
|
+
A list of available SDGym synthesizer names.
|
|
13
|
+
"""
|
|
14
|
+
synthesizers = BaselineSynthesizer._get_supported_synthesizers()
|
|
15
|
+
sdv_synthesizer = _get_all_sdv_synthesizers()
|
|
16
|
+
sdgym_synthesizer = [
|
|
17
|
+
synthesizer for synthesizer in synthesizers if synthesizer not in sdv_synthesizer
|
|
18
|
+
]
|
|
19
|
+
return sorted(sdgym_synthesizer)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_available_single_table_synthesizers():
|
|
23
|
+
"""List all available single-table synthesizers in SDGym.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
list:
|
|
27
|
+
A sorted list of available single-table synthesizer names.
|
|
28
|
+
"""
|
|
29
|
+
sdv_synthesizers = _get_sdv_synthesizers('single_table')
|
|
30
|
+
sdgym_synthesizers = _get_sdgym_synthesizers()
|
|
31
|
+
return sorted(sdv_synthesizers + sdgym_synthesizers)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_available_multi_table_synthesizers():
|
|
35
|
+
"""List all available multi-table synthesizers in SDGym.
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
list:
|
|
39
|
+
A sorted list of available multi-table synthesizer names.
|
|
40
|
+
"""
|
|
41
|
+
return sorted(_get_sdv_synthesizers('multi_table'))
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _get_supported_synthesizers():
|
|
45
|
+
"""Get SDGym supported synthesizers.
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
list:
|
|
49
|
+
A list of available SDGym supported synthesizer names.
|
|
50
|
+
"""
|
|
51
|
+
return BaselineSynthesizer._get_supported_synthesizers()
|