sdgym 0.12.1.dev0__tar.gz → 0.12.2.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sdgym-0.12.1.dev0/sdgym.egg-info → sdgym-0.12.2.dev0}/PKG-INFO +8 -6
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/pyproject.toml +12 -7
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/__init__.py +11 -4
- sdgym-0.12.2.dev0/sdgym/_benchmark/__init__.py +1 -0
- sdgym-0.12.2.dev0/sdgym/_benchmark/benchmark.py +532 -0
- sdgym-0.12.2.dev0/sdgym/_benchmark/config_utils.py +123 -0
- sdgym-0.12.2.dev0/sdgym/_benchmark/credentials_utils.py +104 -0
- sdgym-0.12.2.dev0/sdgym/_dataset_utils.py +156 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/benchmark.py +503 -415
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/__main__.py +3 -15
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/dataset_explorer.py +20 -14
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/datasets.py +106 -36
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/metrics.py +7 -7
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/result_explorer/result_explorer.py +55 -22
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/result_explorer/result_handler.py +35 -18
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/result_writer.py +25 -3
- sdgym-0.12.2.dev0/sdgym/run_benchmark/__init__.py +1 -0
- sdgym-0.12.2.dev0/sdgym/run_benchmark/run_benchmark.py +152 -0
- sdgym-0.12.2.dev0/sdgym/run_benchmark/upload_benchmark_results.py +601 -0
- sdgym-0.12.2.dev0/sdgym/run_benchmark/utils.py +205 -0
- sdgym-0.12.2.dev0/sdgym/synthesizer_descriptions.yaml +105 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/__init__.py +2 -1
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/base.py +73 -8
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/column.py +16 -10
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/generate.py +21 -14
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/identity.py +6 -9
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/realtabformer.py +5 -3
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/sdv.py +18 -16
- sdgym-0.12.2.dev0/sdgym/synthesizers/uniform.py +140 -0
- sdgym-0.12.2.dev0/sdgym/synthesizers/utils.py +37 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/utils.py +1 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0/sdgym.egg-info}/PKG-INFO +8 -6
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/SOURCES.txt +9 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/requires.txt +7 -5
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/tests/test_tasks.py +5 -1
- sdgym-0.12.1.dev0/sdgym/_dataset_utils.py +0 -107
- sdgym-0.12.1.dev0/sdgym/synthesizers/uniform.py +0 -72
- sdgym-0.12.1.dev0/sdgym/synthesizers/utils.py +0 -51
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/LICENSE +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/README.md +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/__init__.py +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/collect.py +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/summary.py +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/utils.py +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/errors.py +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/progress.py +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/result_explorer/__init__.py +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/s3.py +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/dependency_links.txt +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/entry_points.txt +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/top_level.txt +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/setup.cfg +0 -0
- {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/tests/test_scripts.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sdgym
|
|
3
|
-
Version: 0.12.1.dev0
|
|
3
|
+
Version: 0.12.2.dev0
|
|
4
4
|
Summary: Benchmark tabular synthetic data generators using a variety of datasets
|
|
5
5
|
Author-email: "DataCebo, Inc." <info@sdv.dev>
|
|
6
6
|
License: BSL-1.1
|
|
@@ -29,16 +29,18 @@ Requires-Dist: boto3<2,>=1.28
|
|
|
29
29
|
Requires-Dist: botocore<2,>=1.31
|
|
30
30
|
Requires-Dist: cloudpickle>=2.1.0
|
|
31
31
|
Requires-Dist: compress-pickle>=1.2.0
|
|
32
|
+
Requires-Dist: google-cloud-compute>=1.0.0
|
|
33
|
+
Requires-Dist: google-auth>=2.0.0
|
|
32
34
|
Requires-Dist: humanfriendly>=10.0
|
|
33
35
|
Requires-Dist: numpy>=1.22.2; python_version < "3.10"
|
|
34
36
|
Requires-Dist: numpy>=1.24.0; python_version >= "3.10" and python_version < "3.12"
|
|
35
37
|
Requires-Dist: numpy>=1.26.0; python_version >= "3.12" and python_version < "3.13"
|
|
36
38
|
Requires-Dist: numpy>=2.1.0; python_version >= "3.13"
|
|
37
39
|
Requires-Dist: openpyxl>=3.1.2
|
|
38
|
-
Requires-Dist: pandas>=1.4.0; python_version < "3.11"
|
|
39
|
-
Requires-Dist: pandas>=1.5.0; python_version >= "3.11" and python_version < "3.12"
|
|
40
|
-
Requires-Dist: pandas>=2.1.1; python_version >= "3.12" and python_version < "3.13"
|
|
41
|
-
Requires-Dist: pandas>=2.2.3; python_version >= "3.13"
|
|
40
|
+
Requires-Dist: pandas<3.0.0,>=1.4.0; python_version < "3.11"
|
|
41
|
+
Requires-Dist: pandas<3.0.0,>=1.5.0; python_version >= "3.11" and python_version < "3.12"
|
|
42
|
+
Requires-Dist: pandas<3.0.0,>=2.1.1; python_version >= "3.12" and python_version < "3.13"
|
|
43
|
+
Requires-Dist: pandas<3.0.0,>=2.2.3; python_version >= "3.13"
|
|
42
44
|
Requires-Dist: psutil>=5.7
|
|
43
45
|
Requires-Dist: scikit-learn>=1.0.2; python_version < "3.10"
|
|
44
46
|
Requires-Dist: scikit-learn>=1.1.0; python_version >= "3.10" and python_version < "3.11"
|
|
@@ -60,7 +62,7 @@ Provides-Extra: dask
|
|
|
60
62
|
Requires-Dist: dask; extra == "dask"
|
|
61
63
|
Requires-Dist: distributed; extra == "dask"
|
|
62
64
|
Provides-Extra: realtabformer
|
|
63
|
-
Requires-Dist: realtabformer>=0.2.3; extra == "realtabformer"
|
|
65
|
+
Requires-Dist: realtabformer!=0.2.4,>=0.2.3; extra == "realtabformer"
|
|
64
66
|
Requires-Dist: torch>=2.6.0; extra == "realtabformer"
|
|
65
67
|
Requires-Dist: transformers<4.51; extra == "realtabformer"
|
|
66
68
|
Provides-Extra: test
|
|
@@ -26,16 +26,18 @@ dependencies = [
|
|
|
26
26
|
'botocore>=1.31,<2',
|
|
27
27
|
'cloudpickle>=2.1.0',
|
|
28
28
|
'compress-pickle>=1.2.0',
|
|
29
|
+
'google-cloud-compute>=1.0.0',
|
|
30
|
+
'google-auth>=2.0.0',
|
|
29
31
|
'humanfriendly>=10.0',
|
|
30
32
|
"numpy>=1.22.2;python_version<'3.10'",
|
|
31
33
|
"numpy>=1.24.0;python_version>='3.10' and python_version<'3.12'",
|
|
32
34
|
"numpy>=1.26.0;python_version>='3.12' and python_version<'3.13'",
|
|
33
35
|
"numpy>=2.1.0;python_version>='3.13'",
|
|
34
36
|
'openpyxl>=3.1.2',
|
|
35
|
-
"pandas>=1.4.0;python_version<'3.11'",
|
|
36
|
-
"pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
|
|
37
|
-
"pandas>=2.1.1;python_version>='3.12' and python_version<'3.13'",
|
|
38
|
-
"pandas>=2.2.3;python_version>='3.13'",
|
|
37
|
+
"pandas>=1.4.0,<3.0.0;python_version<'3.11'",
|
|
38
|
+
"pandas>=1.5.0,<3.0.0;python_version>='3.11' and python_version<'3.12'",
|
|
39
|
+
"pandas>=2.1.1,<3.0.0;python_version>='3.12' and python_version<'3.13'",
|
|
40
|
+
"pandas>=2.2.3,<3.0.0;python_version>='3.13'",
|
|
39
41
|
'psutil>=5.7',
|
|
40
42
|
"scikit-learn>=1.0.2;python_version<'3.10'",
|
|
41
43
|
"scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'",
|
|
@@ -68,7 +70,7 @@ sdgym = { main = 'sdgym.cli.__main__:main' }
|
|
|
68
70
|
[project.optional-dependencies]
|
|
69
71
|
dask = ['dask', 'distributed']
|
|
70
72
|
realtabformer = [
|
|
71
|
-
'realtabformer>=0.2.3',
|
|
73
|
+
'realtabformer>=0.2.3,!=0.2.4',
|
|
72
74
|
"torch>=2.6.0",
|
|
73
75
|
'transformers<4.51',
|
|
74
76
|
]
|
|
@@ -131,7 +133,10 @@ namespaces = false
|
|
|
131
133
|
'*.png',
|
|
132
134
|
'*.gif'
|
|
133
135
|
]
|
|
134
|
-
'sdgym' = [
|
|
136
|
+
'sdgym' = [
|
|
137
|
+
'leaderboard.csv',
|
|
138
|
+
'synthesizer_descriptions.yaml'
|
|
139
|
+
]
|
|
135
140
|
|
|
136
141
|
[tool.setuptools.exclude-package-data]
|
|
137
142
|
'*' = [
|
|
@@ -144,7 +149,7 @@ namespaces = false
|
|
|
144
149
|
version = {attr = 'sdgym.__version__'}
|
|
145
150
|
|
|
146
151
|
[tool.bumpversion]
|
|
147
|
-
current_version = "0.12.1.dev0"
|
|
152
|
+
current_version = "0.12.2.dev0"
|
|
148
153
|
parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
|
|
149
154
|
serialize = [
|
|
150
155
|
'{major}.{minor}.{patch}.{release}{candidate}',
|
|
@@ -8,11 +8,16 @@ __author__ = 'DataCebo, Inc.'
|
|
|
8
8
|
__copyright__ = 'Copyright (c) 2022 DataCebo, Inc.'
|
|
9
9
|
__email__ = 'info@sdv.dev'
|
|
10
10
|
__license__ = 'BSL-1.1'
|
|
11
|
-
__version__ = '0.12.1.dev0'
|
|
11
|
+
__version__ = '0.12.2.dev0'
|
|
12
12
|
|
|
13
13
|
import logging
|
|
14
14
|
|
|
15
|
-
from sdgym.benchmark import
|
|
15
|
+
from sdgym.benchmark import (
|
|
16
|
+
benchmark_multi_table,
|
|
17
|
+
benchmark_single_table,
|
|
18
|
+
benchmark_single_table_aws,
|
|
19
|
+
benchmark_multi_table_aws,
|
|
20
|
+
)
|
|
16
21
|
from sdgym.cli.collect import collect_results
|
|
17
22
|
from sdgym.cli.summary import make_summary_spreadsheet
|
|
18
23
|
from sdgym.dataset_explorer import DatasetExplorer
|
|
@@ -31,12 +36,14 @@ list(map(logging.root.removeFilter, logging.root.filters))
|
|
|
31
36
|
__all__ = [
|
|
32
37
|
'DatasetExplorer',
|
|
33
38
|
'ResultsExplorer',
|
|
39
|
+
'benchmark_multi_table',
|
|
40
|
+
'benchmark_multi_table_aws',
|
|
34
41
|
'benchmark_single_table',
|
|
35
42
|
'benchmark_single_table_aws',
|
|
36
43
|
'collect_results',
|
|
37
|
-
'create_synthesizer_variant',
|
|
38
|
-
'create_single_table_synthesizer',
|
|
39
44
|
'create_multi_table_synthesizer',
|
|
45
|
+
'create_single_table_synthesizer',
|
|
46
|
+
'create_synthesizer_variant',
|
|
40
47
|
'load_dataset',
|
|
41
48
|
'make_summary_spreadsheet',
|
|
42
49
|
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Benchmark on GCP compute instances."""
|
|
@@ -0,0 +1,532 @@
|
|
|
1
|
+
import textwrap
|
|
2
|
+
from urllib.parse import urlparse
|
|
3
|
+
|
|
4
|
+
from google.cloud import compute_v1
|
|
5
|
+
from google.oauth2 import service_account
|
|
6
|
+
|
|
7
|
+
from sdgym._benchmark.config_utils import (
|
|
8
|
+
_make_instance_name,
|
|
9
|
+
resolve_compute_config,
|
|
10
|
+
validate_compute_config,
|
|
11
|
+
)
|
|
12
|
+
from sdgym._benchmark.credentials_utils import get_credentials, sdv_install_cmd
|
|
13
|
+
from sdgym.benchmark import (
|
|
14
|
+
DEFAULT_MULTI_TABLE_DATASETS,
|
|
15
|
+
DEFAULT_MULTI_TABLE_SYNTHESIZERS,
|
|
16
|
+
DEFAULT_SINGLE_TABLE_DATASETS,
|
|
17
|
+
DEFAULT_SINGLE_TABLE_SYNTHESIZERS,
|
|
18
|
+
S3_REGION,
|
|
19
|
+
_ensure_uniform_included,
|
|
20
|
+
_generate_job_args_list,
|
|
21
|
+
_get_empty_dataframe,
|
|
22
|
+
_get_s3_script_content,
|
|
23
|
+
_import_and_validate_synthesizers,
|
|
24
|
+
_store_job_args_in_s3,
|
|
25
|
+
_validate_output_destination,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _get_logs_s3_uri(output_destination, instance_name):
|
|
30
|
+
"""Store logs next to output destination prefix.
|
|
31
|
+
|
|
32
|
+
Example:
|
|
33
|
+
output_destination='s3://bucket/prefix'
|
|
34
|
+
-> s3://bucket/prefix/logs/<instance>-user-data.log
|
|
35
|
+
"""
|
|
36
|
+
if not output_destination.startswith('s3://'):
|
|
37
|
+
return ''
|
|
38
|
+
|
|
39
|
+
parsed = urlparse(output_destination)
|
|
40
|
+
bucket = parsed.netloc
|
|
41
|
+
prefix = parsed.path.lstrip('/').rstrip('/')
|
|
42
|
+
prefix = f'{prefix}/logs' if prefix else 'logs'
|
|
43
|
+
|
|
44
|
+
return f's3://{bucket}/{prefix}/{instance_name}-user-data.log'
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _prepare_script_content(
    output_destination,
    synthesizers,
    s3_client,
    job_args_list,
    credentials,
):
    """Store the job arguments and build the script content that will run them.

    Args:
        output_destination (str):
            Destination passed to ``_store_job_args_in_s3``.
        synthesizers (list of dict):
            Synthesizer descriptions; only their ``'name'`` entries are forwarded.
        s3_client (boto3.client):
            The S3 client used to store the job arguments.
        job_args_list (list):
            The job arguments to store.
        credentials (dict):
            Credentials; ``credentials['aws']`` must hold ``aws_access_key_id`` and
            ``aws_secret_access_key``.

    Returns:
        The value returned by ``_get_s3_script_content``.
    """
    bucket_name, job_args_key = _store_job_args_in_s3(output_destination, job_args_list, s3_client)

    # Strip every synthesizer description down to its name before forwarding it.
    names_only = [{'name': synthesizer['name']} for synthesizer in synthesizers]

    aws_credentials = credentials['aws']
    return _get_s3_script_content(
        aws_credentials['aws_access_key_id'],
        aws_credentials['aws_secret_access_key'],
        S3_REGION,
        bucket_name,
        job_args_key,
        names_only,
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _terminate_instance(compute_service):
|
|
71
|
+
if compute_service not in ('aws', 'gcp'):
|
|
72
|
+
raise ValueError(f'Unsupported compute service: {compute_service}')
|
|
73
|
+
|
|
74
|
+
if compute_service == 'aws':
|
|
75
|
+
return textwrap.dedent(
|
|
76
|
+
"""\
|
|
77
|
+
cleanup() {
|
|
78
|
+
log "======== Kernel messages (OOM info) =========="
|
|
79
|
+
dmesg | tail -50 || true
|
|
80
|
+
upload_logs || true
|
|
81
|
+
INSTANCE_ID=$(curl -sf http://169.254.169.254/latest/meta-data/instance-id || true)
|
|
82
|
+
if [ -n "$INSTANCE_ID" ]; then
|
|
83
|
+
log "Terminating EC2 instance: $INSTANCE_ID"
|
|
84
|
+
aws ec2 terminate-instances --instance-ids "$INSTANCE_ID" >/dev/null 2>&1 || true
|
|
85
|
+
fi
|
|
86
|
+
}
|
|
87
|
+
"""
|
|
88
|
+
).strip()
|
|
89
|
+
|
|
90
|
+
# GCP
|
|
91
|
+
return textwrap.dedent(
|
|
92
|
+
"""\
|
|
93
|
+
cleanup() {
|
|
94
|
+
log "======== Kernel messages (OOM info) =========="
|
|
95
|
+
dmesg | tail -50 || true
|
|
96
|
+
upload_logs || true
|
|
97
|
+
log "Shutting down GCE instance"
|
|
98
|
+
shutdown -h now || true
|
|
99
|
+
}
|
|
100
|
+
"""
|
|
101
|
+
).strip()
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _gpu_wait_block():
|
|
105
|
+
return textwrap.dedent(
|
|
106
|
+
"""\
|
|
107
|
+
log "======== Waiting for GPU =========="
|
|
108
|
+
for i in {1..60}; do
|
|
109
|
+
if command -v nvidia-smi >/dev/null && nvidia-smi >/dev/null; then
|
|
110
|
+
nvidia-smi
|
|
111
|
+
break
|
|
112
|
+
fi
|
|
113
|
+
sleep 10
|
|
114
|
+
done
|
|
115
|
+
"""
|
|
116
|
+
).strip()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _upload_logs(log_uri):
|
|
120
|
+
if not log_uri:
|
|
121
|
+
return 'upload_logs() { :; }'
|
|
122
|
+
|
|
123
|
+
return textwrap.dedent(
|
|
124
|
+
f"""\
|
|
125
|
+
upload_logs() {{
|
|
126
|
+
log "======== Uploading logs =========="
|
|
127
|
+
aws s3 cp /var/log/user-data.log "{log_uri}" >/dev/null 2>&1 || true
|
|
128
|
+
}}
|
|
129
|
+
"""
|
|
130
|
+
).strip()
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _get_user_data_script(
    credentials,
    script_content,
    config,
    instance_name,
    output_destination,
):
    """Build the bash startup script that runs the benchmark on a compute instance.

    The generated script redirects its output to ``/var/log/user-data.log``, registers a
    ``cleanup`` trap, installs the system and Python dependencies, sets up swap, configures
    the AWS CLI, writes ``script_content`` to ``~/sdgym_script.py`` and runs it.

    The script is assembled line by line instead of dedenting an interpolated template.
    The interpolated fragments (helper functions, GPU block, Python script) span several
    lines whose continuation lines start at column 0; that would leave ``textwrap.dedent``
    without a common margin and the heredoc terminator ``EOF`` indented, where bash does
    not recognise it.

    NOTE: the returned script contains the AWS credentials in clear text.

    Args:
        credentials (dict):
            Credentials; ``credentials['aws']`` must hold ``aws_access_key_id`` and
            ``aws_secret_access_key``. The whole dict is also passed to ``sdv_install_cmd``.
        script_content (str):
            Python source written to ``~/sdgym_script.py`` and executed.
        config (dict):
            Compute configuration. ``'service'`` is required. Optional keys are
            ``'swap_gb'`` (defaults to 32), ``'upload_logs'`` (defaults to True) and
            ``'gpu'``, ``'gpu_count'``, ``'gpu_type'``; if any of the GPU keys is set, the
            script waits for the GPU before running.
        instance_name (str):
            Name of the instance, logged by the script and used to name the uploaded log.
        output_destination (str):
            Destination of the benchmark results, used to derive the log URI.

    Returns:
        str:
            The bash script.

    Raises:
        ValueError:
            If ``config['service']`` is rejected by ``_terminate_instance``.
    """
    compute_service = config['service']
    swap_gb = int(config.get('swap_gb', 32))
    gpu = (
        bool(config.get('gpu'))
        # ``gpu_count`` may be present but unset (None); treat that as zero.
        or int(config.get('gpu_count') or 0) > 0
        or bool(config.get('gpu_type'))
    )
    upload_logs = bool(config.get('upload_logs', True))

    aws_key = credentials['aws']['aws_access_key_id']
    aws_secret = credentials['aws']['aws_secret_access_key']

    log_uri = _get_logs_s3_uri(output_destination, instance_name) if upload_logs else ''

    sdv_install = sdv_install_cmd(credentials).rstrip()
    terminate_fn = _terminate_instance(compute_service)
    upload_logs_fn = _upload_logs(log_uri)
    gpu_block = _gpu_wait_block() if gpu else ''

    script_lines = [
        '#!/bin/bash',
        'set -e',
        '',
        'LOG_FILE=/var/log/user-data.log',
        'exec >> "$LOG_FILE" 2>&1',
        '',
        'log() {',
        '    echo "$@"',
        '}',
        '',
        upload_logs_fn,
        terminate_fn,
        '',
        '# Always cleanup on exit',
        'trap cleanup EXIT',
        '',
        f'log "======== Instance: {instance_name} =========="',
        '',
        # NOTE(review): panic on OOM with ``kernel.panic=0`` (no automatic reboot) looks
        # like it leaves the machine halted rather than running ``cleanup`` -- confirm
        # this is the intended behavior.
        'log "======== Configure kernel OOM behavior =========="',
        'sudo sysctl -w vm.panic_on_oom=1',
        'sudo sysctl -w kernel.panic=0',
        '',
        'log "======== Update and Install Dependencies =========="',
        'sudo apt update -y',
        'sudo apt install -y python3-pip python3-venv awscli git jq',
        '',
        f'log "======== Setting up swap ({swap_gb}G) =========="',
        # Fall back to ``dd`` when ``fallocate`` fails.
        (
            f'sudo fallocate -l {swap_gb}G /swapfile || '
            f'sudo dd if=/dev/zero of=/swapfile bs=1M count=$(({swap_gb}*1024))'
        ),
        'sudo chmod 600 /swapfile',
        'sudo mkswap /swapfile',
        'sudo swapon /swapfile',
        '',
        'log "======== Configure AWS CLI =========="',
        f"aws configure set aws_access_key_id '{aws_key}'",
        f"aws configure set aws_secret_access_key '{aws_secret}'",
        f"aws configure set default.region '{S3_REGION}'",
        '',
        'log "======== Create Virtual Environment =========="',
        'python3 -m venv ~/env',
        'source ~/env/bin/activate',
        '',
        'log "======== Install Dependencies =========="',
        'pip install --upgrade pip',
        sdv_install,
        'pip install "sdgym[all]"',
        '',
        gpu_block,
        '',
        'log "======== Write Script =========="',
        # The quoted delimiter keeps bash from expanding anything inside the Python source;
        # the closing ``EOF`` must sit alone at the start of its line.
        "cat << 'EOF' > ~/sdgym_script.py",
        script_content,
        'EOF',
        '',
        'log "======== Run Script =========="',
        'python -u ~/sdgym_script.py | tee -a /var/log/sdgym.log',
        '',
        'log "======== Complete =========="',
    ]
    return '\n'.join(script_lines).strip()
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _run_on_gcp(
    output_destination, synthesizers, s3_client, job_args_list, credentials, compute_config
):
    """Launch a GCP Compute Engine instance to run a benchmark.

    This method creates and configures a VM using the provided compute settings,
    prepares a startup script with the benchmark configuration, and starts execution
    automatically when the instance boots. It waits for the instance to be created
    and raises an error if provisioning fails.

    NOTE: the startup script, which embeds the AWS credentials, is attached to the
    instance as the ``startup-script`` metadata item.

    Args:
        output_destination (str):
            The S3 URI where results will be stored.
        synthesizers (list of dict):
            The synthesizers to use in the benchmark.
        s3_client (boto3.client):
            The S3 client to use for storing job arguments.
        job_args_list (list of dict):
            The list of job arguments for each dataset.
        credentials (dict):
            The credentials for AWS and GCP. ``credentials['gcp']`` must hold the service
            account info together with ``gcp_zone`` and ``gcp_project``.
        compute_config (dict):
            The compute configuration for the GCP instance. ``name_prefix``,
            ``machine_type``, ``source_image`` and ``disk_size_gb`` are required. A GPU is
            attached only when ``gpu_type`` is set and ``gpu_count`` is positive.

    Returns:
        str:
            The name of the created instance.

    Raises:
        RuntimeError:
            If the instance creation operation reports errors.
    """
    script_content = _prepare_script_content(
        output_destination,
        synthesizers,
        s3_client,
        job_args_list,
        credentials,
    )

    gcp_zone = credentials['gcp']['gcp_zone']
    gcp_project = credentials['gcp']['gcp_project']
    gcp_creds = service_account.Credentials.from_service_account_info(
        credentials['gcp'],
    )

    instance_name = _make_instance_name(compute_config['name_prefix'])
    print(  # noqa: T201
        f'Launching instance: {instance_name} (service=gcp project={gcp_project} zone={gcp_zone})'
    )
    startup_script = _get_user_data_script(
        credentials,
        script_content,
        compute_config,
        instance_name,
        output_destination,
    )

    machine_type = f'zones/{gcp_zone}/machineTypes/{compute_config["machine_type"]}'
    source_disk_image = compute_config['source_image']

    # The GPU settings are optional (the startup script treats them the same way), so only
    # request an accelerator when both a type and a positive count are configured.
    gpu_type = compute_config.get('gpu_type')
    gpu_count = int(compute_config.get('gpu_count') or 0)
    guest_accelerators = []
    if gpu_type and gpu_count > 0:
        guest_accelerators.append(
            compute_v1.AcceleratorConfig(
                accelerator_type=f'zones/{gcp_zone}/acceleratorTypes/{gpu_type}',
                accelerator_count=gpu_count,
            )
        )

    boot_disk = compute_v1.AttachedDisk(
        auto_delete=True,
        boot=True,
        initialize_params=compute_v1.AttachedDiskInitializeParams(
            source_image=source_disk_image,
            disk_size_gb=int(compute_config['disk_size_gb']),
        ),
    )

    nic = compute_v1.NetworkInterface()
    nic.network = 'global/networks/default'
    nic.access_configs = [
        compute_v1.AccessConfig(
            name='External NAT',
            type_='ONE_TO_ONE_NAT',
        )
    ]

    items = [compute_v1.Items(key='startup-script', value=startup_script)]
    if compute_config.get('install_nvidia_driver', True):
        items.append(
            compute_v1.Items(key='install-nvidia-driver', value='true'),
        )
    metadata = compute_v1.Metadata(items=items)

    scheduling = compute_v1.Scheduling(
        on_host_maintenance='TERMINATE',
        automatic_restart=False,
    )

    instance = compute_v1.Instance(
        name=instance_name,
        machine_type=machine_type,
        disks=[boot_disk],
        network_interfaces=[nic],
        metadata=metadata,
        guest_accelerators=guest_accelerators,
        scheduling=scheduling,
        service_accounts=[
            compute_v1.ServiceAccount(
                email='default',
                scopes=['https://www.googleapis.com/auth/cloud-platform'],
            )
        ],
    )

    instance_client = compute_v1.InstancesClient(credentials=gcp_creds)
    operation = instance_client.insert(
        project=gcp_project,
        zone=gcp_zone,
        instance_resource=instance,
    )

    # NOTE(review): ``wait`` presumably returns after a bounded time even when the operation
    # is still in progress -- confirm whether the status should be checked before reporting
    # success.
    op_client = compute_v1.ZoneOperationsClient(credentials=gcp_creds)
    operation = op_client.wait(
        project=gcp_project,
        zone=gcp_zone,
        operation=operation.name,
    )

    if operation.error and operation.error.errors:
        messages = [e.message for e in operation.error.errors if e.message]
        joined = '; '.join(messages) if messages else str(operation.error)
        raise RuntimeError(f'GCP instance creation failed: {joined}')

    print(f'Instance created: {instance_name}')  # noqa: T201
    return instance_name
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def _benchmark_compute_gcp(
    output_destination,
    credential_filepath,
    compute_config,
    synthesizers,
    sdv_datasets,
    additional_datasets_folder,
    limit_dataset_size,
    compute_quality_score,
    compute_diagnostic_score,
    compute_privacy_score,
    sdmetrics,
    timeout,
    modality,
):
    """Run the SDGym benchmark on datasets for the given modality.

    The compute configuration is resolved and validated, the output destination is
    validated with the AWS credentials, the job arguments are generated, and a GCP
    instance is launched to run them.

    Args:
        output_destination (str):
            The S3 URI where results will be stored.
        credential_filepath (str or Path):
            Path to the credentials file.
        compute_config (dict or None):
            The compute configuration, resolved through ``resolve_compute_config``.
        synthesizers (list or None):
            The synthesizers to use. The list is copied, so the caller's object is never
            modified.
        sdv_datasets (list of str):
            The SDV datasets to use in the benchmark.
        additional_datasets_folder (str or Path or None):
            Folder containing additional datasets.
        limit_dataset_size (bool):
            Whether to limit the size of the datasets.
        compute_quality_score (bool):
            Whether to compute the quality score.
        compute_diagnostic_score (bool):
            Whether to compute the diagnostic score.
        compute_privacy_score (bool):
            Whether to compute the privacy score.
        sdmetrics (list or None):
            The sdmetrics to use for evaluation.
        timeout (int or None):
            Timeout for each synthesizer-dataset run.
        modality (str):
            The modality of the benchmark, e.g. ``'single_table'`` or ``'multi_table'``.

    Returns:
        The result of ``_get_empty_dataframe`` when there is no job to run; otherwise
        ``None`` once the instance has been launched.
    """
    compute_config = resolve_compute_config('gcp', compute_config)
    credentials = get_credentials(credential_filepath)
    validate_compute_config(compute_config)

    s3_client = _validate_output_destination(
        output_destination,
        aws_keys={
            'aws_access_key_id': credentials['aws']['aws_access_key_id'],
            'aws_secret_access_key': credentials['aws']['aws_secret_access_key'],
        },
    )

    # ``_ensure_uniform_included`` works on the list it receives, and the public wrappers
    # pass shared module-level defaults here: always work on a private copy.
    synthesizers = list(synthesizers) if synthesizers else []

    _ensure_uniform_included(synthesizers, modality)
    synthesizers = _import_and_validate_synthesizers(
        synthesizers=synthesizers,
        custom_synthesizers=None,
        modality=modality,
    )

    job_args_list = _generate_job_args_list(
        limit_dataset_size=limit_dataset_size,
        sdv_datasets=sdv_datasets,
        additional_datasets_folder=additional_datasets_folder,
        sdmetrics=sdmetrics,
        timeout=timeout,
        output_destination=output_destination,
        compute_quality_score=compute_quality_score,
        compute_diagnostic_score=compute_diagnostic_score,
        compute_privacy_score=compute_privacy_score,
        synthesizers=synthesizers,
        s3_client=s3_client,
        modality=modality,
    )
    if not job_args_list:
        # Nothing to run: skip launching an instance altogether.
        return _get_empty_dataframe(
            compute_diagnostic_score=compute_diagnostic_score,
            compute_quality_score=compute_quality_score,
            compute_privacy_score=compute_privacy_score,
            sdmetrics=sdmetrics,
        )

    _run_on_gcp(
        output_destination=output_destination,
        synthesizers=synthesizers,
        s3_client=s3_client,
        job_args_list=job_args_list,
        credentials=credentials,
        compute_config=compute_config,
    )
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def _benchmark_single_table_compute_gcp(
    output_destination,
    credential_filepath,
    compute_config=None,
    synthesizers=DEFAULT_SINGLE_TABLE_SYNTHESIZERS,
    sdv_datasets=DEFAULT_SINGLE_TABLE_DATASETS,
    additional_datasets_folder=None,
    limit_dataset_size=False,
    compute_quality_score=True,
    compute_diagnostic_score=True,
    compute_privacy_score=False,
    sdmetrics=None,
    timeout=None,
):
    """Run the SDGym benchmark on GCP with the single-table modality.

    Args:
        output_destination (str):
            The S3 URI where results will be stored.
        credential_filepath (str or Path):
            Path to the credentials file for AWS, GCP and SDV-Enterprise.
        compute_config (dict, optional):
            The compute configuration for the GCP instance. If None, default settings will be used.
        synthesizers (list of dict, optional):
            The synthesizers to use in the benchmark. Defaults to DEFAULT_SINGLE_TABLE_SYNTHESIZERS.
        sdv_datasets (list of str, optional):
            The SDV datasets to use in the benchmark. Defaults to DEFAULT_SINGLE_TABLE_DATASETS.
        additional_datasets_folder (str or Path, optional):
            Path to a folder containing additional datasets to include in the benchmark.
        limit_dataset_size (bool, optional):
            Whether to limit the size of datasets for faster benchmarking. Defaults to False.
        compute_quality_score (bool, optional):
            Whether to compute the quality score. Defaults to True.
        compute_diagnostic_score (bool, optional):
            Whether to compute the diagnostic score. Defaults to True.
        compute_privacy_score (bool, optional):
            Whether to compute the privacy score. Defaults to False.
        sdmetrics (list of str, optional):
            The sdmetrics to use for evaluation. If None, default metrics will be used.
        timeout (int, optional):
            Timeout in seconds for each synthesizer-dataset run. If None, no timeout is applied.

    Returns:
        The value returned by ``_benchmark_compute_gcp``.
    """
    return _benchmark_compute_gcp(
        output_destination=output_destination,
        credential_filepath=credential_filepath,
        compute_config=compute_config,
        synthesizers=synthesizers,
        sdv_datasets=sdv_datasets,
        additional_datasets_folder=additional_datasets_folder,
        limit_dataset_size=limit_dataset_size,
        compute_quality_score=compute_quality_score,
        compute_diagnostic_score=compute_diagnostic_score,
        compute_privacy_score=compute_privacy_score,
        sdmetrics=sdmetrics,
        timeout=timeout,
        modality='single_table',
    )
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def _benchmark_multi_table_compute_gcp(
    output_destination,
    credential_filepath,
    compute_config=None,
    synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
    sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
    additional_datasets_folder=None,
    limit_dataset_size=False,
    compute_quality_score=True,
    compute_diagnostic_score=True,
    sdmetrics=None,
    timeout=None,
):
    """Benchmark multi-table synthesizers on a GCP compute instance.

    Args:
        output_destination (str):
            S3 URI under which the results are stored.
        credential_filepath (str or Path):
            Location of the file holding the AWS, GCP and SDV-Enterprise credentials.
        compute_config (dict, optional):
            Settings of the GCP instance. Default settings are used when None.
        synthesizers (list of dict, optional):
            Synthesizers to benchmark. Defaults to DEFAULT_MULTI_TABLE_SYNTHESIZERS.
        sdv_datasets (list of str, optional):
            SDV datasets to benchmark on. Defaults to DEFAULT_MULTI_TABLE_DATASETS.
        additional_datasets_folder (str or Path, optional):
            Folder with extra datasets to add to the benchmark.
        limit_dataset_size (bool, optional):
            Whether to limit the dataset size for faster benchmarking. Defaults to False.
        compute_quality_score (bool, optional):
            Whether the quality score is computed. Defaults to True.
        compute_diagnostic_score (bool, optional):
            Whether the diagnostic score is computed. Defaults to True.
        sdmetrics (list of str, optional):
            Metrics used for the evaluation. Default metrics are used when None.
        timeout (int, optional):
            Timeout in seconds of each synthesizer-dataset run. No timeout when None.

    Returns:
        The value returned by ``_benchmark_compute_gcp``.
    """
    benchmark_kwargs = {
        'output_destination': output_destination,
        'credential_filepath': credential_filepath,
        'compute_config': compute_config,
        'synthesizers': synthesizers,
        'sdv_datasets': sdv_datasets,
        'additional_datasets_folder': additional_datasets_folder,
        'limit_dataset_size': limit_dataset_size,
        'compute_quality_score': compute_quality_score,
        'compute_diagnostic_score': compute_diagnostic_score,
        # This entry point exposes no privacy-score option, so it is always disabled.
        'compute_privacy_score': False,
        'sdmetrics': sdmetrics,
        'timeout': timeout,
        'modality': 'multi_table',
    }
    return _benchmark_compute_gcp(**benchmark_kwargs)
|