metadentify 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metadentify/__init__.py +0 -0
- metadentify/args.py +59 -0
- metadentify/baselines.py +460 -0
- metadentify/launch_sweep.py +153 -0
- metadentify/mechanisms.py +167 -0
- metadentify/mixture.py +778 -0
- metadentify/modules.py +926 -0
- metadentify/py.typed +0 -0
- metadentify/queries.py +147 -0
- metadentify/run_experiment.py +94 -0
- metadentify/sbatch.py +52 -0
- metadentify/train.py +276 -0
- metadentify/utils.py +49 -0
- metadentify-0.1.0a0.dist-info/METADATA +220 -0
- metadentify-0.1.0a0.dist-info/RECORD +18 -0
- metadentify-0.1.0a0.dist-info/WHEEL +4 -0
- metadentify-0.1.0a0.dist-info/entry_points.txt +3 -0
- metadentify-0.1.0a0.dist-info/licenses/LICENSE +201 -0
metadentify/__init__.py
ADDED
|
File without changes
|
metadentify/args.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_base_parser(description: str = "") -> argparse.ArgumentParser:
    """Create the argument parser shared by the experiment entry points.

    Options are registered in a fixed order, grouped as: paths and experiment
    tracking, experiment/query/baseline selection, model architecture, data
    generation and loading, and optimisation/output. Only
    ``--experiment_name`` is required.

    Args:
        description: Text passed to ``argparse.ArgumentParser(description=...)``.

    Returns:
        The configured (not yet parsed) parser.
    """
    flag = {"action": "store_true"}
    # --flag / --no-flag pair that is enabled unless explicitly disabled.
    on_by_default = {"default": True, "action": argparse.BooleanOptionalAction}

    option_specs = [
        # Paths and experiment tracking.
        ("--experiment_setup_path", {"type": str, "default": ""}),
        ("--config_path", {"type": str, "default": "config.py"}),
        ("--wandb_project", {"type": str, "default": ""}),
        ("--wandb_entity", {"type": str, "default": ""}),
        ("--progress_bar", flag),
        ("--savelast", flag),
        ("--overwrite_data", on_by_default),
        ("--disable_wandb", flag),
        # Experiment, query and baseline selection.
        ("--experiment_name", {"type": str, "required": True}),
        ("--query_name", {"type": str, "default": "PATE"}),
        ("--baseline_setting", {"type": str, "default": None}),
        # Model architecture.
        ("--backbone_type", {"type": str, "default": "q-cnp"}),
        ("--embed_dim", {"type": int, "default": 64}),
        ("--num_heads", {"type": int, "default": 4}),
        ("--num_layers", {"type": int, "default": 2}),
        ("--num_tau_samples", {"type": int, "default": 8}),
        ("--source_embed_dim", {"type": int, "default": 16}),
        ("--output_dim", {"type": int, "default": 1}),
        ("--num_inducing_points", {"type": int, "default": 32}),
        ("--dropout", {"type": float, "default": 0.0}),
        # Data generation and loading.
        ("--standardize", flag),
        ("--num_pytorch_workers", {"type": int, "default": 0}),
        ("--num_datagen_workers", {"type": int, "default": -2}),
        ("--dataset_size", {"type": int, "default": 1000}),
        ("--num_query_points", {"type": int, "default": 10}),
        ("--num_scms", {"type": int, "default": None}),
        ("--num_train_tasks", {"type": int, "default": 10000}),
        ("--num_val_tasks", {"type": int, "default": 1000}),
        ("--num_test_tasks", {"type": int, "default": 1000}),
        ("--tasks_per_file", {"type": int, "default": 500}),
        ("--batch_size", {"type": int, "default": 64}),
        ("--prefetch_factor", {"type": int, "default": 2}),
        ("--online_data", flag),
        ("--total_steps", {"type": int, "default": None}),
        # Optimisation, evaluation and outputs.
        ("--val_metrics_normalized", on_by_default),
        ("--plot_results", flag),
        ("--lr", {"type": float, "default": 1e-3}),
        ("--lambda_crossing_penalty", {"type": float, "default": 1.0}),
        ("--weight_decay", {"type": float, "default": 0.0}),
        ("--patience", {"type": int, "default": 30}),
        ("--max_epochs", {"type": int, "default": 100}),
        ("--run_baselines", flag),
        ("--save_dir", {"type": str, "default": "data"}),
        ("--checkpoint_dir", {"type": str, "default": "checkpoints"}),
        ("--plot_diagnostics", flag),
    ]

    parser = argparse.ArgumentParser(description=description)
    for option_name, option_kwargs in option_specs:
        parser.add_argument(option_name, **option_kwargs)
    return parser
|
metadentify/baselines.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from econml.dml import CausalForestDML, LinearDML
|
|
5
|
+
from econml.iv.dml import DMLIV
|
|
6
|
+
from sklearn.linear_model import LinearRegression, Ridge
|
|
7
|
+
from sklearn.neural_network import MLPRegressor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def confounder_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Adjusted effect estimate from an OLS regression of ``y`` on ``[t, x]``.

    The coefficient on the treatment column (the first design column) is the
    estimate. ``x_sources`` is accepted for interface parity and ignored.

    Returns:
        The treatment coefficient. If ``query_x`` is given, that same value is
        repeated once per row of ``query_x``.
    """
    treatment_coef = LinearRegression().fit(X=np.c_[t, x], y=y).coef_[0]
    if query_x is None:
        return treatment_coef
    return np.ones((query_x.shape[0],)) * treatment_coef
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def confounder_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Adjusted effect estimate from a default ``Ridge`` fit of ``y`` on ``[t, x]``.

    The coefficient on the treatment column (the first design column) is the
    estimate. ``x_sources`` is accepted for interface parity and ignored.

    Returns:
        The treatment coefficient. If ``query_x`` is given, that same value is
        repeated once per row of ``query_x``.
    """
    treatment_coef = Ridge().fit(X=np.c_[t, x], y=y).coef_[0]
    if query_x is None:
        return treatment_coef
    return np.ones((query_x.shape[0],)) * treatment_coef
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def confounder_mlp_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Average unit-increment effect of ``t`` under an MLP fit of ``y`` on ``[t, x]``.

    After fitting, the treatment is set to a baseline level per sample. The
    baseline is all zeros when ``t`` takes exactly the values {0, 1}; otherwise
    it is drawn uniformly from ``[t.min(), t.max()]`` with NumPy's global RNG.
    The estimate is the mean difference between predictions at ``baseline + 1``
    and ``baseline``, with the observed covariates ``x``. ``x_sources`` is
    ignored.

    Returns:
        The mean difference. If ``query_x`` is given, that same value is
        repeated once per row of ``query_x``.
    """
    # Fit first: the MLP and the uniform draw below share the global RNG, so
    # this ordering keeps runs reproducible under a fixed seed.
    regressor = MLPRegressor(max_iter=3000).fit(X=np.c_[t, x], y=y)

    is_binary_treatment = np.array_equal(np.unique(t), [0, 1])
    if is_binary_treatment:
        base_t = t * 0
    else:
        base_t = np.random.uniform(low=t.min(), high=t.max(), size=x.shape[0])

    predicted_treated = regressor.predict(X=np.c_[base_t + 1, x])
    predicted_control = regressor.predict(X=np.c_[base_t, x])
    average_effect = np.mean(predicted_treated - predicted_control)

    if query_x is None:
        return average_effect
    return np.ones((query_x.shape[0],)) * average_effect
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def treatment_only_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Unadjusted effect estimate: the OLS slope of ``y`` on ``t`` alone.

    ``x`` and ``x_sources`` are ignored, so no covariate adjustment is made.

    Returns:
        The slope. If ``query_x`` is given, that same value is repeated once
        per row of ``query_x``.
    """
    treatment_column = t.reshape(-1, 1)
    slope = LinearRegression().fit(X=treatment_column, y=y).coef_[0]
    if query_x is None:
        return slope
    return np.ones((query_x.shape[0],)) * slope
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def treatment_only_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Unadjusted effect estimate: the default ``Ridge`` slope of ``y`` on ``t`` alone.

    ``x`` and ``x_sources`` are ignored, so no covariate adjustment is made.

    Returns:
        The slope. If ``query_x`` is given, that same value is repeated once
        per row of ``query_x``.
    """
    treatment_column = t.reshape(-1, 1)
    slope = Ridge().fit(X=treatment_column, y=y).coef_[0]
    if query_x is None:
        return slope
    return np.ones((query_x.shape[0],)) * slope
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def tsls_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage least squares with ``x`` as instruments (OLS in both stages).

    Stage 1 regresses ``t`` on ``x``. Stage 2 regresses ``y`` on the fitted
    treatment, and its slope is the estimate. ``query_x`` and ``x_sources``
    are ignored, so a single value is always returned.
    """
    first_stage = LinearRegression().fit(X=x, y=t)
    t_hat = first_stage.predict(X=x).reshape(-1, 1)
    second_stage = LinearRegression().fit(X=t_hat, y=y)
    return second_stage.coef_[0]
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def tsls_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage regression with ``x`` as instruments (default ``Ridge`` in both stages).

    Stage 1 regresses ``t`` on ``x``. Stage 2 regresses ``y`` on the fitted
    treatment, and its slope is the estimate. ``query_x`` and ``x_sources``
    are ignored, so a single value is always returned.
    """
    first_stage = Ridge().fit(X=x, y=t)
    t_hat = first_stage.predict(X=x).reshape(-1, 1)
    second_stage = Ridge().fit(X=t_hat, y=y)
    return second_stage.coef_[0]
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def tsls_mlp_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage estimate with an MLP first stage and a ``Ridge`` second stage.

    Stage 1 fits a default ``MLPRegressor`` of ``t`` (flattened) on ``x``.
    Stage 2 regresses ``y`` on the predicted treatment with ``Ridge``, and its
    slope is the estimate. ``query_x`` and ``x_sources`` are ignored.
    """
    first_stage = MLPRegressor().fit(X=x, y=t.ravel())
    # Add a trailing feature axis so the prediction is a single-column design.
    t_hat = first_stage.predict(X=x)[..., np.newaxis]
    second_stage = Ridge().fit(X=t_hat, y=y)
    return second_stage.coef_[0]
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def proxy_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage proxy regression using the first two columns of ``x`` (OLS in both stages).

    Stage 1 regresses column 0 of ``x`` on ``[x[:, 1], t]``. Stage 2
    regresses ``y`` on ``[t, stage-1 prediction]``, and the coefficient on
    ``t`` is the estimate. ``query_x`` and ``x_sources`` are ignored.
    """
    first_proxy = x[:, [0]]
    second_proxy = x[:, [1]]

    stage_one_design = np.c_[second_proxy, t]
    first_proxy_hat = (
        LinearRegression()
        .fit(X=stage_one_design, y=first_proxy)
        .predict(X=stage_one_design)
    )

    stage_two_design = np.c_[t, first_proxy_hat]
    return LinearRegression().fit(X=stage_two_design, y=y).coef_[0]
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def proxy_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage proxy regression using the first two columns of ``x`` (``Ridge`` in both stages).

    Stage 1 regresses column 0 of ``x`` on ``[x[:, 1], t]``. Stage 2
    regresses ``y`` on ``[t, stage-1 prediction]``, and the coefficient on
    ``t`` is the estimate. ``query_x`` and ``x_sources`` are ignored.
    """
    first_proxy = x[:, [0]]
    second_proxy = x[:, [1]]

    stage_one_design = np.c_[second_proxy, t]
    first_proxy_hat = Ridge().fit(X=stage_one_design, y=first_proxy).predict(X=stage_one_design)

    stage_two_design = np.c_[t, first_proxy_hat]
    return Ridge().fit(X=stage_two_design, y=y).coef_[0]
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def proxy_mlp_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray | None = None,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Two-stage proxy regression with an MLP first stage and a ``Ridge`` second stage.

    Stage 1 fits a default ``MLPRegressor`` of column 0 of ``x`` on
    ``[x[:, 1], t]``. Stage 2 regresses ``y`` on ``[t, stage-1 prediction]``
    with ``Ridge``, and the coefficient on ``t`` is the estimate. ``query_x``
    and ``x_sources`` are ignored.
    """
    proxy_1 = x[:, [0]]
    proxy_2 = x[:, [1]]
    treatment_and_proxy_2 = np.c_[proxy_2, t]
    # MLPRegressor expects a 1-D target for a single output. Passing the
    # (n, 1) column triggers a DataConversionWarning on every call, and sklearn
    # ravels it anyway. Ravel explicitly, as tsls_mlp_baseline does.
    stage1model = MLPRegressor().fit(X=treatment_and_proxy_2, y=proxy_1.ravel())
    predicted_proxy_1 = stage1model.predict(X=treatment_and_proxy_2)
    treatment_and_predicted_proxy = np.c_[t, predicted_proxy_1]
    stage2model = Ridge().fit(X=treatment_and_predicted_proxy, y=y)
    estimate = stage2model.coef_[0]
    return estimate
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def confounder_dml_baseline(
|
|
197
|
+
x: np.ndarray,
|
|
198
|
+
t: np.ndarray,
|
|
199
|
+
y: np.ndarray,
|
|
200
|
+
query_x: np.ndarray | None = None,
|
|
201
|
+
x_sources: np.ndarray | None = None,
|
|
202
|
+
type: str = "forest",
|
|
203
|
+
binary_t: bool = False,
|
|
204
|
+
) -> float | np.ndarray:
|
|
205
|
+
if type == "forest":
|
|
206
|
+
model = CausalForestDML(
|
|
207
|
+
discrete_treatment=binary_t,
|
|
208
|
+
discrete_outcome=False,
|
|
209
|
+
n_jobs=None,
|
|
210
|
+
)
|
|
211
|
+
elif type == "linear":
|
|
212
|
+
model = LinearDML(
|
|
213
|
+
discrete_treatment=binary_t,
|
|
214
|
+
discrete_outcome=False,
|
|
215
|
+
)
|
|
216
|
+
model.fit(Y=y, T=t, X=x, W=None)
|
|
217
|
+
estimate = model.ate(X=x)
|
|
218
|
+
return estimate
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def iv_dml_baseline(
|
|
222
|
+
x: np.ndarray,
|
|
223
|
+
t: np.ndarray,
|
|
224
|
+
y: np.ndarray,
|
|
225
|
+
query_x: np.ndarray | None = None,
|
|
226
|
+
x_sources: np.ndarray | None = None,
|
|
227
|
+
type: str = "forest",
|
|
228
|
+
binary_t: bool = False,
|
|
229
|
+
) -> float | np.ndarray:
|
|
230
|
+
if type == "forest":
|
|
231
|
+
model = DMLIV(
|
|
232
|
+
discrete_treatment=binary_t,
|
|
233
|
+
discrete_instrument=False,
|
|
234
|
+
model_y_xw="forest",
|
|
235
|
+
model_t_xw="forest",
|
|
236
|
+
model_t_xwz="forest",
|
|
237
|
+
)
|
|
238
|
+
elif type == "linear":
|
|
239
|
+
model = DMLIV(
|
|
240
|
+
discrete_treatment=binary_t,
|
|
241
|
+
discrete_instrument=False,
|
|
242
|
+
model_y_xw="linear",
|
|
243
|
+
model_t_xw="linear",
|
|
244
|
+
model_t_xwz="linear",
|
|
245
|
+
)
|
|
246
|
+
model.fit(Y=y, T=t, Z=x, X=None, W=None)
|
|
247
|
+
estimate = model.ate()
|
|
248
|
+
return estimate
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def cate_confounder_linear_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Linear CATE estimate from an OLS fit of ``y`` on ``[x * t, t]``.

    With interaction coefficients ``beta`` (one per covariate column) and
    treatment coefficient ``gamma`` (the last design column), the CATE at a
    query point is ``beta . query_x + gamma``. ``x_sources`` is ignored.

    NOTE(review): the design has no main effect of ``x``, so any direct effect
    of ``x`` on ``y`` is absorbed by the interaction terms. Confirm this is
    intended for the simulated settings.

    Returns:
        For a single covariate column, ``beta * query_x + gamma`` elementwise
        (this keeps the shape of ``query_x``, as before). For several
        covariates, ``query_x @ beta + gamma`` with one value per query row.
    """
    dataset_x = np.c_[x * t, t]
    coefs = LinearRegression().fit(X=dataset_x, y=y).coef_
    num_interactions = dataset_x.shape[1] - 1
    if num_interactions == 1:
        return coefs[0] * query_x + coefs[1]
    # With more than one covariate, coef_[1] is the second interaction term,
    # not the treatment main effect. Use every interaction plus the last coef.
    return np.asarray(query_x) @ coefs[:num_interactions] + coefs[num_interactions]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def cate_confounder_ridge_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray,
    x_sources: np.ndarray | None = None,
) -> float | np.ndarray:
    """Linear CATE estimate from a default ``Ridge`` fit of ``y`` on ``[x * t, t]``.

    With interaction coefficients ``beta`` (one per covariate column) and
    treatment coefficient ``gamma`` (the last design column), the CATE at a
    query point is ``beta . query_x + gamma``. ``x_sources`` is ignored.

    NOTE(review): the design has no main effect of ``x``, so any direct effect
    of ``x`` on ``y`` is absorbed by the interaction terms. Confirm this is
    intended for the simulated settings.

    Returns:
        For a single covariate column, ``query_x * beta + gamma`` elementwise
        (this keeps the shape of ``query_x``, as before). For several
        covariates, ``query_x @ beta + gamma`` with one value per query row.
    """
    dataset_x = np.c_[x * t, t]
    coefs = Ridge().fit(X=dataset_x, y=y).coef_
    num_interactions = dataset_x.shape[1] - 1
    if num_interactions == 1:
        return query_x * coefs[0] + coefs[1]
    # With more than one covariate, coef_[1] is the second interaction term,
    # not the treatment main effect. Use every interaction plus the last coef.
    return np.asarray(query_x) @ coefs[:num_interactions] + coefs[num_interactions]
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def cate_confounder_mlp_baseline(
    x: np.ndarray,
    t: np.ndarray,
    y: np.ndarray,
    query_x: np.ndarray,
    num_t_samples: int = 1000,
    x_sources: np.ndarray | None = None,
) -> list[float]:
    """Per-query unit-increment effect of ``t`` under an MLP fit of ``y`` on ``[t, x]``.

    ``num_t_samples`` baseline treatment levels are chosen once. They are all
    zeros when ``t`` takes exactly the values {0, 1}; otherwise they are
    uniform draws from ``[t.min(), t.max()]`` using NumPy's global RNG. For
    each row of ``query_x``, the estimate is the mean difference between
    predictions at ``baseline + 1`` and ``baseline``, with the covariates
    fixed to that row. ``x_sources`` is ignored.

    Returns:
        One estimate per row of ``query_x``, in order.
    """
    # Fit before drawing treatment levels: both consume the global RNG.
    regressor = MLPRegressor(max_iter=3000).fit(X=np.c_[t, x], y=y)

    if np.array_equal(np.unique(t), [0, 1]):
        base_t = np.zeros((num_t_samples, 1))
    else:
        base_t = np.random.uniform(low=t.min(), high=t.max(), size=(num_t_samples, 1))
    shifted_t = base_t + 1

    def _mean_effect_at(row) -> float:
        # Tile the query covariates so each baseline level is paired with them.
        covariates = np.repeat(np.array(row).reshape(1, -1), repeats=num_t_samples, axis=0)
        treated = regressor.predict(X=np.c_[shifted_t, covariates])
        control = regressor.predict(X=np.c_[base_t, covariates])
        return np.mean(treated - control)

    return [_mean_effect_at(query_x[row_idx]) for row_idx in range(query_x.shape[0])]
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def get_baselines(baseline_setting: str) -> list[dict[str, Any]]:
    """Return the baseline estimators evaluated for an experimental setting.

    Every setting includes the treatment-only ridge baseline
    (``"T-Only-Ridge"``). Most settings add two setting-specific baselines.
    Each entry is a dict with keys ``"model"`` (a callable taking ``x, t, y``
    and optionally ``query_x``), ``"name"`` and ``"alias"``.

    Args:
        baseline_setting: Identifier of the experimental setting, e.g.
            ``"confounder"``, ``"iv-linreg"`` or ``"cate-confounder-linear"``.

    Returns:
        The baseline descriptions, T-only baseline first.

    Raises:
        ValueError: If ``baseline_setting`` is not a known setting.
    """

    def _bind_dml(estimator, dml_type: str, binary_t: bool):
        """Return a baseline callable that runs ``estimator`` with a fixed DML flavour."""

        def dml_model(x, t, y, query_x=None, x_sources=None):
            # x_sources is accepted for signature parity with the other
            # baselines. As before, it is not forwarded to the DML estimator.
            return estimator(x, t, y, query_x, type=dml_type, binary_t=binary_t)

        return dml_model

    baselines: list[dict[str, Any]] = [
        dict(model=treatment_only_ridge_baseline, name="t_only_ridge", alias="T-Only-Ridge")
    ]

    # Settings that report only the treatment-only baseline.
    # NOTE(review): "confounder-and-iv" and "iv-invalid-tcp" previously hit a
    # bare `pass` and then crashed with UnboundLocalError, because no
    # setting-specific baseline was defined. They now behave like the t-only
    # settings; confirm this is intended.
    t_only_settings = {
        "t-only",
        "t-only-double-invalid-ocp",
        "t-only-iv-invalid-tcp",
        "confounder-and-iv",
        "iv-invalid-tcp",
    }
    if baseline_setting in t_only_settings:
        return baselines

    reg_ridge = (confounder_ridge_baseline, "ridge_regression", "Reg-Ridge")
    reg_lin = (confounder_linear_baseline, "linear_regression", "Reg-Lin")
    reg_mlp = (confounder_mlp_baseline, "mlp_regression", "Reg-MLP")
    tsls_ridge = (tsls_ridge_baseline, "tsls_ridge", "TSLS-Ridge")
    tsls_lin = (tsls_linear_baseline, "tsls_linear", "TSLS-Lin")
    tsls_mlp = (tsls_mlp_baseline, "mlp_tsls", "TSLS-MLP")
    proxy_ridge = (proxy_ridge_baseline, "proxy_tsls_ridge", "PrTSLS-Ridge")
    proxy_lin = (proxy_linear_baseline, "proxy_tsls_linear", "PrTSLS-Lin")
    proxy_mlp = (proxy_mlp_baseline, "mlp_proxy_tsls", "PrTSLS-MLP")
    cate_ridge = (cate_confounder_ridge_baseline, "cate_ridge", "CATE-Ridge")
    cate_mlp = (cate_confounder_mlp_baseline, "cate_mlp", "CATE-MLP")

    # Non-DML settings: setting -> (baseline 1, baseline 2) as (model, name, alias).
    static_pairs = {
        "confounder": (reg_ridge, reg_mlp),
        "iv": (tsls_ridge, tsls_mlp),
        "proxy": (proxy_ridge, proxy_mlp),
        "double-invalid-ocp": (proxy_ridge, proxy_mlp),
        "confounder-linreg": (reg_lin, reg_mlp),
        "iv-linreg": (tsls_lin, tsls_mlp),
        "proxy-linreg": (proxy_lin, proxy_mlp),
        "double-invalid-ocp-linreg": (proxy_lin, proxy_mlp),
        "cate-confounder-linear": (cate_ridge, cate_mlp),
        "cate-confounder-nonlinear": (cate_ridge, cate_mlp),
    }

    # DML settings: setting -> (estimator, binary_t, linear (name, alias), forest (name, alias)).
    dml_settings = {
        "confounder-binary-t-dml": (
            confounder_dml_baseline, True, ("linear_dml", "DML-Lin"), ("forest_dml", "DML-RF"),
        ),
        "confounder-continuous-t-dml": (
            confounder_dml_baseline, False, ("linear_dml", "DML-Lin"), ("forest_dml", "DML-RF"),
        ),
        "iv-binary-t-dml": (
            iv_dml_baseline, True, ("linear_dmliv", "DMLIV-Lin"), ("forest_dmliv", "DMLIV-RF"),
        ),
        "iv-continuous-t-dml": (
            iv_dml_baseline, False, ("linear_dmliv", "DMLIV-Lin"), ("forest_dmliv", "DMLIV-RF"),
        ),
    }

    if baseline_setting in dml_settings:
        estimator, binary_t, linear_labels, forest_labels = dml_settings[baseline_setting]
        pair = (
            (_bind_dml(estimator, "linear", binary_t), *linear_labels),
            (_bind_dml(estimator, "forest", binary_t), *forest_labels),
        )
    elif baseline_setting in static_pairs:
        pair = static_pairs[baseline_setting]
    else:
        raise ValueError(f"Unsupported baseline_setting: {baseline_setting}")

    baselines.extend(dict(model=model, name=name, alias=alias) for model, name, alias in pair)
    return baselines
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import subprocess
|
|
5
|
+
|
|
6
|
+
from dotenv import load_dotenv
|
|
7
|
+
|
|
8
|
+
from metadentify.sbatch import format_sbatch_script
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def initialize_wandb_sweep(yaml_config_path):
    """Create a W&B sweep by running ``wandb sweep`` inside the Singularity container.

    Reads ``SINGULARITY_ENV_PATH`` (passed to ``--overlay``) and
    ``SINGULARITY_CONTAINER_PATH`` from the environment. It then runs
    ``source /ext3/env.sh; wandb sweep <yaml_config_path>`` via
    ``singularity exec --nv``.

    Args:
        yaml_config_path: Path to the sweep configuration YAML.

    Returns:
        The sweep identifier: the token following ``wandb agent`` in the
        command's combined stdout/stderr.

    Raises:
        ValueError: If a Singularity environment variable is unset, or the
            sweep ID cannot be parsed from the output.
        RuntimeError: If the command exits with a non-zero status.
    """
    import shlex

    print(f"Initializing W&B sweep from {yaml_config_path} inside Singularity...")

    singularity_env_path = os.environ.get("SINGULARITY_ENV_PATH")
    singularity_container_path = os.environ.get("SINGULARITY_CONTAINER_PATH")
    # Without this check, a missing variable puts None into the argv list and
    # subprocess.run fails with an opaque TypeError.
    if not singularity_env_path or not singularity_container_path:
        raise ValueError(
            "SINGULARITY_ENV_PATH and SINGULARITY_CONTAINER_PATH must be set in the environment."
        )

    # The path is interpolated into a `bash -c` string. Quote it so spaces or
    # shell metacharacters cannot split or alter the command. Ordinary paths
    # are left unchanged by shlex.quote.
    inner_command = f"source /ext3/env.sh; wandb sweep {shlex.quote(yaml_config_path)}"

    singularity_cmd = [
        "singularity",
        "exec",
        "--nv",
        "--overlay",
        singularity_env_path,
        singularity_container_path,
        "/bin/bash",
        "-c",
        inner_command,
    ]

    result = subprocess.run(singularity_cmd, capture_output=True, text=True)
    # wandb may print the agent hint on either stream, so search both.
    output = result.stdout + result.stderr

    if result.returncode != 0:
        print(f"Error initializing sweep:\n{output}")
        raise RuntimeError("Failed to create W&B sweep.")

    match = re.search(r"wandb agent\s+([^\s]+)", output)
    if not match:
        print(f"Could not parse sweep ID from W&B output:\n{output}")
        raise ValueError("Sweep ID regex match failed.")

    sweep_id = match.group(1)
    print(f"Successfully created sweep: {sweep_id}")
    return sweep_id
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def launch_sweep_agents(
    sweep_id,
    num_agents,
    num_nodes,
    tasks_per_node,
    cpus_per_task,
    num_hours,
    gigs_memory,
    num_gpus,
    slurm_directory="slurm_scripts",
):
    """Write and submit ``num_agents`` SLURM jobs, each running a W&B agent for ``sweep_id``.

    Each job's script comes from ``format_sbatch_script`` and is written to
    ``<slurm_directory>/swp_<id>_<i>.sbatch``, then submitted with ``sbatch``.
    Logs go to ``out/<id>``, where ``<id>`` is the last ``/``-separated
    component of ``sweep_id``. ``WANDB_ENTITY`` and ``WANDB_PROJECT`` are read
    from the environment.

    A failed submission is reported and the remaining jobs are still
    submitted. The closing summary states how many jobs succeeded.
    """
    os.makedirs(slurm_directory, exist_ok=True)
    base_sweep_id = sweep_id.split("/")[-1]

    sweep_out_dir = os.path.join("out", base_sweep_id)
    os.makedirs(sweep_out_dir, exist_ok=True)

    wandb_entity = os.environ.get("WANDB_ENTITY")
    wandb_project = os.environ.get("WANDB_PROJECT")

    local_run_command = f"wandb agent {sweep_id} --entity {wandb_entity} --project {wandb_project}"

    failed_tasks = []
    for i in range(1, num_agents + 1):
        task_name = f"swp_{base_sweep_id}_{i}"
        # Shorter name for the SLURM job itself (first 6 chars of the sweep id).
        job_name = f"swp_{base_sweep_id[:6]}_{i}"

        sbatch_script = format_sbatch_script(
            local_run_command=local_run_command,
            out_dir=sweep_out_dir,
            task_name=task_name,
            job_name=job_name,
            num_nodes=num_nodes,
            tasks_per_node=tasks_per_node,
            cpus_per_task=cpus_per_task,
            num_hours=num_hours,
            gigs_memory=gigs_memory,
            num_gpus=num_gpus,
        )

        task_path = os.path.join(slurm_directory, f"{task_name}.sbatch")
        with open(task_path, "w") as out_file:
            out_file.write(sbatch_script)

        print(f"[{i}/{num_agents}] Submitting SLURM task: {task_name}")
        # Pass argv as a list (no shell) so the path reaches sbatch verbatim,
        # even if it contains spaces or shell metacharacters.
        try:
            returncode = subprocess.run(["sbatch", task_path]).returncode
        except OSError as err:  # e.g. sbatch is not on PATH
            print(f"[{i}/{num_agents}] Could not run sbatch for {task_name}: {err}")
            failed_tasks.append(task_name)
            continue
        if returncode != 0:
            print(f"[{i}/{num_agents}] sbatch exited with code {returncode} for {task_name}")
            failed_tasks.append(task_name)

    if failed_tasks:
        submitted = num_agents - len(failed_tasks)
        print(
            f"{submitted}/{num_agents} agents submitted to queue; "
            f"failed: {', '.join(failed_tasks)}"
        )
    else:
        print(f"All {num_agents} agents submitted to queue")
    print(f"SLURM logs will be saved to: {sweep_out_dir}/")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def main():
    """CLI entry point: create or attach to a W&B sweep and submit SLURM agents.

    Loads a ``.env`` file and parses the command line. It then checks that
    ``SINGULARITY_ENV_PATH``, ``SINGULARITY_CONTAINER_PATH``, ``WANDB_ENTITY``
    and ``WANDB_PROJECT`` are set and non-empty. Finally it either creates a
    new sweep from ``--config`` or reuses ``--sweep_id``, and launches
    ``--num_jobs`` agents.

    Raises:
        ValueError: If a required environment variable is missing or empty.
    """
    load_dotenv()

    parser = argparse.ArgumentParser(description="W&B Sweep Launcher for SLURM")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--config", type=str, help="Path to your sweep_config.yaml to start a NEW sweep"
    )
    group.add_argument(
        "--sweep_id",
        type=str,
        help="Existing W&B sweep ID (e.g., entity/project/abcd1234) to RESUME an old sweep",
    )
    parser.add_argument(
        "--num_jobs", type=int, default=4, help="Number of parallel SLURM jobs/agents to launch"
    )

    parser.add_argument("--nodes", type=int, default=1, help="Number of nodes per agent")
    parser.add_argument("--tasks_per_node", type=int, default=1, help="Tasks per node")
    parser.add_argument("--cpus", type=int, default=14, help="CPUs per task")
    parser.add_argument("--hours", type=int, default=5, help="Wall time in hours")
    parser.add_argument("--memory", type=int, default=32, help="Memory in GB")
    parser.add_argument("--gpus", type=int, default=1, help="GPUs per agent")

    # Parse before validating the environment so `--help` and usage errors
    # work without a populated .env file.
    args = parser.parse_args()

    required_keys = [
        "SINGULARITY_ENV_PATH",
        "SINGULARITY_CONTAINER_PATH",
        "WANDB_ENTITY",
        "WANDB_PROJECT",
    ]
    for key in required_keys:
        if not os.environ.get(key):
            raise ValueError(f"Missing {key} in .env file.")

    if args.config:
        target_sweep_id = initialize_wandb_sweep(args.config)
    else:
        target_sweep_id = args.sweep_id
        print(f"Attaching {args.num_jobs} new agents to existing sweep: {target_sweep_id}")

    launch_sweep_agents(
        sweep_id=target_sweep_id,
        num_agents=args.num_jobs,
        num_nodes=args.nodes,
        tasks_per_node=args.tasks_per_node,
        cpus_per_task=args.cpus,
        num_hours=args.hours,
        gigs_memory=args.memory,
        num_gpus=args.gpus,
    )


if __name__ == "__main__":
    main()
|