sdgym 0.9.1.dev0__tar.gz → 0.10.0.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/PKG-INFO +15 -10
  2. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/pyproject.toml +19 -12
  3. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/__init__.py +1 -1
  4. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/benchmark.py +8 -8
  5. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/__main__.py +6 -6
  6. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/__init__.py +2 -0
  7. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/base.py +7 -4
  8. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/column.py +26 -0
  9. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/generate.py +1 -1
  10. sdgym-0.10.0.dev0/sdgym/synthesizers/realtabformer.py +46 -0
  11. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/uniform.py +3 -2
  12. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/PKG-INFO +15 -10
  13. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/SOURCES.txt +1 -0
  14. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/requires.txt +18 -8
  15. sdgym-0.10.0.dev0/tests/test_tasks.py +207 -0
  16. sdgym-0.9.1.dev0/tests/test_tasks.py +0 -39
  17. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/LICENSE +0 -0
  18. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/README.md +0 -0
  19. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/__init__.py +0 -0
  20. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/collect.py +0 -0
  21. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/summary.py +0 -0
  22. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/cli/utils.py +0 -0
  23. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/datasets.py +0 -0
  24. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/errors.py +0 -0
  25. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/metrics.py +0 -0
  26. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/progress.py +0 -0
  27. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/s3.py +0 -0
  28. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/identity.py +0 -0
  29. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/synthesizers/sdv.py +0 -0
  30. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym/utils.py +0 -0
  31. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/dependency_links.txt +0 -0
  32. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/entry_points.txt +0 -0
  33. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/sdgym.egg-info/top_level.txt +0 -0
  34. {sdgym-0.9.1.dev0 → sdgym-0.10.0.dev0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: sdgym
3
- Version: 0.9.1.dev0
3
+ Version: 0.10.0.dev0
4
4
  Summary: Benchmark tabular synthetic data generators using a variety of datasets
5
5
  Author-email: "DataCebo, Inc." <info@sdv.dev>
6
6
  License: BSL-1.1
@@ -30,9 +30,9 @@ Requires-Dist: botocore<2,>=1.31
30
30
  Requires-Dist: cloudpickle>=2.1.0
31
31
  Requires-Dist: compress-pickle>=1.2.0
32
32
  Requires-Dist: humanfriendly>=8.2
33
- Requires-Dist: numpy<2.0.0,>=1.21.0; python_version < "3.10"
34
- Requires-Dist: numpy<2.0.0,>=1.23.3; python_version >= "3.10" and python_version < "3.12"
35
- Requires-Dist: numpy<2.0.0,>=1.26.0; python_version >= "3.12"
33
+ Requires-Dist: numpy>=1.21.6; python_version < "3.10"
34
+ Requires-Dist: numpy>=1.23.3; python_version >= "3.10" and python_version < "3.12"
35
+ Requires-Dist: numpy>=1.26.0; python_version >= "3.12"
36
36
  Requires-Dist: pandas>=1.4.0; python_version < "3.11"
37
37
  Requires-Dist: pandas>=1.5.0; python_version >= "3.11" and python_version < "3.12"
38
38
  Requires-Dist: pandas>=2.1.1; python_version >= "3.12"
@@ -45,18 +45,23 @@ Requires-Dist: scipy>=1.7.3; python_version < "3.10"
45
45
  Requires-Dist: scipy>=1.9.2; python_version >= "3.10" and python_version < "3.12"
46
46
  Requires-Dist: scipy>=1.12.0; python_version >= "3.12"
47
47
  Requires-Dist: tabulate<0.9,>=0.8.3
48
- Requires-Dist: torch>=1.9.0; python_version < "3.10"
48
+ Requires-Dist: torch>=1.12.1; python_version < "3.10"
49
49
  Requires-Dist: torch>=2.0.0; python_version >= "3.10" and python_version < "3.12"
50
50
  Requires-Dist: torch>=2.2.0; python_version >= "3.12"
51
- Requires-Dist: tqdm>=4.29
51
+ Requires-Dist: tqdm>=4.66.3
52
52
  Requires-Dist: XlsxWriter>=1.2.8
53
- Requires-Dist: rdt>=1.12.1
54
- Requires-Dist: sdmetrics>=0.14.1
55
- Requires-Dist: sdv>=1.13.1
53
+ Requires-Dist: rdt>=1.13.1
54
+ Requires-Dist: sdmetrics>=0.17.0
55
+ Requires-Dist: sdv>=1.17.2
56
56
  Provides-Extra: dask
57
57
  Requires-Dist: dask; extra == "dask"
58
58
  Requires-Dist: distributed; extra == "dask"
59
+ Provides-Extra: realtabformer
60
+ Requires-Dist: realtabformer>=0.2.2; extra == "realtabformer"
61
+ Requires-Dist: torch>=2.0.0; (python_version >= "3.8" and python_version < "3.12") and extra == "realtabformer"
62
+ Requires-Dist: torch>=2.2.0; python_version >= "3.12" and extra == "realtabformer"
59
63
  Provides-Extra: test
64
+ Requires-Dist: sdgym[realtabformer]; extra == "test"
60
65
  Requires-Dist: pytest>=6.2.5; extra == "test"
61
66
  Requires-Dist: pytest-cov>=2.6.0; extra == "test"
62
67
  Requires-Dist: jupyter<2,>=1.0.0; extra == "test"
@@ -27,9 +27,9 @@ dependencies = [
27
27
  'cloudpickle>=2.1.0',
28
28
  'compress-pickle>=1.2.0',
29
29
  'humanfriendly>=8.2',
30
- "numpy>=1.21.0,<2.0.0;python_version<'3.10'",
31
- "numpy>=1.23.3,<2.0.0;python_version>='3.10' and python_version<'3.12'",
32
- "numpy>=1.26.0,<2.0.0;python_version>='3.12'",
30
+ "numpy>=1.21.6;python_version<'3.10'",
31
+ "numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
32
+ "numpy>=1.26.0;python_version>='3.12'",
33
33
  "pandas>=1.4.0;python_version<'3.11'",
34
34
  "pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
35
35
  "pandas>=2.1.1;python_version>='3.12'",
@@ -42,14 +42,14 @@ dependencies = [
42
42
  "scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
43
43
  "scipy>=1.12.0;python_version>='3.12'",
44
44
  'tabulate>=0.8.3,<0.9',
45
- "torch>=1.9.0;python_version<'3.10'",
45
+ "torch>=1.12.1;python_version<'3.10'",
46
46
  "torch>=2.0.0;python_version>='3.10' and python_version<'3.12'",
47
47
  "torch>=2.2.0;python_version>='3.12'",
48
- 'tqdm>=4.29',
48
+ 'tqdm>=4.66.3',
49
49
  'XlsxWriter>=1.2.8',
50
- 'rdt>=1.12.1',
51
- 'sdmetrics>=0.14.1',
52
- 'sdv>=1.13.1',
50
+ 'rdt>=1.13.1',
51
+ 'sdmetrics>=0.17.0',
52
+ 'sdv>=1.17.2',
53
53
  ]
54
54
 
55
55
  [project.urls]
@@ -64,7 +64,13 @@ sdgym = { main = 'sdgym.cli.__main__:main' }
64
64
 
65
65
  [project.optional-dependencies]
66
66
  dask = ['dask', 'distributed']
67
+ realtabformer = [
68
+ 'realtabformer>=0.2.2',
69
+ "torch>=2.0.0;python_version>='3.8' and python_version<'3.12'",
70
+ "torch>=2.2.0;python_version>='3.12'",
71
+ ]
67
72
  test = [
73
+ 'sdgym[realtabformer]',
68
74
  'pytest>=6.2.5',
69
75
  'pytest-cov>=2.6.0',
70
76
  'jupyter>=1.0.0,<2',
@@ -134,7 +140,7 @@ namespaces = false
134
140
  version = {attr = 'sdgym.__version__'}
135
141
 
136
142
  [tool.bumpversion]
137
- current_version = "0.9.1.dev0"
143
+ current_version = "0.10.0.dev0"
138
144
  parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
139
145
  serialize = [
140
146
  '{major}.{minor}.{patch}.{release}{candidate}',
@@ -198,10 +204,11 @@ select = [
198
204
  # print statements
199
205
  "T201",
200
206
  # pandas-vet
201
- "PD"
207
+ "PD",
208
+ # numpy 2.0
209
+ "NPY201"
202
210
  ]
203
211
  ignore = [
204
- "E501",
205
212
  # pydocstyle
206
213
  "D107", # Missing docstring in __init__
207
214
  "D417", # Missing argument descriptions in the docstring, this is a bug from pydocstyle: https://github.com/PyCQA/pydocstyle/issues/449
@@ -230,4 +237,4 @@ convention = "google"
230
237
 
231
238
  [tool.ruff.lint.pycodestyle]
232
239
  max-doc-length = 100
233
- max-line-length = 100
240
+ max-line-length = 100
@@ -8,7 +8,7 @@ __author__ = 'DataCebo, Inc.'
8
8
  __copyright__ = 'Copyright (c) 2022 DataCebo, Inc.'
9
9
  __email__ = 'info@sdv.dev'
10
10
  __license__ = 'BSL-1.1'
11
- __version__ = '0.9.1.dev0'
11
+ __version__ = '0.10.0.dev0'
12
12
 
13
13
  import logging
14
14
 
@@ -66,8 +66,7 @@ N_BYTES_IN_MB = 1000 * 1000
66
66
  def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers):
67
67
  if output_filepath and os.path.exists(output_filepath):
68
68
  raise ValueError(
69
- f'{output_filepath} already exists. '
70
- 'Please provide a file that does not already exist.'
69
+ f'{output_filepath} already exists. Please provide a file that does not already exist.'
71
70
  )
72
71
 
73
72
  if detailed_results_folder and os.path.exists(detailed_results_folder):
@@ -149,14 +148,14 @@ def _generate_job_args_list(
149
148
  def _synthesize(synthesizer_dict, real_data, metadata):
150
149
  synthesizer = synthesizer_dict['synthesizer']
151
150
  if isinstance(synthesizer, type):
152
- assert issubclass(
153
- synthesizer, BaselineSynthesizer
154
- ), '`synthesizer` must be a synthesizer class'
151
+ assert issubclass(synthesizer, BaselineSynthesizer), (
152
+ '`synthesizer` must be a synthesizer class'
153
+ )
155
154
  synthesizer = synthesizer()
156
155
  else:
157
- assert issubclass(
158
- type(synthesizer), BaselineSynthesizer
159
- ), '`synthesizer` must be an instance of a synthesizer class.'
156
+ assert issubclass(type(synthesizer), BaselineSynthesizer), (
157
+ '`synthesizer` must be an instance of a synthesizer class.'
158
+ )
160
159
 
161
160
  get_synthesizer = synthesizer.get_trained_synthesizer
162
161
  sample_from_synthesizer = synthesizer.sample_from_synthesizer
@@ -747,6 +746,7 @@ def benchmark_single_table(
747
746
  - ``CTGANSynthesizer``
748
747
  - ``CopulaGANSynthesizer``
749
748
  - ``TVAESynthesizer``
749
+ - ``RealTabFormerSynthesizer``
750
750
 
751
751
  custom_synthesizers (list[class] or ``None``):
752
752
  A list of custom synthesizer classes to use. These can be completely custom or
@@ -180,9 +180,9 @@ def _get_parser():
180
180
  )
181
181
  run.add_argument('-m', '--metrics', nargs='+', help='Metrics to apply. Accepts multiple names.')
182
182
  run.add_argument('-b', '--bucket', help='Bucket from which to download the datasets.')
183
- run.add_argument('-dp' '--datasets-path', help='Path where datasets can be found.')
183
+ run.add_argument('-dp--datasets-path', help='Path where datasets can be found.')
184
184
  run.add_argument(
185
- '-dm' '--modalities', nargs='+', help='Data Modalities to run. Accepts multiple names.'
185
+ '-dm--modalities', nargs='+', help='Data Modalities to run. Accepts multiple names.'
186
186
  )
187
187
  run.add_argument('-i', '--iterations', type=int, default=1, help='Number of iterations.')
188
188
  run.add_argument(
@@ -219,13 +219,13 @@ def _get_parser():
219
219
  '-g', '--groupby', nargs='+', help='Group scores leaderboard by the given fields.'
220
220
  )
221
221
  run.add_argument(
222
- '-ak' '--aws-key',
222
+ '-ak--aws-key',
223
223
  type=str,
224
224
  required=False,
225
225
  help='Aws access key ID to use when reading datasets.',
226
226
  )
227
227
  run.add_argument(
228
- '-as' '--aws-secret',
228
+ '-as--aws-secret',
229
229
  type=str,
230
230
  required=False,
231
231
  help='Aws secret access key to use when reading datasets.',
@@ -234,10 +234,10 @@ def _get_parser():
234
234
  '-j', '--jobs', type=str, required=False, help='Serialized list of jobs to run.'
235
235
  )
236
236
  run.add_argument(
237
- '-mr' '--max-rows', type=int, help='Cap the number of rows to model from each dataset.'
237
+ '-mr--max-rows', type=int, help='Cap the number of rows to model from each dataset.'
238
238
  )
239
239
  run.add_argument(
240
- '-mc' '--max-columns',
240
+ '-mc--max-columns',
241
241
  type=int,
242
242
  help='Cap the number of columns to model from each dataset.',
243
243
  )
@@ -9,6 +9,7 @@ from sdgym.synthesizers.generate import (
9
9
  )
10
10
  from sdgym.synthesizers.identity import DataIdentity
11
11
  from sdgym.synthesizers.column import ColumnSynthesizer
12
+ from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
12
13
  from sdgym.synthesizers.sdv import (
13
14
  CopulaGANSynthesizer,
14
15
  CTGANSynthesizer,
@@ -38,4 +39,5 @@ __all__ = (
38
39
  'create_sdv_synthesizer_variant',
39
40
  'create_sequential_synthesizer',
40
41
  'SYNTHESIZER_MAPPING',
42
+ 'RealTabFormerSynthesizer',
41
43
  )
@@ -2,9 +2,9 @@
2
2
 
3
3
  import abc
4
4
  import logging
5
+ import warnings
5
6
 
6
- from sdv.metadata.multi_table import MultiTableMetadata
7
- from sdv.metadata.single_table import SingleTableMetadata
7
+ from sdv.metadata import Metadata
8
8
 
9
9
  LOGGER = logging.getLogger(__name__)
10
10
 
@@ -54,8 +54,11 @@ class BaselineSynthesizer(abc.ABC):
54
54
  obj:
55
55
  The synthesizer object.
56
56
  """
57
- metadata_class = MultiTableMetadata() if 'tables' in metadata else SingleTableMetadata()
58
- metadata = metadata_class.load_from_dict(metadata)
57
+ metadata_object = Metadata()
58
+ with warnings.catch_warnings():
59
+ warnings.simplefilter('ignore', UserWarning)
60
+ metadata = metadata_object.load_from_dict(metadata)
61
+
59
62
  return self._get_trained_synthesizer(data, metadata)
60
63
 
61
64
  def sample_from_synthesizer(self, synthesizer, n_samples):
@@ -1,11 +1,16 @@
1
1
  """ColumnSynthesizer module."""
2
2
 
3
+ import logging
4
+
3
5
  import pandas as pd
4
6
  from rdt.hyper_transformer import HyperTransformer
7
+ from sdv.metadata import Metadata
5
8
  from sklearn.mixture import GaussianMixture
6
9
 
7
10
  from sdgym.synthesizers.base import BaselineSynthesizer
8
11
 
12
+ LOGGER = logging.getLogger(__name__)
13
+
9
14
 
10
15
  class ColumnSynthesizer(BaselineSynthesizer):
11
16
  """Synthesizer that learns each column independently.
@@ -17,6 +22,27 @@ class ColumnSynthesizer(BaselineSynthesizer):
17
22
  def _get_trained_synthesizer(self, real_data, metadata):
18
23
  hyper_transformer = HyperTransformer()
19
24
  hyper_transformer.detect_initial_config(real_data)
25
+ supported_sdtypes = hyper_transformer._get_supported_sdtypes()
26
+ config = {}
27
+ if isinstance(metadata, Metadata):
28
+ table_name = metadata._get_single_table_name()
29
+ columns = metadata.tables[table_name].columns
30
+ else:
31
+ columns = metadata.columns
32
+
33
+ for column_name, column in columns.items():
34
+ sdtype = column['sdtype']
35
+ if sdtype in supported_sdtypes:
36
+ config[column_name] = sdtype
37
+ elif column.get('pii', False):
38
+ config[column_name] = 'pii'
39
+ else:
40
+ LOGGER.info(
41
+ f'Column {column} sdtype: {sdtype} is not supported, '
42
+ f'defaulting to inferred type.'
43
+ )
44
+
45
+ hyper_transformer.update_sdtypes(config)
20
46
 
21
47
  # This is done to match the behavior of the synthesizer for SDGym <= 0.6.0
22
48
  columns_to_remove = [
@@ -49,7 +49,7 @@ def create_sdv_synthesizer_variant(display_name, synthesizer_class, synthesizer_
49
49
  if synthesizer_class not in SYNTHESIZER_MAPPING.keys():
50
50
  raise ValueError(
51
51
  f'Synthesizer class {synthesizer_class} is not recognized. '
52
- f"The supported options are {', '.join(SYNTHESIZER_MAPPING.keys())}"
52
+ f'The supported options are {", ".join(SYNTHESIZER_MAPPING.keys())}'
53
53
  )
54
54
 
55
55
  baseclass = SDVTabularSynthesizer
@@ -0,0 +1,46 @@
1
+ """REaLTabFormer integration."""
2
+
3
+ import contextlib
4
+ import logging
5
+ from functools import partialmethod
6
+
7
+ import tqdm
8
+
9
+ from sdgym.synthesizers.base import BaselineSynthesizer
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def prevent_tqdm_output():
14
+ """Temporarily disables tqdm m."""
15
+ tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
16
+ try:
17
+ yield
18
+ finally:
19
+ tqdm.__init__ = partialmethod(tqdm.__init__, disable=False)
20
+
21
+
22
+ class RealTabFormerSynthesizer(BaselineSynthesizer):
23
+ """Custom wrapper for the REaLTabFormer synthesizer to make it work with SDGym."""
24
+
25
+ LOGGER = logging.getLogger(__name__)
26
+ _MODEL_KWARGS = None
27
+
28
+ def _get_trained_synthesizer(self, data, metadata):
29
+ try:
30
+ from realtabformer import REaLTabFormer
31
+ except Exception as exception:
32
+ raise ValueError(
33
+ "In order to use 'RealTabFormerSynthesizer' you have to install the extra"
34
+ " dependencies by running pip install sdgym['realtabformer'] "
35
+ ) from exception
36
+
37
+ with prevent_tqdm_output():
38
+ model_kwargs = self._MODEL_KWARGS.copy() if self._MODEL_KWARGS else {}
39
+ model = REaLTabFormer(model_type='tabular', **model_kwargs)
40
+ model.fit(data, device='cpu')
41
+
42
+ return model
43
+
44
+ def _sample_from_synthesizer(self, synthesizer, n_sample):
45
+ """Sample synthetic data with specified sample count."""
46
+ return synthesizer.sample(n_sample, device='cpu')
@@ -19,7 +19,8 @@ class UniformSynthesizer(BaselineSynthesizer):
19
19
  hyper_transformer.detect_initial_config(real_data)
20
20
  supported_sdtypes = hyper_transformer._get_supported_sdtypes()
21
21
  config = {}
22
- for column_name, column in metadata.columns.items():
22
+ table = next(iter(metadata.tables.values()))
23
+ for column_name, column in table.columns.items():
23
24
  sdtype = column['sdtype']
24
25
  if sdtype in supported_sdtypes:
25
26
  config[column_name] = sdtype
@@ -27,7 +28,7 @@ class UniformSynthesizer(BaselineSynthesizer):
27
28
  config[column_name] = 'pii'
28
29
  else:
29
30
  LOGGER.info(
30
- f'Column {column} sdtype: {sdtype} is not supported, '
31
+ f'Column {column_name} sdtype: {sdtype} is not supported, '
31
32
  f'defaulting to inferred type.'
32
33
  )
33
34
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: sdgym
3
- Version: 0.9.1.dev0
3
+ Version: 0.10.0.dev0
4
4
  Summary: Benchmark tabular synthetic data generators using a variety of datasets
5
5
  Author-email: "DataCebo, Inc." <info@sdv.dev>
6
6
  License: BSL-1.1
@@ -30,9 +30,9 @@ Requires-Dist: botocore<2,>=1.31
30
30
  Requires-Dist: cloudpickle>=2.1.0
31
31
  Requires-Dist: compress-pickle>=1.2.0
32
32
  Requires-Dist: humanfriendly>=8.2
33
- Requires-Dist: numpy<2.0.0,>=1.21.0; python_version < "3.10"
34
- Requires-Dist: numpy<2.0.0,>=1.23.3; python_version >= "3.10" and python_version < "3.12"
35
- Requires-Dist: numpy<2.0.0,>=1.26.0; python_version >= "3.12"
33
+ Requires-Dist: numpy>=1.21.6; python_version < "3.10"
34
+ Requires-Dist: numpy>=1.23.3; python_version >= "3.10" and python_version < "3.12"
35
+ Requires-Dist: numpy>=1.26.0; python_version >= "3.12"
36
36
  Requires-Dist: pandas>=1.4.0; python_version < "3.11"
37
37
  Requires-Dist: pandas>=1.5.0; python_version >= "3.11" and python_version < "3.12"
38
38
  Requires-Dist: pandas>=2.1.1; python_version >= "3.12"
@@ -45,18 +45,23 @@ Requires-Dist: scipy>=1.7.3; python_version < "3.10"
45
45
  Requires-Dist: scipy>=1.9.2; python_version >= "3.10" and python_version < "3.12"
46
46
  Requires-Dist: scipy>=1.12.0; python_version >= "3.12"
47
47
  Requires-Dist: tabulate<0.9,>=0.8.3
48
- Requires-Dist: torch>=1.9.0; python_version < "3.10"
48
+ Requires-Dist: torch>=1.12.1; python_version < "3.10"
49
49
  Requires-Dist: torch>=2.0.0; python_version >= "3.10" and python_version < "3.12"
50
50
  Requires-Dist: torch>=2.2.0; python_version >= "3.12"
51
- Requires-Dist: tqdm>=4.29
51
+ Requires-Dist: tqdm>=4.66.3
52
52
  Requires-Dist: XlsxWriter>=1.2.8
53
- Requires-Dist: rdt>=1.12.1
54
- Requires-Dist: sdmetrics>=0.14.1
55
- Requires-Dist: sdv>=1.13.1
53
+ Requires-Dist: rdt>=1.13.1
54
+ Requires-Dist: sdmetrics>=0.17.0
55
+ Requires-Dist: sdv>=1.17.2
56
56
  Provides-Extra: dask
57
57
  Requires-Dist: dask; extra == "dask"
58
58
  Requires-Dist: distributed; extra == "dask"
59
+ Provides-Extra: realtabformer
60
+ Requires-Dist: realtabformer>=0.2.2; extra == "realtabformer"
61
+ Requires-Dist: torch>=2.0.0; (python_version >= "3.8" and python_version < "3.12") and extra == "realtabformer"
62
+ Requires-Dist: torch>=2.2.0; python_version >= "3.12" and extra == "realtabformer"
59
63
  Provides-Extra: test
64
+ Requires-Dist: sdgym[realtabformer]; extra == "test"
60
65
  Requires-Dist: pytest>=6.2.5; extra == "test"
61
66
  Requires-Dist: pytest-cov>=2.6.0; extra == "test"
62
67
  Requires-Dist: jupyter<2,>=1.0.0; extra == "test"
@@ -25,6 +25,7 @@ sdgym/synthesizers/base.py
25
25
  sdgym/synthesizers/column.py
26
26
  sdgym/synthesizers/generate.py
27
27
  sdgym/synthesizers/identity.py
28
+ sdgym/synthesizers/realtabformer.py
28
29
  sdgym/synthesizers/sdv.py
29
30
  sdgym/synthesizers/uniform.py
30
31
  tests/test_tasks.py
@@ -6,17 +6,17 @@ compress-pickle>=1.2.0
6
6
  humanfriendly>=8.2
7
7
  psutil>=5.7
8
8
  tabulate<0.9,>=0.8.3
9
- tqdm>=4.29
9
+ tqdm>=4.66.3
10
10
  XlsxWriter>=1.2.8
11
- rdt>=1.12.1
12
- sdmetrics>=0.14.1
13
- sdv>=1.13.1
11
+ rdt>=1.13.1
12
+ sdmetrics>=0.17.0
13
+ sdv>=1.17.2
14
14
 
15
15
  [:python_version < "3.10"]
16
- numpy<2.0.0,>=1.21.0
16
+ numpy>=1.21.6
17
17
  scikit-learn>=1.0.2
18
18
  scipy>=1.7.3
19
- torch>=1.9.0
19
+ torch>=1.12.1
20
20
 
21
21
  [:python_version < "3.11"]
22
22
  pandas>=1.4.0
@@ -25,7 +25,7 @@ pandas>=1.4.0
25
25
  scikit-learn>=1.1.0
26
26
 
27
27
  [:python_version >= "3.10" and python_version < "3.12"]
28
- numpy<2.0.0,>=1.23.3
28
+ numpy>=1.23.3
29
29
  scipy>=1.9.2
30
30
  torch>=2.0.0
31
31
 
@@ -34,7 +34,7 @@ pandas>=1.5.0
34
34
  scikit-learn>=1.1.3
35
35
 
36
36
  [:python_version >= "3.12"]
37
- numpy<2.0.0,>=1.26.0
37
+ numpy>=1.26.0
38
38
  pandas>=2.1.1
39
39
  scikit-learn>=1.3.1
40
40
  scipy>=1.12.0
@@ -61,7 +61,17 @@ tox<5,>=2.9.1
61
61
  importlib-metadata>=3.6
62
62
  invoke
63
63
 
64
+ [realtabformer]
65
+ realtabformer>=0.2.2
66
+
67
+ [realtabformer:python_version >= "3.12"]
68
+ torch>=2.2.0
69
+
70
+ [realtabformer:python_version >= "3.8" and python_version < "3.12"]
71
+ torch>=2.0.0
72
+
64
73
  [test]
74
+ sdgym[realtabformer]
65
75
  pytest>=6.2.5
66
76
  pytest-cov>=2.6.0
67
77
  jupyter<2,>=1.0.0
@@ -0,0 +1,207 @@
1
+ """Tests for the ``tasks.py`` file."""
2
+
3
+ from tasks import _get_extra_dependencies, _get_minimum_versions, _resolve_version_conflicts
4
+
5
+
6
+ def test_get_minimum_versions():
7
+ """Test the ``_get_minimum_versions`` method.
8
+
9
+ The method should return the minimum versions of the dependencies for the given python version.
10
+ If a library is linked to an URL, the minimum version should be the URL.
11
+ """
12
+ # Setup
13
+ dependencies = [
14
+ "numpy>=1.20.0,<2;python_version<'3.10'",
15
+ "numpy>=1.23.3,<2;python_version>='3.10'",
16
+ "pandas>=1.2.0,<2;python_version<'3.10'",
17
+ "pandas>=1.3.0,<2;python_version>='3.10'",
18
+ 'humanfriendly>=8.2,<11',
19
+ 'pandas @ git+https://github.com/pandas-dev/pandas.git@master',
20
+ ]
21
+
22
+ # Run
23
+ minimum_versions_39 = _get_minimum_versions(dependencies, '3.9')
24
+ minimum_versions_310 = _get_minimum_versions(dependencies, '3.10')
25
+
26
+ # Assert
27
+ expected_versions_39 = {
28
+ 'numpy': 'numpy==1.20.0',
29
+ 'pandas': 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
30
+ 'humanfriendly': 'humanfriendly==8.2',
31
+ }
32
+ expected_versions_310 = {
33
+ 'numpy': 'numpy==1.23.3',
34
+ 'pandas': 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
35
+ 'humanfriendly': 'humanfriendly==8.2',
36
+ }
37
+
38
+ assert minimum_versions_39 == expected_versions_39
39
+ assert minimum_versions_310 == expected_versions_310
40
+
41
+
42
+ def _get_example_pyproject_dict():
43
+ return {
44
+ 'build-system': {
45
+ 'build-backend': 'setuptools.build_meta',
46
+ 'requires': ['setuptools', 'wheel'],
47
+ },
48
+ 'project': {
49
+ 'authors': [{'email': 'info@sdv.dev', 'name': 'DataCebo, Inc.'}],
50
+ 'classifiers': [
51
+ 'Intended Audience :: Developers',
52
+ 'License :: Free for non-commercial use',
53
+ 'Natural Language :: English',
54
+ 'Programming Language :: Python :: 3.10',
55
+ 'Programming Language :: Python :: 3.11',
56
+ 'Programming Language :: Python :: 3.12',
57
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
58
+ ],
59
+ 'dependencies': [
60
+ 'appdirs>=1.3',
61
+ 'boto3>=1.28,<2',
62
+ 'botocore>=1.31,<2',
63
+ 'cloudpickle>=2.1.0',
64
+ 'compress-pickle>=1.2.0',
65
+ 'humanfriendly>=8.2',
66
+ "numpy>=1.21.6;python_version<'3.10'",
67
+ "numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
68
+ "numpy>=1.26.0;python_version>='3.12'",
69
+ "pandas>=1.4.0;python_version<'3.11'",
70
+ "pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
71
+ "pandas>=2.1.1;python_version>='3.12'",
72
+ 'psutil>=5.7',
73
+ "scikit-learn>=1.0.2;python_version<'3.10'",
74
+ "scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'",
75
+ "scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'",
76
+ "scikit-learn>=1.3.1;python_version>='3.12'",
77
+ "scipy>=1.7.3;python_version<'3.10'",
78
+ "scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
79
+ "scipy>=1.12.0;python_version>='3.12'",
80
+ 'tabulate>=0.8.3,<0.9',
81
+ "torch>=1.12.1;python_version<'3.10'",
82
+ "torch>=2.0.0;python_version>='3.10' and python_version<'3.12'",
83
+ "torch>=2.2.0;python_version>='3.12'",
84
+ 'tqdm>=4.66.3',
85
+ 'XlsxWriter>=1.2.8',
86
+ 'rdt>=1.13.1',
87
+ 'sdmetrics>=0.17.0',
88
+ 'sdv>=1.17.2',
89
+ ],
90
+ 'dynamic': ['version'],
91
+ 'license': {'text': 'BSL-1.1'},
92
+ 'name': 'sdgym',
93
+ 'optional-dependencies': {
94
+ 'all': ['sdgym[dask, test, dev]'],
95
+ 'dask': ['dask', 'distributed'],
96
+ 'dev': [
97
+ 'sdgym[dask, test]',
98
+ 'build>=1.0.0,<2',
99
+ 'bump-my-version>=0.18.3,<1',
100
+ 'pip>=9.0.1',
101
+ 'watchdog>=1.0.1,<5',
102
+ 'ruff>=0.4.5,<1',
103
+ 'twine>=1.10.0,<6',
104
+ 'wheel>=0.30.0',
105
+ 'coverage>=4.5.12,<8',
106
+ 'tox>=2.9.1,<5',
107
+ 'importlib-metadata>=3.6',
108
+ 'invoke',
109
+ ],
110
+ 'realtabformer': ['realtabformer>=0.2.1'],
111
+ 'test': [
112
+ 'sdgym[realtabformer]',
113
+ 'pytest>=6.2.5',
114
+ ],
115
+ },
116
+ 'readme': 'README.md',
117
+ 'requires-python': '>=3.8,<3.13',
118
+ },
119
+ 'tool': {
120
+ 'bumpversion': {
121
+ 'allow_dirty': False,
122
+ 'commit': True,
123
+ 'commit_args': '',
124
+ },
125
+ 'ruff': {
126
+ 'exclude': [
127
+ 'docs',
128
+ '.tox',
129
+ '.git',
130
+ '__pycache__',
131
+ '.ipynb_checkpoints',
132
+ 'tasks.py',
133
+ ],
134
+ 'indent-width': 4,
135
+ },
136
+ },
137
+ }
138
+
139
+
140
+ def test__get_extra_dependencies():
141
+ """Test that the proper dependency strings are extracted from the pyproject dictionary."""
142
+ # Setup
143
+ pyproject_dict = _get_example_pyproject_dict()
144
+
145
+ # Run
146
+ extra_dependencies = _get_extra_dependencies(pyproject_dict)
147
+
148
+ # Assert
149
+ assert extra_dependencies == ['realtabformer>=0.2.1']
150
+
151
+
152
+ def test__resolve_version_conflicts_conflicting_versions():
153
+ """Test that any conflicts for the same dependency are resolved to the higher version."""
154
+ # Setup
155
+ deps = {
156
+ 'numpy': 'numpy==2.0.1',
157
+ 'pandas': 'pandas==2.2.1',
158
+ 'sdv': 'sdv==2.1.1',
159
+ 'rdt': 'rdt==1.1.2',
160
+ }
161
+ extra_deps = {
162
+ 'numpy': 'numpy==2.0.0',
163
+ 'pandas': 'pandas==2.3.0',
164
+ 'sdv': 'sdv==3.0.0',
165
+ 'copulas': 'copulas==0.12.0',
166
+ }
167
+
168
+ # Run
169
+ versions = _resolve_version_conflicts(deps, extra_deps)
170
+
171
+ # Assert
172
+ assert sorted(versions) == sorted([
173
+ 'numpy==2.0.1',
174
+ 'pandas==2.3.0',
175
+ 'sdv==3.0.0',
176
+ 'rdt==1.1.2',
177
+ 'copulas==0.12.0',
178
+ ])
179
+
180
+
181
+ def test__resolve_version_conflicts_pointing_to_branch():
182
+ """Test specific branches are always selected over normal version numbers."""
183
+ # Setup
184
+ deps = {
185
+ 'numpy': 'git+https://github.com/numpy-dev/numpy.git@master#egg=numpy',
186
+ 'pandas': 'pandas==2.2.1',
187
+ 'sdv': 'sdv==2.1.1',
188
+ 'rdt': 'rdt==1.1.2',
189
+ }
190
+ extra_deps = {
191
+ 'numpy': 'numpy==2.0.0',
192
+ 'pandas': 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
193
+ 'sdv': 'sdv==3.0.0',
194
+ 'copulas': 'copulas==0.12.0',
195
+ }
196
+
197
+ # Run
198
+ versions = _resolve_version_conflicts(deps, extra_deps)
199
+
200
+ # Assert
201
+ assert sorted(versions) == sorted([
202
+ 'git+https://github.com/numpy-dev/numpy.git@master#egg=numpy',
203
+ 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
204
+ 'sdv==3.0.0',
205
+ 'rdt==1.1.2',
206
+ 'copulas==0.12.0',
207
+ ])
@@ -1,39 +0,0 @@
1
- """Tests for the ``tasks.py`` file."""
2
-
3
- from tasks import _get_minimum_versions
4
-
5
-
6
- def test_get_minimum_versions():
7
- """Test the ``_get_minimum_versions`` method.
8
-
9
- The method should return the minimum versions of the dependencies for the given python version.
10
- If a library is linked to an URL, the minimum version should be the URL.
11
- """
12
- # Setup
13
- dependencies = [
14
- "numpy>=1.20.0,<2;python_version<'3.10'",
15
- "numpy>=1.23.3,<2;python_version>='3.10'",
16
- "pandas>=1.2.0,<2;python_version<'3.10'",
17
- "pandas>=1.3.0,<2;python_version>='3.10'",
18
- 'humanfriendly>=8.2,<11',
19
- 'pandas @ git+https://github.com/pandas-dev/pandas.git@master',
20
- ]
21
-
22
- # Run
23
- minimum_versions_39 = _get_minimum_versions(dependencies, '3.9')
24
- minimum_versions_310 = _get_minimum_versions(dependencies, '3.10')
25
-
26
- # Assert
27
- expected_versions_39 = [
28
- 'numpy==1.20.0',
29
- 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
30
- 'humanfriendly==8.2',
31
- ]
32
- expected_versions_310 = [
33
- 'numpy==1.23.3',
34
- 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
35
- 'humanfriendly==8.2',
36
- ]
37
-
38
- assert minimum_versions_39 == expected_versions_39
39
- assert minimum_versions_310 == expected_versions_310
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes