ddi-fw 0.0.149__py3-none-any.whl → 0.0.150__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,83 +1,186 @@
1
- import mlflow
2
1
  import torch
3
- from ddi_fw.ml.evaluation_helper import evaluate
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader, TensorDataset
4
+ import mlflow
5
+ from typing import Any, Dict, Tuple
6
+ from ddi_fw.ml.evaluation_helper import Metrics, evaluate
4
7
  from ddi_fw.ml.model_wrapper import ModelWrapper
5
-
8
+ import ddi_fw.utils as utils
6
9
 
7
10
  class PTModelWrapper(ModelWrapper):
8
- def __init__(self, date, descriptor, model_func, batch_size=128, epochs=100, **kwargs):
9
- super().__init__(date, descriptor, model_func, batch_size, epochs)
11
+ def __init__(self, date, descriptor, model_func, **kwargs):
12
+ super().__init__(date, descriptor, model_func, **kwargs)
13
+ self.batch_size = kwargs.get('batch_size',128)
14
+ self.epochs = kwargs.get('epochs',100)
10
15
  self.optimizer = kwargs['optimizer']
11
16
  self.criterion = kwargs['criterion']
12
17
 
13
- def _create_dataloader(self, data, labels):
14
- dataset = torch.utils.data.TensorDataset(data, labels)
15
- return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
16
-
17
- def predict(self):
18
- print(self.train_data.shape)
19
-
20
- with mlflow.start_run(run_name=self.descriptor, description="***", nested=True) as run:
21
- models = {}
22
- # models_val_acc = {}
23
-
24
- for i, (train_idx, val_idx) in enumerate(zip(self.train_idx_arr, self.val_idx_arr)):
25
- print(f"Validation {i}")
26
-
27
- with mlflow.start_run(run_name=f'Validation {i}', description='CV models', nested=True) as cv_fit:
28
- model = self.model_func(self.train_data.shape[1])
29
- models[f'validation_{i}'] = model
30
-
31
- # Create DataLoaders
32
- X_train_cv = torch.tensor(self.train_data[train_idx], dtype=torch.float16)
33
- y_train_cv = torch.tensor(self.train_label[train_idx], dtype=torch.float16)
34
- X_valid_cv = torch.tensor(self.train_data[val_idx], dtype=torch.float16)
35
- y_valid_cv = torch.tensor(self.train_label[val_idx], dtype=torch.float16)
36
-
37
- train_loader = self._create_dataloader(X_train_cv, y_train_cv)
38
- valid_loader = self._create_dataloader(X_valid_cv, y_valid_cv)
39
-
40
- optimizer = self.optimizer
41
- criterion = self.criterion
42
- best_val_loss = float('inf')
43
-
44
- for epoch in range(self.epochs):
45
- model.train()
46
- for batch_X, batch_y in train_loader:
47
- optimizer.zero_grad()
48
- output = model(batch_X)
49
- loss = criterion(output, batch_y)
50
- loss.backward()
51
- optimizer.step()
52
-
53
- model.eval()
54
- with torch.no_grad():
55
- val_loss = self._validate(model, valid_loader)
56
-
57
- # Callbacks after each epoch
58
- for callback in self.callbacks:
59
- callback.on_epoch_end(epoch, logs={'loss': loss.item(), 'val_loss': val_loss.item()})
60
-
61
- if val_loss < best_val_loss:
62
- best_val_loss = val_loss
63
- best_model = model
64
-
65
- # Evaluate on test data
66
- with torch.no_grad():
67
- pred = best_model(torch.tensor(self.test_data, dtype=torch.float16))
68
- logs, metrics = evaluate(
69
- actual=self.test_label, pred=pred.numpy(), info=self.descriptor)
70
- mlflow.log_metrics(logs)
71
-
72
- return logs, metrics, pred.numpy()
18
+ def fit_model(self, X_train, y_train, X_valid, y_valid):
19
+ self.model = self.model_func(self.train_data.shape[1])
20
+ train_dataset = TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.float32))
21
+ valid_dataset = TensorDataset(torch.tensor(X_valid, dtype=torch.float32), torch.tensor(y_valid, dtype=torch.float32))
22
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
23
+ valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
24
+
25
+ best_loss = float('inf')
26
+ best_model = None
27
+
28
+ for epoch in range(self.epochs):
29
+ self.model.train()
30
+ for batch_X, batch_y in train_loader:
31
+ self.optimizer.zero_grad()
32
+ output = self.model(batch_X)
33
+ loss = self.criterion(output, batch_y)
34
+ loss.backward()
35
+ self.optimizer.step()
36
+
37
+ valid_loss = self._validate(self.model, valid_loader)
38
+ if valid_loss < best_loss:
39
+ best_loss = valid_loss
40
+ best_model = self.model.state_dict()
41
+
42
+ self.model.load_state_dict(best_model)
43
+ return self.model, best_loss
73
44
 
74
45
  def _validate(self, model, valid_loader):
46
+ model.eval()
75
47
  total_loss = 0
76
- criterion = self.criterion
77
-
78
- for batch_X, batch_y in valid_loader:
79
- output = model(batch_X)
80
- loss = criterion(output, batch_y)
81
- total_loss += loss.item()
48
+ with torch.no_grad():
49
+ for batch_X, batch_y in valid_loader:
50
+ output = model(batch_X)
51
+ loss = self.criterion(output, batch_y)
52
+ total_loss += loss.item()
53
+ return total_loss / len(valid_loader)
54
+
55
+ def fit(self):
56
+ models = {}
57
+ models_val_acc = {}
58
+ for i, (train_idx, val_idx) in enumerate(zip(self.train_idx_arr, self.val_idx_arr)):
59
+ print(f"Validation {i}")
60
+ with mlflow.start_run(run_name=f'Validation {i}', description='CV models', nested=True) as cv_fit:
61
+ X_train_cv = self.train_data[train_idx]
62
+ y_train_cv = self.train_label[train_idx]
63
+ X_valid_cv = self.train_data[val_idx]
64
+ y_valid_cv = self.train_label[val_idx]
65
+ model, best_loss = self.fit_model(X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
66
+ models[f'{self.descriptor}_validation_{i}'] = model
67
+ models_val_acc[f'{self.descriptor}_validation_{i}'] = best_loss
68
+
69
+ best_model_key = min(models_val_acc, key=lambda k: models_val_acc[k])
70
+ best_model = models[best_model_key]
71
+ return best_model, best_model_key
82
72
 
83
- return total_loss / len(valid_loader)
73
+ def predict(self):
74
+ test_dataset = TensorDataset(torch.tensor(self.test_data, dtype=torch.float32), torch.tensor(self.test_label, dtype=torch.float32))
75
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
76
+ self.model.eval()
77
+ preds = []
78
+ with torch.no_grad():
79
+ for batch_X, _ in test_loader:
80
+ output = self.model(batch_X)
81
+ preds.append(output)
82
+ return torch.cat(preds, dim=0).numpy()
83
+
84
+ def fit_and_evaluate(self) -> Tuple[Dict[str, Any], Metrics, Any]:
85
+ with mlflow.start_run(run_name=self.descriptor, description="***", nested=True) as run:
86
+ print(run.info.artifact_uri)
87
+ best_model, best_model_key = self.fit()
88
+ print(best_model_key)
89
+ self.best_model = best_model
90
+ pred = self.predict()
91
+ logs, metrics = evaluate(actual=self.test_label, pred=pred, info=self.descriptor)
92
+ metrics.format_float()
93
+ mlflow.log_metrics(logs)
94
+ mlflow.log_param('best_cv', best_model_key)
95
+ utils.compress_and_save_data(metrics.__dict__, run.info.artifact_uri, f'{self.date}_metrics.gzip')
96
+ mlflow.log_artifact(f'{run.info.artifact_uri}/{self.date}_metrics.gzip')
97
+
98
+ return logs, metrics, pred
99
+
100
+ # from typing import Any
101
+ # import mlflow
102
+ # import torch
103
+ # from ddi_fw.ml.evaluation_helper import Metrics, evaluate
104
+ # from ddi_fw.ml.model_wrapper import ModelWrapper
105
+
106
+
107
+ # class PTModelWrapper(ModelWrapper):
108
+ # def __init__(self, date, descriptor, model_func, batch_size=128, epochs=100, **kwargs):
109
+ # super().__init__(date, descriptor, model_func, batch_size, epochs)
110
+ # self.optimizer = kwargs['optimizer']
111
+ # self.criterion = kwargs['criterion']
112
+
113
+ # def _create_dataloader(self, data, labels):
114
+ # dataset = torch.utils.data.TensorDataset(data, labels)
115
+ # return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
116
+
117
+ # def predict(self):
118
+ # print(self.train_data.shape)
119
+
120
+ # with mlflow.start_run(run_name=self.descriptor, description="***", nested=True) as run:
121
+ # models = {}
122
+ # # models_val_acc = {}
123
+
124
+ # for i, (train_idx, val_idx) in enumerate(zip(self.train_idx_arr, self.val_idx_arr)):
125
+ # print(f"Validation {i}")
126
+
127
+ # with mlflow.start_run(run_name=f'Validation {i}', description='CV models', nested=True) as cv_fit:
128
+ # model = self.model_func(self.train_data.shape[1])
129
+ # models[f'validation_{i}'] = model
130
+
131
+ # # Create DataLoaders
132
+ # X_train_cv = torch.tensor(self.train_data[train_idx], dtype=torch.float16)
133
+ # y_train_cv = torch.tensor(self.train_label[train_idx], dtype=torch.float16)
134
+ # X_valid_cv = torch.tensor(self.train_data[val_idx], dtype=torch.float16)
135
+ # y_valid_cv = torch.tensor(self.train_label[val_idx], dtype=torch.float16)
136
+
137
+ # train_loader = self._create_dataloader(X_train_cv, y_train_cv)
138
+ # valid_loader = self._create_dataloader(X_valid_cv, y_valid_cv)
139
+
140
+ # optimizer = self.optimizer
141
+ # criterion = self.criterion
142
+ # best_val_loss = float('inf')
143
+
144
+ # for epoch in range(self.epochs):
145
+ # model.train()
146
+ # for batch_X, batch_y in train_loader:
147
+ # optimizer.zero_grad()
148
+ # output = model(batch_X)
149
+ # loss = criterion(output, batch_y)
150
+ # loss.backward()
151
+ # optimizer.step()
152
+
153
+ # model.eval()
154
+ # with torch.no_grad():
155
+ # val_loss = self._validate(model, valid_loader)
156
+
157
+ # # Callbacks after each epoch
158
+ # for callback in self.callbacks:
159
+ # callback.on_epoch_end(epoch, logs={'loss': loss.item(), 'val_loss': val_loss.item()})
160
+
161
+ # if val_loss < best_val_loss:
162
+ # best_val_loss = val_loss
163
+ # best_model = model
164
+
165
+ # # Evaluate on test data
166
+ # with torch.no_grad():
167
+ # pred = best_model(torch.tensor(self.test_data, dtype=torch.float16))
168
+ # logs, metrics = evaluate(
169
+ # actual=self.test_label, pred=pred.numpy(), info=self.descriptor)
170
+ # mlflow.log_metrics(logs)
171
+
172
+ # return logs, metrics, pred.numpy()
173
+
174
+ # def _validate(self, model, valid_loader):
175
+ # total_loss = 0
176
+ # criterion = self.criterion
177
+
178
+ # for batch_X, batch_y in valid_loader:
179
+ # output = model(batch_X)
180
+ # loss = criterion(output, batch_y)
181
+ # total_loss += loss.item()
182
+
183
+ # return total_loss / len(valid_loader)
184
+
185
+ # def fit_and_evaluate(self) -> tuple[dict[str, Any], Metrics, Any]:
186
+ # return None,None,None
@@ -1,3 +1,4 @@
1
+ from typing import Any, Callable
1
2
  from ddi_fw.ml.model_wrapper import ModelWrapper
2
3
  import tensorflow as tf
3
4
  from tensorflow import keras
@@ -5,7 +6,7 @@ from tensorflow import keras
5
6
  from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
6
7
  from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
7
8
  import numpy as np
8
-
9
+ from tensorflow.keras import Model
9
10
  import mlflow
10
11
  from mlflow.utils.autologging_utils import batch_metrics_logger
11
12
 
@@ -21,10 +22,11 @@ import os
21
22
 
22
23
  class TFModelWrapper(ModelWrapper):
23
24
 
24
- def __init__(self, date, descriptor, model_func, **kwargs):
25
+ def __init__(self, date, descriptor, model_func, use_mlflow=True, **kwargs):
25
26
  super().__init__(date, descriptor, model_func, **kwargs)
26
- self.batch_size = kwargs.get('batch_size',128)
27
- self.epochs = kwargs.get('epochs',100)
27
+ self.batch_size = kwargs.get('batch_size', 128)
28
+ self.epochs = kwargs.get('epochs', 100)
29
+ self.use_mlflow = use_mlflow
28
30
 
29
31
  def fit_model(self, X_train, y_train, X_valid, y_valid):
30
32
  self.kwargs['input_shape'] = self.train_data.shape
@@ -39,18 +41,24 @@ class TFModelWrapper(ModelWrapper):
39
41
  )
40
42
  early_stopping = EarlyStopping(
41
43
  monitor='val_loss', patience=10, mode='auto')
42
- custom_callback = CustomCallback()
44
+ custom_callback = CustomCallback(self.use_mlflow)
43
45
  train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
44
- val_dataset = tf.data.Dataset.from_tensor_slices((X_valid, y_valid))
45
46
  train_dataset = train_dataset.batch(batch_size=self.batch_size)
46
- val_dataset = val_dataset.batch(batch_size=self.batch_size)
47
+
48
+ if X_valid is not None and y_valid is not None:
49
+ val_dataset = tf.data.Dataset.from_tensor_slices(
50
+ (X_valid, y_valid))
51
+ val_dataset = val_dataset.batch(batch_size=self.batch_size)
52
+ else:
53
+ val_dataset = None
54
+
47
55
  history = model.fit(
48
56
  train_dataset,
49
57
  epochs=self.epochs,
50
- validation_data=val_dataset,
58
+ # validation_data=val_dataset,
51
59
  callbacks=[early_stopping, checkpoint, custom_callback]
52
60
  )
53
- # ex
61
+ # ex
54
62
  # history = model.fit(
55
63
  # X_train, y_train,
56
64
  # batch_size=self.batch_size,
@@ -68,101 +76,185 @@ class TFModelWrapper(ModelWrapper):
68
76
  print(self.train_data.shape)
69
77
  models = {}
70
78
  models_val_acc = {}
71
- for i, (train_idx, val_idx) in enumerate(zip(self.train_idx_arr, self.val_idx_arr)):
72
- print(f"Validation {i}")
73
- with mlflow.start_run(run_name=f'Validation {i}', description='CV models', nested=True) as cv_fit:
74
- X_train_cv = self.train_data[train_idx]
75
- y_train_cv = self.train_label[train_idx]
76
- X_valid_cv = self.train_data[val_idx]
77
- y_valid_cv = self.train_label[val_idx]
79
+ if self.train_idx_arr is not None and self.val_idx_arr is not None:
80
+ for i, (train_idx, val_idx) in enumerate(zip(self.train_idx_arr, self.val_idx_arr)):
81
+ print(f"Validation {i}")
82
+
83
+ if self.use_mlflow:
84
+ with mlflow.start_run(run_name=f'Validation {i}', description='CV models', nested=True) as cv_fit:
85
+ X_train_cv = self.train_data[train_idx]
86
+ y_train_cv = self.train_label[train_idx]
87
+ X_valid_cv = self.train_data[val_idx]
88
+ y_valid_cv = self.train_label[val_idx]
89
+ model, checkpoint = self.fit_model(
90
+ X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
91
+ models[f'{self.descriptor}_validation_{i}'] = model
92
+ models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
93
+ else:
94
+ X_train_cv = self.train_data[train_idx]
95
+ y_train_cv = self.train_label[train_idx]
96
+ X_valid_cv = self.train_data[val_idx]
97
+ y_valid_cv = self.train_label[val_idx]
98
+ model, checkpoint = self.fit_model(
99
+ X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
100
+ models[f'{self.descriptor}_validation_{i}'] = model
101
+ models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
102
+ else:
103
+ if self.use_mlflow:
104
+ with mlflow.start_run(run_name=f'Training', description='Training', nested=True) as cv_fit:
105
+ model, checkpoint = self.fit_model(
106
+ self.train_data, self.train_label, None, None)
107
+ models[self.descriptor] = model
108
+ models_val_acc[self.descriptor] = checkpoint.best
109
+ else:
78
110
  model, checkpoint = self.fit_model(
79
- X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
80
- models[f'{self.descriptor}_validation_{i}'] = model
81
- models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
111
+ self.train_data, self.train_label, None, None)
112
+ models[self.descriptor] = model
113
+ models_val_acc[self.descriptor] = checkpoint.best
82
114
 
83
- best_model_key = max(models_val_acc, key=models_val_acc.get)
115
+ best_model_key = max(models_val_acc, key=lambda k: models_val_acc[k])
116
+ # best_model_key = max(models_val_acc, key=models_val_acc.get)
84
117
  best_model = models[best_model_key]
85
118
  return best_model, best_model_key
86
119
 
87
120
  # https://github.com/mlflow/mlflow/blob/master/examples/tensorflow/train.py
88
121
 
89
122
  def predict(self):
90
- test_dataset = tf.data.Dataset.from_tensor_slices((self.test_data, self.test_label))
123
+ test_dataset = tf.data.Dataset.from_tensor_slices(
124
+ (self.test_data, self.test_label))
91
125
  test_dataset = test_dataset.batch(batch_size=1)
92
126
  # pred = self.best_model.predict(self.test_data)
93
127
  pred = self.best_model.predict(test_dataset)
94
128
  return pred
95
129
 
96
- def fit_and_evaluate(self):
97
-
98
- with mlflow.start_run(run_name=self.descriptor, description="***", nested=True) as run:
99
- print(run.info.artifact_uri)
100
- best_model, best_model_key =self.fit()
130
+ def fit_and_evaluate(self, print_detail=False) -> tuple[dict[str, Any], Metrics, Any]:
131
+ if self.use_mlflow:
132
+ with mlflow.start_run(run_name=self.descriptor, description="***", nested=True) as run:
133
+ print(run.info.artifact_uri)
134
+ best_model, best_model_key = self.fit()
135
+ print(best_model_key)
136
+ self.best_model: Model = best_model
137
+ pred = self.predict()
138
+ logs, metrics = evaluate(
139
+ actual=self.test_label, pred=pred, info=self.descriptor, print_detail=print_detail)
140
+ metrics.format_float()
141
+ mlflow.log_metrics(logs)
142
+ mlflow.log_param('best_cv', best_model_key)
143
+ utils.compress_and_save_data(
144
+ metrics.__dict__, run.info.artifact_uri, f'{self.date}_metrics.gzip')
145
+ mlflow.log_artifact(
146
+ f'{run.info.artifact_uri}/{self.date}_metrics.gzip')
147
+
148
+ return logs, metrics, pred
149
+ else:
150
+ best_model, best_model_key = self.fit()
101
151
  print(best_model_key)
102
152
  self.best_model = best_model
103
153
  pred = self.predict()
104
154
  logs, metrics = evaluate(
105
155
  actual=self.test_label, pred=pred, info=self.descriptor)
106
156
  metrics.format_float()
107
- mlflow.log_metrics(logs)
108
- mlflow.log_param('best_cv', best_model_key)
109
- utils.compress_and_save_data(
110
- metrics.__dict__, run.info.artifact_uri, f'{self.date}_metrics.gzip')
111
- mlflow.log_artifact(f'{run.info.artifact_uri}/{self.date}_metrics.gzip')
112
-
113
157
  return logs, metrics, pred
114
158
 
159
+
160
+ """
161
+ Custom Keras callback for logging training metrics and model summary to MLflow.
162
+ """
163
+
164
+
115
165
  class CustomCallback(Callback):
166
+ def __init__(self, use_mlflow: bool = True):
167
+ super().__init__()
168
+ self.use_mlflow = use_mlflow
169
+
170
+ def _mlflow_log(self, func: Callable):
171
+ if self.use_mlflow:
172
+ func()
173
+
116
174
  def on_train_begin(self, logs=None):
175
+ if logs is None:
176
+ logs = {}
177
+ if not isinstance(self.model, Model):
178
+ raise TypeError("self.model must be an instance of Model")
179
+
117
180
  keys = list(logs.keys())
118
- mlflow.log_param("train_begin_keys", keys)
119
- config = self.model.optimizer.get_config()
181
+ self._mlflow_log(lambda: mlflow.log_param("train_begin_keys", keys))
182
+ # config = self.model.optimizer.get_config()
183
+ config = self.model.get_config()
120
184
  for attribute in config:
121
- mlflow.log_param("opt_" + attribute, config[attribute])
185
+ self._mlflow_log(lambda: mlflow.log_param(
186
+ "opt_" + attribute, config[attribute]))
122
187
 
123
188
  sum_list = []
124
189
  self.model.summary(print_fn=sum_list.append)
125
190
  summary = "\n".join(sum_list)
126
- mlflow.log_text(summary, artifact_file="model_summary.txt")
191
+ self._mlflow_log(lambda: mlflow.log_text(
192
+ summary, artifact_file="model_summary.txt"))
127
193
 
128
194
  def on_train_end(self, logs=None):
195
+ if logs is None:
196
+ logs = {}
129
197
  print(logs)
130
- mlflow.log_metrics(logs)
198
+ self._mlflow_log(lambda: mlflow.log_metrics(logs))
131
199
 
132
200
  def on_epoch_begin(self, epoch, logs=None):
201
+ if logs is None:
202
+ logs = {}
133
203
  keys = list(logs.keys())
134
204
 
135
205
  def on_epoch_end(self, epoch, logs=None):
206
+ if logs is None:
207
+ logs = {}
136
208
  keys = list(logs.keys())
137
209
 
138
210
  def on_test_begin(self, logs=None):
211
+ if logs is None:
212
+ logs = {}
139
213
  keys = list(logs.keys())
140
214
 
141
215
  def on_test_end(self, logs=None):
142
- mlflow.log_metrics(logs)
216
+ if logs is None:
217
+ logs = {}
218
+ self._mlflow_log(lambda: mlflow.log_metrics(logs))
143
219
  print(logs)
144
220
 
145
221
  def on_predict_begin(self, logs=None):
222
+ if logs is None:
223
+ logs = {}
146
224
  keys = list(logs.keys())
147
225
 
148
226
  def on_predict_end(self, logs=None):
227
+ if logs is None:
228
+ logs = {}
149
229
  keys = list(logs.keys())
150
- mlflow.log_metrics(logs)
230
+ self._mlflow_log(lambda: mlflow.log_metrics(logs))
151
231
 
152
232
  def on_train_batch_begin(self, batch, logs=None):
233
+ if logs is None:
234
+ logs = {}
153
235
  keys = list(logs.keys())
154
236
 
155
237
  def on_train_batch_end(self, batch, logs=None):
238
+ if logs is None:
239
+ logs = {}
156
240
  keys = list(logs.keys())
157
241
 
158
242
  def on_test_batch_begin(self, batch, logs=None):
243
+ if logs is None:
244
+ logs = {}
159
245
  keys = list(logs.keys())
160
246
 
161
247
  def on_test_batch_end(self, batch, logs=None):
248
+ if logs is None:
249
+ logs = {}
162
250
  keys = list(logs.keys())
163
251
 
164
252
  def on_predict_batch_begin(self, batch, logs=None):
253
+ if logs is None:
254
+ logs = {}
165
255
  keys = list(logs.keys())
166
256
 
167
257
  def on_predict_batch_end(self, batch, logs=None):
258
+ if logs is None:
259
+ logs = {}
168
260
  keys = list(logs.keys())