metadentify 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
metadentify/args.py ADDED
@@ -0,0 +1,59 @@
1
+ import argparse
2
+
3
+
4
def get_base_parser(description: str = "") -> argparse.ArgumentParser:
    """Build the command-line parser shared by the experiment entry points.

    Options cover run bookkeeping (paths, W&B, logging toggles), experiment
    identity, model architecture, and data/training hyper-parameters.
    ``--experiment_name`` is the only required option; every other option has
    a default.

    Args:
        description: Text shown at the top of ``--help`` output.

    Returns:
        A new ``argparse.ArgumentParser``; callers may register extra options
        before calling ``parse_args``.
    """
    bool_toggle = argparse.BooleanOptionalAction  # adds a --no-<name> variant
    flag = {"action": "store_true"}

    # (option string, add_argument keyword arguments), in registration order.
    options: list[tuple[str, dict]] = [
        # Run setup, logging and bookkeeping.
        ("--experiment_setup_path", {"type": str, "default": ""}),
        ("--config_path", {"type": str, "default": "config.py"}),
        ("--wandb_project", {"type": str, "default": ""}),
        ("--wandb_entity", {"type": str, "default": ""}),
        ("--progress_bar", flag),
        ("--savelast", flag),
        ("--overwrite_data", {"default": True, "action": bool_toggle}),
        ("--disable_wandb", flag),
        # Experiment identity.
        ("--experiment_name", {"type": str, "required": True}),
        ("--query_name", {"type": str, "default": "PATE"}),
        ("--baseline_setting", {"type": str, "default": None}),
        # Model architecture.
        ("--backbone_type", {"type": str, "default": "q-cnp"}),
        ("--embed_dim", {"type": int, "default": 64}),
        ("--num_heads", {"type": int, "default": 4}),
        ("--num_layers", {"type": int, "default": 2}),
        ("--num_tau_samples", {"type": int, "default": 8}),
        ("--source_embed_dim", {"type": int, "default": 16}),
        ("--output_dim", {"type": int, "default": 1}),
        ("--num_inducing_points", {"type": int, "default": 32}),
        ("--dropout", {"type": float, "default": 0.0}),
        # Data generation and loading.
        ("--standardize", flag),
        ("--num_pytorch_workers", {"type": int, "default": 0}),
        ("--num_datagen_workers", {"type": int, "default": -2}),
        ("--dataset_size", {"type": int, "default": 1000}),
        ("--num_query_points", {"type": int, "default": 10}),
        ("--num_scms", {"type": int, "default": None}),
        ("--num_train_tasks", {"type": int, "default": 10000}),
        ("--num_val_tasks", {"type": int, "default": 1000}),
        ("--num_test_tasks", {"type": int, "default": 1000}),
        ("--tasks_per_file", {"type": int, "default": 500}),
        ("--batch_size", {"type": int, "default": 64}),
        ("--prefetch_factor", {"type": int, "default": 2}),
        ("--online_data", flag),
        ("--total_steps", {"type": int, "default": None}),
        # Training, evaluation and output.
        ("--val_metrics_normalized", {"default": True, "action": bool_toggle}),
        ("--plot_results", flag),
        ("--lr", {"type": float, "default": 1e-3}),
        ("--lambda_crossing_penalty", {"type": float, "default": 1.0}),
        ("--weight_decay", {"type": float, "default": 0.0}),
        ("--patience", {"type": int, "default": 30}),
        ("--max_epochs", {"type": int, "default": 100}),
        ("--run_baselines", flag),
        ("--save_dir", {"type": str, "default": "data"}),
        ("--checkpoint_dir", {"type": str, "default": "checkpoints"}),
        ("--plot_diagnostics", flag),
    ]

    parser = argparse.ArgumentParser(description=description)
    for option_string, settings in options:
        parser.add_argument(option_string, **settings)
    return parser
@@ -0,0 +1,460 @@
1
+ from typing import Any
2
+
3
+ import numpy as np
4
+ from econml.dml import CausalForestDML, LinearDML
5
+ from econml.iv.dml import DMLIV
6
+ from sklearn.linear_model import LinearRegression, Ridge
7
+ from sklearn.neural_network import MLPRegressor
8
+
9
+
10
def confounder_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Covariate-adjusted OLS estimate of the treatment effect.

    Fits ``LinearRegression`` on the design ``[t, x]`` and reports the
    coefficient of column 0 (the treatment), assuming a 1-D ``y``.
    ``x_sources`` is accepted for interface compatibility and ignored.

    Returns:
        The treatment coefficient, or — when ``query_x`` is given — that value
        repeated once per row of ``query_x``.
    """
    design = np.c_[t, x]  # treatment occupies column 0
    treatment_coef = LinearRegression().fit(X=design, y=y).coef_[0]
    if query_x is None:
        return treatment_coef
    return np.ones((query_x.shape[0],)) * treatment_coef
25
+
26
+
27
def confounder_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Ridge-regularised, covariate-adjusted estimate of the treatment effect.

    Fits ``Ridge()`` (scikit-learn defaults) on ``[t, x]`` and reports the
    coefficient of column 0 (the treatment), assuming a 1-D ``y``.
    ``x_sources`` is ignored.

    Returns:
        The treatment coefficient, or that value repeated per row of
        ``query_x`` when ``query_x`` is given.
    """
    regressor = Ridge()
    regressor.fit(X=np.c_[t, x], y=y)
    effect = regressor.coef_[0]
    return effect if query_x is None else np.ones((query_x.shape[0],)) * effect
42
+
43
+
44
def confounder_mlp_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Average effect of a unit increase in ``t`` under an MLP outcome model.

    Fits ``MLPRegressor(max_iter=3000)`` on ``[t, x]``. For each training row
    a reference treatment is chosen — 0 when ``t`` takes exactly the values
    {0, 1}, otherwise a uniform draw on ``[t.min(), t.max()]`` from NumPy's
    global RNG. The estimate is the mean of
    ``f(reference + 1, x) - f(reference, x)``. ``x_sources`` is ignored.

    Returns:
        The averaged effect, or that value repeated per row of ``query_x``
        when ``query_x`` is given.
    """
    fitted = MLPRegressor(max_iter=3000).fit(X=np.c_[t, x], y=y)

    is_binary = np.array_equal(np.unique(t), [0, 1])
    # Only the continuous branch draws from the global RNG, after the fit.
    reference_t = (
        t * 0
        if is_binary
        else np.random.uniform(low=t.min(), high=t.max(), size=x.shape[0])
    )

    treated_pred = fitted.predict(X=np.c_[reference_t + 1, x])
    control_pred = fitted.predict(X=np.c_[reference_t, x])
    effect = np.mean(treated_pred - control_pred)

    if query_x is None:
        return effect
    return np.ones((query_x.shape[0],)) * effect
65
+
66
+
67
def treatment_only_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Unadjusted OLS slope of ``y`` on ``t``; covariates ``x`` are not used.

    ``x`` and ``x_sources`` are accepted for interface compatibility only.

    Returns:
        The slope (assuming a 1-D ``y``), or that value repeated per row of
        ``query_x`` when ``query_x`` is given.
    """
    treatment_column = t.reshape(-1, 1)
    slope = LinearRegression().fit(X=treatment_column, y=y).coef_[0]
    if query_x is not None:
        return slope * np.ones((query_x.shape[0],))
    return slope
81
+
82
+
83
def treatment_only_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Unadjusted ridge slope of ``y`` on ``t``; covariates ``x`` are not used.

    ``x`` and ``x_sources`` are accepted for interface compatibility only.

    Returns:
        The slope (assuming a 1-D ``y``), or that value repeated per row of
        ``query_x`` when ``query_x`` is given.
    """
    treatment_column = t.reshape(-1, 1)
    slope = Ridge().fit(X=treatment_column, y=y).coef_[0]
    if query_x is not None:
        return slope * np.ones((query_x.shape[0],))
    return slope
97
+
98
+
99
def tsls_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage least squares with ``x`` as instruments (OLS in both stages).

    Stage 1 regresses ``t`` on ``x``. Stage 2 regresses ``y`` on the stage-1
    fitted values and reports the slope. ``x_sources`` is ignored.

    Returns:
        The stage-2 slope, or that value repeated per row of ``query_x`` when
        ``query_x`` is given (matching the confounder / treatment-only
        baselines).
    """
    stage1 = LinearRegression().fit(X=x, y=t)
    predicted_t = stage1.predict(X=x).reshape(-1, 1)
    stage2 = LinearRegression().fit(X=predicted_t, y=y)
    # ravel so a column-vector y (coef_ of shape (1, 1)) still yields a scalar.
    estimate = np.ravel(stage2.coef_)[0]

    if query_x is not None:
        estimate = np.ones((query_x.shape[0],)) * estimate
    return estimate
111
+
112
+
113
def tsls_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage regression with ``x`` as instruments, ridge in both stages.

    Stage 1 regresses ``t`` on ``x`` with ``Ridge()``. Stage 2 regresses ``y``
    on the stage-1 fitted values with ``Ridge()`` and reports the slope.
    ``x_sources`` is ignored.

    Returns:
        The stage-2 slope, or that value repeated per row of ``query_x`` when
        ``query_x`` is given (matching the confounder / treatment-only
        baselines).
    """
    stage1 = Ridge().fit(X=x, y=t)
    predicted_t = stage1.predict(X=x).reshape(-1, 1)
    stage2 = Ridge().fit(X=predicted_t, y=y)
    # ravel so a column-vector y (coef_ of shape (1, 1)) still yields a scalar.
    estimate = np.ravel(stage2.coef_)[0]

    if query_x is not None:
        estimate = np.ones((query_x.shape[0],)) * estimate
    return estimate
125
+
126
+
127
def tsls_mlp_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage estimate with an MLP first stage and a ridge second stage.

    Stage 1 fits ``MLPRegressor()`` (scikit-learn defaults) predicting ``t``
    from the instruments ``x``. Stage 2 regresses ``y`` on the predicted
    treatment with ``Ridge()`` and reports the slope. ``x_sources`` is ignored.

    Returns:
        The stage-2 slope, or that value repeated per row of ``query_x`` when
        ``query_x`` is given (matching the confounder / treatment-only
        baselines).
    """
    stage1 = MLPRegressor().fit(X=x, y=t.ravel())
    predicted_t = np.expand_dims(stage1.predict(X=x), -1)
    stage2 = Ridge().fit(X=predicted_t, y=y)
    # ravel so a column-vector y (coef_ of shape (1, 1)) still yields a scalar.
    estimate = np.ravel(stage2.coef_)[0]

    if query_x is not None:
        estimate = np.ones((query_x.shape[0],)) * estimate
    return estimate
140
+
141
+
142
def proxy_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage proxy regression (OLS in both stages).

    Column 0 of ``x`` is the first proxy and column 1 the second; any further
    columns are ignored. Stage 1 predicts the first proxy from
    ``[second proxy, t]``. Stage 2 regresses ``y`` on
    ``[t, predicted first proxy]`` and reports the ``t`` coefficient.
    ``x_sources`` is ignored.

    Returns:
        The treatment coefficient, or that value repeated per row of
        ``query_x`` when ``query_x`` is given (matching the other baselines).
    """
    proxy_1 = x[:, [0]]
    proxy_2 = x[:, [1]]

    stage1_inputs = np.c_[proxy_2, t]
    stage1 = LinearRegression().fit(X=stage1_inputs, y=proxy_1)
    predicted_proxy_1 = stage1.predict(X=stage1_inputs)

    stage2 = LinearRegression().fit(X=np.c_[t, predicted_proxy_1], y=y)
    # For a column-vector y, coef_ is (1, 2) and coef_[0] would be the whole
    # row; ravel()[0] is always the treatment coefficient.
    estimate = np.ravel(stage2.coef_)[0]

    if query_x is not None:
        estimate = np.ones((query_x.shape[0],)) * estimate
    return estimate
158
+
159
+
160
def proxy_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage proxy regression with ``Ridge()`` in both stages.

    Column 0 of ``x`` is the first proxy and column 1 the second; any further
    columns are ignored. Stage 1 predicts the first proxy from
    ``[second proxy, t]``. Stage 2 regresses ``y`` on
    ``[t, predicted first proxy]`` and reports the ``t`` coefficient.
    ``x_sources`` is ignored.

    Returns:
        The treatment coefficient, or that value repeated per row of
        ``query_x`` when ``query_x`` is given (matching the other baselines).
    """
    proxy_1 = x[:, [0]]
    proxy_2 = x[:, [1]]

    stage1_inputs = np.c_[proxy_2, t]
    stage1 = Ridge().fit(X=stage1_inputs, y=proxy_1)
    predicted_proxy_1 = stage1.predict(X=stage1_inputs)

    stage2 = Ridge().fit(X=np.c_[t, predicted_proxy_1], y=y)
    # For a column-vector y, coef_ is (1, 2) and coef_[0] would be the whole
    # row; ravel()[0] is always the treatment coefficient.
    estimate = np.ravel(stage2.coef_)[0]

    if query_x is not None:
        estimate = np.ones((query_x.shape[0],)) * estimate
    return estimate
176
+
177
+
178
def proxy_mlp_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage proxy regression with an MLP first stage and ridge second stage.

    Column 0 of ``x`` is the first proxy and column 1 the second; any further
    columns are ignored. Stage 1 fits ``MLPRegressor()`` predicting the first
    proxy from ``[second proxy, t]``. Stage 2 regresses ``y`` on
    ``[t, predicted first proxy]`` with ``Ridge()`` and reports the ``t``
    coefficient. ``x_sources`` is ignored.

    Returns:
        The treatment coefficient, or that value repeated per row of
        ``query_x`` when ``query_x`` is given (matching the other baselines).
    """
    proxy_1 = x[:, [0]]
    proxy_2 = x[:, [1]]

    stage1_inputs = np.c_[proxy_2, t]
    # MLPRegressor expects a 1-D target for a single output; passing the
    # (n, 1) column triggers a DataConversionWarning but is otherwise equivalent.
    stage1 = MLPRegressor().fit(X=stage1_inputs, y=proxy_1.ravel())
    predicted_proxy_1 = stage1.predict(X=stage1_inputs)

    stage2 = Ridge().fit(X=np.c_[t, predicted_proxy_1], y=y)
    # For a column-vector y, coef_ is (1, 2) and coef_[0] would be the whole
    # row; ravel()[0] is always the treatment coefficient.
    estimate = np.ravel(stage2.coef_)[0]

    if query_x is not None:
        estimate = np.ones((query_x.shape[0],)) * estimate
    return estimate
194
+
195
+
196
def confounder_dml_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
    type: str = "forest",
    binary_t: bool = False,
) -> float | np.ndarray:
    """ATE from econml double machine learning, passing ``x`` as ``X``.

    Args:
        x, t, y: Covariates, treatment and outcome passed to ``model.fit``.
        query_x: Accepted for interface compatibility; not used. The ATE is
            averaged over the training ``x``.
        x_sources: Ignored.
        type: ``"forest"`` for ``CausalForestDML`` or ``"linear"`` for
            ``LinearDML``. The name shadows the builtin but is kept because
            callers pass it by keyword.
        binary_t: Forwarded as ``discrete_treatment``.

    Returns:
        ``model.ate(X=x)``.

    Raises:
        ValueError: if ``type`` is not ``"forest"`` or ``"linear"``. Previously
            this surfaced as an UnboundLocalError on ``model``.
    """
    if type == "forest":
        model = CausalForestDML(
            discrete_treatment=binary_t,
            discrete_outcome=False,
            n_jobs=None,
        )
    elif type == "linear":
        model = LinearDML(
            discrete_treatment=binary_t,
            discrete_outcome=False,
        )
    else:
        raise ValueError(f"Unsupported DML type: {type!r} (expected 'forest' or 'linear')")

    model.fit(Y=y, T=t, X=x, W=None)
    return model.ate(X=x)
219
+
220
+
221
def iv_dml_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
    type: str = "forest",
    binary_t: bool = False,
) -> float | np.ndarray:
    """ATE from econml's ``DMLIV`` with ``x`` used as the instrument ``Z``.

    No effect modifiers (``X``) or controls (``W``) are supplied.

    Args:
        x: Instruments, passed as ``Z``.
        t, y: Treatment and outcome.
        query_x: Accepted for interface compatibility; not used.
        x_sources: Ignored.
        type: ``"forest"`` or ``"linear"``; used for all three nuisance models
            (``model_y_xw``, ``model_t_xw``, ``model_t_xwz``). The name shadows
            the builtin but is kept because callers pass it by keyword.
        binary_t: Forwarded as ``discrete_treatment``.

    Returns:
        ``model.ate()``.

    Raises:
        ValueError: if ``type`` is not ``"forest"`` or ``"linear"``. Previously
            this surfaced as an UnboundLocalError on ``model``.
    """
    if type not in ("forest", "linear"):
        raise ValueError(f"Unsupported DMLIV type: {type!r} (expected 'forest' or 'linear')")

    model = DMLIV(
        discrete_treatment=binary_t,
        discrete_instrument=False,
        model_y_xw=type,
        model_t_xw=type,
        model_t_xwz=type,
    )
    model.fit(Y=y, T=t, Z=x, X=None, W=None)
    return model.ate()
249
+
250
+
251
def cate_confounder_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Linear CATE from an OLS fit of ``y`` on ``[x * t, t]``.

    The effect at a query point ``q`` is ``q @ beta + tau``. Here ``beta`` are
    the coefficients of the ``x * t`` interaction columns and ``tau`` is the
    coefficient of the plain ``t`` column. ``x_sources`` is ignored.

    NOTE(review): the design has no main-effect columns for ``x``. If ``x``
    also affects ``y`` directly (confounding), the interaction coefficients
    may be biased. Confirm this is intended.

    Returns:
        For single-feature ``x``: ``coef[0] * query_x + coef[1]``, with the
        shape of ``query_x``. For multi-feature ``x``: one value per query row.
    """
    dataset_x = np.c_[x * t, t]
    linmod = LinearRegression().fit(X=dataset_x, y=y)

    num_features = 1 if np.ndim(x) < 2 else x.shape[1]
    if num_features == 1:
        return linmod.coef_[0] * query_x + linmod.coef_[1]

    # With d features the first d coefficients are the interactions and the
    # treatment coefficient sits at index d (not index 1).
    interaction_coefs = linmod.coef_[:num_features]
    treatment_coef = linmod.coef_[num_features]
    query_rows = np.asarray(query_x).reshape(-1, num_features)
    return query_rows @ interaction_coefs + treatment_coef
262
+
263
+
264
def cate_confounder_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Linear CATE from a ``Ridge()`` fit of ``y`` on ``[x * t, t]``.

    The effect at a query point ``q`` is ``q @ beta + tau``. Here ``beta`` are
    the coefficients of the ``x * t`` interaction columns and ``tau`` is the
    coefficient of the plain ``t`` column. ``x_sources`` is ignored.

    NOTE(review): the design has no main-effect columns for ``x``. If ``x``
    also affects ``y`` directly (confounding), the interaction coefficients
    may be biased. Confirm this is intended.

    Returns:
        For single-feature ``x``: ``query_x * coef[0] + coef[1]``, with the
        shape of ``query_x``. For multi-feature ``x``: one value per query row.
    """
    dataset_x = np.c_[x * t, t]
    linmod = Ridge().fit(X=dataset_x, y=y)

    num_features = 1 if np.ndim(x) < 2 else x.shape[1]
    if num_features == 1:
        return query_x * linmod.coef_[0] + linmod.coef_[1]

    # With d features the first d coefficients are the interactions and the
    # treatment coefficient sits at index d (not index 1).
    interaction_coefs = linmod.coef_[:num_features]
    treatment_coef = linmod.coef_[num_features]
    query_rows = np.asarray(query_x).reshape(-1, num_features)
    return query_rows @ interaction_coefs + treatment_coef
275
+
276
+
277
def cate_confounder_mlp_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray,
    num_t_samples: int = 1000,
    x_sources: np.ndarray | None = None,
) -> list[float]:
    """Per-query effect of a unit increase in ``t`` under an MLP outcome model.

    Fits ``MLPRegressor(max_iter=3000)`` on ``[t, x]``. It then draws
    ``num_t_samples`` reference treatments, shared by all query points. These
    are zeros when ``t`` takes exactly the values {0, 1}, otherwise uniform on
    ``[t.min(), t.max()]`` from NumPy's global RNG. For each row of
    ``query_x`` the estimate is the mean of ``f(ref + 1, q) - f(ref, q)``
    over those references. ``x_sources`` is ignored.

    Returns:
        A list with one estimate per row of ``query_x``.
    """
    model = MLPRegressor(max_iter=3000)
    model.fit(X=np.c_[t, x], y=y)

    if np.array_equal(np.unique(t), [0, 1]):
        t_sample = np.zeros((num_t_samples, 1))
    else:
        t_sample = np.random.uniform(low=t.min(), high=t.max(), size=(num_t_samples, 1))

    def mean_effect(row):
        # Pair the single query point with every reference treatment.
        covariates = np.repeat(np.array(row).reshape(1, -1), repeats=num_t_samples, axis=0)
        treated = model.predict(X=np.c_[t_sample + 1, covariates])
        control = model.predict(X=np.c_[t_sample, covariates])
        return np.mean(treated - control)

    return [mean_effect(query_x[idx]) for idx in range(query_x.shape[0])]
302
+
303
+
304
def get_baselines(baseline_setting: str) -> list[dict[str, Any]]:
    """Return the baseline estimators to evaluate for ``baseline_setting``.

    Each entry is a dict with the following keys:

    - ``"model"``: the estimator callable, called as ``model(x, t, y, query_x)``.
      Every model also accepts ``x_sources``.
    - ``"name"``: a machine-readable identifier.
    - ``"alias"``: a short display label.

    The treatment-only ridge baseline is always first. Most settings append
    two setting-specific baselines.

    Raises:
        ValueError: if ``baseline_setting`` is not a recognised setting.
    """
    from functools import partial

    def entry(model, name, alias):
        return dict(model=model, name=name, alias=alias)

    def dml_pair(estimator, binary_t, name_suffix, alias_prefix):
        # Linear and random-forest variants of an econml wrapper. partial()
        # keeps the full baseline signature (query_x / x_sources keywords),
        # which the previous 4-argument closures did not accept.
        return [
            entry(
                partial(estimator, type="linear", binary_t=binary_t),
                f"linear_{name_suffix}",
                f"{alias_prefix}-Lin",
            ),
            entry(
                partial(estimator, type="forest", binary_t=binary_t),
                f"forest_{name_suffix}",
                f"{alias_prefix}-RF",
            ),
        ]

    baselines = [entry(treatment_only_ridge_baseline, "t_only_ridge", "T-Only-Ridge")]

    t_only_settings = {
        "t-only",
        "t-only-double-invalid-ocp",
        "t-only-iv-invalid-tcp",
        # These two used to reach the append step with no baseline variables
        # bound, raising UnboundLocalError. They have no dedicated baselines,
        # so they report the treatment-only baseline.
        "confounder-and-iv",
        "iv-invalid-tcp",
    }
    if baseline_setting in t_only_settings:
        return baselines

    reg_mlp = entry(confounder_mlp_baseline, "mlp_regression", "Reg-MLP")
    tsls_mlp = entry(tsls_mlp_baseline, "mlp_tsls", "TSLS-MLP")
    proxy_ridge = entry(proxy_ridge_baseline, "proxy_tsls_ridge", "PrTSLS-Ridge")
    proxy_linear = entry(proxy_linear_baseline, "proxy_tsls_linear", "PrTSLS-Lin")
    proxy_mlp = entry(proxy_mlp_baseline, "mlp_proxy_tsls", "PrTSLS-MLP")
    cate_pair = [
        entry(cate_confounder_ridge_baseline, "cate_ridge", "CATE-Ridge"),
        entry(cate_confounder_mlp_baseline, "cate_mlp", "CATE-MLP"),
    ]

    extra_by_setting = {
        "confounder": [entry(confounder_ridge_baseline, "ridge_regression", "Reg-Ridge"), reg_mlp],
        "iv": [entry(tsls_ridge_baseline, "tsls_ridge", "TSLS-Ridge"), tsls_mlp],
        "proxy": [proxy_ridge, proxy_mlp],
        "double-invalid-ocp": [proxy_ridge, proxy_mlp],
        "confounder-linreg": [entry(confounder_linear_baseline, "linear_regression", "Reg-Lin"), reg_mlp],
        "iv-linreg": [entry(tsls_linear_baseline, "tsls_linear", "TSLS-Lin"), tsls_mlp],
        "proxy-linreg": [proxy_linear, proxy_mlp],
        "double-invalid-ocp-linreg": [proxy_linear, proxy_mlp],
        "confounder-binary-t-dml": dml_pair(confounder_dml_baseline, True, "dml", "DML"),
        "confounder-continuous-t-dml": dml_pair(confounder_dml_baseline, False, "dml", "DML"),
        "iv-binary-t-dml": dml_pair(iv_dml_baseline, True, "dmliv", "DMLIV"),
        "iv-continuous-t-dml": dml_pair(iv_dml_baseline, False, "dmliv", "DMLIV"),
        "cate-confounder-linear": cate_pair,
        "cate-confounder-nonlinear": cate_pair,
    }

    try:
        extra = extra_by_setting[baseline_setting]
    except KeyError:
        raise ValueError(f"Unsupported baseline_setting: {baseline_setting}") from None

    return baselines + extra
@@ -0,0 +1,153 @@
1
+ import argparse
2
+ import os
3
+ import re
4
+ import subprocess
5
+
6
+ from dotenv import load_dotenv
7
+
8
+ from metadentify.sbatch import format_sbatch_script
9
+
10
+
11
def initialize_wandb_sweep(yaml_config_path):
    """Create a W&B sweep by running ``wandb sweep`` inside the Singularity container.

    The overlay and container image are read from the ``SINGULARITY_ENV_PATH``
    and ``SINGULARITY_CONTAINER_PATH`` environment variables. Inside the
    container the command first sources ``/ext3/env.sh``.

    Args:
        yaml_config_path: Path to the sweep configuration YAML.

    Returns:
        The sweep identifier: the token following ``wandb agent`` in the
        command's combined stdout/stderr.

    Raises:
        ValueError: if a required environment variable is unset, or the sweep
            ID cannot be parsed from the output.
        RuntimeError: if the containerised ``wandb sweep`` exits non-zero.
    """
    import shlex

    print(f"Initializing W&B sweep from {yaml_config_path} inside Singularity...")

    singularity_env_path = os.environ.get("SINGULARITY_ENV_PATH")
    singularity_container_path = os.environ.get("SINGULARITY_CONTAINER_PATH")
    # Without this check, a None entry makes subprocess.run fail with an
    # opaque TypeError.
    if not singularity_env_path or not singularity_container_path:
        raise ValueError(
            "SINGULARITY_ENV_PATH and SINGULARITY_CONTAINER_PATH must be set in the environment."
        )

    # The path is embedded in a `bash -c` string, so quote it. Spaces or shell
    # metacharacters would otherwise break or alter the command.
    quoted_config = shlex.quote(str(yaml_config_path))
    inner_command = f"source /ext3/env.sh; wandb sweep {quoted_config}"

    singularity_cmd = [
        "singularity",
        "exec",
        "--nv",
        "--overlay",
        singularity_env_path,
        singularity_container_path,
        "/bin/bash",
        "-c",
        inner_command,
    ]

    result = subprocess.run(singularity_cmd, capture_output=True, text=True)
    # wandb may print the agent hint to either stream.
    output = result.stdout + result.stderr

    if result.returncode != 0:
        print(f"Error initializing sweep:\n{output}")
        raise RuntimeError("Failed to create W&B sweep.")

    match = re.search(r"wandb agent\s+([^\s]+)", output)
    if not match:
        print(f"Could not parse sweep ID from W&B output:\n{output}")
        raise ValueError("Sweep ID regex match failed.")

    sweep_id = match.group(1)
    print(f"Successfully created sweep: {sweep_id}")
    return sweep_id
46
+
47
+
48
def launch_sweep_agents(
    sweep_id,
    num_agents,
    num_nodes,
    tasks_per_node,
    cpus_per_task,
    num_hours,
    gigs_memory,
    num_gpus,
    slurm_directory="slurm_scripts",
):
    """Write one sbatch script per agent and submit each with ``sbatch``.

    Each script runs ``wandb agent <sweep_id> --entity ... --project ...``. The
    entity and project come from the ``WANDB_ENTITY`` / ``WANDB_PROJECT``
    environment variables. Scripts are written to ``slurm_directory``, and
    ``format_sbatch_script`` is pointed at ``out/<sweep id>`` for logs.

    Args:
        sweep_id: Full or bare sweep ID; the last ``/``-separated part names
            the output directory and the jobs.
        num_agents: Number of agent jobs to submit.
        num_nodes, tasks_per_node, cpus_per_task, num_hours, gigs_memory,
            num_gpus: Resource requests forwarded to ``format_sbatch_script``.
        slurm_directory: Directory for the generated ``.sbatch`` files.

    A failed submission is reported and the remaining agents are still
    submitted.
    """
    os.makedirs(slurm_directory, exist_ok=True)
    base_sweep_id = sweep_id.split("/")[-1]

    sweep_out_dir = os.path.join("out", base_sweep_id)
    os.makedirs(sweep_out_dir, exist_ok=True)

    wandb_entity = os.environ.get("WANDB_ENTITY")
    wandb_project = os.environ.get("WANDB_PROJECT")

    # NOTE(review): these values are interpolated unquoted into the job
    # script; confirm how format_sbatch_script embeds the command.
    local_run_command = f"wandb agent {sweep_id} --entity {wandb_entity} --project {wandb_project}"

    failed_tasks = []
    for i in range(1, num_agents + 1):
        task_name = f"swp_{base_sweep_id}_{i}"
        job_name = f"swp_{base_sweep_id[:6]}_{i}"  # short name for the SLURM queue

        sbatch_script = format_sbatch_script(
            local_run_command=local_run_command,
            out_dir=sweep_out_dir,
            task_name=task_name,
            job_name=job_name,
            num_nodes=num_nodes,
            tasks_per_node=tasks_per_node,
            cpus_per_task=cpus_per_task,
            num_hours=num_hours,
            gigs_memory=gigs_memory,
            num_gpus=num_gpus,
        )

        task_path = os.path.join(slurm_directory, f"{task_name}.sbatch")
        with open(task_path, "w", encoding="utf-8") as out_file:
            out_file.write(sbatch_script)

        print(f"[{i}/{num_agents}] Submitting SLURM task: {task_name}")
        # Argument list (no shell) so the path is passed verbatim. The exit
        # status was previously ignored, hiding failed submissions.
        try:
            returncode = subprocess.run(["sbatch", task_path]).returncode
        except OSError as err:
            print(f"Could not run sbatch for {task_name}: {err}")
            failed_tasks.append(task_name)
            continue
        if returncode != 0:
            print(f"sbatch failed for {task_name} (exit code {returncode})")
            failed_tasks.append(task_name)

    if failed_tasks:
        submitted = num_agents - len(failed_tasks)
        print(
            f"{submitted} of {num_agents} agents submitted to queue; "
            f"failed: {', '.join(failed_tasks)}"
        )
    else:
        print(f"All {num_agents} agents submitted to queue")
    print(f"SLURM logs will be saved to: {sweep_out_dir}/")
96
+
97
+
98
def main():
    """Command-line entry point: create or resume a W&B sweep, then submit SLURM agents.

    Exactly one of ``--config`` (create a new sweep from a YAML file) or
    ``--sweep_id`` (attach agents to an existing sweep) is required. The
    Singularity and W&B settings come from environment variables, loaded from a
    ``.env`` file via ``load_dotenv``.

    Raises:
        ValueError: if a required environment variable is missing.
    """
    # Parse arguments first so `--help` and usage errors work even when no
    # .env file is present.
    parser = argparse.ArgumentParser(description="W&B Sweep Launcher for SLURM")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--config", type=str, help="Path to your sweep_config.yaml to start a NEW sweep"
    )
    group.add_argument(
        "--sweep_id",
        type=str,
        help="Existing W&B sweep ID (e.g., entity/project/abcd1234) to RESUME an old sweep",
    )
    parser.add_argument(
        "--num_jobs", type=int, default=4, help="Number of parallel SLURM jobs/agents to launch"
    )

    parser.add_argument("--nodes", type=int, default=1, help="Number of nodes per agent")
    parser.add_argument("--tasks_per_node", type=int, default=1, help="Tasks per node")
    parser.add_argument("--cpus", type=int, default=14, help="CPUs per task")
    parser.add_argument("--hours", type=int, default=5, help="Wall time in hours")
    parser.add_argument("--memory", type=int, default=32, help="Memory in GB")
    parser.add_argument("--gpus", type=int, default=1, help="GPUs per agent")

    args = parser.parse_args()

    load_dotenv()
    required_keys = [
        "SINGULARITY_ENV_PATH",
        "SINGULARITY_CONTAINER_PATH",
        "WANDB_ENTITY",
        "WANDB_PROJECT",
    ]
    for key in required_keys:
        if not os.environ.get(key):
            raise ValueError(f"Missing {key} in .env file.")

    if args.config:
        target_sweep_id = initialize_wandb_sweep(args.config)
    else:
        target_sweep_id = args.sweep_id
        print(f"Attaching {args.num_jobs} new agents to existing sweep: {target_sweep_id}")

    launch_sweep_agents(
        sweep_id=target_sweep_id,
        num_agents=args.num_jobs,
        num_nodes=args.nodes,
        tasks_per_node=args.tasks_per_node,
        cpus_per_task=args.cpus,
        num_hours=args.hours,
        gigs_memory=args.memory,
        num_gpus=args.gpus,
    )


if __name__ == "__main__":
    main()