sdgym 0.11.1.dev0__tar.gz → 0.11.2.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {sdgym-0.11.1.dev0/sdgym.egg-info → sdgym-0.11.2.dev0}/PKG-INFO +3 -3
  2. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/README.md +2 -2
  3. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/pyproject.toml +1 -1
  4. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/__init__.py +11 -6
  5. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/benchmark.py +20 -16
  6. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/__main__.py +2 -2
  7. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/dataset_explorer.py +33 -0
  8. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/datasets.py +0 -15
  9. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/result_explorer/result_handler.py +10 -1
  10. sdgym-0.11.2.dev0/sdgym/synthesizers/__init__.py +32 -0
  11. sdgym-0.11.2.dev0/sdgym/synthesizers/base.py +92 -0
  12. sdgym-0.11.2.dev0/sdgym/synthesizers/generate.py +110 -0
  13. sdgym-0.11.2.dev0/sdgym/synthesizers/sdv.py +107 -0
  14. sdgym-0.11.2.dev0/sdgym/synthesizers/utils.py +51 -0
  15. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/utils.py +15 -6
  16. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0/sdgym.egg-info}/PKG-INFO +3 -3
  17. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/SOURCES.txt +1 -0
  18. sdgym-0.11.1.dev0/sdgym/synthesizers/__init__.py +0 -43
  19. sdgym-0.11.1.dev0/sdgym/synthesizers/base.py +0 -154
  20. sdgym-0.11.1.dev0/sdgym/synthesizers/generate.py +0 -287
  21. sdgym-0.11.1.dev0/sdgym/synthesizers/sdv.py +0 -126
  22. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/LICENSE +0 -0
  23. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/_dataset_utils.py +0 -0
  24. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/__init__.py +0 -0
  25. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/collect.py +0 -0
  26. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/summary.py +0 -0
  27. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/cli/utils.py +0 -0
  28. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/errors.py +0 -0
  29. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/metrics.py +0 -0
  30. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/progress.py +0 -0
  31. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/result_explorer/__init__.py +0 -0
  32. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/result_explorer/result_explorer.py +0 -0
  33. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/result_writer.py +0 -0
  34. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/s3.py +0 -0
  35. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/synthesizers/column.py +0 -0
  36. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/synthesizers/identity.py +0 -0
  37. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/synthesizers/realtabformer.py +0 -0
  38. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym/synthesizers/uniform.py +0 -0
  39. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/dependency_links.txt +0 -0
  40. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/entry_points.txt +0 -0
  41. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/requires.txt +0 -0
  42. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/sdgym.egg-info/top_level.txt +0 -0
  43. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/setup.cfg +0 -0
  44. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/tests/test_scripts.py +0 -0
  45. {sdgym-0.11.1.dev0 → sdgym-0.11.2.dev0}/tests/test_tasks.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sdgym
3
- Version: 0.11.1.dev0
3
+ Version: 0.11.2.dev0
4
4
  Summary: Benchmark tabular synthetic data generators using a variety of datasets
5
5
  Author-email: "DataCebo, Inc." <info@sdv.dev>
6
6
  License: BSL-1.1
@@ -194,10 +194,10 @@ Learn more in the [Custom Synthesizers Guide](https://docs.sdv.dev/sdgym/customi
194
194
  ## Customizing your datasets
195
195
 
196
196
  The SDGym library includes many publicly available datasets that you can include right away.
197
- List these using the ``get_available_datasets`` feature.
197
+ List these using the ``list_datasets`` feature.
198
198
 
199
199
  ```python
200
- sdgym.get_available_datasets()
200
+ sdgym.dataset_explorer.DatasetExplorer().list_datasets(modality='single_table')
201
201
  ```
202
202
 
203
203
  ```
@@ -103,10 +103,10 @@ Learn more in the [Custom Synthesizers Guide](https://docs.sdv.dev/sdgym/customi
103
103
  ## Customizing your datasets
104
104
 
105
105
  The SDGym library includes many publicly available datasets that you can include right away.
106
- List these using the ``get_available_datasets`` feature.
106
+ List these using the ``list_datasets`` feature.
107
107
 
108
108
  ```python
109
- sdgym.get_available_datasets()
109
+ sdgym.dataset_explorer.DatasetExplorer().list_datasets(modality='single_table')
110
110
  ```
111
111
 
112
112
  ```
@@ -144,7 +144,7 @@ namespaces = false
144
144
  version = {attr = 'sdgym.__version__'}
145
145
 
146
146
  [tool.bumpversion]
147
- current_version = "0.11.1.dev0"
147
+ current_version = "0.11.2.dev0"
148
148
  parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
149
149
  serialize = [
150
150
  '{major}.{minor}.{patch}.{release}{candidate}',
@@ -8,16 +8,20 @@ __author__ = 'DataCebo, Inc.'
8
8
  __copyright__ = 'Copyright (c) 2022 DataCebo, Inc.'
9
9
  __email__ = 'info@sdv.dev'
10
10
  __license__ = 'BSL-1.1'
11
- __version__ = '0.11.1.dev0'
11
+ __version__ = '0.11.2.dev0'
12
12
 
13
13
  import logging
14
14
 
15
- from sdgym.benchmark import benchmark_single_table
15
+ from sdgym.benchmark import benchmark_single_table, benchmark_single_table_aws
16
16
  from sdgym.cli.collect import collect_results
17
17
  from sdgym.cli.summary import make_summary_spreadsheet
18
18
  from sdgym.dataset_explorer import DatasetExplorer
19
- from sdgym.datasets import get_available_datasets, load_dataset
20
- from sdgym.synthesizers import create_sdv_synthesizer_variant, create_single_table_synthesizer
19
+ from sdgym.datasets import load_dataset
20
+ from sdgym.synthesizers import (
21
+ create_synthesizer_variant,
22
+ create_single_table_synthesizer,
23
+ create_multi_table_synthesizer,
24
+ )
21
25
  from sdgym.result_explorer import ResultsExplorer
22
26
 
23
27
  # Clear the logging wrongfully configured by tensorflow/absl
@@ -28,10 +32,11 @@ __all__ = [
28
32
  'DatasetExplorer',
29
33
  'ResultsExplorer',
30
34
  'benchmark_single_table',
35
+ 'benchmark_single_table_aws',
31
36
  'collect_results',
32
- 'create_sdv_synthesizer_variant',
37
+ 'create_synthesizer_variant',
33
38
  'create_single_table_synthesizer',
34
- 'get_available_datasets',
39
+ 'create_multi_table_synthesizer',
35
40
  'load_dataset',
36
41
  'make_summary_spreadsheet',
37
42
  ]
@@ -52,7 +52,7 @@ from sdgym.s3 import (
52
52
  write_csv,
53
53
  write_file,
54
54
  )
55
- from sdgym.synthesizers import CTGANSynthesizer, GaussianCopulaSynthesizer, UniformSynthesizer
55
+ from sdgym.synthesizers import UniformSynthesizer
56
56
  from sdgym.synthesizers.base import BaselineSynthesizer
57
57
  from sdgym.utils import (
58
58
  calculate_score_time,
@@ -67,7 +67,7 @@ from sdgym.utils import (
67
67
  )
68
68
 
69
69
  LOGGER = logging.getLogger(__name__)
70
- DEFAULT_SYNTHESIZERS = [GaussianCopulaSynthesizer, CTGANSynthesizer, UniformSynthesizer]
70
+ DEFAULT_SYNTHESIZERS = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer', 'UniformSynthesizer']
71
71
  DEFAULT_DATASETS = [
72
72
  'adult',
73
73
  'alarm',
@@ -271,7 +271,11 @@ def _generate_job_args_list(
271
271
  if additional_datasets_folder is None
272
272
  else get_dataset_paths(
273
273
  modality='single_table',
274
- bucket=additional_datasets_folder,
274
+ bucket=(
275
+ additional_datasets_folder
276
+ if is_s3_path(additional_datasets_folder)
277
+ else os.path.join(additional_datasets_folder, 'single_table')
278
+ ),
275
279
  aws_access_key_id=aws_access_key_id,
276
280
  aws_secret_access_key=aws_secret_access_key_key,
277
281
  )
@@ -861,6 +865,7 @@ def _directory_exists(bucket_name, s3_file_path):
861
865
 
862
866
 
863
867
  def _check_write_permissions(s3_client, bucket_name):
868
+ s3_client = s3_client or boto3.client('s3')
864
869
  try:
865
870
  s3_client.put_object(Bucket=bucket_name, Key='__test__', Body=b'')
866
871
  write_permission = True
@@ -881,7 +886,7 @@ def _create_sdgym_script(params, output_filepath):
881
886
  bucket_name, key_prefix = parse_s3_path(output_filepath)
882
887
  if not _directory_exists(bucket_name, key_prefix):
883
888
  raise ValueError(f'Directories in {key_prefix} do not exist')
884
- if not _check_write_permissions(bucket_name):
889
+ if not _check_write_permissions(None, bucket_name):
885
890
  raise ValueError('No write permissions allowed for the bucket.')
886
891
 
887
892
  # Add quotes to parameter strings
@@ -893,23 +898,22 @@ def _create_sdgym_script(params, output_filepath):
893
898
  params['output_filepath'] = "'" + params['output_filepath'] + "'"
894
899
 
895
900
  # Generate the output script to run on the EC2 instance
896
- synthesizer_string = 'synthesizers=['
897
- for synthesizer in params['synthesizers']:
901
+ synthesizers = params.get('synthesizers', [])
902
+ names = []
903
+ for synthesizer in synthesizers:
898
904
  if isinstance(synthesizer, str):
899
- synthesizer_string += synthesizer + ', '
905
+ names.append(synthesizer)
906
+ elif hasattr(synthesizer, '__name__'):
907
+ names.append(synthesizer.__name__)
900
908
  else:
901
- synthesizer_string += synthesizer.__name__ + ', '
902
- if params['synthesizers']:
903
- synthesizer_string = synthesizer_string[:-2]
904
- synthesizer_string += ']'
909
+ names.append(synthesizer.__class__.__name__)
910
+
911
+ all_names = '", "'.join(names)
912
+ synthesizer_string = f'synthesizers=["{all_names}"]'
905
913
  # The indentation of the string is important for the python script
906
914
  script_content = f"""import boto3
907
915
  from io import StringIO
908
916
  import sdgym
909
- from sdgym.synthesizers.sdv import (CopulaGANSynthesizer, CTGANSynthesizer,
910
- GaussianCopulaSynthesizer, HMASynthesizer, PARSynthesizer, SDVRelationalSynthesizer,
911
- SDVTabularSynthesizer, TVAESynthesizer)
912
- from sdgym.synthesizers import RealTabFormerSynthesizer
913
917
 
914
918
  results = sdgym.benchmark_single_table(
915
919
  {synthesizer_string}, custom_synthesizers={params['custom_synthesizers']},
@@ -1186,7 +1190,7 @@ def benchmark_single_table(
1186
1190
  custom_synthesizers (list[class] or ``None``):
1187
1191
  A list of custom synthesizer classes to use. These can be completely custom or
1188
1192
  they can be synthesizer variants (the output from ``create_single_table_synthesizer``
1189
- or ``create_sdv_synthesizer_variant``). Defaults to ``None``.
1193
+ or ``create_synthesizer_variant``). Defaults to ``None``.
1190
1194
  sdv_datasets (list[str] or ``None``):
1191
1195
  Names of the SDV demo datasets to use for the benchmark. Defaults to
1192
1196
  ``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news,
@@ -97,7 +97,7 @@ def _download_datasets(args):
97
97
  _env_setup(args.logfile, args.verbose)
98
98
  datasets = args.datasets
99
99
  if not datasets:
100
- datasets = sdgym.datasets.get_available_datasets(
100
+ datasets = sdgym.datasets._get_available_datasets(
101
101
  args.bucket, args.aws_access_key_id, args.aws_secret_access_key
102
102
  )['name']
103
103
 
@@ -118,7 +118,7 @@ def _list_downloaded(args):
118
118
 
119
119
 
120
120
  def _list_available(args):
121
- datasets = sdgym.datasets.get_available_datasets(
121
+ datasets = sdgym.datasets._get_available_datasets(
122
122
  args.bucket, args.aws_access_key_id, args.aws_secret_access_key
123
123
  )
124
124
  _print_table(datasets, args.sort, args.reverse, {'size': humanfriendly.format_size})
@@ -275,3 +275,36 @@ class DatasetExplorer:
275
275
  dataset_summary.to_csv(output_filepath, index=False)
276
276
 
277
277
  return dataset_summary
278
+
279
+ def list_datasets(self, modality, output_filepath=None):
280
+ """List available datasets for a modality using metainfo only.
281
+
282
+ This is a lightweight alternative to ``summarize_datasets`` that does not load
283
+ the actual data. It reads dataset information from the ``metainfo.yaml`` files
284
+ in the bucket and returns a table equivalent to the legacy
285
+ ``get_available_datasets`` output.
286
+
287
+ Args:
288
+ modality (str):
289
+ It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
290
+ output_filepath (str, optional):
291
+ Full path to a ``.csv`` file where the resulting table will be written.
292
+ If not provided, the table is only returned.
293
+
294
+ Returns:
295
+ pd.DataFrame:
296
+ A DataFrame with columns: ``['dataset_name', 'size_MB', 'num_tables']``.
297
+ """
298
+ self._validate_output_filepath(output_filepath)
299
+ _validate_modality(modality)
300
+
301
+ dataframe = _get_available_datasets(
302
+ modality=modality,
303
+ bucket=self._bucket_name,
304
+ aws_access_key_id=self.aws_access_key_id,
305
+ aws_secret_access_key=self.aws_secret_access_key,
306
+ )
307
+ if output_filepath:
308
+ dataframe.to_csv(output_filepath, index=False)
309
+
310
+ return dataframe
@@ -254,21 +254,6 @@ def load_dataset(
254
254
  return data, metadata_dict
255
255
 
256
256
 
257
- def get_available_datasets(modality='single_table'):
258
- """Get available single_table datasets.
259
-
260
- Args:
261
- modality (str):
262
- It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
263
-
264
- Return:
265
- pd.DataFrame:
266
- Table of available datasets and their sizes.
267
- """
268
- _validate_modality(modality)
269
- return _get_available_datasets(modality)
270
-
271
-
272
257
  def get_dataset_paths(
273
258
  modality,
274
259
  datasets=None,
@@ -16,6 +16,7 @@ RESULTS_FOLDER_PREFIX = 'SDGym_results_'
16
16
  metainfo_PREFIX = 'metainfo'
17
17
  RESULTS_FILE_PREFIX = 'results'
18
18
  NUM_DIGITS_DATE = 10
19
+ REGEX_SYNTHESIZER_NAME = r'\s*\(\d+\)\s*$'
19
20
 
20
21
 
21
22
  class ResultsHandler(ABC):
@@ -120,7 +121,15 @@ class ResultsHandler(ABC):
120
121
  def _process_results(self, results):
121
122
  """Process results to ensure they are unique and each dataset has all synthesizers."""
122
123
  aggregated_results = pd.concat(results, ignore_index=True)
123
- aggregated_results = aggregated_results.drop_duplicates(subset=['Dataset', 'Synthesizer'])
124
+ aggregated_results['Synthesizer'] = (
125
+ aggregated_results['Synthesizer']
126
+ .astype(str)
127
+ .str.replace(REGEX_SYNTHESIZER_NAME, '', regex=True)
128
+ .str.strip()
129
+ )
130
+ aggregated_results = aggregated_results.drop_duplicates(
131
+ subset=['Dataset', 'Synthesizer'], keep='first'
132
+ )
124
133
  all_synthesizers = aggregated_results['Synthesizer'].unique()
125
134
  dataset_synth_counts = aggregated_results.groupby('Dataset')['Synthesizer'].nunique()
126
135
  valid_datasets = dataset_synth_counts[dataset_synth_counts == len(all_synthesizers)].index
@@ -0,0 +1,32 @@
1
+ """Synthesizers module."""
2
+
3
+ from sdgym.synthesizers.generate import (
4
+ create_synthesizer_variant,
5
+ create_single_table_synthesizer,
6
+ create_multi_table_synthesizer,
7
+ )
8
+ from sdgym.synthesizers.identity import DataIdentity
9
+ from sdgym.synthesizers.column import ColumnSynthesizer
10
+ from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
11
+ from sdgym.synthesizers.uniform import UniformSynthesizer
12
+ from sdgym.synthesizers.utils import (
13
+ get_available_single_table_synthesizers,
14
+ get_available_multi_table_synthesizers,
15
+ )
16
+ from sdgym.synthesizers.sdv import create_sdv_synthesizer_class, _get_all_sdv_synthesizers
17
+
18
+
19
+ __all__ = [
20
+ 'DataIdentity',
21
+ 'ColumnSynthesizer',
22
+ 'UniformSynthesizer',
23
+ 'RealTabFormerSynthesizer',
24
+ 'create_single_table_synthesizer',
25
+ 'create_multi_table_synthesizer',
26
+ 'create_synthesizer_variant',
27
+ 'get_available_single_table_synthesizers',
28
+ 'get_available_multi_table_synthesizers',
29
+ ]
30
+
31
+ for sdv_name in _get_all_sdv_synthesizers():
32
+ create_sdv_synthesizer_class(sdv_name)
@@ -0,0 +1,92 @@
1
+ """Base classes for synthesizers."""
2
+
3
+ import abc
4
+ import logging
5
+ import warnings
6
+
7
+ from sdv.metadata import Metadata
8
+
9
+ LOGGER = logging.getLogger(__name__)
10
+
11
+
12
+ class BaselineSynthesizer(abc.ABC):
13
+ """Base class for all the ``SDGym`` baselines."""
14
+
15
+ _MODEL_KWARGS = {}
16
+ _NATIVELY_SUPPORTED = True
17
+
18
+ @classmethod
19
+ def get_subclasses(cls, include_parents=False):
20
+ """Recursively find subclasses of this Baseline.
21
+
22
+ Args:
23
+ include_parents (bool):
24
+ Whether to include subclasses which are parents to
25
+ other classes. Defaults to ``False``.
26
+ """
27
+ subclasses = {}
28
+ for child in cls.__subclasses__():
29
+ grandchildren = child.get_subclasses(include_parents)
30
+ subclasses.update(grandchildren)
31
+ if include_parents or not grandchildren:
32
+ subclasses[child.__name__] = child
33
+
34
+ return subclasses
35
+
36
+ @classmethod
37
+ def _get_supported_synthesizers(cls):
38
+ """Get the natively supported synthesizer class names."""
39
+ subclasses = cls.get_subclasses(include_parents=True)
40
+ synthesizers = set()
41
+ for name, subclass in subclasses.items():
42
+ if subclass._NATIVELY_SUPPORTED:
43
+ synthesizers.add(name)
44
+
45
+ return sorted(synthesizers)
46
+
47
+ @classmethod
48
+ def get_baselines(cls):
49
+ """Get baseline classes."""
50
+ subclasses = cls.get_subclasses(include_parents=True)
51
+ synthesizers = []
52
+ for _, subclass in subclasses.items():
53
+ if abc.ABC not in subclass.__bases__:
54
+ synthesizers.append(subclass)
55
+
56
+ return synthesizers
57
+
58
+ def get_trained_synthesizer(self, data, metadata):
59
+ """Get a synthesizer that has been trained on the provided data and metadata.
60
+
61
+ Args:
62
+ data (pandas.DataFrame):
63
+ The data to train on.
64
+ metadata (dict):
65
+ The metadata dictionary.
66
+
67
+ Returns:
68
+ obj:
69
+ The synthesizer object.
70
+ """
71
+ metadata_object = Metadata()
72
+ with warnings.catch_warnings():
73
+ warnings.simplefilter('ignore', UserWarning)
74
+ metadata = metadata_object.load_from_dict(metadata)
75
+
76
+ return self._get_trained_synthesizer(data, metadata)
77
+
78
+ def sample_from_synthesizer(self, synthesizer, n_samples):
79
+ """Sample data from the provided synthesizer.
80
+
81
+ Args:
82
+ synthesizer (obj):
83
+ The synthesizer object to sample data from.
84
+ n_samples (int):
85
+ The number of samples to create.
86
+
87
+ Returns:
88
+ pandas.DataFrame or dict:
89
+ The sampled data. If single-table, should be a DataFrame. If multi-table,
90
+ should be a dict mapping table name to DataFrame.
91
+ """
92
+ return self._sample_from_synthesizer(synthesizer, n_samples)
@@ -0,0 +1,110 @@
1
+ """Helpers to create SDGym synthesizer variants."""
2
+
3
+ from sdgym.synthesizers.base import BaselineSynthesizer
4
+ from sdgym.synthesizers.utils import _get_supported_synthesizers
5
+
6
+
7
+ def create_synthesizer_variant(display_name, synthesizer_class, synthesizer_parameters):
8
+ """Create a new synthesizer variant.
9
+
10
+ Args:
11
+ display_name (str):
12
+ Name of this synthesizer, used for display purposes in results.
13
+ synthesizer_class (str):
14
+ Name of the SDGym-supported synthesizer class to wrap.
15
+ synthesizer_parameters (dict):
16
+ A dictionary of the parameter names and values that will be used for the synthesizer.
17
+
18
+ Returns:
19
+ class:
20
+ The synthesizer class.
21
+ """
22
+ if synthesizer_class not in _get_supported_synthesizers():
23
+ raise ValueError(f"Synthesizer '{synthesizer_class}' is not a SDGym supported synthesizer.")
24
+
25
+ base_class = BaselineSynthesizer.get_subclasses().get(synthesizer_class)
26
+ NewSynthesizer = type(
27
+ f'Variant:{display_name}',
28
+ (base_class,),
29
+ {
30
+ '__module__': __name__,
31
+ '_MODEL_KWARGS': synthesizer_parameters,
32
+ '_NATIVELY_SUPPORTED': False,
33
+ },
34
+ )
35
+
36
+ return NewSynthesizer
37
+
38
+
39
+ def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, sample_arg_name):
40
+ """Create a synthesizer class.
41
+
42
+ Args:
43
+ display_name (str):
44
+ A string with the name of this synthesizer, used for display purposes only when
45
+ the results are generated
46
+ get_trained_fn (callable):
47
+ A function to generate and train a synthesizer, given the real data and metadata.
48
+ sample_fn (callable):
49
+ A function to sample from the given synthesizer.
50
+ sample_arg_name (str):
51
+ The name of the argument used to specify the number of samples to generate.
52
+ Either 'num_samples' for single-table synthesizers, or 'scale' for multi-table
53
+ synthesizers.
54
+
55
+ Returns:
56
+ class:
57
+ The synthesizer class.
58
+ """
59
+ class_name = f'Custom:{display_name}'
60
+
61
+ def get_trained_synthesizer(self, data, metadata):
62
+ return get_trained_fn(data, metadata)
63
+
64
+ if sample_arg_name == 'num_samples':
65
+
66
+ def sample_from_synthesizer(self, synthesizer, num_samples):
67
+ return sample_fn(synthesizer, num_samples)
68
+
69
+ else:
70
+
71
+ def sample_from_synthesizer(self, synthesizer, scale):
72
+ return sample_fn(synthesizer, scale)
73
+
74
+ CustomSynthesizer = type(
75
+ class_name,
76
+ (BaselineSynthesizer,),
77
+ {
78
+ '__module__': __name__,
79
+ '_NATIVELY_SUPPORTED': False,
80
+ 'get_trained_synthesizer': get_trained_synthesizer,
81
+ 'sample_from_synthesizer': sample_from_synthesizer,
82
+ },
83
+ )
84
+
85
+ globals()[class_name] = CustomSynthesizer
86
+ return CustomSynthesizer
87
+
88
+
89
+ def create_single_table_synthesizer(
90
+ display_name, get_trained_synthesizer_fn, sample_from_synthesizer_fn
91
+ ):
92
+ """Create a single-table synthesizer class."""
93
+ return _create_synthesizer_class(
94
+ display_name,
95
+ get_trained_synthesizer_fn,
96
+ sample_from_synthesizer_fn,
97
+ sample_arg_name='num_samples',
98
+ )
99
+
100
+
101
+ def create_multi_table_synthesizer(
102
+ display_name, get_trained_synthesizer_fn, sample_from_synthesizer_fn
103
+ ):
104
+ """Create a multi-table synthesizer class."""
105
+ return _create_synthesizer_class(
106
+ display_name,
107
+ get_trained_synthesizer_fn,
108
+ sample_from_synthesizer_fn,
109
+ sample_arg_name='scale',
110
+ )
@@ -0,0 +1,107 @@
1
+ """SDV synthesizers wrappers for SDGym."""
2
+
3
+ import logging
4
+ import sys
5
+ from importlib import import_module
6
+
7
+ from sdv import multi_table, single_table
8
+
9
+ from sdgym.synthesizers.base import BaselineSynthesizer
10
+
11
+ LOGGER = logging.getLogger(__name__)
12
+ UNSUPPORTED_SDV_SYNTHESIZERS = ['DayZSynthesizer']
13
+ MODALITY_TO_MODULE = {
14
+ 'single_table': single_table,
15
+ 'multi_table': multi_table,
16
+ }
17
+
18
+
19
+ def _validate_modality(modality):
20
+ """Validate that the modality is correct."""
21
+ if modality not in ['single_table', 'multi_table']:
22
+ raise ValueError("`modality` must be one of 'single_table' or 'multi_table'.")
23
+
24
+
25
+ def _get_sdv_synthesizers(modality):
26
+ _validate_modality(modality)
27
+ module = MODALITY_TO_MODULE[modality]
28
+ available_synthesizer = {name for name, cls in module.__dict__.items() if isinstance(cls, type)}
29
+ available_synthesizer = available_synthesizer - set(UNSUPPORTED_SDV_SYNTHESIZERS)
30
+ return sorted(available_synthesizer)
31
+
32
+
33
+ def _get_all_sdv_synthesizers():
34
+ """Get all available SDV synthesizers."""
35
+ synthesizers = set()
36
+ for modality in MODALITY_TO_MODULE.keys():
37
+ synthesizers.update(_get_sdv_synthesizers(modality))
38
+
39
+ return sorted(synthesizers)
40
+
41
+
42
+ def _get_trained_synthesizer(self, data, metadata):
43
+ LOGGER.info('Fitting %s', self.__class__.__name__)
44
+ sdv_class = getattr(import_module(f'sdv.{self.modality}'), self.SDV_NAME)
45
+ synthesizer = sdv_class(metadata=metadata, **self._MODEL_KWARGS)
46
+ synthesizer.fit(data)
47
+ return synthesizer
48
+
49
+
50
+ def _sample_from_synthesizer(self, synthesizer, sample_arg):
51
+ LOGGER.info('Sampling %s', self.__class__.__name__)
52
+ if self.modality == 'multi_table':
53
+ return synthesizer.sample(scale=sample_arg)
54
+
55
+ return synthesizer.sample(num_rows=sample_arg)
56
+
57
+
58
+ def _retrieve_sdv_class(sdv_name):
59
+ current_module = sys.modules[__name__]
60
+ if hasattr(current_module, sdv_name):
61
+ existing_class = getattr(current_module, sdv_name)
62
+ if isinstance(existing_class, type):
63
+ return existing_class
64
+
65
+ return None
66
+
67
+
68
+ def _get_modality(sdv_name):
69
+ """Get the modality of a SDV synthesizer."""
70
+ st_synthesizers = _get_sdv_synthesizers('single_table')
71
+ if sdv_name in st_synthesizers:
72
+ return 'single_table'
73
+
74
+ mt_synthesizers = _get_sdv_synthesizers('multi_table')
75
+ if sdv_name in mt_synthesizers:
76
+ return 'multi_table'
77
+
78
+ raise ValueError(f"Synthesizer '{sdv_name}' is not a SDV synthesizer.")
79
+
80
+
81
+ def _create_sdv_class(sdv_name):
82
+ """Create a SDV synthesizer class dynamically."""
83
+ current_module = sys.modules[__name__]
84
+ modality = _get_modality(sdv_name)
85
+ synthesizer_class = type(
86
+ sdv_name,
87
+ (BaselineSynthesizer,),
88
+ {
89
+ '__module__': __name__,
90
+ 'SDV_NAME': sdv_name,
91
+ 'modality': modality,
92
+ '_MODEL_KWARGS': {},
93
+ '_get_trained_synthesizer': _get_trained_synthesizer,
94
+ '_sample_from_synthesizer': _sample_from_synthesizer,
95
+ },
96
+ )
97
+ setattr(current_module, sdv_name, synthesizer_class)
98
+
99
+ return synthesizer_class
100
+
101
+
102
+ def create_sdv_synthesizer_class(sdv_name):
103
+ """Factory for dynamically creating or retrieving SDV synthesizer classes."""
104
+ if sdv_name not in _get_all_sdv_synthesizers():
105
+ raise ValueError(f"Synthesizer '{sdv_name}' is not a supported SDV synthesizer.")
106
+
107
+ return _retrieve_sdv_class(sdv_name) or _create_sdv_class(sdv_name)
@@ -0,0 +1,51 @@
1
+ """Utility functions for synthesizers in SDGym."""
2
+
3
+ from sdgym.synthesizers.base import BaselineSynthesizer
4
+ from sdgym.synthesizers.sdv import _get_all_sdv_synthesizers, _get_sdv_synthesizers
5
+
6
+
7
+ def _get_sdgym_synthesizers():
8
+ """Get SDGym synthesizers.
9
+
10
+ Returns:
11
+ list:
12
+ A list of available SDGym synthesizer names.
13
+ """
14
+ synthesizers = BaselineSynthesizer._get_supported_synthesizers()
15
+ sdv_synthesizer = _get_all_sdv_synthesizers()
16
+ sdgym_synthesizer = [
17
+ synthesizer for synthesizer in synthesizers if synthesizer not in sdv_synthesizer
18
+ ]
19
+ return sorted(sdgym_synthesizer)
20
+
21
+
22
+ def get_available_single_table_synthesizers():
23
+ """List all available single-table synthesizers in SDGym.
24
+
25
+ Returns:
26
+ list:
27
+ A sorted list of available single-table synthesizer names.
28
+ """
29
+ sdv_synthesizers = _get_sdv_synthesizers('single_table')
30
+ sdgym_synthesizers = _get_sdgym_synthesizers()
31
+ return sorted(sdv_synthesizers + sdgym_synthesizers)
32
+
33
+
34
+ def get_available_multi_table_synthesizers():
35
+ """List all available multi-table synthesizers in SDGym.
36
+
37
+ Returns:
38
+ list:
39
+ A sorted list of available multi-table synthesizer names.
40
+ """
41
+ return sorted(_get_sdv_synthesizers('multi_table'))
42
+
43
+
44
+ def _get_supported_synthesizers():
45
+ """Get SDGym supported synthesizers.
46
+
47
+ Returns:
48
+ list:
49
+ A list of available SDGym supported synthesizer names.
50
+ """
51
+ return BaselineSynthesizer._get_supported_synthesizers()