sdgym 0.9.1.dev0__tar.gz → 0.10.0.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/PKG-INFO +15 -10
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/pyproject.toml +19 -12
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/__init__.py +1 -1
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/benchmark.py +8 -8
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/__main__.py +6 -6
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/__init__.py +2 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/base.py +7 -4
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/column.py +26 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/generate.py +1 -1
- sdgym-0.10.0.dev0/sdgym/synthesizers/realtabformer.py +46 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/uniform.py +3 -2
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/PKG-INFO +15 -10
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/SOURCES.txt +1 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/requires.txt +18 -8
- sdgym-0.10.0.dev0/tests/test_tasks.py +207 -0
- sdgym-0.9.1.dev0/tests/test_tasks.py +0 -39
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/LICENSE +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/README.md +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/__init__.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/collect.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/summary.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/utils.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/datasets.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/errors.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/metrics.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/progress.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/s3.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/identity.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/sdv.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/utils.py +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/dependency_links.txt +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/entry_points.txt +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/top_level.txt +0 -0
- {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: sdgym
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.10.0.dev0
|
|
4
4
|
Summary: Benchmark tabular synthetic data generators using a variety of datasets
|
|
5
5
|
Author-email: "DataCebo, Inc." <info@sdv.dev>
|
|
6
6
|
License: BSL-1.1
|
|
@@ -30,9 +30,9 @@ Requires-Dist: botocore<2,>=1.31
|
|
|
30
30
|
Requires-Dist: cloudpickle>=2.1.0
|
|
31
31
|
Requires-Dist: compress-pickle>=1.2.0
|
|
32
32
|
Requires-Dist: humanfriendly>=8.2
|
|
33
|
-
Requires-Dist: numpy
|
|
34
|
-
Requires-Dist: numpy
|
|
35
|
-
Requires-Dist: numpy
|
|
33
|
+
Requires-Dist: numpy>=1.21.6; python_version < "3.10"
|
|
34
|
+
Requires-Dist: numpy>=1.23.3; python_version >= "3.10" and python_version < "3.12"
|
|
35
|
+
Requires-Dist: numpy>=1.26.0; python_version >= "3.12"
|
|
36
36
|
Requires-Dist: pandas>=1.4.0; python_version < "3.11"
|
|
37
37
|
Requires-Dist: pandas>=1.5.0; python_version >= "3.11" and python_version < "3.12"
|
|
38
38
|
Requires-Dist: pandas>=2.1.1; python_version >= "3.12"
|
|
@@ -45,18 +45,23 @@ Requires-Dist: scipy>=1.7.3; python_version < "3.10"
|
|
|
45
45
|
Requires-Dist: scipy>=1.9.2; python_version >= "3.10" and python_version < "3.12"
|
|
46
46
|
Requires-Dist: scipy>=1.12.0; python_version >= "3.12"
|
|
47
47
|
Requires-Dist: tabulate<0.9,>=0.8.3
|
|
48
|
-
Requires-Dist: torch>=1.
|
|
48
|
+
Requires-Dist: torch>=1.12.1; python_version < "3.10"
|
|
49
49
|
Requires-Dist: torch>=2.0.0; python_version >= "3.10" and python_version < "3.12"
|
|
50
50
|
Requires-Dist: torch>=2.2.0; python_version >= "3.12"
|
|
51
|
-
Requires-Dist: tqdm>=4.
|
|
51
|
+
Requires-Dist: tqdm>=4.66.3
|
|
52
52
|
Requires-Dist: XlsxWriter>=1.2.8
|
|
53
|
-
Requires-Dist: rdt>=1.
|
|
54
|
-
Requires-Dist: sdmetrics>=0.
|
|
55
|
-
Requires-Dist: sdv>=1.
|
|
53
|
+
Requires-Dist: rdt>=1.13.1
|
|
54
|
+
Requires-Dist: sdmetrics>=0.17.0
|
|
55
|
+
Requires-Dist: sdv>=1.17.2
|
|
56
56
|
Provides-Extra: dask
|
|
57
57
|
Requires-Dist: dask; extra == "dask"
|
|
58
58
|
Requires-Dist: distributed; extra == "dask"
|
|
59
|
+
Provides-Extra: realtabformer
|
|
60
|
+
Requires-Dist: realtabformer>=0.2.2; extra == "realtabformer"
|
|
61
|
+
Requires-Dist: torch>=2.0.0; (python_version >= "3.8" and python_version < "3.12") and extra == "realtabformer"
|
|
62
|
+
Requires-Dist: torch>=2.2.0; python_version >= "3.12" and extra == "realtabformer"
|
|
59
63
|
Provides-Extra: test
|
|
64
|
+
Requires-Dist: sdgym[realtabformer]; extra == "test"
|
|
60
65
|
Requires-Dist: pytest>=6.2.5; extra == "test"
|
|
61
66
|
Requires-Dist: pytest-cov>=2.6.0; extra == "test"
|
|
62
67
|
Requires-Dist: jupyter<2,>=1.0.0; extra == "test"
|
|
@@ -27,9 +27,9 @@ dependencies = [
|
|
|
27
27
|
'cloudpickle>=2.1.0',
|
|
28
28
|
'compress-pickle>=1.2.0',
|
|
29
29
|
'humanfriendly>=8.2',
|
|
30
|
-
"numpy>=1.21.
|
|
31
|
-
"numpy>=1.23.3
|
|
32
|
-
"numpy>=1.26.0
|
|
30
|
+
"numpy>=1.21.6;python_version<'3.10'",
|
|
31
|
+
"numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
|
|
32
|
+
"numpy>=1.26.0;python_version>='3.12'",
|
|
33
33
|
"pandas>=1.4.0;python_version<'3.11'",
|
|
34
34
|
"pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
|
|
35
35
|
"pandas>=2.1.1;python_version>='3.12'",
|
|
@@ -42,14 +42,14 @@ dependencies = [
|
|
|
42
42
|
"scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
|
|
43
43
|
"scipy>=1.12.0;python_version>='3.12'",
|
|
44
44
|
'tabulate>=0.8.3,<0.9',
|
|
45
|
-
"torch>=1.
|
|
45
|
+
"torch>=1.12.1;python_version<'3.10'",
|
|
46
46
|
"torch>=2.0.0;python_version>='3.10' and python_version<'3.12'",
|
|
47
47
|
"torch>=2.2.0;python_version>='3.12'",
|
|
48
|
-
'tqdm>=4.
|
|
48
|
+
'tqdm>=4.66.3',
|
|
49
49
|
'XlsxWriter>=1.2.8',
|
|
50
|
-
'rdt>=1.
|
|
51
|
-
'sdmetrics>=0.
|
|
52
|
-
'sdv>=1.
|
|
50
|
+
'rdt>=1.13.1',
|
|
51
|
+
'sdmetrics>=0.17.0',
|
|
52
|
+
'sdv>=1.17.2',
|
|
53
53
|
]
|
|
54
54
|
|
|
55
55
|
[project.urls]
|
|
@@ -64,7 +64,13 @@ sdgym = { main = 'sdgym.cli.__main__:main' }
|
|
|
64
64
|
|
|
65
65
|
[project.optional-dependencies]
|
|
66
66
|
dask = ['dask', 'distributed']
|
|
67
|
+
realtabformer = [
|
|
68
|
+
'realtabformer>=0.2.2',
|
|
69
|
+
"torch>=2.0.0;python_version>='3.8' and python_version<'3.12'",
|
|
70
|
+
"torch>=2.2.0;python_version>='3.12'",
|
|
71
|
+
]
|
|
67
72
|
test = [
|
|
73
|
+
'sdgym[realtabformer]',
|
|
68
74
|
'pytest>=6.2.5',
|
|
69
75
|
'pytest-cov>=2.6.0',
|
|
70
76
|
'jupyter>=1.0.0,<2',
|
|
@@ -134,7 +140,7 @@ namespaces = false
|
|
|
134
140
|
version = {attr = 'sdgym.__version__'}
|
|
135
141
|
|
|
136
142
|
[tool.bumpversion]
|
|
137
|
-
current_version = "0.
|
|
143
|
+
current_version = "0.10.0.dev0"
|
|
138
144
|
parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
|
|
139
145
|
serialize = [
|
|
140
146
|
'{major}.{minor}.{patch}.{release}{candidate}',
|
|
@@ -198,10 +204,11 @@ select = [
|
|
|
198
204
|
# print statements
|
|
199
205
|
"T201",
|
|
200
206
|
# pandas-vet
|
|
201
|
-
"PD"
|
|
207
|
+
"PD",
|
|
208
|
+
# numpy 2.0
|
|
209
|
+
"NPY201"
|
|
202
210
|
]
|
|
203
211
|
ignore = [
|
|
204
|
-
"E501",
|
|
205
212
|
# pydocstyle
|
|
206
213
|
"D107", # Missing docstring in __init__
|
|
207
214
|
"D417", # Missing argument descriptions in the docstring, this is a bug from pydocstyle: https://github.com/PyCQA/pydocstyle/issues/449
|
|
@@ -230,4 +237,4 @@ convention = "google"
|
|
|
230
237
|
|
|
231
238
|
[tool.ruff.lint.pycodestyle]
|
|
232
239
|
max-doc-length = 100
|
|
233
|
-
max-line-length = 100
|
|
240
|
+
max-line-length = 100
|
|
@@ -66,8 +66,7 @@ N_BYTES_IN_MB = 1000 * 1000
|
|
|
66
66
|
def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers):
|
|
67
67
|
if output_filepath and os.path.exists(output_filepath):
|
|
68
68
|
raise ValueError(
|
|
69
|
-
f'{output_filepath} already exists. '
|
|
70
|
-
'Please provide a file that does not already exist.'
|
|
69
|
+
f'{output_filepath} already exists. Please provide a file that does not already exist.'
|
|
71
70
|
)
|
|
72
71
|
|
|
73
72
|
if detailed_results_folder and os.path.exists(detailed_results_folder):
|
|
@@ -149,14 +148,14 @@ def _generate_job_args_list(
|
|
|
149
148
|
def _synthesize(synthesizer_dict, real_data, metadata):
|
|
150
149
|
synthesizer = synthesizer_dict['synthesizer']
|
|
151
150
|
if isinstance(synthesizer, type):
|
|
152
|
-
assert issubclass(
|
|
153
|
-
synthesizer
|
|
154
|
-
)
|
|
151
|
+
assert issubclass(synthesizer, BaselineSynthesizer), (
|
|
152
|
+
'`synthesizer` must be a synthesizer class'
|
|
153
|
+
)
|
|
155
154
|
synthesizer = synthesizer()
|
|
156
155
|
else:
|
|
157
|
-
assert issubclass(
|
|
158
|
-
|
|
159
|
-
)
|
|
156
|
+
assert issubclass(type(synthesizer), BaselineSynthesizer), (
|
|
157
|
+
'`synthesizer` must be an instance of a synthesizer class.'
|
|
158
|
+
)
|
|
160
159
|
|
|
161
160
|
get_synthesizer = synthesizer.get_trained_synthesizer
|
|
162
161
|
sample_from_synthesizer = synthesizer.sample_from_synthesizer
|
|
@@ -747,6 +746,7 @@ def benchmark_single_table(
|
|
|
747
746
|
- ``CTGANSynthesizer``
|
|
748
747
|
- ``CopulaGANSynthesizer``
|
|
749
748
|
- ``TVAESynthesizer``
|
|
749
|
+
- ``RealTabFormerSynthesizer``
|
|
750
750
|
|
|
751
751
|
custom_synthesizers (list[class] or ``None``):
|
|
752
752
|
A list of custom synthesizer classes to use. These can be completely custom or
|
|
@@ -180,9 +180,9 @@ def _get_parser():
|
|
|
180
180
|
)
|
|
181
181
|
run.add_argument('-m', '--metrics', nargs='+', help='Metrics to apply. Accepts multiple names.')
|
|
182
182
|
run.add_argument('-b', '--bucket', help='Bucket from which to download the datasets.')
|
|
183
|
-
run.add_argument('-dp
|
|
183
|
+
run.add_argument('-dp--datasets-path', help='Path where datasets can be found.')
|
|
184
184
|
run.add_argument(
|
|
185
|
-
'-dm
|
|
185
|
+
'-dm--modalities', nargs='+', help='Data Modalities to run. Accepts multiple names.'
|
|
186
186
|
)
|
|
187
187
|
run.add_argument('-i', '--iterations', type=int, default=1, help='Number of iterations.')
|
|
188
188
|
run.add_argument(
|
|
@@ -219,13 +219,13 @@ def _get_parser():
|
|
|
219
219
|
'-g', '--groupby', nargs='+', help='Group scores leaderboard by the given fields.'
|
|
220
220
|
)
|
|
221
221
|
run.add_argument(
|
|
222
|
-
'-ak
|
|
222
|
+
'-ak--aws-key',
|
|
223
223
|
type=str,
|
|
224
224
|
required=False,
|
|
225
225
|
help='Aws access key ID to use when reading datasets.',
|
|
226
226
|
)
|
|
227
227
|
run.add_argument(
|
|
228
|
-
'-as
|
|
228
|
+
'-as--aws-secret',
|
|
229
229
|
type=str,
|
|
230
230
|
required=False,
|
|
231
231
|
help='Aws secret access key to use when reading datasets.',
|
|
@@ -234,10 +234,10 @@ def _get_parser():
|
|
|
234
234
|
'-j', '--jobs', type=str, required=False, help='Serialized list of jobs to run.'
|
|
235
235
|
)
|
|
236
236
|
run.add_argument(
|
|
237
|
-
'-mr
|
|
237
|
+
'-mr--max-rows', type=int, help='Cap the number of rows to model from each dataset.'
|
|
238
238
|
)
|
|
239
239
|
run.add_argument(
|
|
240
|
-
'-mc
|
|
240
|
+
'-mc--max-columns',
|
|
241
241
|
type=int,
|
|
242
242
|
help='Cap the number of columns to model from each dataset.',
|
|
243
243
|
)
|
|
@@ -9,6 +9,7 @@ from sdgym.synthesizers.generate import (
|
|
|
9
9
|
)
|
|
10
10
|
from sdgym.synthesizers.identity import DataIdentity
|
|
11
11
|
from sdgym.synthesizers.column import ColumnSynthesizer
|
|
12
|
+
from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
|
|
12
13
|
from sdgym.synthesizers.sdv import (
|
|
13
14
|
CopulaGANSynthesizer,
|
|
14
15
|
CTGANSynthesizer,
|
|
@@ -38,4 +39,5 @@ __all__ = (
|
|
|
38
39
|
'create_sdv_synthesizer_variant',
|
|
39
40
|
'create_sequential_synthesizer',
|
|
40
41
|
'SYNTHESIZER_MAPPING',
|
|
42
|
+
'RealTabFormerSynthesizer',
|
|
41
43
|
)
|
|
@@ -2,9 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
4
|
import logging
|
|
5
|
+
import warnings
|
|
5
6
|
|
|
6
|
-
from sdv.metadata
|
|
7
|
-
from sdv.metadata.single_table import SingleTableMetadata
|
|
7
|
+
from sdv.metadata import Metadata
|
|
8
8
|
|
|
9
9
|
LOGGER = logging.getLogger(__name__)
|
|
10
10
|
|
|
@@ -54,8 +54,11 @@ class BaselineSynthesizer(abc.ABC):
|
|
|
54
54
|
obj:
|
|
55
55
|
The synthesizer object.
|
|
56
56
|
"""
|
|
57
|
-
|
|
58
|
-
|
|
57
|
+
metadata_object = Metadata()
|
|
58
|
+
with warnings.catch_warnings():
|
|
59
|
+
warnings.simplefilter('ignore', UserWarning)
|
|
60
|
+
metadata = metadata_object.load_from_dict(metadata)
|
|
61
|
+
|
|
59
62
|
return self._get_trained_synthesizer(data, metadata)
|
|
60
63
|
|
|
61
64
|
def sample_from_synthesizer(self, synthesizer, n_samples):
|
|
@@ -1,11 +1,16 @@
|
|
|
1
1
|
"""ColumnSynthesizer module."""
|
|
2
2
|
|
|
3
|
+
import logging
|
|
4
|
+
|
|
3
5
|
import pandas as pd
|
|
4
6
|
from rdt.hyper_transformer import HyperTransformer
|
|
7
|
+
from sdv.metadata import Metadata
|
|
5
8
|
from sklearn.mixture import GaussianMixture
|
|
6
9
|
|
|
7
10
|
from sdgym.synthesizers.base import BaselineSynthesizer
|
|
8
11
|
|
|
12
|
+
LOGGER = logging.getLogger(__name__)
|
|
13
|
+
|
|
9
14
|
|
|
10
15
|
class ColumnSynthesizer(BaselineSynthesizer):
|
|
11
16
|
"""Synthesizer that learns each column independently.
|
|
@@ -17,6 +22,27 @@ class ColumnSynthesizer(BaselineSynthesizer):
|
|
|
17
22
|
def _get_trained_synthesizer(self, real_data, metadata):
|
|
18
23
|
hyper_transformer = HyperTransformer()
|
|
19
24
|
hyper_transformer.detect_initial_config(real_data)
|
|
25
|
+
supported_sdtypes = hyper_transformer._get_supported_sdtypes()
|
|
26
|
+
config = {}
|
|
27
|
+
if isinstance(metadata, Metadata):
|
|
28
|
+
table_name = metadata._get_single_table_name()
|
|
29
|
+
columns = metadata.tables[table_name].columns
|
|
30
|
+
else:
|
|
31
|
+
columns = metadata.columns
|
|
32
|
+
|
|
33
|
+
for column_name, column in columns.items():
|
|
34
|
+
sdtype = column['sdtype']
|
|
35
|
+
if sdtype in supported_sdtypes:
|
|
36
|
+
config[column_name] = sdtype
|
|
37
|
+
elif column.get('pii', False):
|
|
38
|
+
config[column_name] = 'pii'
|
|
39
|
+
else:
|
|
40
|
+
LOGGER.info(
|
|
41
|
+
f'Column {column} sdtype: {sdtype} is not supported, '
|
|
42
|
+
f'defaulting to inferred type.'
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
hyper_transformer.update_sdtypes(config)
|
|
20
46
|
|
|
21
47
|
# This is done to match the behavior of the synthesizer for SDGym <= 0.6.0
|
|
22
48
|
columns_to_remove = [
|
|
@@ -49,7 +49,7 @@ def create_sdv_synthesizer_variant(display_name, synthesizer_class, synthesizer_
|
|
|
49
49
|
if synthesizer_class not in SYNTHESIZER_MAPPING.keys():
|
|
50
50
|
raise ValueError(
|
|
51
51
|
f'Synthesizer class {synthesizer_class} is not recognized. '
|
|
52
|
-
f
|
|
52
|
+
f'The supported options are {", ".join(SYNTHESIZER_MAPPING.keys())}'
|
|
53
53
|
)
|
|
54
54
|
|
|
55
55
|
baseclass = SDVTabularSynthesizer
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""REaLTabFormer integration."""
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import logging
|
|
5
|
+
from functools import partialmethod
|
|
6
|
+
|
|
7
|
+
import tqdm
|
|
8
|
+
|
|
9
|
+
from sdgym.synthesizers.base import BaselineSynthesizer
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@contextlib.contextmanager
|
|
13
|
+
def prevent_tqdm_output():
|
|
14
|
+
"""Temporarily disables tqdm m."""
|
|
15
|
+
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
|
|
16
|
+
try:
|
|
17
|
+
yield
|
|
18
|
+
finally:
|
|
19
|
+
tqdm.__init__ = partialmethod(tqdm.__init__, disable=False)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RealTabFormerSynthesizer(BaselineSynthesizer):
|
|
23
|
+
"""Custom wrapper for the REaLTabFormer synthesizer to make it work with SDGym."""
|
|
24
|
+
|
|
25
|
+
LOGGER = logging.getLogger(__name__)
|
|
26
|
+
_MODEL_KWARGS = None
|
|
27
|
+
|
|
28
|
+
def _get_trained_synthesizer(self, data, metadata):
|
|
29
|
+
try:
|
|
30
|
+
from realtabformer import REaLTabFormer
|
|
31
|
+
except Exception as exception:
|
|
32
|
+
raise ValueError(
|
|
33
|
+
"In order to use 'RealTabFormerSynthesizer' you have to install the extra"
|
|
34
|
+
" dependencies by running pip install sdgym['realtabformer'] "
|
|
35
|
+
) from exception
|
|
36
|
+
|
|
37
|
+
with prevent_tqdm_output():
|
|
38
|
+
model_kwargs = self._MODEL_KWARGS.copy() if self._MODEL_KWARGS else {}
|
|
39
|
+
model = REaLTabFormer(model_type='tabular', **model_kwargs)
|
|
40
|
+
model.fit(data, device='cpu')
|
|
41
|
+
|
|
42
|
+
return model
|
|
43
|
+
|
|
44
|
+
def _sample_from_synthesizer(self, synthesizer, n_sample):
|
|
45
|
+
"""Sample synthetic data with specified sample count."""
|
|
46
|
+
return synthesizer.sample(n_sample, device='cpu')
|
|
@@ -19,7 +19,8 @@ class UniformSynthesizer(BaselineSynthesizer):
|
|
|
19
19
|
hyper_transformer.detect_initial_config(real_data)
|
|
20
20
|
supported_sdtypes = hyper_transformer._get_supported_sdtypes()
|
|
21
21
|
config = {}
|
|
22
|
-
|
|
22
|
+
table = next(iter(metadata.tables.values()))
|
|
23
|
+
for column_name, column in table.columns.items():
|
|
23
24
|
sdtype = column['sdtype']
|
|
24
25
|
if sdtype in supported_sdtypes:
|
|
25
26
|
config[column_name] = sdtype
|
|
@@ -27,7 +28,7 @@ class UniformSynthesizer(BaselineSynthesizer):
|
|
|
27
28
|
config[column_name] = 'pii'
|
|
28
29
|
else:
|
|
29
30
|
LOGGER.info(
|
|
30
|
-
f'Column {
|
|
31
|
+
f'Column {column_name} sdtype: {sdtype} is not supported, '
|
|
31
32
|
f'defaulting to inferred type.'
|
|
32
33
|
)
|
|
33
34
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: sdgym
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.10.0.dev0
|
|
4
4
|
Summary: Benchmark tabular synthetic data generators using a variety of datasets
|
|
5
5
|
Author-email: "DataCebo, Inc." <info@sdv.dev>
|
|
6
6
|
License: BSL-1.1
|
|
@@ -30,9 +30,9 @@ Requires-Dist: botocore<2,>=1.31
|
|
|
30
30
|
Requires-Dist: cloudpickle>=2.1.0
|
|
31
31
|
Requires-Dist: compress-pickle>=1.2.0
|
|
32
32
|
Requires-Dist: humanfriendly>=8.2
|
|
33
|
-
Requires-Dist: numpy
|
|
34
|
-
Requires-Dist: numpy
|
|
35
|
-
Requires-Dist: numpy
|
|
33
|
+
Requires-Dist: numpy>=1.21.6; python_version < "3.10"
|
|
34
|
+
Requires-Dist: numpy>=1.23.3; python_version >= "3.10" and python_version < "3.12"
|
|
35
|
+
Requires-Dist: numpy>=1.26.0; python_version >= "3.12"
|
|
36
36
|
Requires-Dist: pandas>=1.4.0; python_version < "3.11"
|
|
37
37
|
Requires-Dist: pandas>=1.5.0; python_version >= "3.11" and python_version < "3.12"
|
|
38
38
|
Requires-Dist: pandas>=2.1.1; python_version >= "3.12"
|
|
@@ -45,18 +45,23 @@ Requires-Dist: scipy>=1.7.3; python_version < "3.10"
|
|
|
45
45
|
Requires-Dist: scipy>=1.9.2; python_version >= "3.10" and python_version < "3.12"
|
|
46
46
|
Requires-Dist: scipy>=1.12.0; python_version >= "3.12"
|
|
47
47
|
Requires-Dist: tabulate<0.9,>=0.8.3
|
|
48
|
-
Requires-Dist: torch>=1.
|
|
48
|
+
Requires-Dist: torch>=1.12.1; python_version < "3.10"
|
|
49
49
|
Requires-Dist: torch>=2.0.0; python_version >= "3.10" and python_version < "3.12"
|
|
50
50
|
Requires-Dist: torch>=2.2.0; python_version >= "3.12"
|
|
51
|
-
Requires-Dist: tqdm>=4.
|
|
51
|
+
Requires-Dist: tqdm>=4.66.3
|
|
52
52
|
Requires-Dist: XlsxWriter>=1.2.8
|
|
53
|
-
Requires-Dist: rdt>=1.
|
|
54
|
-
Requires-Dist: sdmetrics>=0.
|
|
55
|
-
Requires-Dist: sdv>=1.
|
|
53
|
+
Requires-Dist: rdt>=1.13.1
|
|
54
|
+
Requires-Dist: sdmetrics>=0.17.0
|
|
55
|
+
Requires-Dist: sdv>=1.17.2
|
|
56
56
|
Provides-Extra: dask
|
|
57
57
|
Requires-Dist: dask; extra == "dask"
|
|
58
58
|
Requires-Dist: distributed; extra == "dask"
|
|
59
|
+
Provides-Extra: realtabformer
|
|
60
|
+
Requires-Dist: realtabformer>=0.2.2; extra == "realtabformer"
|
|
61
|
+
Requires-Dist: torch>=2.0.0; (python_version >= "3.8" and python_version < "3.12") and extra == "realtabformer"
|
|
62
|
+
Requires-Dist: torch>=2.2.0; python_version >= "3.12" and extra == "realtabformer"
|
|
59
63
|
Provides-Extra: test
|
|
64
|
+
Requires-Dist: sdgym[realtabformer]; extra == "test"
|
|
60
65
|
Requires-Dist: pytest>=6.2.5; extra == "test"
|
|
61
66
|
Requires-Dist: pytest-cov>=2.6.0; extra == "test"
|
|
62
67
|
Requires-Dist: jupyter<2,>=1.0.0; extra == "test"
|
|
@@ -6,17 +6,17 @@ compress-pickle>=1.2.0
|
|
|
6
6
|
humanfriendly>=8.2
|
|
7
7
|
psutil>=5.7
|
|
8
8
|
tabulate<0.9,>=0.8.3
|
|
9
|
-
tqdm>=4.
|
|
9
|
+
tqdm>=4.66.3
|
|
10
10
|
XlsxWriter>=1.2.8
|
|
11
|
-
rdt>=1.
|
|
12
|
-
sdmetrics>=0.
|
|
13
|
-
sdv>=1.
|
|
11
|
+
rdt>=1.13.1
|
|
12
|
+
sdmetrics>=0.17.0
|
|
13
|
+
sdv>=1.17.2
|
|
14
14
|
|
|
15
15
|
[:python_version < "3.10"]
|
|
16
|
-
numpy
|
|
16
|
+
numpy>=1.21.6
|
|
17
17
|
scikit-learn>=1.0.2
|
|
18
18
|
scipy>=1.7.3
|
|
19
|
-
torch>=1.
|
|
19
|
+
torch>=1.12.1
|
|
20
20
|
|
|
21
21
|
[:python_version < "3.11"]
|
|
22
22
|
pandas>=1.4.0
|
|
@@ -25,7 +25,7 @@ pandas>=1.4.0
|
|
|
25
25
|
scikit-learn>=1.1.0
|
|
26
26
|
|
|
27
27
|
[:python_version >= "3.10" and python_version < "3.12"]
|
|
28
|
-
numpy
|
|
28
|
+
numpy>=1.23.3
|
|
29
29
|
scipy>=1.9.2
|
|
30
30
|
torch>=2.0.0
|
|
31
31
|
|
|
@@ -34,7 +34,7 @@ pandas>=1.5.0
|
|
|
34
34
|
scikit-learn>=1.1.3
|
|
35
35
|
|
|
36
36
|
[:python_version >= "3.12"]
|
|
37
|
-
numpy
|
|
37
|
+
numpy>=1.26.0
|
|
38
38
|
pandas>=2.1.1
|
|
39
39
|
scikit-learn>=1.3.1
|
|
40
40
|
scipy>=1.12.0
|
|
@@ -61,7 +61,17 @@ tox<5,>=2.9.1
|
|
|
61
61
|
importlib-metadata>=3.6
|
|
62
62
|
invoke
|
|
63
63
|
|
|
64
|
+
[realtabformer]
|
|
65
|
+
realtabformer>=0.2.2
|
|
66
|
+
|
|
67
|
+
[realtabformer:python_version >= "3.12"]
|
|
68
|
+
torch>=2.2.0
|
|
69
|
+
|
|
70
|
+
[realtabformer:python_version >= "3.8" and python_version < "3.12"]
|
|
71
|
+
torch>=2.0.0
|
|
72
|
+
|
|
64
73
|
[test]
|
|
74
|
+
sdgym[realtabformer]
|
|
65
75
|
pytest>=6.2.5
|
|
66
76
|
pytest-cov>=2.6.0
|
|
67
77
|
jupyter<2,>=1.0.0
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""Tests for the ``tasks.py`` file."""
|
|
2
|
+
|
|
3
|
+
from tasks import _get_extra_dependencies, _get_minimum_versions, _resolve_version_conflicts
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_get_minimum_versions():
|
|
7
|
+
"""Test the ``_get_minimum_versions`` method.
|
|
8
|
+
|
|
9
|
+
The method should return the minimum versions of the dependencies for the given python version.
|
|
10
|
+
If a library is linked to an URL, the minimum version should be the URL.
|
|
11
|
+
"""
|
|
12
|
+
# Setup
|
|
13
|
+
dependencies = [
|
|
14
|
+
"numpy>=1.20.0,<2;python_version<'3.10'",
|
|
15
|
+
"numpy>=1.23.3,<2;python_version>='3.10'",
|
|
16
|
+
"pandas>=1.2.0,<2;python_version<'3.10'",
|
|
17
|
+
"pandas>=1.3.0,<2;python_version>='3.10'",
|
|
18
|
+
'humanfriendly>=8.2,<11',
|
|
19
|
+
'pandas @ git+https://github.com/pandas-dev/pandas.git@master',
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
# Run
|
|
23
|
+
minimum_versions_39 = _get_minimum_versions(dependencies, '3.9')
|
|
24
|
+
minimum_versions_310 = _get_minimum_versions(dependencies, '3.10')
|
|
25
|
+
|
|
26
|
+
# Assert
|
|
27
|
+
expected_versions_39 = {
|
|
28
|
+
'numpy': 'numpy==1.20.0',
|
|
29
|
+
'pandas': 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
|
|
30
|
+
'humanfriendly': 'humanfriendly==8.2',
|
|
31
|
+
}
|
|
32
|
+
expected_versions_310 = {
|
|
33
|
+
'numpy': 'numpy==1.23.3',
|
|
34
|
+
'pandas': 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
|
|
35
|
+
'humanfriendly': 'humanfriendly==8.2',
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
assert minimum_versions_39 == expected_versions_39
|
|
39
|
+
assert minimum_versions_310 == expected_versions_310
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_example_pyproject_dict():
|
|
43
|
+
return {
|
|
44
|
+
'build-system': {
|
|
45
|
+
'build-backend': 'setuptools.build_meta',
|
|
46
|
+
'requires': ['setuptools', 'wheel'],
|
|
47
|
+
},
|
|
48
|
+
'project': {
|
|
49
|
+
'authors': [{'email': 'info@sdv.dev', 'name': 'DataCebo, Inc.'}],
|
|
50
|
+
'classifiers': [
|
|
51
|
+
'Intended Audience :: Developers',
|
|
52
|
+
'License :: Free for non-commercial use',
|
|
53
|
+
'Natural Language :: English',
|
|
54
|
+
'Programming Language :: Python :: 3.10',
|
|
55
|
+
'Programming Language :: Python :: 3.11',
|
|
56
|
+
'Programming Language :: Python :: 3.12',
|
|
57
|
+
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
|
58
|
+
],
|
|
59
|
+
'dependencies': [
|
|
60
|
+
'appdirs>=1.3',
|
|
61
|
+
'boto3>=1.28,<2',
|
|
62
|
+
'botocore>=1.31,<2',
|
|
63
|
+
'cloudpickle>=2.1.0',
|
|
64
|
+
'compress-pickle>=1.2.0',
|
|
65
|
+
'humanfriendly>=8.2',
|
|
66
|
+
"numpy>=1.21.6;python_version<'3.10'",
|
|
67
|
+
"numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
|
|
68
|
+
"numpy>=1.26.0;python_version>='3.12'",
|
|
69
|
+
"pandas>=1.4.0;python_version<'3.11'",
|
|
70
|
+
"pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
|
|
71
|
+
"pandas>=2.1.1;python_version>='3.12'",
|
|
72
|
+
'psutil>=5.7',
|
|
73
|
+
"scikit-learn>=1.0.2;python_version<'3.10'",
|
|
74
|
+
"scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'",
|
|
75
|
+
"scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'",
|
|
76
|
+
"scikit-learn>=1.3.1;python_version>='3.12'",
|
|
77
|
+
"scipy>=1.7.3;python_version<'3.10'",
|
|
78
|
+
"scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
|
|
79
|
+
"scipy>=1.12.0;python_version>='3.12'",
|
|
80
|
+
'tabulate>=0.8.3,<0.9',
|
|
81
|
+
"torch>=1.12.1;python_version<'3.10'",
|
|
82
|
+
"torch>=2.0.0;python_version>='3.10' and python_version<'3.12'",
|
|
83
|
+
"torch>=2.2.0;python_version>='3.12'",
|
|
84
|
+
'tqdm>=4.66.3',
|
|
85
|
+
'XlsxWriter>=1.2.8',
|
|
86
|
+
'rdt>=1.13.1',
|
|
87
|
+
'sdmetrics>=0.17.0',
|
|
88
|
+
'sdv>=1.17.2',
|
|
89
|
+
],
|
|
90
|
+
'dynamic': ['version'],
|
|
91
|
+
'license': {'text': 'BSL-1.1'},
|
|
92
|
+
'name': 'sdgym',
|
|
93
|
+
'optional-dependencies': {
|
|
94
|
+
'all': ['sdgym[dask, test, dev]'],
|
|
95
|
+
'dask': ['dask', 'distributed'],
|
|
96
|
+
'dev': [
|
|
97
|
+
'sdgym[dask, test]',
|
|
98
|
+
'build>=1.0.0,<2',
|
|
99
|
+
'bump-my-version>=0.18.3,<1',
|
|
100
|
+
'pip>=9.0.1',
|
|
101
|
+
'watchdog>=1.0.1,<5',
|
|
102
|
+
'ruff>=0.4.5,<1',
|
|
103
|
+
'twine>=1.10.0,<6',
|
|
104
|
+
'wheel>=0.30.0',
|
|
105
|
+
'coverage>=4.5.12,<8',
|
|
106
|
+
'tox>=2.9.1,<5',
|
|
107
|
+
'importlib-metadata>=3.6',
|
|
108
|
+
'invoke',
|
|
109
|
+
],
|
|
110
|
+
'realtabformer': ['realtabformer>=0.2.1'],
|
|
111
|
+
'test': [
|
|
112
|
+
'sdgym[realtabformer]',
|
|
113
|
+
'pytest>=6.2.5',
|
|
114
|
+
],
|
|
115
|
+
},
|
|
116
|
+
'readme': 'README.md',
|
|
117
|
+
'requires-python': '>=3.8,<3.13',
|
|
118
|
+
},
|
|
119
|
+
'tool': {
|
|
120
|
+
'bumpversion': {
|
|
121
|
+
'allow_dirty': False,
|
|
122
|
+
'commit': True,
|
|
123
|
+
'commit_args': '',
|
|
124
|
+
},
|
|
125
|
+
'ruff': {
|
|
126
|
+
'exclude': [
|
|
127
|
+
'docs',
|
|
128
|
+
'.tox',
|
|
129
|
+
'.git',
|
|
130
|
+
'__pycache__',
|
|
131
|
+
'.ipynb_checkpoints',
|
|
132
|
+
'tasks.py',
|
|
133
|
+
],
|
|
134
|
+
'indent-width': 4,
|
|
135
|
+
},
|
|
136
|
+
},
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test__get_extra_dependencies():
|
|
141
|
+
"""Test that the proper dependency strings are extracted from the pyproject dictionary."""
|
|
142
|
+
# Setup
|
|
143
|
+
pyproject_dict = _get_example_pyproject_dict()
|
|
144
|
+
|
|
145
|
+
# Run
|
|
146
|
+
extra_dependencies = _get_extra_dependencies(pyproject_dict)
|
|
147
|
+
|
|
148
|
+
# Assert
|
|
149
|
+
assert extra_dependencies == ['realtabformer>=0.2.1']
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def test__resolve_version_conflicts_conflicting_versions():
|
|
153
|
+
"""Test that any conflicts for the same dependency are resolved to the higher version."""
|
|
154
|
+
# Setup
|
|
155
|
+
deps = {
|
|
156
|
+
'numpy': 'numpy==2.0.1',
|
|
157
|
+
'pandas': 'pandas==2.2.1',
|
|
158
|
+
'sdv': 'sdv==2.1.1',
|
|
159
|
+
'rdt': 'rdt==1.1.2',
|
|
160
|
+
}
|
|
161
|
+
extra_deps = {
|
|
162
|
+
'numpy': 'numpy==2.0.0',
|
|
163
|
+
'pandas': 'pandas==2.3.0',
|
|
164
|
+
'sdv': 'sdv==3.0.0',
|
|
165
|
+
'copulas': 'copulas==0.12.0',
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
# Run
|
|
169
|
+
versions = _resolve_version_conflicts(deps, extra_deps)
|
|
170
|
+
|
|
171
|
+
# Assert
|
|
172
|
+
assert sorted(versions) == sorted([
|
|
173
|
+
'numpy==2.0.1',
|
|
174
|
+
'pandas==2.3.0',
|
|
175
|
+
'sdv==3.0.0',
|
|
176
|
+
'rdt==1.1.2',
|
|
177
|
+
'copulas==0.12.0',
|
|
178
|
+
])
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def test__resolve_version_conflicts_pointing_to_branch():
|
|
182
|
+
"""Test specific branches are always selected over normal version numbers."""
|
|
183
|
+
# Setup
|
|
184
|
+
deps = {
|
|
185
|
+
'numpy': 'git+https://github.com/numpy-dev/numpy.git@master#egg=numpy',
|
|
186
|
+
'pandas': 'pandas==2.2.1',
|
|
187
|
+
'sdv': 'sdv==2.1.1',
|
|
188
|
+
'rdt': 'rdt==1.1.2',
|
|
189
|
+
}
|
|
190
|
+
extra_deps = {
|
|
191
|
+
'numpy': 'numpy==2.0.0',
|
|
192
|
+
'pandas': 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
|
|
193
|
+
'sdv': 'sdv==3.0.0',
|
|
194
|
+
'copulas': 'copulas==0.12.0',
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
# Run
|
|
198
|
+
versions = _resolve_version_conflicts(deps, extra_deps)
|
|
199
|
+
|
|
200
|
+
# Assert
|
|
201
|
+
assert sorted(versions) == sorted([
|
|
202
|
+
'git+https://github.com/numpy-dev/numpy.git@master#egg=numpy',
|
|
203
|
+
'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
|
|
204
|
+
'sdv==3.0.0',
|
|
205
|
+
'rdt==1.1.2',
|
|
206
|
+
'copulas==0.12.0',
|
|
207
|
+
])
|
|
@@ -1,39 +0,0 @@
|
|
|
1
|
-
"""Tests for the ``tasks.py`` file."""
|
|
2
|
-
|
|
3
|
-
from tasks import _get_minimum_versions
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
def test_get_minimum_versions():
|
|
7
|
-
"""Test the ``_get_minimum_versions`` method.
|
|
8
|
-
|
|
9
|
-
The method should return the minimum versions of the dependencies for the given python version.
|
|
10
|
-
If a library is linked to an URL, the minimum version should be the URL.
|
|
11
|
-
"""
|
|
12
|
-
# Setup
|
|
13
|
-
dependencies = [
|
|
14
|
-
"numpy>=1.20.0,<2;python_version<'3.10'",
|
|
15
|
-
"numpy>=1.23.3,<2;python_version>='3.10'",
|
|
16
|
-
"pandas>=1.2.0,<2;python_version<'3.10'",
|
|
17
|
-
"pandas>=1.3.0,<2;python_version>='3.10'",
|
|
18
|
-
'humanfriendly>=8.2,<11',
|
|
19
|
-
'pandas @ git+https://github.com/pandas-dev/pandas.git@master',
|
|
20
|
-
]
|
|
21
|
-
|
|
22
|
-
# Run
|
|
23
|
-
minimum_versions_39 = _get_minimum_versions(dependencies, '3.9')
|
|
24
|
-
minimum_versions_310 = _get_minimum_versions(dependencies, '3.10')
|
|
25
|
-
|
|
26
|
-
# Assert
|
|
27
|
-
expected_versions_39 = [
|
|
28
|
-
'numpy==1.20.0',
|
|
29
|
-
'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
|
|
30
|
-
'humanfriendly==8.2',
|
|
31
|
-
]
|
|
32
|
-
expected_versions_310 = [
|
|
33
|
-
'numpy==1.23.3',
|
|
34
|
-
'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
|
|
35
|
-
'humanfriendly==8.2',
|
|
36
|
-
]
|
|
37
|
-
|
|
38
|
-
assert minimum_versions_39 == expected_versions_39
|
|
39
|
-
assert minimum_versions_310 == expected_versions_310
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|