sdv 1.36.3.dev0__tar.gz → 1.36.4.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. {sdv-1.36.3.dev0/sdv.egg-info → sdv-1.36.4.dev0}/PKG-INFO +1 -1
  2. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/pyproject.toml +1 -1
  3. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/__init__.py +1 -1
  4. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/_utils.py +42 -0
  5. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/base.py +31 -0
  6. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/programmable_constraint.py +36 -0
  7. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/datasets/demo.py +82 -72
  8. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/multi_table/base.py +67 -2
  9. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/single_table/base.py +68 -2
  10. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0/sdv.egg-info}/PKG-INFO +1 -1
  11. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/LICENSE +0 -0
  12. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/README.md +0 -0
  13. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/_utils.py +0 -0
  14. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/__init__.py +0 -0
  15. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/_errors.py +0 -0
  16. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/fixed_combinations.py +0 -0
  17. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/fixed_increments.py +0 -0
  18. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/inequality.py +0 -0
  19. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/one_hot_encoding.py +0 -0
  20. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/cag/range.py +0 -0
  21. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/constraints/__init__.py +0 -0
  22. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/constraints/base.py +0 -0
  23. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/constraints/errors.py +0 -0
  24. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/constraints/tabular.py +0 -0
  25. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/constraints/utils.py +0 -0
  26. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/data_processing/__init__.py +0 -0
  27. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/data_processing/data_processor.py +0 -0
  28. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/data_processing/datetime_formatter.py +0 -0
  29. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/data_processing/errors.py +0 -0
  30. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/data_processing/numerical_formatter.py +0 -0
  31. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/data_processing/utils.py +0 -0
  32. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/datasets/__init__.py +0 -0
  33. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/datasets/local.py +0 -0
  34. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/errors.py +0 -0
  35. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/evaluation/__init__.py +0 -0
  36. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/evaluation/_utils.py +0 -0
  37. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/evaluation/multi_table.py +0 -0
  38. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/evaluation/single_table.py +0 -0
  39. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/io/__init__.py +0 -0
  40. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/io/local/__init__.py +0 -0
  41. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/io/local/local.py +0 -0
  42. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/lite/__init__.py +0 -0
  43. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/lite/single_table.py +0 -0
  44. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/logging/__init__.py +0 -0
  45. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/logging/logger.py +0 -0
  46. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/logging/sdv_logger_config.yml +0 -0
  47. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/logging/utils.py +0 -0
  48. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metadata/__init__.py +0 -0
  49. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metadata/errors.py +0 -0
  50. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metadata/metadata.py +0 -0
  51. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metadata/metadata_upgrader.py +0 -0
  52. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metadata/multi_table.py +0 -0
  53. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metadata/single_table.py +0 -0
  54. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metadata/utils.py +0 -0
  55. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metadata/visualization.py +0 -0
  56. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metrics/__init__.py +0 -0
  57. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metrics/demos.py +0 -0
  58. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metrics/relational.py +0 -0
  59. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metrics/tabular.py +0 -0
  60. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/metrics/timeseries.py +0 -0
  61. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/multi_table/__init__.py +0 -0
  62. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/multi_table/dayz.py +0 -0
  63. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/multi_table/hma.py +0 -0
  64. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/multi_table/utils.py +0 -0
  65. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/sampling/__init__.py +0 -0
  66. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/sampling/hierarchical_sampler.py +0 -0
  67. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/sampling/independent_sampler.py +0 -0
  68. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/sampling/tabular.py +0 -0
  69. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/sequential/__init__.py +0 -0
  70. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/sequential/par.py +0 -0
  71. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/single_table/__init__.py +0 -0
  72. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/single_table/copulagan.py +0 -0
  73. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/single_table/copulas.py +0 -0
  74. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/single_table/ctgan.py +0 -0
  75. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/single_table/dayz.py +0 -0
  76. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/single_table/utils.py +0 -0
  77. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/utils/__init__.py +0 -0
  78. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/utils/mixins.py +0 -0
  79. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/utils/poc.py +0 -0
  80. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/utils/utils.py +0 -0
  81. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv/version/__init__.py +0 -0
  82. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv.egg-info/SOURCES.txt +0 -0
  83. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv.egg-info/dependency_links.txt +0 -0
  84. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv.egg-info/entry_points.txt +0 -0
  85. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv.egg-info/requires.txt +0 -0
  86. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/sdv.egg-info/top_level.txt +0 -0
  87. {sdv-1.36.3.dev0 → sdv-1.36.4.dev0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sdv
3
- Version: 1.36.3.dev0
3
+ Version: 1.36.4.dev0
4
4
  Summary: Generate synthetic data for single table, multi table and sequential data
5
5
  Author-email: "DataCebo, Inc." <info@sdv.dev>
6
6
  License-Expression: BUSL-1.1
@@ -149,7 +149,7 @@ namespaces = false
149
149
  version = {attr = 'sdv.__version__'}
150
150
 
151
151
  [tool.bumpversion]
152
- current_version = "1.36.3.dev0"
152
+ current_version = "1.36.4.dev0"
153
153
  parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
154
154
  serialize = [
155
155
  '{major}.{minor}.{patch}.{release}{candidate}',
@@ -6,7 +6,7 @@
6
6
 
7
7
  __author__ = 'DataCebo, Inc.'
8
8
  __email__ = 'info@sdv.dev'
9
- __version__ = '1.36.3.dev0'
9
+ __version__ = '1.36.4.dev0'
10
10
 
11
11
 
12
12
  import sys
@@ -1,3 +1,4 @@
1
+ import importlib
1
2
  import re
2
3
  import warnings
3
4
 
@@ -227,3 +228,44 @@ def _validate_constraints_single_table(constraints, synthesizer_fitted):
227
228
  )
228
229
 
229
230
  return constraints
231
+
232
+
233
def load_constraint_from_dict(constraint_dict):
    """Load a constraint from a constraint dictionary.

    The constraint class is looked up in ``sdv.cag`` first. Only if it is not found there is
    the optional ``sdv.cag.sandbox`` module consulted.

    Args:
        constraint_dict (dict):
            A constraint dictionary containing exactly the following keys:
                - `class_name` (str): The constraint class name.
                - `parameters` (dict): Dictionary of the parameters used to instantiate the
                  constraint.

    Returns:
        Instance of `class_name` constraint instantiated with the given `parameters`.

    Raises:
        ValueError:
            If ``constraint_dict`` is malformed, or if ``class_name`` does not name a known
            constraint class.
    """
    expected_keys = {'class_name', 'parameters'}
    if not isinstance(constraint_dict, dict) or set(constraint_dict.keys()) != expected_keys:
        raise ValueError(
            'Invalid `constraint_dict`. Expected dictionary with keys `class_name` and '
            f'`parameters`, got {constraint_dict}.'
        )

    class_name = constraint_dict['class_name']
    parameters = constraint_dict['parameters']
    if not isinstance(class_name, str):
        raise ValueError('`class_name` must be a string.')

    if not isinstance(parameters, dict):
        raise ValueError('`parameters` must be a dict.')

    constraint_class = getattr(importlib.import_module('sdv.cag'), class_name, None)
    if constraint_class is None:
        # The sandbox module is optional; it is only consulted when ``sdv.cag`` has no match.
        try:
            sandbox_module = importlib.import_module('sdv.cag.sandbox')
        except ModuleNotFoundError:
            sandbox_module = None

        constraint_class = getattr(sandbox_module, class_name, None)

    # ``class_name`` typically comes from a user-supplied file. Only accept real classes
    # that know how to rebuild themselves, not arbitrary module attributes (submodules,
    # functions, dunder attributes, ...).
    if not isinstance(constraint_class, type) or not hasattr(
        constraint_class, 'load_constraint_from_dict'
    ):
        raise ValueError(f"Unknown `constraint_class` '{class_name}'.")

    return constraint_class.load_constraint_from_dict(parameters=parameters)
@@ -59,6 +59,37 @@ class BaseConstraint:
59
59
  args_string = ', '.join(custom_args)
60
60
  return f'{class_name}({args_string})'
61
61
 
62
def get_constraint_dict(self):
    """Return the constraint as a serializable dict.

    The init parameters are recovered from the instance. For every argument of ``__init__``
    (positional-or-keyword and keyword-only), the public attribute ``<name>`` is read, or
    else the private attribute ``_<name>``.

    Returns:
        dict:
            A dictionary with the following keys:
                - `class_name` [str]: The name of the constraint class.
                - `parameters` [dict]: A dictionary of the init parameters used to
                  create this constraint instance.

    Raises:
        AttributeError:
            If any ``__init__`` parameter is not stored as an attribute on the constraint.
    """
    missing = object()
    argspec = inspect.getfullargspec(self.__init__)
    # Skip ``self``. Keyword-only arguments are included so that they survive a round
    # trip through ``load_constraint_from_dict``, which calls ``cls(**parameters)``.
    keys = argspec.args[1:] + argspec.kwonlyargs
    instanced = {}
    missing_attrs = []
    for key in keys:
        value = getattr(self, key, missing)
        if value is missing:
            value = getattr(self, f'_{key}', missing)

        if value is missing:
            missing_attrs.append(key)
        else:
            instanced[key] = value

    if missing_attrs:
        raise AttributeError(
            'Cannot convert constraint to dictionary because required parameters '
            f'{sorted(missing_attrs)} are not saved as attributes on the constraint.'
        )

    return {'class_name': self.__class__.__name__, 'parameters': instanced}
87
+
88
@classmethod
def load_constraint_from_dict(cls, parameters):
    """Rebuild a constraint instance from its serialized init parameters.

    Args:
        parameters (dict):
            Keyword arguments forwarded unchanged to the class constructor.

    Returns:
        A new instance of ``cls``.
    """
    constraint = cls(**parameters)
    return constraint
92
+
62
93
  def __init__(self):
63
94
  self.metadata = None
64
95
  self._fitted = False
@@ -1,5 +1,6 @@
1
1
  """Programmable constraints base classes."""
2
2
 
3
+ import inspect
3
4
  from copy import deepcopy
4
5
 
5
6
  from sdv.cag.base import BaseConstraint
@@ -10,6 +11,11 @@ class ProgrammableConstraint:
10
11
 
11
12
  _is_single_table = True
12
13
 
14
@classmethod
def load_constraint_from_dict(cls, parameters):
    """Create a programmable constraint from a dictionary of init parameters.

    Args:
        parameters (dict):
            Mapping of ``__init__`` argument names to values; it is expanded as
            keyword arguments.

    Returns:
        An instance of the class this method is called on.
    """
    new_constraint = cls(**parameters)
    return new_constraint
18
+
13
19
  def validate(self, metadata):
14
20
  """Validates that the metadata is compatible with the constraint and its parameters.
15
21
 
@@ -133,6 +139,36 @@ class ProgrammableConstraintHarness(BaseConstraint):
133
139
  self.table_name = getattr(self.programmable_constraint, 'table_name', None)
134
140
  self._is_single_table = self.programmable_constraint._is_single_table
135
141
 
142
def get_constraint_dict(self):
    """Return the wrapped programmable constraint as a serializable dict.

    The init parameters are recovered from ``self.programmable_constraint``. For every
    argument of its ``__init__`` (positional-or-keyword and keyword-only), the public
    attribute ``<name>`` is read, or else the private attribute ``_<name>``.

    Returns:
        dict:
            A dictionary with the following keys:
                - `class_name` [str]: The name of the programmable constraint class.
                - `parameters` [dict]: A dictionary of the init parameters used to
                  create the programmable constraint instance.

    Raises:
        AttributeError:
            If any ``__init__`` parameter is not stored as an attribute on the
            programmable constraint.
    """
    constraint = self.programmable_constraint
    missing = object()
    argspec = inspect.getfullargspec(constraint.__init__)
    # Skip ``self``. Keyword-only arguments are included so that they survive a round
    # trip through ``load_constraint_from_dict``, which calls ``cls(**parameters)``.
    keys = argspec.args[1:] + argspec.kwonlyargs
    instanced = {}
    missing_attrs = []
    for key in keys:
        value = getattr(constraint, key, missing)
        if value is missing:
            value = getattr(constraint, f'_{key}', missing)

        if value is missing:
            missing_attrs.append(key)
        else:
            instanced[key] = value

    if missing_attrs:
        raise AttributeError(
            'Cannot convert constraint to dictionary because required parameters '
            f'{sorted(missing_attrs)} are not saved as attributes on the constraint.'
        )

    return {
        'class_name': constraint.__class__.__name__,
        'parameters': instanced,
    }
171
+
136
172
  def _validate_constraint_with_metadata(self, metadata):
137
173
  self.programmable_constraint.validate(metadata)
138
174
 
@@ -55,6 +55,20 @@ def _create_s3_client(bucket, credentials=None):
55
55
 
56
56
 
57
57
  def _get_data_from_bucket(object_key, bucket, client):
58
+ """Get a file from an S3 bucket as a bytes object.
59
+
60
+ Args:
61
+ object_key (str):
62
+ The key of the object to get.
63
+ bucket (str):
64
+ The name of the bucket to get the object from.
65
+ client (botocore.client.S3):
66
+ S3 client.
67
+
68
+ Returns:
69
+ bytes:
70
+ The file data from the S3 object as a bytes object.
71
+ """
58
72
  response = client.get_object(Bucket=bucket, Key=object_key)
59
73
  return response['Body'].read()
60
74
 
@@ -271,11 +285,31 @@ def handle_aws_client_errors(error_message_builder):
271
285
 
272
286
 
273
287
  def _download(modality, dataset_name, bucket, credentials=None):
274
- """Download dataset resources from a bucket.
288
+ """Download dataset resources from S3 bucket and return the bytes.
289
+
290
+ Args:
291
+ modality (str):
292
+ The modality of the dataset: ``'single_table'``, ``'multi_table'``,
293
+ ``'sequential'``.
294
+ dataset_name (str):
295
+ The name of the dataset.
296
+ bucket (str):
297
+ The name of the bucket to download from.
298
+ credentials (dict or None):
299
+ Dictionary containing DataCebo license key and username. It takes the form:
300
+ { 'username': 'example@datacebo.com', 'license_key': '<MY_LICENSE_KEY>' }
301
+ If None, the function will use the default credentials.
275
302
 
276
303
  Returns:
277
- tuple:
278
- (BytesIO(zip_bytes), metadata_bytes)
304
+ tuple[BytesIO, bytes]:
305
+ (data_bytes, metadata_bytes)
306
+ The data is bytes of the ``data.zip`` and
307
+ ``metadata_bytes`` is the raw bytes of the metadata JSON.
308
+
309
+ Raises:
310
+ DemoResourceNotFoundError:
311
+ If the dataset prefix is missing in the bucket, if ``data.zip`` is
312
+ missing, or if no V1 metadata file is present.
279
313
  """
280
314
  client = _create_s3_client(bucket=bucket, credentials=credentials)
281
315
  dataset_prefix = f'{modality}/{dataset_name}/'
@@ -286,86 +320,64 @@ def _download(modality, dataset_name, bucket, credentials=None):
286
320
  )
287
321
  contents = _list_objects(dataset_prefix, bucket=bucket, client=client)
288
322
  zip_key = _find_data_zip_key(contents, dataset_prefix, bucket)
289
- zip_bytes = _get_data_from_bucket(zip_key, bucket=bucket, client=client)
323
+ data_bytes = io.BytesIO(_get_data_from_bucket(zip_key, bucket=bucket, client=client))
290
324
  metadata_bytes = _get_first_v1_metadata_bytes(
291
325
  contents, dataset_prefix, bucket=bucket, client=client
292
326
  )
293
327
 
294
- return io.BytesIO(zip_bytes), metadata_bytes
328
+ return data_bytes, metadata_bytes
295
329
 
296
330
 
297
- def _extract_data(bytes_io, output_folder_name):
298
- with ZipFile(bytes_io) as zf:
299
- if output_folder_name:
300
- os.makedirs(output_folder_name, exist_ok=True)
301
- zf.extractall(output_folder_name)
302
-
303
- else:
304
- in_memory_directory = {}
305
- for name in zf.namelist():
306
- in_memory_directory[name] = zf.read(name)
331
+ def _load_data_from_zip(zip_bytes, bucket, dataset_name, output_folder_name=None):
332
+ """Load CSV tables from in-memory zip bytes into a dict of DataFrames.
307
333
 
308
- return in_memory_directory
334
+ This function iterates over the zip bytes and parses each CSV entry with
335
+ ``pandas.read_csv``. Non-CSV entries are recorded as skipped. When
336
+ ``output_folder_name`` is provided, the archive is also extracted to disk
337
+ as a side effect so the caller keeps a local copy.
309
338
 
339
+ Args:
340
+ zip_bytes (io.BytesIO):
341
+ File-like object containing the bytes of ``data.zip``.
342
+ bucket (str):
343
+ The name of the bucket the zip was downloaded from. Used only for
344
+ error messages.
345
+ dataset_name (str):
346
+ The name of the dataset. Used only for error messages.
347
+ output_folder_name (str or None):
348
+ Optional folder path where the zip will also be extracted to disk so
349
+ the user keeps a local copy. The returned data dict is always built
350
+ in-memory. If ``None``, no folder is created.
310
351
 
311
- def _get_data_with_output_folder(output_folder_name):
312
- """Load CSV tables from an extracted folder on disk.
352
+ Returns:
353
+ dict[str, pandas.DataFrame]:
354
+ Mapping of table name to DataFrame.
313
355
 
314
- Returns a tuple of (data_dict, skipped_files).
315
- Non-CSV files are ignored.
356
+ Raises:
357
+ DemoResourceNotFoundError:
358
+ If the zip contains no valid CSV entries.
316
359
  """
317
360
  data = {}
318
361
  skipped_files = []
319
- for root, _dirs, files in os.walk(output_folder_name):
320
- for filename in files:
362
+ with ZipFile(zip_bytes, 'r') as z:
363
+ if output_folder_name:
364
+ os.makedirs(output_folder_name, exist_ok=True)
365
+ z.extractall(output_folder_name)
366
+
367
+ for filename in z.namelist():
321
368
  if not filename.lower().endswith('.csv'):
322
369
  skipped_files.append(filename)
323
370
  continue
324
371
 
325
372
  table_name = Path(filename).stem
326
- data_path = os.path.join(root, filename)
327
373
  try:
328
- data[table_name] = pd.read_csv(data_path)
374
+ with z.open(filename) as f:
375
+ data[table_name] = pd.read_csv(f, low_memory=False)
329
376
  except UnicodeDecodeError:
330
- data[table_name] = pd.read_csv(data_path, encoding=FALLBACK_ENCODING)
377
+ with z.open(filename) as f:
378
+ data[table_name] = pd.read_csv(f, low_memory=False, encoding=FALLBACK_ENCODING)
331
379
  except Exception as e:
332
- rel = os.path.relpath(data_path, output_folder_name)
333
- skipped_files.append(f'{rel}: {e}')
334
-
335
- return data, skipped_files
336
-
337
-
338
- def _get_data_without_output_folder(in_memory_directory):
339
- """Load CSV tables directly from in-memory zip contents.
340
-
341
- Returns a tuple of (data_dict, skipped_files).
342
- Non-CSV entries are ignored.
343
- """
344
- data = {}
345
- skipped_files = []
346
- for filename, file_ in in_memory_directory.items():
347
- if not filename.lower().endswith('.csv'):
348
- skipped_files.append(filename)
349
- continue
350
-
351
- table_name = Path(filename).stem
352
- try:
353
- data[table_name] = pd.read_csv(io.BytesIO(file_), low_memory=False)
354
- except UnicodeDecodeError:
355
- data[table_name] = pd.read_csv(
356
- io.BytesIO(file_), low_memory=False, encoding=FALLBACK_ENCODING
357
- )
358
- except Exception as e:
359
- skipped_files.append(f'{filename}: {e}')
360
-
361
- return data, skipped_files
362
-
363
-
364
- def _get_data(modality, output_folder_name, in_memory_directory, bucket, dataset_name):
365
- if output_folder_name:
366
- data, skipped_files = _get_data_with_output_folder(output_folder_name)
367
- else:
368
- data, skipped_files = _get_data_without_output_folder(in_memory_directory)
380
+ skipped_files.append(f'{filename}: {e}')
369
381
 
370
382
  if skipped_files:
371
383
  warnings.warn('Skipped files: ' + ', '.join(sorted(skipped_files)))
@@ -376,9 +388,6 @@ def _get_data(modality, output_folder_name, in_memory_directory, bucket, dataset
376
388
  'The dataset is missing `csv` file/s.'
377
389
  )
378
390
 
379
- if modality != 'multi_table':
380
- data = data.popitem()[1]
381
-
382
391
  return data
383
392
 
384
393
 
@@ -464,17 +473,18 @@ def download_demo(
464
473
  _validate_modalities(modality)
465
474
  _validate_output_folder(output_folder_name)
466
475
 
467
- data_io, metadata_bytes = _download(modality, dataset_name, s3_bucket_name, credentials)
468
- in_memory_directory = _extract_data(data_io, output_folder_name)
469
- data = _get_data(
476
+ data, metadata_bytes = _download(
470
477
  modality,
471
- output_folder_name,
472
- in_memory_directory,
473
- s3_bucket_name,
474
478
  dataset_name,
479
+ s3_bucket_name,
480
+ credentials,
475
481
  )
476
- metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
482
+ data = _load_data_from_zip(data, s3_bucket_name, dataset_name, output_folder_name)
477
483
 
484
+ if modality != 'multi_table':
485
+ data = data.popitem()[1]
486
+
487
+ metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
478
488
  return data, metadata
479
489
 
480
490
 
@@ -3,10 +3,13 @@
3
3
  import contextlib
4
4
  import datetime
5
5
  import inspect
6
+ import json
6
7
  import operator
8
+ import traceback
7
9
  import warnings
8
10
  from collections import defaultdict
9
11
  from copy import deepcopy
12
+ from pathlib import Path
10
13
 
11
14
  import cloudpickle
12
15
  import numpy as np
@@ -25,6 +28,7 @@ from sdv.cag._utils import (
25
28
  _convert_to_snake_case,
26
29
  _get_invalid_rows,
27
30
  _validate_constraints,
31
+ load_constraint_from_dict,
28
32
  )
29
33
  from sdv.cag.programmable_constraint import ProgrammableConstraint, ProgrammableConstraintHarness
30
34
  from sdv.errors import (
@@ -269,8 +273,7 @@ class BaseMultiTableSynthesizer:
269
273
 
270
274
  self._single_table_constraints += single_table_constraints
271
275
 
272
- def get_constraints(self):
273
- """Get a copy of the list of constraints applied to the synthesizer."""
276
+ def _get_all_constraints_list(self):
274
277
  if not hasattr(self, 'constraints'):
275
278
  return []
276
279
  constraints = []
@@ -285,6 +288,68 @@ class BaseMultiTableSynthesizer:
285
288
 
286
289
  return constraints
287
290
 
291
def get_constraints(self, filepath=None):
    """Get a list of constraint-augmented generation constraints applied to the synthesizer.

    If `filepath` is provided, will save the constraints in JSON format to `filepath`.
    Otherwise, will return a list of the constraints applied to the synthesizer.

    Args:
        filepath (str, optional):
            Path where the constraints applied to the synthesizer will be serialized. If `None`,
            will return a list of the applied constraints instead. Defaults to `None`.

    Returns:
        list or None:
            The list of applied constraints when ``filepath`` is ``None``, otherwise ``None``.

    Raises:
        ValueError:
            If ``filepath`` already exists.
    """
    constraints = self._get_all_constraints_list()
    if filepath is None:
        return constraints

    path = Path(filepath)
    error_message = f"Cannot save constraints to file because '{filepath}' already exists."
    if path.exists():
        raise ValueError(error_message)

    constraints_dict_list = []
    for constraint in constraints:
        # NOTE(review): the constraints list may hold programmable constraints unwrapped from
        # their harness. If such an object has no ``get_constraint_dict`` of its own,
        # serialize it through ``ProgrammableConstraintHarness``, which defines one.
        if not hasattr(constraint, 'get_constraint_dict') and isinstance(
            constraint, ProgrammableConstraint
        ):
            constraint = ProgrammableConstraintHarness(constraint)

        constraints_dict_list.append(constraint.get_constraint_dict())

    # Serialize fully before touching the filesystem. Otherwise a non JSON-serializable
    # parameter would leave a partial file behind, and that file would block any retry.
    serialized = json.dumps(constraints_dict_list, indent=4)
    try:
        # Exclusive-create mode fails atomically if the file appeared after the check above.
        with open(path, 'x', encoding='utf-8') as file:
            file.write(serialized)
    except FileExistsError:
        raise ValueError(error_message) from None
315
+
316
def set_constraints(self, filepath):
    """Add all the constraints stored in a JSON file to the synthesizer.

    The file is expected to contain a JSON list of constraint dictionaries, such as the one
    written by ``get_constraints(filepath=...)``. A constraint that cannot be loaded or added
    is skipped with a warning; the remaining constraints are still processed.

    Args:
        filepath (str):
            The string path to the file containing the constraints to set on the synthesizer.

    Raises:
        SynthesizerInputError:
            If constraints have already been applied to the synthesizer.
    """
    if self.get_constraints():
        raise SynthesizerInputError(
            'Cannot `set_constraints` since constraints have already been applied.'
        )

    with open(filepath, 'r', encoding='utf-8') as f:
        constraints_json = json.load(f)

    # Load every constraint first so that malformed entries are reported before any
    # constraint is added to the synthesizer.
    constraint_list = []
    for constraint_dict in constraints_json:
        try:
            constraint_list.append(load_constraint_from_dict(constraint_dict))
        except Exception as e:
            warnings.warn(
                f'Could not load constraint ({constraint_dict}):\n'
                f' {traceback.format_exception_only(type(e), e)[0]}'
            )

    # Add constraints one at a time so that a single invalid constraint does not prevent
    # the others from being applied.
    for constraint in constraint_list:
        try:
            self.add_constraints([constraint])
        except Exception as e:
            warnings.warn(
                f'Could not add constraint ({constraint}):\n'
                f' {traceback.format_exception_only(type(e), e)[0]}'
            )
352
+
288
353
  def validate_constraints(self, synthetic_data):
289
354
  """Validate synthetic_data against the constraints.
290
355
 
@@ -3,14 +3,17 @@
3
3
  import datetime
4
4
  import functools
5
5
  import inspect
6
+ import json
6
7
  import logging
7
8
  import math
8
9
  import operator
9
10
  import os
11
+ import traceback
10
12
  import uuid
11
13
  import warnings
12
14
  from collections import defaultdict
13
15
  from copy import deepcopy
16
+ from pathlib import Path
14
17
 
15
18
  import cloudpickle
16
19
  import copulas
@@ -35,6 +38,7 @@ from sdv.cag._utils import (
35
38
  _convert_to_snake_case,
36
39
  _get_invalid_rows,
37
40
  _validate_constraints_single_table,
41
+ load_constraint_from_dict,
38
42
  )
39
43
  from sdv.cag.programmable_constraint import ProgrammableConstraint, ProgrammableConstraintHarness
40
44
  from sdv.data_processing.data_processor import DataProcessor
@@ -497,8 +501,7 @@ class BaseSynthesizer:
497
501
  locales=self.locales,
498
502
  )
499
503
 
500
- def get_constraints(self):
501
- """Get a list of constraint-augmented generation constraints applied to the synthesizer."""
504
+ def _get_all_constraints_list(self):
502
505
  constraints = []
503
506
  for constraint in self._chained_constraints + self._reject_sampling_constraints:
504
507
  if isinstance(constraint, ProgrammableConstraintHarness):
@@ -508,6 +511,69 @@ class BaseSynthesizer:
508
511
 
509
512
  return constraints
510
513
 
514
def get_constraints(self, filepath=None):
    """Get a list of constraint-augmented generation constraints applied to the synthesizer.

    If `filepath` is provided, will save the constraints in JSON format to `filepath`.
    Otherwise, will return a list of the constraints applied to the synthesizer.

    Args:
        filepath (str, optional):
            Path where the constraints applied to the synthesizer will be serialized. If `None`,
            will return a list of the applied constraints instead. Defaults to `None`.

    Returns:
        list or None:
            The list of applied constraints when ``filepath`` is ``None``, otherwise ``None``.

    Raises:
        ValueError:
            If ``filepath`` already exists.
    """
    constraints = self._get_all_constraints_list()
    if filepath is None:
        return constraints

    path = Path(filepath)
    error_message = f"Cannot save constraints to file because '{filepath}' already exists."
    if path.exists():
        raise ValueError(error_message)

    constraints_dict_list = []
    for constraint in constraints:
        # NOTE(review): the constraints list may hold programmable constraints unwrapped from
        # their harness. If such an object has no ``get_constraint_dict`` of its own,
        # serialize it through ``ProgrammableConstraintHarness``, which defines one.
        if not hasattr(constraint, 'get_constraint_dict') and isinstance(
            constraint, ProgrammableConstraint
        ):
            constraint = ProgrammableConstraintHarness(constraint)

        constraints_dict_list.append(constraint.get_constraint_dict())

    # Serialize fully before touching the filesystem. Otherwise a non JSON-serializable
    # parameter would leave a partial file behind, and that file would block any retry.
    serialized = json.dumps(constraints_dict_list, indent=4)
    try:
        # Exclusive-create mode fails atomically if the file appeared after the check above.
        with open(path, 'x', encoding='utf-8') as file:
            file.write(serialized)
    except FileExistsError:
        raise ValueError(error_message) from None
539
+
540
def set_constraints(self, filepath):
    """Add all the constraints stored in a JSON file to the synthesizer.

    The file is expected to contain a JSON list of constraint dictionaries, such as the one
    written by ``get_constraints(filepath=...)``. A constraint that cannot be loaded or added
    is skipped with a warning; the remaining constraints are still processed.

    Args:
        filepath (str):
            The string path to the file containing the constraints to set on the synthesizer.

    Raises:
        SynthesizerInputError:
            If constraints have already been applied to the synthesizer.
    """
    if self.get_constraints():
        raise SynthesizerInputError(
            'Cannot `set_constraints` since constraints have already been applied.'
        )

    with open(filepath, 'r', encoding='utf-8') as f:
        constraints_json = json.load(f)

    # Load every constraint first so that malformed entries are reported before any
    # constraint is added to the synthesizer.
    constraint_list = []
    for constraint_dict in constraints_json:
        try:
            constraint_list.append(load_constraint_from_dict(constraint_dict))
        except Exception as e:
            warnings.warn(
                f'Could not load constraint ({constraint_dict}):\n'
                f' {traceback.format_exception_only(type(e), e)[0]}'
            )

    # Add constraints one at a time so that a single invalid constraint does not prevent
    # the others from being applied.
    for constraint in constraint_list:
        try:
            self.add_constraints([constraint])
        except Exception as e:
            warnings.warn(
                f'Could not add constraint ({constraint}):\n'
                f' {traceback.format_exception_only(type(e), e)[0]}'
            )
576
+
511
577
  def validate_constraints(self, synthetic_data):
512
578
  """Validate synthetic_data against the constraints.
513
579
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sdv
3
- Version: 1.36.3.dev0
3
+ Version: 1.36.4.dev0
4
4
  Summary: Generate synthetic data for single table, multi table and sequential data
5
5
  Author-email: "DataCebo, Inc." <info@sdv.dev>
6
6
  License-Expression: BUSL-1.1
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes