pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
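The same file-level comparison can be reproduced locally. A minimal sketch, assuming both wheels have already been downloaded into the working directory (the filenames below are illustrative; wheels are zip archives, so their contents can be compared directly):

    import zipfile

    # Local copies of the two released wheels (names assumed).
    OLD = "pg_sui-0.2.3-py3-none-any.whl"
    NEW = "pg_sui-1.6.16a3-py3-none-any.whl"

    with zipfile.ZipFile(OLD) as old, zipfile.ZipFile(NEW) as new:
        old_names, new_names = set(old.namelist()), set(new.namelist())
        for name in sorted(new_names - old_names):
            print(f"added:   {name}")
        for name in sorted(old_names - new_names):
            print(f"removed: {name}")
        for name in sorted(old_names & new_names):
            if old.read(name) != new.read(name):
                print(f"changed: {name}")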
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
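The hunk below corresponds to entry 116 in the list above: the full deletion of pgsui/impute/unsupervised/models/in_development/cnn_model.py (486 lines removed, none added).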
@@ -1,486 +0,0 @@
- import logging
- import os
- import sys
- import warnings
- import math
-
- # Import tensorflow with reduced warnings.
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
- logging.getLogger("tensorflow").disabled = True
- warnings.filterwarnings("ignore", category=UserWarning)
-
- import numpy as np
- import pandas as pd
- import tensorflow as tf
-
- # Suppress "can't find CUDA .dll" errors. Also turns off GPU support.
- tf.config.set_visible_devices([], "GPU")
-
- from tensorflow.python.util import deprecation
-
- # Disable warnings and info logs.
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
- tf.get_logger().setLevel(logging.ERROR)
-
-
- # Monkey patch deprecation utils to suppress warnings.
- # noinspection PyUnusedLocal
- def deprecated(
-     date, instructions, warn_once=True
- ):  # pylint: disable=unused-argument
-     def deprecated_wrapper(func):
-         return func
-
-     return deprecated_wrapper
-
-
- deprecation.deprecated = deprecated
-
- from tensorflow.keras.layers import (
-     Dropout,
-     Dense,
-     Reshape,
-     Activation,
-     Flatten,
-     BatchNormalization,
-     LeakyReLU,
-     PReLU,
- )
-
- from tensorflow.keras.regularizers import l1_l2
-
- # Custom Modules
- try:
-     from ...neural_network_methods import NeuralNetworkMethods
- except (ModuleNotFoundError, ValueError, ImportError):
-     from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
-
-
- class SoftOrdering1DCNN(tf.keras.Model):
-     def __init__(
-         self,
-         y=None,
-         output_shape=None,
-         weights_initializer="glorot_normal",
-         hidden_layer_sizes="midpoint",
-         num_hidden_layers=1,
-         hidden_activation="elu",
-         l1_penalty=1e-6,
-         l2_penalty=1e-6,
-         dropout_rate=0.2,
-         num_classes=4,
-         sample_weight=None,
-         batch_size=32,
-         missing_mask=None,
-         activation=None,
-         channel_increase_rate=2,
-         initial_hidden_size=2048,
-         num_groups=256,
-     ):
-         super(SoftOrdering1DCNN, self).__init__()
-
-         self._y = y
-         self._missing_mask = missing_mask
-         self._sample_weight = sample_weight
-         self._batch_idx = 0
-         self._batch_size = batch_size
-         self.output_activation = activation
-         self.sample_weight = sample_weight
-
-         self.nn_ = NeuralNetworkMethods()
-         self.binary_accuracy = self.nn_.make_masked_binary_accuracy(
-             is_vae=True
-         )
-
-         self.total_loss_tracker = tf.keras.metrics.Mean(name="loss")
-         self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
-             name="reconstruction_loss"
-         )
-         self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
-         self.accuracy_tracker = tf.keras.metrics.Mean(name="accuracy")
-
-         # y_train[1] dimension.
-         self.n_features = output_shape * num_classes
-
-         self.weights_initializer = weights_initializer
-         self.hidden_layer_sizes = hidden_layer_sizes
-         self.num_hidden_layers = num_hidden_layers
-         self.hidden_activation = hidden_activation
-         self.l1_penalty = l1_penalty
-         self.l2_penalty = l2_penalty
-         self.dropout_rate = dropout_rate
-         self.num_classes = num_classes
-         self.channel_increase_rate = channel_increase_rate
-         self.initial_hidden_size = initial_hidden_size
-         self.channel_size1 = num_groups
-         self.channel_size2 = num_groups * 2
-         self.channel_size3 = num_groups * 2
-
-         nn = NeuralNetworkMethods()
-
-         # hidden_layer_sizes = nn.validate_hidden_layers(
-         #     self.hidden_layer_sizes, self.num_hidden_layers
-         # )
-
-         # hidden_layer_sizes = nn.get_hidden_layer_sizes(
-         #     self.n_features, self.n_components, hidden_layer_sizes, vae=True
-         # )
-
-         # hidden_layer_sizes = [h * self.num_classes for h in hidden_layer_sizes]
-
-         if self.l1_penalty == 0.0 and self.l2_penalty == 0.0:
-             kernel_regularizer = None
-         else:
-             kernel_regularizer = l1_l2(self.l1_penalty, self.l2_penalty)
-
-         kernel_initializer = self.weights_initializer
-
-         if self.hidden_activation.lower() == "leaky_relu":
-             activation = LeakyReLU(alpha=0.01)
-
-         elif self.hidden_activation.lower() == "prelu":
-             activation = PReLU()
-
-         elif self.hidden_activation.lower() == "selu":
-             activation = "selu"
-             kernel_initializer = "lecun_normal"
-
-         else:
-             activation = self.hidden_activation
-
-         if num_hidden_layers > 5:
-             raise ValueError(
-                 f"The maximum number of hidden layers is 5, but got "
-                 f"{num_hidden_layers}"
-             )
-
-         hidden_size = initial_hidden_size
-         if self.n_features >= hidden_size:
-             scaling_factor = int(math.ceil(self.n_features / hidden_size)) * 2
-             hidden_size *= num_groups * int(
-                 math.ceil((scaling_factor / num_groups))
-             )
-         else:
-             # If hidden_size is close in number to n_features
-             if abs(hidden_size - self.n_features) <= (hidden_size // 2):
-                 hidden_size *= 2
-
-         # Model adapted from: https://medium.com/spikelab/convolutional-neural-networks-on-tabular-datasets-part-1-4abdd67795b6
-
-         signal_size1 = hidden_size // num_groups
-         signal_size2 = signal_size1 // 2
-         signal_size3 = signal_size1 // 4 * self.channel_size3
-
-         self.signal_size1 = signal_size1
-         self.signal_size2 = signal_size2
-         self.signal_size3 = signal_size3
-
-         self.batch_norm1 = BatchNormalization()
-         self.dropout1 = Dropout(self.dropout_rate)
-         self.dense1 = Dense(
-             hidden_size,
-             input_shape=(self.n_features,),
-             activation=hidden_activation,
-             kernel_initializer=kernel_initializer,
-         )
-
-         self.rshp = Reshape((num_groups, signal_size1))
-
-         self.batch_norm_c1 = BatchNormalization()
-         self.conv1 = tf.keras.layers.Conv1D(
-             self.channel_size1 * self.channel_increase_rate,
-             kernel_size=5,
-             strides=1,
-             padding="same",
-             groups=signal_size1,
-             kernel_initializer=kernel_initializer,
-             activation=hidden_activation,
-         )
-
-         self.avg_po_c1 = tf.keras.layers.AveragePooling1D(
-             pool_size=4, padding="valid"
-         )
-
-         self.batch_norm_c2 = BatchNormalization()
-         self.dropout_c2 = Dropout(self.dropout_rate)
-         self.conv2 = tf.keras.layers.Conv1D(
-             self.channel_size2,
-             kernel_size=3,
-             strides=1,
-             padding="same",
-             kernel_initializer=kernel_initializer,
-             activation=hidden_activation,
-         )
-
-         self.batch_norm_c3 = BatchNormalization()
-         self.dropout_c3 = Dropout(self.dropout_rate)
-         self.conv3 = tf.keras.layers.Conv1D(
-             self.channel_size2,
-             kernel_size=3,
-             strides=1,
-             padding="same",
-             kernel_initializer=kernel_initializer,
-             activation=hidden_activation,
-         )
-
-         self.batch_norm_c4 = BatchNormalization()
-         self.dropout_c4 = Dropout(self.dropout_rate)
-         self.conv4 = tf.keras.layers.Conv1D(
-             self.channel_size2,
-             kernel_size=5,
-             strides=1,
-             padding="same",
-             groups=signal_size1,
-             kernel_initializer=kernel_initializer,
-             activation=None,
-         )
-
-         self.act_c4 = Activation(hidden_activation)
-
-         self.max_po_c4 = tf.keras.layers.MaxPooling1D(
-             pool_size=4, strides=2, padding="same"
-         )
-
-         self.flatten = Flatten()
-
-         self.batch_norm2 = BatchNormalization()
-         self.dropout2 = Dropout(self.dropout_rate)
-         self.dense2 = Dense(
-             self.n_features, kernel_initializer=kernel_initializer
-         )
-         self.rshp2 = Reshape((output_shape, num_classes))
-         self.act2 = Activation(activation)
-
-     def call(self, inputs, training=None):
-         """Call the model on a particular input.
-
-         Args:
-             inputs (tf.Tensor): Input tensor. Must be one-hot encoded.
-
-         Returns:
-             tf.Tensor: Output predictions. Will be one-hot encoded.
-         """
-         x = self.dense1(inputs)
-         x = self.batch_norm1(x, training=training)
-         x = self.dropout1(x, training=training)
-         x = self.rshp(x)
-         x = self.conv1(x)
-         x = self.batch_norm_c1(x, training=training)
-         x = self.avg_po_c1(x)
-         x = self.conv2(x)
-         x = self.batch_norm_c2(x, training=training)
-         x = self.dropout_c2(x, training=training)
-         x_s = x
-         x = self.conv3(x)
-         x = self.batch_norm_c3(x, training=training)
-         x = self.dropout_c3(x, training=training)
-         x = self.conv4(x)
-         x = self.batch_norm_c4(x, training=training)
-         x += x_s
-         x = self.act_c4(x)
-         x = self.max_po_c4(x)
-         x = self.dropout1(x, training=training)
-         x = self.rshp(x)
-         x = self.batch_norm_c1(x, training=training)
-         x = self.conv1(x)
-         x = self.avg_po_c1(x)
-         x = self.flatten(x)
-         x = self.dense2(x)
-         x = self.batch_norm2(x, training=training)
-         x = self.dropout2(x, training=training)
-         x = self.rshp2(x)
-         return self.act2(x)
-
-     def model(self):
-         """Here so that mymodel.model().summary() can be called for debugging."""
-         x = tf.keras.Input(shape=(self.n_features * self.num_classes,))
-         return tf.keras.Model(inputs=[x], outputs=self.call(x))
-
-     def set_model_outputs(self):
-         x = tf.keras.Input(shape=(self.n_features * self.num_classes,))
-         model = tf.keras.Model(inputs=[x], outputs=self.call(x))
-         self.outputs = model.outputs
-
-     @property
-     def metrics(self):
-         return [
-             self.total_loss_tracker,
-             self.reconstruction_loss_tracker,
-             self.kl_loss_tracker,
-             self.accuracy_tracker,
-         ]
-
-     @tf.function
-     def train_step(self, data):
-         # if isinstance(data, tuple):
-         #     if len(data) == 2:
-         #         x, y = data
-         #         sample_weight = None
-         #     else:
-         #         x, y, sample_weight = data
-         # else:
-         #     raise TypeError("Target y must be supplied to fit for this model.")
-
-         # Set in the UBPCallbacks() callback.
-         y = self._y
-
-         (
-             y,
-             y_true,
-             sample_weight,
-             missing_mask,
-             batch_start,
-             batch_end,
-         ) = self.nn_.prepare_training_batches(
-             y,
-             y,
-             self._batch_size,
-             self._batch_idx,
-             True,
-             self.n_components,
-             self._sample_weight,
-             self._missing_mask,
-             ubp=False,
-         )
-
-         if sample_weight is not None:
-             sample_weight_masked = tf.convert_to_tensor(
-                 sample_weight[~missing_mask], dtype=tf.float32
-             )
-         else:
-             sample_weight_masked = None
-
-         y_true_masked = tf.boolean_mask(
-             tf.convert_to_tensor(y_true, dtype=tf.float32),
-             tf.reduce_any(tf.not_equal(y_true, -1), axis=2),
-         )
-
-         with tf.GradientTape() as tape:
-             reconstruction = self(tf.convert_to_tensor(y), training=True)
-
-             y_pred_masked = tf.boolean_mask(
-                 reconstruction, tf.reduce_any(tf.not_equal(y_true, -1), axis=2)
-             )
-
-             # Returns binary crossentropy loss.
-             loss = self.compiled_loss(
-                 y_true_masked,
-                 y_pred_masked,
-                 sample_weight=sample_weight_masked,
-             )
-
-         grads = tape.gradient(loss, self.trainable_weights)
-         self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
-         self.total_loss_tracker.update_state(loss)
-
-         ### NOTE: If you get the error "'tuple' object has no attribute
-         ### 'rank'", convert y_true to a tensor object.
-         # self.compiled_metrics.update_state(
-         self.accuracy_tracker.update_state(
-             self.binary_accuracy(
-                 y_true_masked,
-                 y_pred_masked,
-                 sample_weight=sample_weight_masked,
-             )
-         )
-
-         return {
-             "loss": self.total_loss_tracker.result(),
-             "accuracy": self.accuracy_tracker.result(),
-         }
-
-     @tf.function
-     def test_step(self, data):
-         if isinstance(data, tuple):
-             if len(data) == 2:
-                 x, y = data
-                 sample_weight = None
-             else:
-                 x, y, sample_weight = data
-         else:
-             raise TypeError("Target y must be supplied to fit in this model.")
-
-         if sample_weight is not None:
-             sample_weight_masked = tf.boolean_mask(
-                 tf.convert_to_tensor(sample_weight),
-                 tf.reduce_any(tf.not_equal(y, -1), axis=2),
-             )
-         else:
-             sample_weight_masked = None
-
-         # call() returns a single reconstruction tensor (not a VAE tuple).
-         reconstruction = self(x, training=False)
-         reconstruction_loss = self.compiled_loss(
-             y,
-             reconstruction,
-             sample_weight=sample_weight_masked,
-         )
-
-         # Includes KL Divergence Loss.
-         regularization_loss = sum(self.losses)
-
-         total_loss = reconstruction_loss + regularization_loss
-
-         # Use the masked accuracy metric defined in __init__.
-         self.accuracy_tracker.update_state(
-             self.binary_accuracy(
-                 y,
-                 reconstruction,
-                 sample_weight=sample_weight_masked,
-             )
-         )
-
-         self.total_loss_tracker.update_state(total_loss)
-         self.reconstruction_loss_tracker.update_state(reconstruction_loss)
-         self.kl_loss_tracker.update_state(regularization_loss)
-
-         return {
-             "loss": self.total_loss_tracker.result(),
-             "reconstruction_loss": self.reconstruction_loss_tracker.result(),
-             "kl_loss": self.kl_loss_tracker.result(),
-             "accuracy": self.accuracy_tracker.result(),
-         }
-
-     @property
-     def batch_size(self):
-         """Batch (=step) size per epoch."""
-         return self._batch_size
-
-     @property
-     def batch_idx(self):
-         """Current batch (=step) index."""
-         return self._batch_idx
-
-     @property
-     def y(self):
-         return self._y
-
-     @property
-     def missing_mask(self):
-         return self._missing_mask
-
-     @property
-     def sample_weight(self):
-         return self._sample_weight
-
-     @batch_size.setter
-     def batch_size(self, value):
-         """Set batch_size parameter."""
-         self._batch_size = int(value)
-
-     @batch_idx.setter
-     def batch_idx(self, value):
-         """Set current batch (=step) index."""
-         self._batch_idx = int(value)
-
-     @y.setter
-     def y(self, value):
-         """Set y after each epoch."""
-         self._y = value
-
-     @missing_mask.setter
-     def missing_mask(self, value):
-         """Set missing_mask after each epoch."""
-         self._missing_mask = value
-
-     @sample_weight.setter
-     def sample_weight(self, value):
-         self._sample_weight = value
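The trailing property/setter block exists so that training callbacks (the code references UBPCallbacks) can swap the target matrix, missing-data mask, and sample weights between epochs without rebuilding the model. A minimal usage sketch, assuming hypothetical arrays y_train, mask, weights, and y_shuffled:

    # Hypothetical setup for the class removed above.
    model = SoftOrdering1DCNN(
        y=y_train,
        output_shape=y_train.shape[1],
        missing_mask=mask,
        sample_weight=weights,
        batch_size=32,
    )
    model.batch_idx = 0    # advanced by the callback at each training step
    model.y = y_shuffled   # @y.setter replaces the private _y buffer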