sdgym 0.14.1.dev0__tar.gz → 0.14.2.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. {sdgym-0.14.1.dev0/sdgym.egg-info → sdgym-0.14.2.dev0}/PKG-INFO +1 -1
  2. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/pyproject.toml +5 -3
  3. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/__init__.py +1 -1
  4. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/_benchmark/benchmark.py +13 -12
  5. sdgym-0.14.2.dev0/sdgym/_benchmark/credentials_utils.py +17 -0
  6. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/__init__.py +6 -0
  7. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/_instance_manager.py +115 -0
  8. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/_storage_manager.py +64 -0
  9. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/_validation.py +211 -0
  10. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_base.yaml +9 -0
  11. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_config.py +118 -0
  12. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_launcher.py +421 -0
  13. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_multi_table.yaml +180 -0
  14. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_single_table.yaml +131 -0
  15. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/script.py +280 -0
  16. sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/utils.py +241 -0
  17. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/errors.py +4 -0
  18. sdgym-0.14.2.dev0/sdgym/run_benchmark/run_benchmark.py +87 -0
  19. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/run_benchmark/utils.py +1 -0
  20. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0/sdgym.egg-info}/PKG-INFO +1 -1
  21. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/SOURCES.txt +11 -0
  22. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/tests/test_tasks.py +45 -0
  23. sdgym-0.14.1.dev0/sdgym/_benchmark/credentials_utils.py +0 -104
  24. sdgym-0.14.1.dev0/sdgym/run_benchmark/run_benchmark.py +0 -207
  25. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/LICENSE +0 -0
  26. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/README.md +0 -0
  27. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/_benchmark/__init__.py +0 -0
  28. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/_benchmark/config_utils.py +0 -0
  29. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/_dataset_utils.py +0 -0
  30. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/benchmark.py +0 -0
  31. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/__init__.py +0 -0
  32. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/__main__.py +0 -0
  33. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/collect.py +0 -0
  34. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/summary.py +0 -0
  35. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/utils.py +0 -0
  36. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/dataset_explorer.py +0 -0
  37. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/datasets.py +0 -0
  38. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/metrics.py +0 -0
  39. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/progress.py +0 -0
  40. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/result_explorer/__init__.py +0 -0
  41. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/result_explorer/result_explorer.py +0 -0
  42. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/result_explorer/result_handler.py +0 -0
  43. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/result_writer.py +0 -0
  44. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/run_benchmark/__init__.py +0 -0
  45. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/run_benchmark/upload_benchmark_results.py +0 -0
  46. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/s3.py +0 -0
  47. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizer_descriptions.yaml +0 -0
  48. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/__init__.py +0 -0
  49. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/base.py +0 -0
  50. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/column.py +0 -0
  51. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/generate.py +0 -0
  52. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/identity.py +0 -0
  53. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/realtabformer.py +0 -0
  54. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/sdv.py +0 -0
  55. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/uniform.py +0 -0
  56. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/utils.py +0 -0
  57. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/utils.py +0 -0
  58. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/dependency_links.txt +0 -0
  59. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/entry_points.txt +0 -0
  60. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/requires.txt +0 -0
  61. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/top_level.txt +0 -0
  62. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/setup.cfg +0 -0
  63. {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/tests/test_scripts.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sdgym
3
- Version: 0.14.1.dev0
3
+ Version: 0.14.2.dev0
4
4
  Summary: Benchmark tabular synthetic data generators using a variety of datasets
5
5
  Author-email: "DataCebo, Inc." <info@sdv.dev>
6
6
  License-Expression: BUSL-1.1
@@ -143,11 +143,13 @@ namespaces = false
143
143
  'make.bat',
144
144
  '*.jpg',
145
145
  '*.png',
146
- '*.gif'
146
+ '*.gif',
147
+ '*.yaml'
147
148
  ]
148
149
  'sdgym' = [
149
150
  'leaderboard.csv',
150
- 'synthesizer_descriptions.yaml'
151
+ 'synthesizer_descriptions.yaml',
152
+ '_benchmark_launcher/*.yaml',
151
153
  ]
152
154
 
153
155
  [tool.setuptools.exclude-package-data]
@@ -161,7 +163,7 @@ namespaces = false
161
163
  version = {attr = 'sdgym.__version__'}
162
164
 
163
165
  [tool.bumpversion]
164
- current_version = "0.14.1.dev0"
166
+ current_version = "0.14.2.dev0"
165
167
  parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
166
168
  serialize = [
167
169
  '{major}.{minor}.{patch}.{release}{candidate}',
@@ -8,7 +8,7 @@ __author__ = 'DataCebo, Inc.'
8
8
  __copyright__ = 'Copyright (c) 2022 DataCebo, Inc.'
9
9
  __email__ = 'info@sdv.dev'
10
10
  __license__ = 'BSL-1.1'
11
- __version__ = '0.14.1.dev0'
11
+ __version__ = '0.14.2.dev0'
12
12
 
13
13
  import logging
14
14
 
@@ -9,7 +9,7 @@ from sdgym._benchmark.config_utils import (
9
9
  resolve_compute_config,
10
10
  validate_compute_config,
11
11
  )
12
- from sdgym._benchmark.credentials_utils import get_credentials, sdv_install_cmd
12
+ from sdgym._benchmark.credentials_utils import sdv_install_cmd
13
13
  from sdgym.benchmark import (
14
14
  DEFAULT_MULTI_TABLE_DATASETS,
15
15
  DEFAULT_MULTI_TABLE_SYNTHESIZERS,
@@ -349,7 +349,7 @@ def _run_on_gcp(
349
349
 
350
350
  def _benchmark_compute_gcp(
351
351
  output_destination,
352
- credential_filepath,
352
+ credentials,
353
353
  compute_config,
354
354
  synthesizers,
355
355
  sdv_datasets,
@@ -364,7 +364,6 @@ def _benchmark_compute_gcp(
364
364
  ):
365
365
  """Run the SDGym benchmark on datasets for the given modality."""
366
366
  compute_config = resolve_compute_config('gcp', compute_config)
367
- credentials = get_credentials(credential_filepath)
368
367
  validate_compute_config(compute_config)
369
368
 
370
369
  s3_client = _validate_output_destination(
@@ -407,7 +406,7 @@ def _benchmark_compute_gcp(
407
406
  sdmetrics=sdmetrics,
408
407
  )
409
408
 
410
- _run_on_gcp(
409
+ instance_name = _run_on_gcp(
411
410
  output_destination=output_destination,
412
411
  synthesizers=synthesizers,
413
412
  s3_client=s3_client,
@@ -416,10 +415,12 @@ def _benchmark_compute_gcp(
416
415
  compute_config=compute_config,
417
416
  )
418
417
 
418
+ return instance_name
419
+
419
420
 
420
421
  def _benchmark_single_table_compute_gcp(
421
422
  output_destination,
422
- credential_filepath,
423
+ credentials,
423
424
  compute_config=None,
424
425
  synthesizers=DEFAULT_SINGLE_TABLE_SYNTHESIZERS,
425
426
  sdv_datasets=DEFAULT_SINGLE_TABLE_DATASETS,
@@ -436,8 +437,8 @@ def _benchmark_single_table_compute_gcp(
436
437
  Args:
437
438
  output_destination (str):
438
439
  The S3 URI where results will be stored.
439
- credential_filepath (str or Path):
440
- Path to the credentials file for AWS, GCP and SDV-Enterprise.
440
+ credentials (dict):
441
+ The credentials for AWS, GCP and SDV-Enterprise.
441
442
  compute_config (dict, optional):
442
443
  The compute configuration for the GCP instance. If None, default settings will be used.
443
444
  synthesizers (list of dict, optional):
@@ -461,7 +462,7 @@ def _benchmark_single_table_compute_gcp(
461
462
  """
462
463
  return _benchmark_compute_gcp(
463
464
  output_destination=output_destination,
464
- credential_filepath=credential_filepath,
465
+ credentials=credentials,
465
466
  compute_config=compute_config,
466
467
  synthesizers=synthesizers,
467
468
  sdv_datasets=sdv_datasets,
@@ -478,7 +479,7 @@ def _benchmark_single_table_compute_gcp(
478
479
 
479
480
  def _benchmark_multi_table_compute_gcp(
480
481
  output_destination,
481
- credential_filepath,
482
+ credentials,
482
483
  compute_config=None,
483
484
  synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
484
485
  sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
@@ -494,8 +495,8 @@ def _benchmark_multi_table_compute_gcp(
494
495
  Args:
495
496
  output_destination (str):
496
497
  The S3 URI where results will be stored.
497
- credential_filepath (str or Path):
498
- Path to the credentials file for AWS, GCP and SDV-Enterprise.
498
+ credentials (dict):
499
+ The credentials for AWS, GCP and SDV-Enterprise.
499
500
  compute_config (dict, optional):
500
501
  The compute configuration for the GCP instance. If None, default settings will be used.
501
502
  synthesizers (list of dict, optional):
@@ -517,7 +518,7 @@ def _benchmark_multi_table_compute_gcp(
517
518
  """
518
519
  return _benchmark_compute_gcp(
519
520
  output_destination=output_destination,
520
- credential_filepath=credential_filepath,
521
+ credentials=credentials,
521
522
  compute_config=compute_config,
522
523
  synthesizers=synthesizers,
523
524
  sdv_datasets=sdv_datasets,
@@ -0,0 +1,17 @@
1
+ import textwrap
2
+
3
+
4
def sdv_install_cmd(credentials):
    """Build the shell snippet that installs sdv-enterprise through sdv-installer.

    Args:
        credentials (dict):
            Resolved credentials. The ``sdv_enterprise`` entry, when present, is read for
            ``sdv_enterprise_username`` and ``sdv_enterprise_license_key``.

    Returns:
        str:
            The install commands, or an empty string when either value is missing or empty.
    """
    enterprise = credentials.get('sdv_enterprise') or {}
    username = enterprise.get('sdv_enterprise_username')
    license_key = enterprise.get('sdv_enterprise_license_key')
    if not username or not license_key:
        return ''

    # The trailing backslash is emitted literally so the ``python -c`` argument continues
    # on the next line of the generated script.
    script = f"""\
        pip install sdv-installer

        python -c "from sdv_installer.installation.installer import install_packages; \\
        install_packages(username='{username}', license_key='{license_key}')"
    """
    return textwrap.dedent(script)
@@ -0,0 +1,6 @@
1
+ """Benchmark Launcher Module."""
2
+
3
+ from sdgym._benchmark_launcher.benchmark_config import BenchmarkConfig
4
+ from sdgym._benchmark_launcher.benchmark_launcher import BenchmarkLauncher
5
+
6
+ __all__ = ('BenchmarkConfig', 'BenchmarkLauncher')
@@ -0,0 +1,115 @@
1
+ import logging
2
+
3
+ from google.cloud import compute_v1
4
+ from google.oauth2 import service_account
5
+
6
+ from sdgym._benchmark_launcher._validation import _validate_gcp_credentials
7
+ from sdgym._benchmark_launcher.utils import resolve_credentials
8
+
9
+ LOGGER = logging.getLogger(__name__)
10
+
11
+
12
class BaseInstanceManager:
    """Interface shared by the compute-service-specific instance managers.

    Every method raises ``NotImplementedError`` here; subclasses override them.
    """

    def list_instances(self):
        """Return the instances that have not been terminated."""
        raise NotImplementedError()

    def update_instance_statuses(self, instance_names, instance_name_to_status):
        """Refresh, in place, the statuses tracked for the given instance names."""
        raise NotImplementedError()

    def terminate_instances(self, instance_names, verbose):
        """Terminate the named instances and return the names that were deleted."""
        raise NotImplementedError()
26
+
27
+
28
class GCPInstanceManager(BaseInstanceManager):
    """Manage GCP benchmark instances.

    Args:
        credentials_filepath (str or None):
            Passed to ``resolve_credentials`` each time a GCP client is built.
    """

    def __init__(self, credentials_filepath):
        self.credentials_filepath = credentials_filepath

    def _get_client(self):
        """Build and return the GCP instances client and the project id.

        Returns:
            tuple:
                ``(client, project_id)`` where ``client`` is a ``compute_v1.InstancesClient``.

        Raises:
            ValueError:
                If the GCP credentials are missing, empty or fail validation.
        """
        credentials = resolve_credentials(self.credentials_filepath)
        errors = _validate_gcp_credentials(credentials)
        gcp_info = credentials.get('gcp')
        if not errors and not gcp_info:
            # ``_validate_gcp_credentials`` only checks the required keys when the section
            # is non-empty, so an absent section passes it. A client cannot be built
            # without one: report it as a ValueError instead of a KeyError below.
            errors = ["credentials['gcp'] is missing or empty."]

        if errors:
            error_message = '\n'.join(errors)
            raise ValueError(f'Invalid GCP credentials:\n{error_message}')

        project_id = gcp_info['project_id']
        gcp_credentials = service_account.Credentials.from_service_account_info(gcp_info)
        client = compute_v1.InstancesClient(credentials=gcp_credentials)

        return client, project_id

    def list_instances(self):
        """List all non-terminated GCP instances of the project.

        Returns:
            list[dict]:
                One dict per instance with the keys ``id``, ``name``, ``zone`` and ``status``.
        """
        client, project_id = self._get_client()
        instances = []
        response = client.aggregated_list(project=project_id)
        for _, scoped_list in response:
            scoped_instances = getattr(scoped_list, 'instances', None)
            if not scoped_instances:
                continue

            for instance in scoped_instances:
                if instance.status == 'TERMINATED':
                    continue

                instances.append({
                    'id': str(instance.id),
                    'name': instance.name,
                    # Keep only the last path segment of the zone value.
                    'zone': instance.zone.split('/')[-1],
                    'status': instance.status,
                })

        return instances

    def update_instance_statuses(self, instance_names, instance_name_to_status):
        """Update launcher-tracked instance statuses in place.

        A name that is currently listed is marked ``'running'``. A name that is no longer
        listed is marked ``'completed'`` only if it was previously ``'running'``; any other
        status is left untouched.

        Args:
            instance_names (list[str]):
                Names of the instances to refresh.
            instance_name_to_status (dict):
                Mapping of instance name to status, modified in place.
        """
        running_instances = self.list_instances()
        running_instance_names = {instance['name'] for instance in running_instances}
        for instance_name in instance_names:
            if instance_name in running_instance_names:
                instance_name_to_status[instance_name] = 'running'
            elif instance_name_to_status.get(instance_name) == 'running':
                instance_name_to_status[instance_name] = 'completed'

    def terminate_instances(self, instance_names, verbose):
        """Terminate GCP instances by name.

        Args:
            instance_names (list[str]):
                Names of the instances to delete. Repeated names are handled once.
            verbose (bool):
                Whether to print a line for each instance being terminated.

        Returns:
            list[str]:
                Names of the instances that were deleted.
        """
        client, project_id = self._get_client()
        running_instances = self.list_instances()
        running_instances_by_name = {instance['name']: instance for instance in running_instances}

        # Drop repeated names (keeping first-seen order) so that the same instance is not
        # sent a second delete request after it has already been removed.
        requested_names = list(dict.fromkeys(instance_names))
        instances_to_delete = [
            running_instances_by_name[name]
            for name in requested_names
            if name in running_instances_by_name
        ]

        not_running = sorted(set(requested_names) - set(running_instances_by_name))
        if not_running:
            not_running_str = "', '".join(not_running)
            LOGGER.info(
                f"Some provided instance names are not currently running: '{not_running_str}'."
            )

        deleted_instances = []
        for instance in instances_to_delete:
            if verbose:
                print(  # noqa: T201
                    f"Terminating GCP instance '{instance['name']}' "
                    f'(id={instance["id"]}, zone={instance["zone"]})...'
                )

            operation = client.delete(
                project=project_id,
                zone=instance['zone'],
                instance=instance['name'],
            )
            # NOTE(review): ``result()`` presumably blocks until the delete finishes - confirm.
            operation.result()
            deleted_instances.append(instance['name'])

        return deleted_instances
@@ -0,0 +1,64 @@
1
+ from sdgym._benchmark_launcher.utils import resolve_credentials
2
+ from sdgym.s3 import _list_s3_bucket_contents, get_s3_client, is_s3_path, parse_s3_path
3
+
4
+
5
+ def _validate_s3_output_destinations(instance_jobs):
6
+ """Validate that all output destinations are S3 paths."""
7
+ for instance_job in instance_jobs:
8
+ output_destination = instance_job['output_destination']
9
+ if not is_s3_path(output_destination):
10
+ raise ValueError(
11
+ f'Only S3 storage is currently supported. Found: {output_destination!r}.'
12
+ )
13
+
14
+
15
class BaseStorageManager:
    """Interface shared by the storage-specific managers.

    Every method raises ``NotImplementedError`` here; subclasses override them.
    """

    def handles_destination(self, output_destination):
        """Tell whether this manager supports the given destination."""
        raise NotImplementedError()

    def list_files(self, output_destination):
        """Return the files currently stored under the given destination."""
        raise NotImplementedError()

    def get_existing_filenames(self, output_destination):
        """Return the filenames that already exist under the given destination."""
        raise NotImplementedError()
29
+
30
+
31
class S3StorageManager(BaseStorageManager):
    """Manage benchmark artifacts stored in S3.

    Args:
        credentials_filepath (str or None):
            Passed to ``resolve_credentials`` each time an S3 client is built.
        instance_jobs (list[dict]):
            Jobs whose ``output_destination`` values are checked to be S3 paths.

    Raises:
        ValueError:
            If any job has a non-S3 ``output_destination``.
    """

    def __init__(self, credentials_filepath, instance_jobs):
        _validate_s3_output_destinations(instance_jobs)
        self.credentials_filepath = credentials_filepath

    def handles_destination(self, output_destination):
        """Return whether the destination is an S3 path."""
        return is_s3_path(output_destination)

    def _get_client(self):
        """Build and return the S3 client from the resolved AWS credentials."""
        credentials = resolve_credentials(self.credentials_filepath)
        # ``or {}`` also covers an ``aws`` entry that is present but null, which a plain
        # ``.get('aws', {})`` would return as ``None`` and then fail on ``.get``.
        aws_credentials = credentials.get('aws') or {}
        return get_s3_client(
            aws_access_key_id=aws_credentials.get('aws_access_key_id'),
            aws_secret_access_key=aws_credentials.get('aws_secret_access_key'),
        )

    def list_files(self, output_destination):
        """List files under the provided S3 output destination.

        Args:
            output_destination (str):
                S3 path to list.

        Returns:
            The value returned by ``_list_s3_bucket_contents`` for the parsed bucket and prefix.

        Raises:
            ValueError:
                If ``output_destination`` is not an S3 path.
        """
        if not self.handles_destination(output_destination):
            raise ValueError(
                f'S3StorageManager only supports S3 paths. Found: {output_destination!r}.'
            )

        s3_client = self._get_client()
        bucket_name, key_prefix = parse_s3_path(output_destination)
        return _list_s3_bucket_contents(s3_client, bucket_name, key_prefix)

    def get_existing_filenames(self, output_destination):
        """Return the set of ``Key`` values of the files under the given destination."""
        return {obj['Key'] for obj in self.list_files(output_destination)}
@@ -0,0 +1,211 @@
1
+ from sdgym._benchmark_launcher.utils import (
2
+ _AWS_CREDENTIAL_KEYS,
3
+ _GCP_SERVICE_ACCOUNT_REQUIRED_KEYS,
4
+ _is_unique_string_list,
5
+ resolve_credentials,
6
+ )
7
+
8
# Parameter names that ``_validate_method_params`` rejects when they appear in the
# user-provided ``method_params`` section.
_INJECTED_PARAMS = {
    'credentials',
    'synthesizers',
    'sdv_datasets',
    'compute_config',
    'output_destination',
}
15
+
16
+
17
+ def _as_errors(value):
18
+ if value is None:
19
+ return []
20
+ if isinstance(value, list):
21
+ return [str(v) for v in value if v]
22
+
23
+ return [str(value)]
24
+
25
+
26
+ def _format_sectioned_errors(section_errors):
27
+ parts = ['BenchmarkConfig validation failed:\n']
28
+ for section, raw in section_errors.items():
29
+ errs = _as_errors(raw)
30
+ if not errs:
31
+ continue
32
+ parts.append(f'[{section}]')
33
+ parts.extend([f'- {e}' for e in errs])
34
+ parts.append('')
35
+
36
+ return '\n'.join(parts).rstrip()
37
+
38
+
39
+ def _validate_structure(config):
40
+ errors = []
41
+ if config.modality not in ('single_table', 'multi_table'):
42
+ errors.append(
43
+ f"modality: must be 'single_table' or 'multi_table'. Found: {config.modality!r}"
44
+ )
45
+
46
+ if config.credentials_filepath is not None and not isinstance(config.credentials_filepath, str):
47
+ errors.append('credentials_filepath must be a string or None.')
48
+
49
+ expected_types = {
50
+ 'method_params': dict,
51
+ 'compute': dict,
52
+ 'instance_jobs': list,
53
+ }
54
+ for key, expected_type in expected_types.items():
55
+ value = getattr(config, key, None)
56
+ if value is None:
57
+ errors.append(f'{key}: is a required section but missing.')
58
+ elif not isinstance(value, expected_type):
59
+ errors.append(f'{key}: must be a {expected_type.__name__}. Found: {type(value)}')
60
+
61
+ compute = getattr(config, 'compute', None)
62
+ if isinstance(compute, dict):
63
+ service = compute.get('service')
64
+ if service not in ('gcp',):
65
+ errors.append(f"compute.service: must be 'gcp'. Found: {service!r}")
66
+
67
+ return sorted(errors)
68
+
69
+
70
def _validate_method_params(method_params, method_to_run):
    """Validate the ``method_params`` section of a benchmark config.

    Args:
        method_params (dict):
            Parameters to forward to the benchmark method.
        method_to_run:
            The benchmark method the parameters are meant for. Not used by the current
            checks.

    Returns:
        list[str]:
            The error messages; empty when the section is valid.
    """
    errors = []
    timeout = method_params.get('timeout')
    if timeout is not None:
        # ``bool`` is a subclass of ``int``: without the explicit exclusion a value such
        # as ``timeout: true`` would be accepted as a positive integer.
        if isinstance(timeout, bool) or not isinstance(timeout, int):
            errors.append(
                f'method_params.timeout: must be int seconds. Found: {timeout!r} ({type(timeout)})'
            )
        elif timeout <= 0:
            errors.append('method_params.timeout: must be > 0.')

    for key in ('compute_quality_score', 'compute_diagnostic_score', 'compute_privacy_score'):
        value = method_params.get(key)
        if value is not None and not isinstance(value, bool):
            errors.append(f'method_params.{key}: must be bool. Found: {value!r} ({type(value)})')

    illegal = _INJECTED_PARAMS & set(method_params)
    if illegal:
        errors.append(
            f'method_params: must not define injected parameters {sorted(illegal)} '
            f'(resolved from credentials/instance_jobs).'
        )

    return errors
94
+
95
+
96
+ def _validate_instance_jobs(instance_jobs):
97
+ error_message = (
98
+ "Each job in 'instance_jobs' must be a dict with an 'output_destination' (string), "
99
+ "'synthesizers' (list of unique strings), and 'datasets' (list of unique strings or "
100
+ "dict with 'include' and optional 'exclude')."
101
+ )
102
+ invalid_jobs = []
103
+ for job in instance_jobs:
104
+ if not isinstance(job, dict):
105
+ invalid_jobs.append(job)
106
+ continue
107
+
108
+ if 'datasets' not in job or 'synthesizers' not in job or 'output_destination' not in job:
109
+ invalid_jobs.append(job)
110
+ continue
111
+
112
+ synthesizers = job['synthesizers']
113
+ if not _is_unique_string_list(synthesizers):
114
+ invalid_jobs.append(job)
115
+ continue
116
+
117
+ output_destination = job['output_destination']
118
+ if not isinstance(output_destination, str) or not output_destination:
119
+ invalid_jobs.append(job)
120
+ continue
121
+
122
+ datasets = job['datasets']
123
+ if isinstance(datasets, list):
124
+ if not _is_unique_string_list(datasets):
125
+ invalid_jobs.append(job)
126
+ continue
127
+
128
+ if isinstance(datasets, dict):
129
+ include = datasets.get('include')
130
+ exclude = datasets.get('exclude')
131
+ if not _is_unique_string_list(include):
132
+ invalid_jobs.append(job)
133
+ continue
134
+
135
+ if exclude is not None and not _is_unique_string_list(exclude):
136
+ invalid_jobs.append(job)
137
+ continue
138
+
139
+ invalid_jobs.append(job)
140
+
141
+ if not invalid_jobs:
142
+ return []
143
+
144
+ invalid_jobs_str = '\n'.join(str(job) for job in invalid_jobs)
145
+
146
+ return [f'{error_message}\nInvalid jobs:\n{invalid_jobs_str}']
147
+
148
+
149
+ def _validate_aws_credentials(credentials):
150
+ errors = []
151
+ aws = credentials.get('aws', {})
152
+ if not isinstance(aws, dict):
153
+ errors.append("credentials['aws'] must be a dict.")
154
+ else:
155
+ if any(aws.values()):
156
+ for key in _AWS_CREDENTIAL_KEYS:
157
+ key = key.lower()
158
+ if aws.get(key) in (None, ''):
159
+ errors.append(f"credentials['aws']['{key}'] is missing or empty.")
160
+
161
+ return sorted(errors)
162
+
163
+
164
+ def _validate_sdv_enterprise_credentials(credentials):
165
+ errors = []
166
+ sdv = credentials.get('sdv_enterprise', {})
167
+ if not isinstance(sdv, dict):
168
+ errors.append("credentials['sdv_enterprise'] must be a dict.")
169
+ else:
170
+ username = sdv.get('sdv_enterprise_username')
171
+ license_key = sdv.get('sdv_enterprise_license_key')
172
+ message = (
173
+ "credentials['sdv_enterprise'] require both 'sdv_enterprise_username' and "
174
+ "'sdv_enterprise_license_key' to be provided and non-empty if any SDV Enterprise"
175
+ ' credential is provided.'
176
+ )
177
+ if bool(username) != bool(license_key):
178
+ errors.append(message)
179
+
180
+ return sorted(errors)
181
+
182
+
183
+ def _validate_gcp_credentials(credentials):
184
+ errors = []
185
+ gcp = credentials.get('gcp', {})
186
+ if not isinstance(gcp, dict):
187
+ errors.append("credentials['gcp'] must be a dict.")
188
+ else:
189
+ if gcp:
190
+ for key in _GCP_SERVICE_ACCOUNT_REQUIRED_KEYS:
191
+ if gcp.get(key) in (None, ''):
192
+ errors.append(f"credentials['gcp']['{key}'] is missing or empty.")
193
+
194
+ return sorted(errors)
195
+
196
+
197
def _validate_resolved_credentials(credentials):
    """Run the AWS, SDV Enterprise and GCP credential checks.

    Returns:
        list[str]:
            All collected error messages in sorted order.
    """
    validators = (
        _validate_aws_credentials,
        _validate_sdv_enterprise_credentials,
        _validate_gcp_credentials,
    )
    collected = []
    for validator in validators:
        collected += validator(credentials)

    collected.sort()
    return collected
204
+
205
+
206
+ def _validate_credentials(credentials_filepath):
207
+ if credentials_filepath is not None and not isinstance(credentials_filepath, str):
208
+ return ['credentials_filepath: must be a string path to the credentials file or None.']
209
+
210
+ credentials = resolve_credentials(credentials_filepath)
211
+ return _validate_resolved_credentials(credentials)
@@ -0,0 +1,9 @@
1
+ method_params:
2
+ timeout: 345600
3
+ compute_quality_score: true
4
+ compute_diagnostic_score: true
5
+
6
+ compute:
7
+ service: 'gcp'
8
+
9
+ credentials_filepath: null
@@ -0,0 +1,118 @@
1
+ """Define the BenchmarkConfig class, which represents the configuration for a benchmark."""
2
+
3
+ import json
4
+ from copy import deepcopy
5
+
6
+ import yaml
7
+
8
+ from sdgym._benchmark_launcher._validation import (
9
+ _format_sectioned_errors,
10
+ _validate_credentials,
11
+ _validate_instance_jobs,
12
+ _validate_method_params,
13
+ _validate_structure,
14
+ )
15
+ from sdgym._benchmark_launcher.utils import _METHODS, CONFIG_KEYS
16
+ from sdgym.errors import BenchmarkConfigError
17
+
18
+
19
class BenchmarkConfig:
    """BenchmarkConfig class.

    This class represents the configuration for a benchmark. It can be loaded from a YAML file
    or a dictionary and provides methods for validation and conversion to different formats.
    The expected structure of the config is as follows:
    {
        'modality': 'single_table' or 'multi_table',
        'method_params': dict of parameters to pass to the benchmark method (e.g. timeout),
        'credentials_filepath':
            string specifying the path to the credentials file, if None,
            credentials will be resolved from environment variables.
        'compute': dict specifying the compute configuration (e.g. service: 'gcp'),
        'instance_jobs': list of dicts, each specifying a combination of synthesizers
            and datasets and output destination to run a benchmark job on. Each dict should
            have the following structure:
            [
                {
                    'synthesizers': ['synthesizer1', 'synthesizer2'],
                    'datasets': ['dataset1', 'dataset2'] or {'include': [...], 'exclude': [...]},
                    'output_destination': 's3://bucket/path'
                },
                ...
            ]
    }
    """

    def __init__(self):
        self.modality = None
        self.method_params = None
        self.credentials_filepath = None
        self.compute = {'service': None}
        self.instance_jobs = []
        self._is_validated = False

    def to_dict(self):
        """Return a python ``dict`` representation of the ``BenchmarkConfig``.

        Only the keys of ``CONFIG_KEYS`` whose value is not ``None`` are included. The
        result is a deep copy, so changing it does not affect the config.
        """
        config = {}
        for key in CONFIG_KEYS:
            value = getattr(self, key, None)
            if value is not None:
                config[key] = value

        return deepcopy(config)

    def __str__(self):
        """Pretty print the ``BenchmarkConfig`` as indented JSON."""
        return json.dumps(self.to_dict(), indent=4)

    def validate(self):
        """Validate the configuration.

        The structure is checked first; the method parameters, credentials and instance
        jobs are only checked once the structure is valid.

        Raises:
            BenchmarkConfigError:
                If any check fails. The message lists the errors grouped by section.
        """
        # A config that was valid earlier may have been modified since.
        self._is_validated = False
        errors = _validate_structure(self)
        if errors:
            raise BenchmarkConfigError(_format_sectioned_errors({'structure': errors}))

        # Looked up only after the structural checks, so that an unsupported modality or
        # service, or a ``compute`` that is not a dict, is reported as a
        # BenchmarkConfigError above instead of failing in this lookup.
        method_to_run = _METHODS[(self.modality, self.compute.get('service'))]
        section_errors = {
            'method_params': _validate_method_params(self.method_params, method_to_run),
            'credentials_filepath': _validate_credentials(self.credentials_filepath),
            'instance_jobs': _validate_instance_jobs(self.instance_jobs),
        }
        if any(section_errors.values()):
            raise BenchmarkConfigError(_format_sectioned_errors(section_errors))

        self._is_validated = True

    def _validate_no_extra_keys(self, config_dict):
        """Validate that the config dictionary does not contain extra keys.

        Raises:
            ValueError:
                If ``config_dict`` has keys that are not in ``CONFIG_KEYS``.
        """
        extra_keys = set(config_dict.keys()).difference(CONFIG_KEYS)
        if extra_keys:
            extra_keys = "', '".join(sorted(extra_keys))
            valid_keys = "', '".join(sorted(CONFIG_KEYS))
            raise ValueError(
                f"The config dictionary contains extra keys: '{extra_keys}'. "
                f"Valid keys are: '{valid_keys}'."
            )

    @classmethod
    def load_from_dict(cls, config_dict):
        """Load the BenchmarkConfig from a dict.

        Args:
            config_dict (dict):
                Configuration whose keys must all belong to ``CONFIG_KEYS``.

        Returns:
            BenchmarkConfig:
                A new, not yet validated, instance.

        Raises:
            ValueError:
                If ``config_dict`` contains unknown keys.
        """
        instance = cls()
        instance._validate_no_extra_keys(config_dict)
        for attribute_name, attribute_value in config_dict.items():
            setattr(instance, attribute_name, attribute_value)

        return instance

    @classmethod
    def load_from_yaml(cls, filepath):
        """Load a config from a YAML file.

        An empty file is loaded as an empty configuration.
        """
        with open(filepath, 'r') as f:
            config_dict = yaml.safe_load(f)

        if config_dict is None:
            # ``yaml.safe_load`` returns ``None`` for an empty document.
            config_dict = {}

        return cls.load_from_dict(config_dict)

    def save_to_yaml(self, filepath):
        """Save the BenchmarkConfig in a YAML file."""
        config_dict = self.to_dict()
        with open(filepath, 'w') as file:
            yaml.dump(config_dict, file)