sdgym 0.12.1.dev0__tar.gz → 0.12.2.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. {sdgym-0.12.1.dev0/sdgym.egg-info → sdgym-0.12.2.dev0}/PKG-INFO +8 -6
  2. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/pyproject.toml +12 -7
  3. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/__init__.py +11 -4
  4. sdgym-0.12.2.dev0/sdgym/_benchmark/__init__.py +1 -0
  5. sdgym-0.12.2.dev0/sdgym/_benchmark/benchmark.py +532 -0
  6. sdgym-0.12.2.dev0/sdgym/_benchmark/config_utils.py +123 -0
  7. sdgym-0.12.2.dev0/sdgym/_benchmark/credentials_utils.py +104 -0
  8. sdgym-0.12.2.dev0/sdgym/_dataset_utils.py +156 -0
  9. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/benchmark.py +503 -415
  10. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/__main__.py +3 -15
  11. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/dataset_explorer.py +20 -14
  12. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/datasets.py +106 -36
  13. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/metrics.py +7 -7
  14. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/result_explorer/result_explorer.py +55 -22
  15. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/result_explorer/result_handler.py +35 -18
  16. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/result_writer.py +25 -3
  17. sdgym-0.12.2.dev0/sdgym/run_benchmark/__init__.py +1 -0
  18. sdgym-0.12.2.dev0/sdgym/run_benchmark/run_benchmark.py +152 -0
  19. sdgym-0.12.2.dev0/sdgym/run_benchmark/upload_benchmark_results.py +601 -0
  20. sdgym-0.12.2.dev0/sdgym/run_benchmark/utils.py +205 -0
  21. sdgym-0.12.2.dev0/sdgym/synthesizer_descriptions.yaml +105 -0
  22. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/__init__.py +2 -1
  23. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/base.py +73 -8
  24. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/column.py +16 -10
  25. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/generate.py +21 -14
  26. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/identity.py +6 -9
  27. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/realtabformer.py +5 -3
  28. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/synthesizers/sdv.py +18 -16
  29. sdgym-0.12.2.dev0/sdgym/synthesizers/uniform.py +140 -0
  30. sdgym-0.12.2.dev0/sdgym/synthesizers/utils.py +37 -0
  31. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/utils.py +1 -0
  32. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0/sdgym.egg-info}/PKG-INFO +8 -6
  33. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/SOURCES.txt +9 -0
  34. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/requires.txt +7 -5
  35. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/tests/test_tasks.py +5 -1
  36. sdgym-0.12.1.dev0/sdgym/_dataset_utils.py +0 -107
  37. sdgym-0.12.1.dev0/sdgym/synthesizers/uniform.py +0 -72
  38. sdgym-0.12.1.dev0/sdgym/synthesizers/utils.py +0 -51
  39. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/LICENSE +0 -0
  40. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/README.md +0 -0
  41. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/__init__.py +0 -0
  42. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/collect.py +0 -0
  43. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/summary.py +0 -0
  44. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/cli/utils.py +0 -0
  45. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/errors.py +0 -0
  46. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/progress.py +0 -0
  47. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/result_explorer/__init__.py +0 -0
  48. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym/s3.py +0 -0
  49. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/dependency_links.txt +0 -0
  50. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/entry_points.txt +0 -0
  51. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/sdgym.egg-info/top_level.txt +0 -0
  52. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/setup.cfg +0 -0
  53. {sdgym-0.12.1.dev0 → sdgym-0.12.2.dev0}/tests/test_scripts.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sdgym
3
- Version: 0.12.1.dev0
3
+ Version: 0.12.2.dev0
4
4
  Summary: Benchmark tabular synthetic data generators using a variety of datasets
5
5
  Author-email: "DataCebo, Inc." <info@sdv.dev>
6
6
  License: BSL-1.1
@@ -29,16 +29,18 @@ Requires-Dist: boto3<2,>=1.28
29
29
  Requires-Dist: botocore<2,>=1.31
30
30
  Requires-Dist: cloudpickle>=2.1.0
31
31
  Requires-Dist: compress-pickle>=1.2.0
32
+ Requires-Dist: google-cloud-compute>=1.0.0
33
+ Requires-Dist: google-auth>=2.0.0
32
34
  Requires-Dist: humanfriendly>=10.0
33
35
  Requires-Dist: numpy>=1.22.2; python_version < "3.10"
34
36
  Requires-Dist: numpy>=1.24.0; python_version >= "3.10" and python_version < "3.12"
35
37
  Requires-Dist: numpy>=1.26.0; python_version >= "3.12" and python_version < "3.13"
36
38
  Requires-Dist: numpy>=2.1.0; python_version >= "3.13"
37
39
  Requires-Dist: openpyxl>=3.1.2
38
- Requires-Dist: pandas>=1.4.0; python_version < "3.11"
39
- Requires-Dist: pandas>=1.5.0; python_version >= "3.11" and python_version < "3.12"
40
- Requires-Dist: pandas>=2.1.1; python_version >= "3.12" and python_version < "3.13"
41
- Requires-Dist: pandas>=2.2.3; python_version >= "3.13"
40
+ Requires-Dist: pandas<3.0.0,>=1.4.0; python_version < "3.11"
41
+ Requires-Dist: pandas<3.0.0,>=1.5.0; python_version >= "3.11" and python_version < "3.12"
42
+ Requires-Dist: pandas<3.0.0,>=2.1.1; python_version >= "3.12" and python_version < "3.13"
43
+ Requires-Dist: pandas<3.0.0,>=2.2.3; python_version >= "3.13"
42
44
  Requires-Dist: psutil>=5.7
43
45
  Requires-Dist: scikit-learn>=1.0.2; python_version < "3.10"
44
46
  Requires-Dist: scikit-learn>=1.1.0; python_version >= "3.10" and python_version < "3.11"
@@ -60,7 +62,7 @@ Provides-Extra: dask
60
62
  Requires-Dist: dask; extra == "dask"
61
63
  Requires-Dist: distributed; extra == "dask"
62
64
  Provides-Extra: realtabformer
63
- Requires-Dist: realtabformer>=0.2.3; extra == "realtabformer"
65
+ Requires-Dist: realtabformer!=0.2.4,>=0.2.3; extra == "realtabformer"
64
66
  Requires-Dist: torch>=2.6.0; extra == "realtabformer"
65
67
  Requires-Dist: transformers<4.51; extra == "realtabformer"
66
68
  Provides-Extra: test
@@ -26,16 +26,18 @@ dependencies = [
26
26
  'botocore>=1.31,<2',
27
27
  'cloudpickle>=2.1.0',
28
28
  'compress-pickle>=1.2.0',
29
+ 'google-cloud-compute>=1.0.0',
30
+ 'google-auth>=2.0.0',
29
31
  'humanfriendly>=10.0',
30
32
  "numpy>=1.22.2;python_version<'3.10'",
31
33
  "numpy>=1.24.0;python_version>='3.10' and python_version<'3.12'",
32
34
  "numpy>=1.26.0;python_version>='3.12' and python_version<'3.13'",
33
35
  "numpy>=2.1.0;python_version>='3.13'",
34
36
  'openpyxl>=3.1.2',
35
- "pandas>=1.4.0;python_version<'3.11'",
36
- "pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
37
- "pandas>=2.1.1;python_version>='3.12' and python_version<'3.13'",
38
- "pandas>=2.2.3;python_version>='3.13'",
37
+ "pandas>=1.4.0,<3.0.0;python_version<'3.11'",
38
+ "pandas>=1.5.0,<3.0.0;python_version>='3.11' and python_version<'3.12'",
39
+ "pandas>=2.1.1,<3.0.0;python_version>='3.12' and python_version<'3.13'",
40
+ "pandas>=2.2.3,<3.0.0;python_version>='3.13'",
39
41
  'psutil>=5.7',
40
42
  "scikit-learn>=1.0.2;python_version<'3.10'",
41
43
  "scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'",
@@ -68,7 +70,7 @@ sdgym = { main = 'sdgym.cli.__main__:main' }
68
70
  [project.optional-dependencies]
69
71
  dask = ['dask', 'distributed']
70
72
  realtabformer = [
71
- 'realtabformer>=0.2.3',
73
+ 'realtabformer>=0.2.3,!=0.2.4',
72
74
  "torch>=2.6.0",
73
75
  'transformers<4.51',
74
76
  ]
@@ -131,7 +133,10 @@ namespaces = false
131
133
  '*.png',
132
134
  '*.gif'
133
135
  ]
134
- 'sdgym' = ['leaderboard.csv']
136
+ 'sdgym' = [
137
+ 'leaderboard.csv',
138
+ 'synthesizer_descriptions.yaml'
139
+ ]
135
140
 
136
141
  [tool.setuptools.exclude-package-data]
137
142
  '*' = [
@@ -144,7 +149,7 @@ namespaces = false
144
149
  version = {attr = 'sdgym.__version__'}
145
150
 
146
151
  [tool.bumpversion]
147
- current_version = "0.12.1.dev0"
152
+ current_version = "0.12.2.dev0"
148
153
  parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
149
154
  serialize = [
150
155
  '{major}.{minor}.{patch}.{release}{candidate}',
@@ -8,11 +8,16 @@ __author__ = 'DataCebo, Inc.'
8
8
  __copyright__ = 'Copyright (c) 2022 DataCebo, Inc.'
9
9
  __email__ = 'info@sdv.dev'
10
10
  __license__ = 'BSL-1.1'
11
- __version__ = '0.12.1.dev0'
11
+ __version__ = '0.12.2.dev0'
12
12
 
13
13
  import logging
14
14
 
15
- from sdgym.benchmark import benchmark_single_table, benchmark_single_table_aws
15
+ from sdgym.benchmark import (
16
+ benchmark_multi_table,
17
+ benchmark_single_table,
18
+ benchmark_single_table_aws,
19
+ benchmark_multi_table_aws,
20
+ )
16
21
  from sdgym.cli.collect import collect_results
17
22
  from sdgym.cli.summary import make_summary_spreadsheet
18
23
  from sdgym.dataset_explorer import DatasetExplorer
@@ -31,12 +36,14 @@ list(map(logging.root.removeFilter, logging.root.filters))
31
36
  __all__ = [
32
37
  'DatasetExplorer',
33
38
  'ResultsExplorer',
39
+ 'benchmark_multi_table',
40
+ 'benchmark_multi_table_aws',
34
41
  'benchmark_single_table',
35
42
  'benchmark_single_table_aws',
36
43
  'collect_results',
37
- 'create_synthesizer_variant',
38
- 'create_single_table_synthesizer',
39
44
  'create_multi_table_synthesizer',
45
+ 'create_single_table_synthesizer',
46
+ 'create_synthesizer_variant',
40
47
  'load_dataset',
41
48
  'make_summary_spreadsheet',
42
49
  ]
@@ -0,0 +1 @@
1
+ """Benchmark on GCP compute instances."""
@@ -0,0 +1,532 @@
1
+ import textwrap
2
+ from urllib.parse import urlparse
3
+
4
+ from google.cloud import compute_v1
5
+ from google.oauth2 import service_account
6
+
7
+ from sdgym._benchmark.config_utils import (
8
+ _make_instance_name,
9
+ resolve_compute_config,
10
+ validate_compute_config,
11
+ )
12
+ from sdgym._benchmark.credentials_utils import get_credentials, sdv_install_cmd
13
+ from sdgym.benchmark import (
14
+ DEFAULT_MULTI_TABLE_DATASETS,
15
+ DEFAULT_MULTI_TABLE_SYNTHESIZERS,
16
+ DEFAULT_SINGLE_TABLE_DATASETS,
17
+ DEFAULT_SINGLE_TABLE_SYNTHESIZERS,
18
+ S3_REGION,
19
+ _ensure_uniform_included,
20
+ _generate_job_args_list,
21
+ _get_empty_dataframe,
22
+ _get_s3_script_content,
23
+ _import_and_validate_synthesizers,
24
+ _store_job_args_in_s3,
25
+ _validate_output_destination,
26
+ )
27
+
28
+
29
+ def _get_logs_s3_uri(output_destination, instance_name):
30
+ """Build the S3 URI of the user-data log under the output destination prefix.
31
+
32
+ Example:
33
+ output_destination='s3://bucket/prefix'
34
+ -> s3://bucket/prefix/logs/<instance>-user-data.log
35
+ """
36
+ if not output_destination.startswith('s3://'):
37
+ return ''
38
+
39
+ parsed = urlparse(output_destination)
40
+ bucket = parsed.netloc
41
+ prefix = parsed.path.lstrip('/').rstrip('/')
42
+ prefix = f'{prefix}/logs' if prefix else 'logs'
43
+
44
+ return f's3://{bucket}/{prefix}/{instance_name}-user-data.log'
45
+
46
+
47
+ def _prepare_script_content(
48
+ output_destination,
49
+ synthesizers,
50
+ s3_client,
51
+ job_args_list,
52
+ credentials,
53
+ ):
54
+ bucket_name, job_args_key = _store_job_args_in_s3(
55
+ output_destination,
56
+ job_args_list,
57
+ s3_client,
58
+ )
59
+ synthesizer_names = [{'name': s['name']} for s in synthesizers]
60
+ return _get_s3_script_content(
61
+ credentials['aws']['aws_access_key_id'],
62
+ credentials['aws']['aws_secret_access_key'],
63
+ S3_REGION,
64
+ bucket_name,
65
+ job_args_key,
66
+ synthesizer_names,
67
+ )
68
+
69
+
70
+ def _terminate_instance(compute_service):
71
+ if compute_service not in ('aws', 'gcp'):
72
+ raise ValueError(f'Unsupported compute service: {compute_service}')
73
+
74
+ if compute_service == 'aws':
75
+ return textwrap.dedent(
76
+ """\
77
+ cleanup() {
78
+ log "======== Kernel messages (OOM info) =========="
79
+ dmesg | tail -50 || true
80
+ upload_logs || true
81
+ INSTANCE_ID=$(curl -sf http://169.254.169.254/latest/meta-data/instance-id || true)
82
+ if [ -n "$INSTANCE_ID" ]; then
83
+ log "Terminating EC2 instance: $INSTANCE_ID"
84
+ aws ec2 terminate-instances --instance-ids "$INSTANCE_ID" >/dev/null 2>&1 || true
85
+ fi
86
+ }
87
+ """
88
+ ).strip()
89
+
90
+ # GCP
91
+ return textwrap.dedent(
92
+ """\
93
+ cleanup() {
94
+ log "======== Kernel messages (OOM info) =========="
95
+ dmesg | tail -50 || true
96
+ upload_logs || true
97
+ log "Shutting down GCE instance"
98
+ shutdown -h now || true
99
+ }
100
+ """
101
+ ).strip()
102
+
103
+
104
+ def _gpu_wait_block():
105
+ return textwrap.dedent(
106
+ """\
107
+ log "======== Waiting for GPU =========="
108
+ for i in {1..60}; do
109
+ if command -v nvidia-smi >/dev/null && nvidia-smi >/dev/null; then
110
+ nvidia-smi
111
+ break
112
+ fi
113
+ sleep 10
114
+ done
115
+ """
116
+ ).strip()
117
+
118
+
119
+ def _upload_logs(log_uri):
120
+ if not log_uri:
121
+ return 'upload_logs() { :; }'
122
+
123
+ return textwrap.dedent(
124
+ f"""\
125
+ upload_logs() {{
126
+ log "======== Uploading logs =========="
127
+ aws s3 cp /var/log/user-data.log "{log_uri}" >/dev/null 2>&1 || true
128
+ }}
129
+ """
130
+ ).strip()
131
+
132
+
133
+ def _get_user_data_script(
134
+ credentials,
135
+ script_content,
136
+ config,
137
+ instance_name,
138
+ output_destination,
139
+ ):
140
+ compute_service = config['service']
141
+ swap_gb = int(config.get('swap_gb', 32))
142
+ gpu = (
143
+ bool(config.get('gpu'))
144
+ or int(config.get('gpu_count', 0)) > 0
145
+ or bool(config.get('gpu_type'))
146
+ )
147
+ upload_logs = bool(config.get('upload_logs', True))
148
+
149
+ aws_key = credentials['aws']['aws_access_key_id']
150
+ aws_secret = credentials['aws']['aws_secret_access_key']
151
+
152
+ log_uri = _get_logs_s3_uri(output_destination, instance_name) if upload_logs else ''
153
+
154
+ sdv_install = sdv_install_cmd(credentials).rstrip()
155
+ sdv_install = textwrap.indent(sdv_install, ' ') if sdv_install else ''
156
+ terminate_fn = _terminate_instance(compute_service)
157
+ upload_logs_fn = _upload_logs(log_uri)
158
+ gpu_block = _gpu_wait_block() if gpu else ''
159
+
160
+ return textwrap.dedent(
161
+ f"""\
162
+ #!/bin/bash
163
+ set -e
164
+
165
+ LOG_FILE=/var/log/user-data.log
166
+ exec >> "$LOG_FILE" 2>&1
167
+
168
+ log() {{
169
+ echo "$@"
170
+ }}
171
+
172
+ {upload_logs_fn}
173
+ {terminate_fn}
174
+
175
+ # Always cleanup on exit
176
+ trap cleanup EXIT
177
+
178
+ log "======== Instance: {instance_name} =========="
179
+
180
+ log "======== Configure kernel OOM behavior =========="
181
+ sudo sysctl -w vm.panic_on_oom=1
182
+ sudo sysctl -w kernel.panic=0
183
+
184
+ log "======== Update and Install Dependencies =========="
185
+ sudo apt update -y
186
+ sudo apt install -y python3-pip python3-venv awscli git jq
187
+
188
+ log "======== Setting up swap ({swap_gb}G) =========="
189
+ sudo fallocate -l {swap_gb}G /swapfile || \
190
+ sudo dd if=/dev/zero of=/swapfile bs=1M count=$(({swap_gb}*1024))
191
+ sudo chmod 600 /swapfile
192
+ sudo mkswap /swapfile
193
+ sudo swapon /swapfile
194
+
195
+ log "======== Configure AWS CLI =========="
196
+ aws configure set aws_access_key_id '{aws_key}'
197
+ aws configure set aws_secret_access_key '{aws_secret}'
198
+ aws configure set default.region '{S3_REGION}'
199
+
200
+ log "======== Create Virtual Environment =========="
201
+ python3 -m venv ~/env
202
+ source ~/env/bin/activate
203
+
204
+ log "======== Install Dependencies =========="
205
+ pip install --upgrade pip
206
+ {sdv_install}
207
+ pip install "sdgym[all]"
208
+
209
+ {gpu_block}
210
+
211
+ log "======== Write Script =========="
212
+ cat << 'EOF' > ~/sdgym_script.py
213
+ {script_content}
214
+ EOF
215
+
216
+ log "======== Run Script =========="
217
+ python -u ~/sdgym_script.py | tee -a /var/log/sdgym.log
218
+
219
+ log "======== Complete =========="
220
+ """
221
+ ).strip()
222
+
223
+
224
+ def _run_on_gcp(
225
+ output_destination, synthesizers, s3_client, job_args_list, credentials, compute_config
226
+ ):
227
+ """Launch a GCP Compute Engine instance to run a benchmark.
228
+
229
+ This function creates and configures a VM using the provided compute settings,
230
+ prepares a startup script with the benchmark configuration, and starts execution
231
+ automatically when the instance boots. It waits for the instance to be created
232
+ and raises an error if provisioning fails.
233
+
234
+ Args:
235
+ output_destination (str):
236
+ The S3 URI where results will be stored.
237
+ synthesizers (list of dict):
238
+ The synthesizers to use in the benchmark.
239
+ s3_client (boto3.client):
240
+ The S3 client to use for storing job arguments.
241
+ job_args_list (list of dict):
242
+ The list of job arguments for each dataset.
243
+ credentials (dict):
244
+ The credentials for AWS and GCP.
245
+ compute_config (dict):
246
+ The compute configuration for the GCP instance.
247
+ """
248
+ script_content = _prepare_script_content(
249
+ output_destination,
250
+ synthesizers,
251
+ s3_client,
252
+ job_args_list,
253
+ credentials,
254
+ )
255
+
256
+ gcp_zone = credentials['gcp']['gcp_zone']
257
+ gcp_project = credentials['gcp']['gcp_project']
258
+ gcp_creds = service_account.Credentials.from_service_account_info(
259
+ credentials['gcp'],
260
+ )
261
+
262
+ instance_name = _make_instance_name(compute_config['name_prefix'])
263
+ print( # noqa: T201
264
+ f'Launching instance: {instance_name} (service=gcp project={gcp_project} zone={gcp_zone})'
265
+ )
266
+ startup_script = _get_user_data_script(
267
+ credentials,
268
+ script_content,
269
+ compute_config,
270
+ instance_name,
271
+ output_destination,
272
+ )
273
+
274
+ machine_type = f'zones/{gcp_zone}/machineTypes/{compute_config["machine_type"]}'
275
+ source_disk_image = compute_config['source_image']
276
+ gpu = compute_v1.AcceleratorConfig(
277
+ accelerator_type=(f'zones/{gcp_zone}/acceleratorTypes/{compute_config["gpu_type"]}'),
278
+ accelerator_count=int(compute_config['gpu_count']),
279
+ )
280
+
281
+ boot_disk = compute_v1.AttachedDisk(
282
+ auto_delete=True,
283
+ boot=True,
284
+ initialize_params=compute_v1.AttachedDiskInitializeParams(
285
+ source_image=source_disk_image,
286
+ disk_size_gb=int(compute_config['disk_size_gb']),
287
+ ),
288
+ )
289
+
290
+ nic = compute_v1.NetworkInterface()
291
+ nic.network = 'global/networks/default'
292
+ nic.access_configs = [
293
+ compute_v1.AccessConfig(
294
+ name='External NAT',
295
+ type_='ONE_TO_ONE_NAT',
296
+ )
297
+ ]
298
+
299
+ items = [compute_v1.Items(key='startup-script', value=startup_script)]
300
+ if compute_config.get('install_nvidia_driver', True):
301
+ items.append(
302
+ compute_v1.Items(key='install-nvidia-driver', value='true'),
303
+ )
304
+ metadata = compute_v1.Metadata(items=items)
305
+
306
+ scheduling = compute_v1.Scheduling(
307
+ on_host_maintenance='TERMINATE',
308
+ automatic_restart=False,
309
+ )
310
+
311
+ instance = compute_v1.Instance(
312
+ name=instance_name,
313
+ machine_type=machine_type,
314
+ disks=[boot_disk],
315
+ network_interfaces=[nic],
316
+ metadata=metadata,
317
+ guest_accelerators=[gpu],
318
+ scheduling=scheduling,
319
+ service_accounts=[
320
+ compute_v1.ServiceAccount(
321
+ email='default',
322
+ scopes=['https://www.googleapis.com/auth/cloud-platform'],
323
+ )
324
+ ],
325
+ )
326
+
327
+ instance_client = compute_v1.InstancesClient(credentials=gcp_creds)
328
+ operation = instance_client.insert(
329
+ project=gcp_project,
330
+ zone=gcp_zone,
331
+ instance_resource=instance,
332
+ )
333
+
334
+ op_client = compute_v1.ZoneOperationsClient(credentials=gcp_creds)
335
+ operation = op_client.wait(
336
+ project=gcp_project,
337
+ zone=gcp_zone,
338
+ operation=operation.name,
339
+ )
340
+
341
+ if operation.error and operation.error.errors:
342
+ messages = [e.message for e in operation.error.errors if e.message]
343
+ joined = '; '.join(messages) if messages else str(operation.error)
344
+ raise RuntimeError(f'GCP instance creation failed: {joined}')
345
+
346
+ print(f'Instance created: {instance_name}') # noqa: T201
347
+ return instance_name
348
+
349
+
350
+ def _benchmark_compute_gcp(
351
+ output_destination,
352
+ credential_filepath,
353
+ compute_config,
354
+ synthesizers,
355
+ sdv_datasets,
356
+ additional_datasets_folder,
357
+ limit_dataset_size,
358
+ compute_quality_score,
359
+ compute_diagnostic_score,
360
+ compute_privacy_score,
361
+ sdmetrics,
362
+ timeout,
363
+ modality,
364
+ ):
365
+ """Run the SDGym benchmark on datasets for the given modality."""
366
+ compute_config = resolve_compute_config('gcp', compute_config)
367
+ credentials = get_credentials(credential_filepath)
368
+ validate_compute_config(compute_config)
369
+
370
+ s3_client = _validate_output_destination(
371
+ output_destination,
372
+ aws_keys={
373
+ 'aws_access_key_id': credentials['aws']['aws_access_key_id'],
374
+ 'aws_secret_access_key': credentials['aws']['aws_secret_access_key'],
375
+ },
376
+ )
377
+
378
+ if not synthesizers:
379
+ synthesizers = []
380
+
381
+ _ensure_uniform_included(synthesizers, modality)
382
+ synthesizers = _import_and_validate_synthesizers(
383
+ synthesizers=synthesizers,
384
+ custom_synthesizers=None,
385
+ modality=modality,
386
+ )
387
+
388
+ job_args_list = _generate_job_args_list(
389
+ limit_dataset_size=limit_dataset_size,
390
+ sdv_datasets=sdv_datasets,
391
+ additional_datasets_folder=additional_datasets_folder,
392
+ sdmetrics=sdmetrics,
393
+ timeout=timeout,
394
+ output_destination=output_destination,
395
+ compute_quality_score=compute_quality_score,
396
+ compute_diagnostic_score=compute_diagnostic_score,
397
+ compute_privacy_score=compute_privacy_score,
398
+ synthesizers=synthesizers,
399
+ s3_client=s3_client,
400
+ modality=modality,
401
+ )
402
+ if not job_args_list:
403
+ return _get_empty_dataframe(
404
+ compute_diagnostic_score=compute_diagnostic_score,
405
+ compute_quality_score=compute_quality_score,
406
+ compute_privacy_score=compute_privacy_score,
407
+ sdmetrics=sdmetrics,
408
+ )
409
+
410
+ _run_on_gcp(
411
+ output_destination=output_destination,
412
+ synthesizers=synthesizers,
413
+ s3_client=s3_client,
414
+ job_args_list=job_args_list,
415
+ credentials=credentials,
416
+ compute_config=compute_config,
417
+ )
418
+
419
+
420
+ def _benchmark_single_table_compute_gcp(
421
+ output_destination,
422
+ credential_filepath,
423
+ compute_config=None,
424
+ synthesizers=DEFAULT_SINGLE_TABLE_SYNTHESIZERS,
425
+ sdv_datasets=DEFAULT_SINGLE_TABLE_DATASETS,
426
+ additional_datasets_folder=None,
427
+ limit_dataset_size=False,
428
+ compute_quality_score=True,
429
+ compute_diagnostic_score=True,
430
+ compute_privacy_score=False,
431
+ sdmetrics=None,
432
+ timeout=None,
433
+ ):
434
+ """Run the SDGym benchmark on GCP with the single-table modality.
435
+
436
+ Args:
437
+ output_destination (str):
438
+ The S3 URI where results will be stored.
439
+ credential_filepath (str or Path):
440
+ Path to the credentials file for AWS, GCP and SDV-Enterprise.
441
+ compute_config (dict, optional):
442
+ The compute configuration for the GCP instance. If None, default settings will be used.
443
+ synthesizers (list of dict, optional):
444
+ The synthesizers to use in the benchmark. Defaults to DEFAULT_SINGLE_TABLE_SYNTHESIZERS.
445
+ sdv_datasets (list of str, optional):
446
+ The SDV datasets to use in the benchmark. Defaults to DEFAULT_SINGLE_TABLE_DATASETS.
447
+ additional_datasets_folder (str or Path, optional):
448
+ Path to a folder containing additional datasets to include in the benchmark.
449
+ limit_dataset_size (bool, optional):
450
+ Whether to limit the size of datasets for faster benchmarking. Defaults to False.
451
+ compute_quality_score (bool, optional):
452
+ Whether to compute the quality score. Defaults to True.
453
+ compute_diagnostic_score (bool, optional):
454
+ Whether to compute the diagnostic score. Defaults to True.
455
+ compute_privacy_score (bool, optional):
456
+ Whether to compute the privacy score. Defaults to False.
457
+ sdmetrics (list of str, optional):
458
+ The sdmetrics to use for evaluation. If None, default metrics will be used.
459
+ timeout (int, optional):
460
+ Timeout in seconds for each synthesizer-dataset run. If None, no timeout is applied.
461
+ """
462
+ return _benchmark_compute_gcp(
463
+ output_destination=output_destination,
464
+ credential_filepath=credential_filepath,
465
+ compute_config=compute_config,
466
+ synthesizers=synthesizers,
467
+ sdv_datasets=sdv_datasets,
468
+ additional_datasets_folder=additional_datasets_folder,
469
+ limit_dataset_size=limit_dataset_size,
470
+ compute_quality_score=compute_quality_score,
471
+ compute_diagnostic_score=compute_diagnostic_score,
472
+ compute_privacy_score=compute_privacy_score,
473
+ sdmetrics=sdmetrics,
474
+ timeout=timeout,
475
+ modality='single_table',
476
+ )
477
+
478
+
479
+ def _benchmark_multi_table_compute_gcp(
480
+ output_destination,
481
+ credential_filepath,
482
+ compute_config=None,
483
+ synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
484
+ sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
485
+ additional_datasets_folder=None,
486
+ limit_dataset_size=False,
487
+ compute_quality_score=True,
488
+ compute_diagnostic_score=True,
489
+ sdmetrics=None,
490
+ timeout=None,
491
+ ):
492
+ """Run the SDGym benchmark on GCP with the multi-table modality.
493
+
494
+ Args:
495
+ output_destination (str):
496
+ The S3 URI where results will be stored.
497
+ credential_filepath (str or Path):
498
+ Path to the credentials file for AWS, GCP and SDV-Enterprise.
499
+ compute_config (dict, optional):
500
+ The compute configuration for the GCP instance. If None, default settings will be used.
501
+ synthesizers (list of dict, optional):
502
+ The synthesizers to use in the benchmark. Defaults to DEFAULT_MULTI_TABLE_SYNTHESIZERS.
503
+ sdv_datasets (list of str, optional):
504
+ The SDV datasets to use in the benchmark. Defaults to DEFAULT_MULTI_TABLE_DATASETS.
505
+ additional_datasets_folder (str or Path, optional):
506
+ Path to a folder containing additional datasets to include in the benchmark.
507
+ limit_dataset_size (bool, optional):
508
+ Whether to limit the size of datasets for faster benchmarking. Defaults to False.
509
+ compute_quality_score (bool, optional):
510
+ Whether to compute the quality score. Defaults to True.
511
+ compute_diagnostic_score (bool, optional):
512
+ Whether to compute the diagnostic score. Defaults to True.
513
+ sdmetrics (list of str, optional):
514
+ The sdmetrics to use for evaluation. If None, default metrics will be used.
515
+ timeout (int, optional):
516
+ Timeout in seconds for each synthesizer-dataset run. If None, no timeout is applied.
517
+ """
518
+ return _benchmark_compute_gcp(
519
+ output_destination=output_destination,
520
+ credential_filepath=credential_filepath,
521
+ compute_config=compute_config,
522
+ synthesizers=synthesizers,
523
+ sdv_datasets=sdv_datasets,
524
+ additional_datasets_folder=additional_datasets_folder,
525
+ limit_dataset_size=limit_dataset_size,
526
+ compute_quality_score=compute_quality_score,
527
+ compute_diagnostic_score=compute_diagnostic_score,
528
+ compute_privacy_score=False,
529
+ sdmetrics=sdmetrics,
530
+ timeout=timeout,
531
+ modality='multi_table',
532
+ )