sdgym 0.14.1.dev0__tar.gz → 0.14.2.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sdgym-0.14.1.dev0/sdgym.egg-info → sdgym-0.14.2.dev0}/PKG-INFO +1 -1
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/pyproject.toml +5 -3
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/__init__.py +1 -1
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/_benchmark/benchmark.py +13 -12
- sdgym-0.14.2.dev0/sdgym/_benchmark/credentials_utils.py +17 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/__init__.py +6 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/_instance_manager.py +115 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/_storage_manager.py +64 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/_validation.py +211 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_base.yaml +9 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_config.py +118 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_launcher.py +421 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_multi_table.yaml +180 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/benchmark_single_table.yaml +131 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/script.py +280 -0
- sdgym-0.14.2.dev0/sdgym/_benchmark_launcher/utils.py +241 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/errors.py +4 -0
- sdgym-0.14.2.dev0/sdgym/run_benchmark/run_benchmark.py +87 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/run_benchmark/utils.py +1 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0/sdgym.egg-info}/PKG-INFO +1 -1
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/SOURCES.txt +11 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/tests/test_tasks.py +45 -0
- sdgym-0.14.1.dev0/sdgym/_benchmark/credentials_utils.py +0 -104
- sdgym-0.14.1.dev0/sdgym/run_benchmark/run_benchmark.py +0 -207
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/LICENSE +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/README.md +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/_benchmark/__init__.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/_benchmark/config_utils.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/_dataset_utils.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/benchmark.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/__init__.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/__main__.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/collect.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/summary.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/cli/utils.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/dataset_explorer.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/datasets.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/metrics.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/progress.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/result_explorer/__init__.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/result_explorer/result_explorer.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/result_explorer/result_handler.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/result_writer.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/run_benchmark/__init__.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/run_benchmark/upload_benchmark_results.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/s3.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizer_descriptions.yaml +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/__init__.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/base.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/column.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/generate.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/identity.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/realtabformer.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/sdv.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/uniform.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/synthesizers/utils.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym/utils.py +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/dependency_links.txt +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/entry_points.txt +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/requires.txt +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/sdgym.egg-info/top_level.txt +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/setup.cfg +0 -0
- {sdgym-0.14.1.dev0 → sdgym-0.14.2.dev0}/tests/test_scripts.py +0 -0
|
@@ -143,11 +143,13 @@ namespaces = false
|
|
|
143
143
|
'make.bat',
|
|
144
144
|
'*.jpg',
|
|
145
145
|
'*.png',
|
|
146
|
-
'*.gif'
|
|
146
|
+
'*.gif',
|
|
147
|
+
'*.yaml'
|
|
147
148
|
]
|
|
148
149
|
'sdgym' = [
|
|
149
150
|
'leaderboard.csv',
|
|
150
|
-
'synthesizer_descriptions.yaml'
|
|
151
|
+
'synthesizer_descriptions.yaml',
|
|
152
|
+
'_benchmark_launcher/*.yaml',
|
|
151
153
|
]
|
|
152
154
|
|
|
153
155
|
[tool.setuptools.exclude-package-data]
|
|
@@ -161,7 +163,7 @@ namespaces = false
|
|
|
161
163
|
version = {attr = 'sdgym.__version__'}
|
|
162
164
|
|
|
163
165
|
[tool.bumpversion]
|
|
164
|
-
current_version = "0.14.
|
|
166
|
+
current_version = "0.14.2.dev0"
|
|
165
167
|
parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
|
|
166
168
|
serialize = [
|
|
167
169
|
'{major}.{minor}.{patch}.{release}{candidate}',
|
|
@@ -9,7 +9,7 @@ from sdgym._benchmark.config_utils import (
|
|
|
9
9
|
resolve_compute_config,
|
|
10
10
|
validate_compute_config,
|
|
11
11
|
)
|
|
12
|
-
from sdgym._benchmark.credentials_utils import
|
|
12
|
+
from sdgym._benchmark.credentials_utils import sdv_install_cmd
|
|
13
13
|
from sdgym.benchmark import (
|
|
14
14
|
DEFAULT_MULTI_TABLE_DATASETS,
|
|
15
15
|
DEFAULT_MULTI_TABLE_SYNTHESIZERS,
|
|
@@ -349,7 +349,7 @@ def _run_on_gcp(
|
|
|
349
349
|
|
|
350
350
|
def _benchmark_compute_gcp(
|
|
351
351
|
output_destination,
|
|
352
|
-
|
|
352
|
+
credentials,
|
|
353
353
|
compute_config,
|
|
354
354
|
synthesizers,
|
|
355
355
|
sdv_datasets,
|
|
@@ -364,7 +364,6 @@ def _benchmark_compute_gcp(
|
|
|
364
364
|
):
|
|
365
365
|
"""Run the SDGym benchmark on datasets for the given modality."""
|
|
366
366
|
compute_config = resolve_compute_config('gcp', compute_config)
|
|
367
|
-
credentials = get_credentials(credential_filepath)
|
|
368
367
|
validate_compute_config(compute_config)
|
|
369
368
|
|
|
370
369
|
s3_client = _validate_output_destination(
|
|
@@ -407,7 +406,7 @@ def _benchmark_compute_gcp(
|
|
|
407
406
|
sdmetrics=sdmetrics,
|
|
408
407
|
)
|
|
409
408
|
|
|
410
|
-
_run_on_gcp(
|
|
409
|
+
instance_name = _run_on_gcp(
|
|
411
410
|
output_destination=output_destination,
|
|
412
411
|
synthesizers=synthesizers,
|
|
413
412
|
s3_client=s3_client,
|
|
@@ -416,10 +415,12 @@ def _benchmark_compute_gcp(
|
|
|
416
415
|
compute_config=compute_config,
|
|
417
416
|
)
|
|
418
417
|
|
|
418
|
+
return instance_name
|
|
419
|
+
|
|
419
420
|
|
|
420
421
|
def _benchmark_single_table_compute_gcp(
|
|
421
422
|
output_destination,
|
|
422
|
-
|
|
423
|
+
credentials,
|
|
423
424
|
compute_config=None,
|
|
424
425
|
synthesizers=DEFAULT_SINGLE_TABLE_SYNTHESIZERS,
|
|
425
426
|
sdv_datasets=DEFAULT_SINGLE_TABLE_DATASETS,
|
|
@@ -436,8 +437,8 @@ def _benchmark_single_table_compute_gcp(
|
|
|
436
437
|
Args:
|
|
437
438
|
output_destination (str):
|
|
438
439
|
The S3 URI where results will be stored.
|
|
439
|
-
|
|
440
|
-
|
|
440
|
+
credentials (dict):
|
|
441
|
+
The credentials for AWS, GCP and SDV-Enterprise.
|
|
441
442
|
compute_config (dict, optional):
|
|
442
443
|
The compute configuration for the GCP instance. If None, default settings will be used.
|
|
443
444
|
synthesizers (list of dict, optional):
|
|
@@ -461,7 +462,7 @@ def _benchmark_single_table_compute_gcp(
|
|
|
461
462
|
"""
|
|
462
463
|
return _benchmark_compute_gcp(
|
|
463
464
|
output_destination=output_destination,
|
|
464
|
-
|
|
465
|
+
credentials=credentials,
|
|
465
466
|
compute_config=compute_config,
|
|
466
467
|
synthesizers=synthesizers,
|
|
467
468
|
sdv_datasets=sdv_datasets,
|
|
@@ -478,7 +479,7 @@ def _benchmark_single_table_compute_gcp(
|
|
|
478
479
|
|
|
479
480
|
def _benchmark_multi_table_compute_gcp(
|
|
480
481
|
output_destination,
|
|
481
|
-
|
|
482
|
+
credentials,
|
|
482
483
|
compute_config=None,
|
|
483
484
|
synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
|
|
484
485
|
sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
|
|
@@ -494,8 +495,8 @@ def _benchmark_multi_table_compute_gcp(
|
|
|
494
495
|
Args:
|
|
495
496
|
output_destination (str):
|
|
496
497
|
The S3 URI where results will be stored.
|
|
497
|
-
|
|
498
|
-
|
|
498
|
+
credentials (dict):
|
|
499
|
+
The credentials for AWS, GCP and SDV-Enterprise.
|
|
499
500
|
compute_config (dict, optional):
|
|
500
501
|
The compute configuration for the GCP instance. If None, default settings will be used.
|
|
501
502
|
synthesizers (list of dict, optional):
|
|
@@ -517,7 +518,7 @@ def _benchmark_multi_table_compute_gcp(
|
|
|
517
518
|
"""
|
|
518
519
|
return _benchmark_compute_gcp(
|
|
519
520
|
output_destination=output_destination,
|
|
520
|
-
|
|
521
|
+
credentials=credentials,
|
|
521
522
|
compute_config=compute_config,
|
|
522
523
|
synthesizers=synthesizers,
|
|
523
524
|
sdv_datasets=sdv_datasets,
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import textwrap
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _quote_python_str_for_shell(value):
    """Return ``value`` as a Python string literal safe to embed in a double-quoted shell string.

    ``repr`` produces a valid Python literal (escaping quotes, backslashes and newlines). The
    characters that stay special inside double quotes in POSIX shells (backslash, dollar sign,
    backtick and double quote) are then backslash-escaped so the shell passes them through
    verbatim. For plain alphanumeric values the result is exactly ``f"'{value}'"``.
    """
    literal = repr(str(value))
    # Backslash must be escaped first so the escapes added afterwards are not doubled.
    for char in ('\\', '$', '`', '"'):
        literal = literal.replace(char, '\\' + char)

    return literal


def sdv_install_cmd(credentials):
    """Return the shell command to install sdv-enterprise using sdv-installer.

    Args:
        credentials (dict):
            Resolved credentials. Only ``credentials['sdv_enterprise']`` is read, using its
            ``sdv_enterprise_username`` and ``sdv_enterprise_license_key`` entries.

    Returns:
        str:
            A multi-line shell snippet that installs ``sdv-installer`` and then calls its
            ``install_packages`` helper with the credentials, or an empty string when either
            credential is missing or empty.
    """
    sdv_creds = credentials.get('sdv_enterprise') or {}
    username = sdv_creds.get('sdv_enterprise_username')
    license_key = sdv_creds.get('sdv_enterprise_license_key')
    if not (username and license_key):
        return ''

    # The values end up inside ``python -c "..."``: quote them so that quotes, ``$`` or
    # backticks in a credential can neither break the Python code nor be expanded by the shell.
    username_literal = _quote_python_str_for_shell(username)
    license_key_literal = _quote_python_str_for_shell(license_key)
    return textwrap.dedent(f"""\
        pip install sdv-installer

        python -c "from sdv_installer.installation.installer import install_packages; \\
        install_packages(username={username_literal}, license_key={license_key_literal})"
    """)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from google.cloud import compute_v1
|
|
4
|
+
from google.oauth2 import service_account
|
|
5
|
+
|
|
6
|
+
from sdgym._benchmark_launcher._validation import _validate_gcp_credentials
|
|
7
|
+
from sdgym._benchmark_launcher.utils import resolve_credentials
|
|
8
|
+
|
|
9
|
+
LOGGER = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseInstanceManager:
    """Interface shared by the compute-service-specific instance managers.

    Concrete managers override every method below; the implementations here only raise
    ``NotImplementedError``.
    """

    def list_instances(self):
        """Return the instances that are not in a terminated state."""
        raise NotImplementedError

    def update_instance_statuses(self, instance_names, instance_name_to_status):
        """Refresh, in place, the statuses the launcher tracks for ``instance_names``."""
        raise NotImplementedError

    def terminate_instances(self, instance_names, verbose):
        """Terminate the named instances and return the names actually deleted."""
        raise NotImplementedError
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class GCPInstanceManager(BaseInstanceManager):
    """Manage GCP benchmark instances.

    Args:
        credentials_filepath (str or None):
            Path handed to ``resolve_credentials`` every time a GCP client is built.
    """

    def __init__(self, credentials_filepath):
        self.credentials_filepath = credentials_filepath

    def _get_client(self):
        """Build and return the GCP ``InstancesClient`` and the project id.

        Returns:
            tuple:
                ``(client, project_id)`` built from ``credentials['gcp']``.

        Raises:
            ValueError:
                If the resolved GCP credentials are missing, empty or fail validation.
        """
        credentials = resolve_credentials(self.credentials_filepath)
        errors = _validate_gcp_credentials(credentials)
        # ``_validate_gcp_credentials`` only checks the required keys when a non-empty ``gcp``
        # section is present, so a missing/empty section passes it. Report that case here
        # instead of failing below with a bare ``KeyError``.
        if not errors and not credentials.get('gcp'):
            errors = ["credentials['gcp'] is missing or empty."]

        if errors:
            error_message = '\n'.join(errors)
            raise ValueError(f'Invalid GCP credentials:\n{error_message}')

        project_id = credentials['gcp']['project_id']
        gcp_credentials = service_account.Credentials.from_service_account_info(credentials['gcp'])
        client = compute_v1.InstancesClient(credentials=gcp_credentials)

        return client, project_id

    def list_instances(self):
        """List all non-terminated GCP instances of the project.

        Returns:
            list[dict]:
                One dict per instance whose status is not ``'TERMINATED'``, with the keys
                ``'id'`` (as a string), ``'name'``, ``'zone'`` (last path segment of the
                instance zone) and ``'status'``.
        """
        client, project_id = self._get_client()
        instances = []
        response = client.aggregated_list(project=project_id)
        for _, scoped_list in response:
            # Zones without instances yield scoped lists with no ``instances`` entries.
            scoped_instances = getattr(scoped_list, 'instances', None)
            if not scoped_instances:
                continue

            for instance in scoped_instances:
                if instance.status == 'TERMINATED':
                    continue

                instances.append({
                    'id': str(instance.id),
                    'name': instance.name,
                    'zone': instance.zone.split('/')[-1],
                    'status': instance.status,
                })

        return instances

    def update_instance_statuses(self, instance_names, instance_name_to_status):
        """Update launcher-tracked instance statuses in place.

        Names found among the non-terminated instances are marked ``'running'``. Names that
        were ``'running'`` but are no longer listed become ``'completed'``. Any other entry is
        left unchanged.
        """
        running_instances = self.list_instances()
        running_instance_names = {instance['name'] for instance in running_instances}
        for instance_name in instance_names:
            if instance_name in running_instance_names:
                instance_name_to_status[instance_name] = 'running'
            elif instance_name_to_status.get(instance_name) == 'running':
                instance_name_to_status[instance_name] = 'completed'

    def terminate_instances(self, instance_names, verbose):
        """Terminate GCP instances by name.

        Names that are not currently listed by ``list_instances`` are skipped (and logged).
        Each deletion is waited on (``operation.result()``) before the next one starts.

        Args:
            instance_names (list[str]):
                Names of the instances to delete.
            verbose (bool):
                Whether to print a line for each instance being terminated.

        Returns:
            list[str]:
                The names of the instances that were deleted.
        """
        client, project_id = self._get_client()
        running_instances = self.list_instances()
        running_instances_by_name = {instance['name']: instance for instance in running_instances}
        instances_to_delete = [
            running_instances_by_name[name]
            for name in instance_names
            if name in running_instances_by_name
        ]

        not_running = sorted(set(instance_names) - set(running_instances_by_name))
        if not_running:
            not_running_str = "', '".join(not_running)
            LOGGER.info(
                f"Some provided instance names are not currently running: '{not_running_str}'."
            )

        deleted_instances = []
        for instance in instances_to_delete:
            if verbose:
                print(  # noqa: T201
                    f"Terminating GCP instance '{instance['name']}' "
                    f'(id={instance["id"]}, zone={instance["zone"]})...'
                )

            operation = client.delete(
                project=project_id,
                zone=instance['zone'],
                instance=instance['name'],
            )
            operation.result()
            deleted_instances.append(instance['name'])

        return deleted_instances
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from sdgym._benchmark_launcher.utils import resolve_credentials
|
|
2
|
+
from sdgym.s3 import _list_s3_bucket_contents, get_s3_client, is_s3_path, parse_s3_path
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _validate_s3_output_destinations(instance_jobs):
    """Ensure every job writes its results to an S3 path.

    Raises:
        ValueError:
            For the first job whose ``output_destination`` is not an S3 path.
    """
    destinations = (job['output_destination'] for job in instance_jobs)
    for destination in destinations:
        if is_s3_path(destination):
            continue

        raise ValueError(f'Only S3 storage is currently supported. Found: {destination!r}.')
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BaseStorageManager:
    """Interface for managers that inspect one kind of output storage.

    Every method must be overridden by a concrete subclass; the defaults here only raise
    ``NotImplementedError``.
    """

    def handles_destination(self, output_destination):
        """Tell whether ``output_destination`` is a location this manager can work with."""
        raise NotImplementedError

    def list_files(self, output_destination):
        """List the files currently stored under ``output_destination``."""
        raise NotImplementedError

    def get_existing_filenames(self, output_destination):
        """Return the names of the files already present at ``output_destination``."""
        raise NotImplementedError
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class S3StorageManager(BaseStorageManager):
    """Storage manager for benchmark artifacts kept in S3.

    Args:
        credentials_filepath (str or None):
            Path handed to ``resolve_credentials`` whenever an S3 client is needed.
        instance_jobs (list[dict]):
            Jobs whose ``output_destination`` values must all be S3 paths.
    """

    def __init__(self, credentials_filepath, instance_jobs):
        # Fail fast at construction if any job targets a non-S3 destination.
        _validate_s3_output_destinations(instance_jobs)
        self.credentials_filepath = credentials_filepath

    def handles_destination(self, output_destination):
        """Tell whether ``output_destination`` is an S3 path."""
        return is_s3_path(output_destination)

    def _get_client(self):
        """Create an S3 client from the ``aws`` section of the resolved credentials."""
        aws_section = resolve_credentials(self.credentials_filepath).get('aws', {})
        access_key_id = aws_section.get('aws_access_key_id')
        secret_access_key = aws_section.get('aws_secret_access_key')
        return get_s3_client(
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
        )

    def list_files(self, output_destination):
        """List the objects stored under the given S3 destination.

        Raises:
            ValueError:
                If ``output_destination`` is not an S3 path.
        """
        if not self.handles_destination(output_destination):
            raise ValueError(
                f'S3StorageManager only supports S3 paths. Found: {output_destination!r}.'
            )

        client = self._get_client()
        bucket, prefix = parse_s3_path(output_destination)
        return _list_s3_bucket_contents(client, bucket, prefix)

    def get_existing_filenames(self, output_destination):
        """Return the set of object keys found under ``output_destination``."""
        stored_objects = self.list_files(output_destination)
        return {stored_object['Key'] for stored_object in stored_objects}
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
from sdgym._benchmark_launcher.utils import (
|
|
2
|
+
_AWS_CREDENTIAL_KEYS,
|
|
3
|
+
_GCP_SERVICE_ACCOUNT_REQUIRED_KEYS,
|
|
4
|
+
_is_unique_string_list,
|
|
5
|
+
resolve_credentials,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
# Benchmark-method arguments that the launcher fills in itself rather than taking from the
# user's ``method_params``; ``_validate_method_params`` reports an error if any of them
# appears there.
_INJECTED_PARAMS = {
    'credentials',
    'synthesizers',
    'sdv_datasets',
    'compute_config',
    'output_destination',
}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _as_errors(value):
    """Normalize ``value`` into a list of error strings.

    A list keeps only its truthy items, each converted with ``str``. ``None`` becomes an empty
    list, and any other value becomes a single-item list holding ``str(value)``.
    """
    if isinstance(value, list):
        return list(map(str, filter(None, value)))

    return [] if value is None else [str(value)]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _format_sectioned_errors(section_errors):
    """Render per-section validation errors as one human-readable message.

    Sections whose errors normalize (via ``_as_errors``) to an empty list are omitted. Each
    remaining section is rendered as a ``[section]`` header followed by ``- error`` lines.
    """
    lines = ['BenchmarkConfig validation failed:\n']
    for section_name, raw_errors in section_errors.items():
        messages = _as_errors(raw_errors)
        if messages:
            lines.append(f'[{section_name}]')
            lines.extend(f'- {message}' for message in messages)
            lines.append('')

    return '\n'.join(lines).rstrip()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _validate_structure(config):
    """Check the top-level shape of a ``BenchmarkConfig``.

    Verifies the modality, the type of ``credentials_filepath``, the presence and type of the
    ``method_params``, ``compute`` and ``instance_jobs`` sections, and the compute service.

    Returns:
        list[str]:
            Sorted error messages; empty when the structure is valid.
    """
    problems = []
    modality = config.modality
    if modality not in ('single_table', 'multi_table'):
        problems.append(
            f"modality: must be 'single_table' or 'multi_table'. Found: {modality!r}"
        )

    filepath = config.credentials_filepath
    if not (filepath is None or isinstance(filepath, str)):
        problems.append('credentials_filepath must be a string or None.')

    required_sections = (('method_params', dict), ('compute', dict), ('instance_jobs', list))
    for section_name, section_type in required_sections:
        section_value = getattr(config, section_name, None)
        if section_value is None:
            problems.append(f'{section_name}: is a required section but missing.')
            continue

        if not isinstance(section_value, section_type):
            problems.append(
                f'{section_name}: must be a {section_type.__name__}. '
                f'Found: {type(section_value)}'
            )

    compute_section = getattr(config, 'compute', None)
    if isinstance(compute_section, dict):
        service = compute_section.get('service')
        if service not in ('gcp',):
            problems.append(f"compute.service: must be 'gcp'. Found: {service!r}")

    return sorted(problems)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _validate_method_params(method_params, method_to_run):
    """Validate the user-provided ``method_params`` section.

    Args:
        method_params (dict):
            Extra keyword arguments for the benchmark method.
        method_to_run:
            The benchmark method selected for the config. Accepted for interface
            compatibility; the current checks do not use it.

    Returns:
        list[str]:
            Error messages, empty when ``method_params`` is valid.
    """
    errors = []
    timeout = method_params.get('timeout')
    if timeout is not None:
        # ``bool`` is a subclass of ``int``: reject it explicitly so ``timeout: true`` in a
        # YAML config is not silently accepted as a 1-second timeout.
        if isinstance(timeout, bool) or not isinstance(timeout, int):
            errors.append(
                f'method_params.timeout: must be int seconds. Found: {timeout!r} ({type(timeout)})'
            )
        elif timeout <= 0:
            errors.append('method_params.timeout: must be > 0.')

    for key in ('compute_quality_score', 'compute_diagnostic_score', 'compute_privacy_score'):
        value = method_params.get(key)
        if value is not None and not isinstance(value, bool):
            errors.append(f'method_params.{key}: must be bool. Found: {value!r} ({type(value)})')

    # These arguments are supplied by the launcher itself and must not be overridden.
    illegal = _INJECTED_PARAMS & set(method_params)
    if illegal:
        errors.append(
            f'method_params: must not define injected parameters {sorted(illegal)} '
            f'(resolved from credentials/instance_jobs).'
        )

    return errors
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _is_valid_instance_job(job):
    """Return whether one ``instance_jobs`` entry has the expected shape."""
    if not isinstance(job, dict):
        return False

    if any(key not in job for key in ('datasets', 'synthesizers', 'output_destination')):
        return False

    if not _is_unique_string_list(job['synthesizers']):
        return False

    destination = job['output_destination']
    if not isinstance(destination, str) or not destination:
        return False

    datasets = job['datasets']
    if isinstance(datasets, list):
        return bool(_is_unique_string_list(datasets))

    if isinstance(datasets, dict):
        if not _is_unique_string_list(datasets.get('include')):
            return False

        exclude = datasets.get('exclude')
        return exclude is None or bool(_is_unique_string_list(exclude))

    # ``datasets`` is neither a list nor an include/exclude dict.
    return False


def _validate_instance_jobs(instance_jobs):
    """Validate the shape of every entry in ``instance_jobs``.

    Returns:
        list[str]:
            Empty when all jobs are valid; otherwise a single message describing the expected
            job format followed by the offending jobs, one per line.
    """
    invalid_jobs = [job for job in instance_jobs if not _is_valid_instance_job(job)]
    if not invalid_jobs:
        return []

    expected_format = (
        "Each job in 'instance_jobs' must be a dict with an 'output_destination' (string), "
        "'synthesizers' (list of unique strings), and 'datasets' (list of unique strings or "
        "dict with 'include' and optional 'exclude')."
    )
    invalid_jobs_str = '\n'.join(map(str, invalid_jobs))
    return [f'{expected_format}\nInvalid jobs:\n{invalid_jobs_str}']
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _validate_aws_credentials(credentials):
    """Validate the ``aws`` section of the resolved credentials.

    The section is optional. Once any AWS value is set, every key from
    ``_AWS_CREDENTIAL_KEYS`` (lower-cased) must be present and non-empty.

    Returns:
        list[str]:
            Sorted error messages; empty when the section is valid.
    """
    aws_section = credentials.get('aws', {})
    if not isinstance(aws_section, dict):
        return ["credentials['aws'] must be a dict."]

    if not any(aws_section.values()):
        return []

    expected_names = (raw_key.lower() for raw_key in _AWS_CREDENTIAL_KEYS)
    return sorted(
        f"credentials['aws']['{name}'] is missing or empty."
        for name in expected_names
        if aws_section.get(name) in (None, '')
    )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _validate_sdv_enterprise_credentials(credentials):
    """Validate the ``sdv_enterprise`` section of the resolved credentials.

    The username and license key must be either both provided (non-empty) or both absent.

    Returns:
        list[str]:
            Sorted error messages; empty when the section is valid.
    """
    sdv_section = credentials.get('sdv_enterprise', {})
    if not isinstance(sdv_section, dict):
        return ["credentials['sdv_enterprise'] must be a dict."]

    has_username = bool(sdv_section.get('sdv_enterprise_username'))
    has_license_key = bool(sdv_section.get('sdv_enterprise_license_key'))
    if has_username == has_license_key:
        return []

    return [
        "credentials['sdv_enterprise'] require both 'sdv_enterprise_username' and "
        "'sdv_enterprise_license_key' to be provided and non-empty if any SDV Enterprise"
        ' credential is provided.'
    ]
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _validate_gcp_credentials(credentials):
    """Validate the ``gcp`` service-account section of the resolved credentials.

    An empty or missing section is accepted. A non-empty one must contain every key from
    ``_GCP_SERVICE_ACCOUNT_REQUIRED_KEYS`` with a non-empty value.

    Returns:
        list[str]:
            Sorted error messages; empty when the section is valid.
    """
    gcp_section = credentials.get('gcp', {})
    if not isinstance(gcp_section, dict):
        return ["credentials['gcp'] must be a dict."]

    if not gcp_section:
        return []

    return sorted(
        f"credentials['gcp']['{name}'] is missing or empty."
        for name in _GCP_SERVICE_ACCOUNT_REQUIRED_KEYS
        if gcp_section.get(name) in (None, '')
    )
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _validate_resolved_credentials(credentials):
    """Run the AWS, SDV Enterprise and GCP credential checks and merge their errors.

    Returns:
        list[str]:
            All error messages from the three validators, sorted.
    """
    validators = (
        _validate_aws_credentials,
        _validate_sdv_enterprise_credentials,
        _validate_gcp_credentials,
    )
    return sorted(error for validator in validators for error in validator(credentials))
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _validate_credentials(credentials_filepath):
    """Resolve the credentials from ``credentials_filepath`` and validate them.

    Returns:
        list[str]:
            Error messages; a single type error when ``credentials_filepath`` is neither a
            string nor ``None``, otherwise the errors from validating the resolved credentials.
    """
    is_supported_type = credentials_filepath is None or isinstance(credentials_filepath, str)
    if not is_supported_type:
        return ['credentials_filepath: must be a string path to the credentials file or None.']

    resolved = resolve_credentials(credentials_filepath)
    return _validate_resolved_credentials(resolved)
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Define the BenchmarkConfig class, which represents the configuration for a benchmark."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from sdgym._benchmark_launcher._validation import (
|
|
9
|
+
_format_sectioned_errors,
|
|
10
|
+
_validate_credentials,
|
|
11
|
+
_validate_instance_jobs,
|
|
12
|
+
_validate_method_params,
|
|
13
|
+
_validate_structure,
|
|
14
|
+
)
|
|
15
|
+
from sdgym._benchmark_launcher.utils import _METHODS, CONFIG_KEYS
|
|
16
|
+
from sdgym.errors import BenchmarkConfigError
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BenchmarkConfig:
    """BenchmarkConfig class.

    This class represents the configuration for a benchmark. It can be loaded from a YAML file
    or a dictionary and provides methods for validation and conversion to different formats.
    The expected structure of the config is as follows:
    {
        'modality': 'single_table' or 'multi_table',
        'method_params': dict of parameters to pass to the benchmark method (e.g. timeout),
        'credentials_filepath':
            string specifying the path to the credentials file, if None,
            credentials will be resolved from environment variables.
        'compute': dict specifying the compute configuration (e.g. service: 'gcp'),
        'instance_jobs': list of dicts, each specifying a combination of synthesizers
            and datasets and output destination to run a benchmark job on. Each dict should
            have the following structure:
            [
                {
                    'synthesizers': ['synthesizer1', 'synthesizer2'],
                    'datasets': ['dataset1', 'dataset2'] or {'include': [...], 'exclude': [...]},
                    'output_destination': 's3://bucket/path'
                },
                ...
            ]
    }
    """

    def __init__(self):
        self.modality = None
        self.method_params = None
        self.credentials_filepath = None
        self.compute = {'service': None}
        self.instance_jobs = []
        self._is_validated = False

    def to_dict(self):
        """Return a python ``dict`` representation of the ``BenchmarkConfig``.

        Only the ``CONFIG_KEYS`` attributes whose value is not ``None`` are included. The
        result is a deep copy, so mutating it does not affect this config.
        """
        config = {}
        for key in CONFIG_KEYS:
            value = getattr(self, key, None)
            if value is not None:
                config[key] = value

        return deepcopy(config)

    def __str__(self):
        """Pretty print the ``BenchmarkConfig`` as indented JSON."""
        return json.dumps(self.to_dict(), indent=4)

    def validate(self):
        """Validate the config and mark it as validated on success.

        The top-level structure (modality, section types, compute service) is checked first.
        Only then is the benchmark method looked up in ``_METHODS``, so an unknown modality or
        service, or a non-dict ``compute``, is reported as a ``BenchmarkConfigError`` instead of
        surfacing as a ``KeyError``/``AttributeError`` from that lookup. The remaining sections
        are then validated together so that all their errors are reported at once.

        Raises:
            BenchmarkConfigError:
                If the structure or any section of the config is invalid.
        """
        errors = _validate_structure(self)
        if errors:
            raise BenchmarkConfigError(_format_sectioned_errors({'structure': errors}))

        # NOTE(review): assumes ``_METHODS`` has an entry for every (modality, service) pair
        # accepted by ``_validate_structure`` — confirm when adding modalities or services.
        method_to_run = _METHODS[(self.modality, self.compute.get('service'))]
        section_errors = {
            'method_params': _validate_method_params(self.method_params, method_to_run),
            'credentials_filepath': _validate_credentials(self.credentials_filepath),
            'instance_jobs': _validate_instance_jobs(self.instance_jobs),
        }
        if any(section_errors.values()):
            raise BenchmarkConfigError(_format_sectioned_errors(section_errors))

        self._is_validated = True

    def _validate_no_extra_keys(self, config_dict):
        """Validate that the config dictionary does not contain extra keys.

        Raises:
            ValueError:
                If ``config_dict`` has keys outside ``CONFIG_KEYS``.
        """
        extra_keys = set(config_dict.keys()).difference(CONFIG_KEYS)
        if extra_keys:
            extra_keys = "', '".join(sorted(extra_keys))
            valid_keys = "', '".join(sorted(CONFIG_KEYS))
            raise ValueError(
                f"The config dictionary contains extra keys: '{extra_keys}'. "
                f"Valid keys are: '{valid_keys}'."
            )

    @classmethod
    def load_from_dict(cls, config_dict):
        """Load the BenchmarkConfig from a dict.

        Raises:
            ValueError:
                If ``config_dict`` is not a dict (e.g. ``None`` from an empty YAML file) or
                contains keys outside ``CONFIG_KEYS``.
        """
        if not isinstance(config_dict, dict):
            raise ValueError(
                f'The config must be a dictionary. Found: {type(config_dict).__name__}.'
            )

        instance = cls()
        instance._validate_no_extra_keys(config_dict)
        for attribute_name, attribute_value in config_dict.items():
            setattr(instance, attribute_name, attribute_value)

        return instance

    @classmethod
    def load_from_yaml(cls, filepath):
        """Load a config from a YAML file (parsed with ``yaml.safe_load``)."""
        with open(filepath, 'r') as f:
            config_dict = yaml.safe_load(f)

        return cls.load_from_dict(config_dict)

    def save_to_yaml(self, filepath):
        """Save the BenchmarkConfig in a YAML file."""
        config_dict = self.to_dict()
        with open(filepath, 'w') as file:
            yaml.dump(config_dict, file)
|