hjxdl 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
hdl/controllers/train/rxn_train.py
@@ -0,0 +1,219 @@
+ from os import path as osp
+ import typing as t
+
+ import torch
+ from torch import nn
+
+ from hdl.models.rxn import build_rxn_mu
+ from hdl.models.utils import load_model, save_model
+ from hdl.data.dataset.seq.rxn_dataset import RXNCSVDataset
+ from hdl.data.dataset.loaders.rxn_loader import RXNLoader
+ from hdl.metric_loss.loss import mtmc_loss
+ from jupyfuncs.pbar import tnrange, tqdm
+ from jupyfuncs.glob import makedirs
+ # from hdl.optims.nadam import Nadam
+ from torch.optim import Adam
+ # from .trainer_base import TorchTrainer
+
+
+ def train_a_batch(
+     model,
+     batch_data,
+     loss_func,
+     optimizer,
+     device,
+     individual,
+     **kwargs
+ ):
+     optimizer.zero_grad()
+
+     X = [x.to(device) for x in batch_data[0]]
+     y = batch_data[1].T.to(device)
+
+     y_preds = model(X)
+     loss = mtmc_loss(
+         y_preds,
+         y,
+         loss_func,
+         individual=individual, **kwargs
+     )
+
+     if not individual:
+         final_loss = loss
+         individual_losses = []
+     else:
+         final_loss = loss[0]
+         individual_losses = loss[1]
+
+     final_loss.backward()
+     optimizer.step()
+
+     return final_loss, individual_losses
+
+
+ def train_an_epoch(
+     base_dir: str,
+     model,
+     data_loader,
+     epoch_id: int,
+     loss_func,
+     optimizer,
+     device,
+     num_warm_epochs: int = 0,
+     individual: bool = True,
+     **kwargs
+ ):
+     if epoch_id < num_warm_epochs:
+         model.freeze_encoder = True
+     else:
+         model.freeze_encoder = False
+
+     for batch in tqdm(data_loader):
+         loss, individual_losses = train_a_batch(
+             model=model,
+             batch_data=batch,
+             loss_func=loss_func,
+             optimizer=optimizer,
+             device=device,
+             individual=individual,
+             **kwargs
+         )
+         with open(
+             osp.join(base_dir, 'loss.log'),
+             'a'
+         ) as f:
+             f.write(str(loss.item()))
+             f.write('\t')
+             for individual_loss in individual_losses:
+                 f.write(str(individual_loss))
+                 f.write('\t')
+             f.write('\n')
+
+     ckpt_file = osp.join(
+         base_dir,
+         f'model.{epoch_id}.ckpt'
+     )
+     save_model(
+         model=model,
+         save_dir=ckpt_file,
+         epoch=epoch_id,
+         optimizer=optimizer,
+         loss=loss,
+     )
+
+
+ def train_rxn(
+     base_dir,
+     model,
+     num_epochs,
+     loss_func,
+     data_loader,
+     optimizer,
+     device,
+     num_warm_epochs: int = 10,
+     ckpt_file: str = None,
+     individual: bool = True,
+     **kwargs
+ ):
+
+     epoch = 0
+     if ckpt_file is not None:
+
+         model, optimizer, epoch, _ = load_model(
+             ckpt_file,
+             model=model,
+             optimizer=optimizer,
+             train=True,
+             device=device,
+         )
+
+     for epoch_id in tnrange(num_epochs):
+
+         train_an_epoch(
+             base_dir=base_dir,
+             model=model,
+             data_loader=data_loader,
+             epoch_id=epoch + epoch_id,
+             loss_func=loss_func,
+             optimizer=optimizer,
+             num_warm_epochs=num_warm_epochs,
+             device=device,
+             individual=individual,
+             **kwargs
+         )
+
+
+ def rxn_engine(
+     base_dir: str,
+     csv_file: str,
+     splitter: str,
+     smiles_col: str,
+     hard: bool = False,
+     num_epochs: int = 20,
+     target_cols: t.List = [],
+     nums_classes: t.List = [],
+     loss_func: str = 'ce',
+     num_warm_epochs: int = 10,
+     batch_size: int = 128,
+     hidden_size: int = 128,
+     lr: float = 0.01,
+     num_hidden_layers: int = 10,
+     shuffle: bool = True,
+     num_workers: int = 12,
+     dim=-1,
+     out_act='softmax',
+     device_id: int = 0,
+     individual: bool = True,
+     **kwargs
+ ):
+
+     base_dir = osp.abspath(base_dir)
+     makedirs(base_dir)
+     model, device = build_rxn_mu(
+         nums_classes=nums_classes,
+         hard=hard,
+         hidden_size=hidden_size,
+         nums_hidden_layers=num_hidden_layers,
+         dim=dim,
+         out_act=out_act,
+         device_id=device_id
+     )
+     if torch.cuda.device_count() > 1:
+         model = nn.DataParallel(model)
+     model.train()
+
+     params = [{
+         'params': model.parameters(),
+         'lr': lr,
+         'weight_decay': 0
+     }]
+     optimizer = Adam(params)
+
+     dataset = RXNCSVDataset(
+         csv_file=csv_file,
+         splitter=splitter,
+         smiles_col=smiles_col,
+         target_cols=target_cols,
+     )
+     data_loader = RXNLoader(
+         dataset=dataset,
+         batch_size=batch_size,
+         shuffle=shuffle,
+         num_workers=num_workers
+     )
+
+     train_rxn(
+         base_dir=base_dir,
+         model=model,
+         num_epochs=num_epochs,
+         loss_func=loss_func,
+         data_loader=data_loader,
+         optimizer=optimizer,
+         device=device,
+         num_warm_epochs=num_warm_epochs,
+         ckpt_file=None,
+         individual=individual,
+         **kwargs
+     )
+
+
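Taken together, rxn_engine is the high-level entry point of this module: it builds the model with build_rxn_mu, wraps the reaction CSV in RXNCSVDataset/RXNLoader, and hands everything to train_rxn. A minimal invocation might look like the sketch below; the CSV path, column names, class counts, and splitter token are illustrative placeholders, not values shipped with the package.

from hdl.controllers.train.rxn_train import rxn_engine

# Hypothetical inputs: a reaction CSV with one SMILES column and two
# categorical target columns with 3 and 5 classes respectively.
rxn_engine(
    base_dir='runs/rxn_demo',        # output dir; created via makedirs
    csv_file='data/reactions.csv',   # placeholder path
    splitter='>',                    # placeholder; interpreted by RXNCSVDataset
    smiles_col='rxn_smiles',
    target_cols=['cat_a', 'cat_b'],
    nums_classes=[3, 5],
    loss_func='ce',
    num_epochs=20,
    num_warm_epochs=10,              # encoder stays frozen for these epochs
    batch_size=128,
    device_id=0,
)

Note that rxn_engine always calls train_rxn with ckpt_file=None, so resuming from a saved checkpoint requires calling train_rxn directly.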
hdl/controllers/train/train.py
@@ -0,0 +1,50 @@
+ import torch
+ from torch_geometric.data import DataLoader
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # Mean squared error; RMSE is taken via torch.sqrt inside train()
+ loss_fn = torch.nn.MSELoss()
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.0007)
+
+ # Use GPU for training if available
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+
+ # Wrap data in data loaders (80/20 train/test split)
+ data_size = len(data)
+ NUM_GRAPHS_PER_BATCH = 64
+ loader = DataLoader(data[:int(data_size * 0.8)],
+                     batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
+ test_loader = DataLoader(data[int(data_size * 0.8):],
+                          batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
+
+ def train(data):
+     # Enumerate over the data
+     for batch in loader:
+         # Use GPU (Data.to returns a new object, so reassign)
+         batch = batch.to(device)
+         # Reset gradients
+         optimizer.zero_grad()
+         # Pass the node features and the connectivity info
+         pred, embedding = model(batch.x.float(), batch.edge_index, batch.batch)
+         # Calculate the loss (RMSE) and gradients
+         loss = torch.sqrt(loss_fn(pred, batch.y))
+         loss.backward()
+         # Update using the gradients
+         optimizer.step()
+     return loss, embedding
+
+ print("Starting training...")
+ losses = []
+ for epoch in range(2000):
+     loss, h = train(data)
+     losses.append(loss)
+     if epoch % 100 == 0:
+         print(f"Epoch {epoch} | Train Loss {loss}")
+ # Visualize learning (training loss)
+ import seaborn as sns
+ losses_float = [float(loss.cpu().detach().numpy()) for loss in losses]
+ loss_indices = [i for i, l in enumerate(losses_float)]
+ ax = sns.lineplot(x=loss_indices, y=losses_float)
+ # As a result we get something like this:
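This script reads like a tutorial excerpt: it uses model and data without defining them, so it is not runnable on its own (and its last line refers to a loss plot that is not part of the wheel). A sketch of the missing preamble, assuming a PyTorch Geometric MoleculeNet dataset and any model whose forward(x, edge_index, batch) returns a (prediction, embedding) pair, might look like this; DemoGNN and the ESOL choice are placeholders, not part of this package.

import torch
from torch.nn import Linear
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn import GCNConv, global_mean_pool

# Hypothetical stand-ins for the undefined `model` and `data` above.
class DemoGNN(torch.nn.Module):
    def __init__(self, num_features, hidden=64):
        super().__init__()
        self.conv = GCNConv(num_features, hidden)
        self.out = Linear(hidden, 1)

    def forward(self, x, edge_index, batch):
        h = self.conv(x, edge_index).relu()
        h = global_mean_pool(h, batch)   # graph-level embedding
        return self.out(h), h            # (prediction, embedding)

data = MoleculeNet(root='data/esol', name='ESOL')  # placeholder dataset
model = DemoGNN(num_features=data.num_features)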
hdl/controllers/train/train_ginet.py
@@ -0,0 +1,316 @@
+ import typing as t
+ from os import path as osp
+ # from os import path as osp
+ from itertools import cycle
+ # import datetime
+
+ import torch
+ import numpy as np
+ import pandas as pd
+
+
+ # from jupyfuncs.glob import makedirs
+ from jupyfuncs.pbar import tnrange, tqdm
+ # from hdl.data.dataset.graph.gin import MoleculeDataset
+ from hdl.data.dataset.graph.gin import MoleculeDatasetWrapper
+ # from hdl.metric_loss.loss import get_lossfunc
+ # from hdl.models.utils import save_model
+ from .trainer_base import TorchTrainer
+
+
+ class GINTrainer(TorchTrainer):
+     def __init__(
+         self,
+         base_dir,
+         data_loader,
+         test_loader,
+         metrics: t.List[str] = ['rsquared', 'rmse', 'mae'],
+         loss_func: str = 'mse',
+         model=None,
+         model_name=None,
+         model_init_args=None,
+         ckpt_file=None,
+         model_ckpt=None,
+         fix_emb=True,
+         optimizer=None,
+         optimizer_name=None,
+         optimizer_kwargs=None,
+         device=torch.device('cpu'),
+         # logger=None
+     ) -> None:
+         super().__init__(
+             base_dir=base_dir,
+             data_loader=data_loader,
+             test_loader=test_loader,
+             metrics=metrics,
+             loss_func=loss_func,
+             model=model,
+             model_name=model_name,
+             model_init_args=model_init_args,
+             ckpt_file=ckpt_file,
+             model_ckpt=model_ckpt,
+             optimizer=optimizer,
+             optimizer_name=optimizer_name,
+             optimizer_kwargs=optimizer_kwargs,
+             device=device,
+         )
+         # self.loss_func = get_lossfunc(self.loss_func)
+         # self.metrics = [get_metric(metric) for metric in metrics]
+         if fix_emb:
+             for gin in self.model.gins:
+                 for param in gin.parameters():
+                     param.requires_grad = False
+
+     def train_a_batch(self, data):
+         self.optimizer.zero_grad()
+         for i in data[: -1]:
+             for j in i:
+                 j.to(self.device)
+         y = data[-1].to(self.device)
+         y = y / 100
+
+         y_pred = self.model(data).flatten()
+
+         loss = self.loss_func(y_pred, y)
+
+         loss.backward()
+         self.optimizer.step()
+
+         return loss
+
+     def load_ckpt(self):
+         self.model.load_ckpt()
+
+     def train_an_epoch(
+         self,
+     ):
+         for i, (data, test_data) in enumerate(
+             zip(
+                 self.data_loader,
+                 cycle(self.test_loader)
+             )
+         ):
+             loss = self.train_a_batch(data)
+             self.losses.append(loss.item())
+             self.n_iter += 1
+             self.logger.add_scalar(
+                 'train_loss',
+                 loss.item(),
+                 global_step=self.n_iter
+             )
+
+             if self.n_iter % 10 == 0:
+                 for i in test_data[: -1]:
+                     for j in i:
+                         j.to(self.device)
+                 y = test_data[-1].to(self.device)
+                 y = y / 100
+
+                 y_pred = self.model(test_data).flatten()
+                 valid_loss = self.loss_func(y_pred, y)
+
+                 y_pred = y_pred.cpu().detach().numpy()
+                 y = y.cpu().detach().numpy()
+
+                 self.logger.add_scalar(
+                     'valid_loss',
+                     valid_loss.item(),
+                     global_step=self.n_iter
+                 )
+
+                 for metric_name, metric in zip(
+                     self.metric_names,
+                     self.metrics
+                 ):
+                     self.logger.add_scalar(
+                         metric_name,
+                         metric(y_pred, y),
+                         global_step=self.n_iter
+                     )
+
+         self.save()
+         self.epoch_id += 1
+
+     def train(self, num_epochs):
+         # dir_name = datetime.now().strftime('%b%d_%H-%M-%S')
+         # makedirs(osp.join(self.base_dir, dir_name))
+
+         for _ in tnrange(num_epochs):
+             self.train_an_epoch()
+
+     def predict(self, data_loader):
+         result_list = []
+         for data in tqdm(data_loader):
+             for i in data[: -1]:
+                 for j in i:
+                     j.to(self.device)
+             # print(data[0][0].x.device)
+             # for param in self.model.parameters():
+             #     print(param.device)
+             #     break
+             y_pred = self.model(data).flatten()
+             result_list.append(y_pred.cpu().detach().numpy())
+         results = np.hstack(result_list)
+         return results
+
+
+ def engine(
+     base_dir,
+     data_path,
+     test_data_path,
+     batch_size=128,
+     num_workers=64,
+     model_name='GINMLPR',
+     num_layers=5,
+     emb_dim=300,
+     feat_dim=512,
+     out_dim=1,
+     drop_ratio=0.0,
+     pool='mean',
+     ckpt_file=None,
+     fix_emb: bool = False,
+     device='cuda:1',
+     num_epochs=300,
+     optimizer_name='adam',
+     lr=0.001,
+     file_type: str = 'csv',
+     smiles_col_names: t.List = [],
+     y_col_name: str = None,  # "yield (%)",
+     loss_func: str = 'mse',
+     metrics: t.List[str] = ['rsquared', 'rmse', 'mae'],
+ ):
+     model_init_args = {
+         "num_layer": num_layers,
+         "emb_dim": emb_dim,
+         "feat_dim": feat_dim,
+         "out_dim": out_dim,
+         "drop_ratio": drop_ratio,
+         "pool": pool,
+         "ckpt_file": ckpt_file,
+         "num_smiles": len(smiles_col_names),
+     }
+     wrapper = MoleculeDatasetWrapper(
+         batch_size=batch_size,
+         num_workers=num_workers,
+         valid_size=0,
+         data_path=data_path,
+         file_type=file_type,
+         smi_col_names=smiles_col_names,
+         y_col_name=y_col_name
+     )
+     test_wrapper = MoleculeDatasetWrapper(
+         batch_size=batch_size,
+         num_workers=num_workers,
+         valid_size=0,
+         data_path=test_data_path,
+         file_type=file_type,
+         smi_col_names=smiles_col_names,
+         y_col_name=y_col_name
+     )
+
+     data_loader = wrapper.get_test_loader(
+         shuffle=True
+     )
+     test_loader = test_wrapper.get_test_loader(
+         shuffle=False
+     )
+
+     trainer = GINTrainer(
+         base_dir=base_dir,
+         model_name=model_name,
+         model_init_args=model_init_args,
+         optimizer_name=optimizer_name,
+         ckpt_file=ckpt_file,
+         fix_emb=fix_emb,
+         optimizer_kwargs={"lr": lr},
+         data_loader=data_loader,
+         test_loader=test_loader,
+         metrics=metrics,
+         loss_func=loss_func,
+         device=device
+     )
+
+     trainer.train(num_epochs=num_epochs)
+
+
+ def predict(
+     base_dir,
+     data_path,
+     batch_size=128,
+     num_workers=64,
+     model_name='GINMLPR',
+     num_layers=5,
+     emb_dim=300,
+     feat_dim=512,
+     out_dim=1,
+     drop_ratio=0.0,
+     pool='mean',
+     ckpt_file=None,
+     model_ckpt=None,
+     device='cuda:1',
+     file_type: str = 'csv',
+     smiles_col_names: t.List = [],
+     y_col_name: str = None,  # "yield (%)",
+     metrics: t.List[str] = ['rsquared', 'rmse', 'mae'],
+ ):
+     model_init_args = {
+         "num_layer": num_layers,
+         "emb_dim": emb_dim,
+         "feat_dim": feat_dim,
+         "out_dim": out_dim,
+         "drop_ratio": drop_ratio,
+         "pool": pool,
+         "ckpt_file": ckpt_file,
+         "num_smiles": len(smiles_col_names),
+     }
+     wrapper = MoleculeDatasetWrapper(
+         batch_size=batch_size,
+         num_workers=num_workers,
+         valid_size=0,
+         data_path=data_path,
+         file_type=file_type,
+         smi_col_names=smiles_col_names,
+         y_col_name=y_col_name
+     )
+     data_loader = wrapper.get_test_loader(
+         shuffle=False
+     )
+     trainer = GINTrainer(
+         base_dir=base_dir,
+         model_name=model_name,
+         model_init_args=model_init_args,
+         model_ckpt=model_ckpt,
+         data_loader=data_loader,
+         test_loader=None,
+         metrics=metrics,
+         device=device
+     )
+     metric_list = trainer.metrics
+     trainer.load(ckpt_file=model_ckpt)
+     trainer.model.eval()
+     results = trainer.predict(data_loader)
+
+     df = pd.read_csv(data_path)
+     df['pred'] = results
+     df.to_csv(
+         osp.join(base_dir, 'pred.csv'),
+         index=False
+     )
+
+     if y_col_name is not None:
+         metrics_df = pd.DataFrame()
+         y = df[y_col_name].array / 100
+         for metric_name, metric in zip(
+             metrics,
+             metric_list
+         ):
+             metrics_df[metric_name] = np.array([metric(
+                 y, results
+             )])
+         metrics_df.to_csv(
+             osp.join(
+                 base_dir, 'metrics.csv'
+             ),
+             index=False
+         )
+
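For completeness, here is a sketch of how engine and predict might be driven end to end. The paths, column names, and checkpoint files are illustrative placeholders; note that the trainer divides targets by 100, so a "yield (%)" column is regressed on a 0-1 scale.

from hdl.controllers.train.train_ginet import engine, predict

cols = ['reactant_smiles', 'product_smiles']  # placeholder column names

# Fine-tune the (optionally frozen) GIN encoders on a yield-regression CSV.
engine(
    base_dir='runs/gin_demo',
    data_path='data/train.csv',            # placeholder
    test_data_path='data/test.csv',        # placeholder
    smiles_col_names=cols,
    y_col_name='yield (%)',
    ckpt_file='ckpts/pretrained_gin.pth',  # placeholder pretrained weights
    fix_emb=True,                          # freeze the GIN embedding layers
    device='cuda:0',
    num_epochs=50,
)

# Score a held-out CSV with a trained checkpoint; writes pred.csv
# (and metrics.csv when y_col_name is given) under base_dir.
predict(
    base_dir='runs/gin_demo',
    data_path='data/holdout.csv',          # placeholder
    smiles_col_names=cols,
    y_col_name='yield (%)',
    ckpt_file='ckpts/pretrained_gin.pth',  # placeholder
    model_ckpt='runs/gin_demo/model.ckpt', # placeholder trained model
    device='cuda:0',
)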