tf-keras-nightly 2.17.0.dev2024031909__py3-none-any.whl → 2.19.0.dev2025011410__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/src/__init__.py +1 -1
- tf_keras/src/backend.py +1 -1
- tf_keras/src/callbacks.py +24 -7
- tf_keras/src/datasets/boston_housing.py +14 -5
- tf_keras/src/datasets/cifar10.py +9 -1
- tf_keras/src/datasets/cifar100.py +7 -1
- tf_keras/src/datasets/fashion_mnist.py +16 -4
- tf_keras/src/datasets/imdb.py +8 -0
- tf_keras/src/datasets/mnist.py +9 -3
- tf_keras/src/datasets/reuters.py +8 -0
- tf_keras/src/engine/base_layer.py +10 -4
- tf_keras/src/engine/base_layer_v1.py +10 -4
- tf_keras/src/engine/node.py +8 -3
- tf_keras/src/layers/activation/prelu.py +1 -1
- tf_keras/src/layers/attention/base_dense_attention.py +2 -1
- tf_keras/src/layers/convolutional/base_conv.py +1 -1
- tf_keras/src/layers/convolutional/base_depthwise_conv.py +3 -1
- tf_keras/src/layers/convolutional/base_separable_conv.py +3 -1
- tf_keras/src/layers/convolutional/conv1d_transpose.py +3 -1
- tf_keras/src/layers/convolutional/conv2d_transpose.py +3 -1
- tf_keras/src/layers/convolutional/conv3d_transpose.py +3 -1
- tf_keras/src/layers/core/dense.py +1 -1
- tf_keras/src/layers/core/embedding.py +1 -1
- tf_keras/src/layers/locally_connected/locally_connected1d.py +1 -1
- tf_keras/src/layers/locally_connected/locally_connected2d.py +1 -1
- tf_keras/src/layers/normalization/batch_normalization.py +1 -1
- tf_keras/src/layers/normalization/layer_normalization.py +1 -1
- tf_keras/src/layers/normalization/unit_normalization.py +2 -1
- tf_keras/src/layers/rnn/abstract_rnn_cell.py +1 -1
- tf_keras/src/layers/rnn/base_conv_lstm.py +0 -1
- tf_keras/src/layers/rnn/base_conv_rnn.py +3 -1
- tf_keras/src/layers/rnn/base_rnn.py +1 -1
- tf_keras/src/layers/rnn/base_wrapper.py +1 -1
- tf_keras/src/layers/rnn/bidirectional.py +2 -1
- tf_keras/src/layers/rnn/cell_wrappers.py +3 -3
- tf_keras/src/layers/rnn/cudnn_gru.py +6 -3
- tf_keras/src/layers/rnn/cudnn_lstm.py +6 -3
- tf_keras/src/layers/rnn/gru.py +35 -47
- tf_keras/src/layers/rnn/legacy_cell_wrappers.py +3 -3
- tf_keras/src/layers/rnn/legacy_cells.py +20 -25
- tf_keras/src/layers/rnn/lstm.py +35 -50
- tf_keras/src/layers/rnn/simple_rnn.py +0 -1
- tf_keras/src/layers/rnn/stacked_rnn_cells.py +1 -1
- tf_keras/src/layers/rnn/time_distributed.py +0 -1
- tf_keras/src/mixed_precision/autocast_variable.py +12 -6
- tf_keras/src/mixed_precision/test_util.py +6 -5
- tf_keras/src/optimizers/legacy/optimizer_v2.py +9 -2
- tf_keras/src/optimizers/optimizer.py +18 -9
- tf_keras/src/premade_models/linear.py +2 -1
- tf_keras/src/saving/legacy/saved_model/json_utils.py +1 -1
- tf_keras/src/saving/saving_api.py +165 -127
- tf_keras/src/saving/saving_lib.py +1 -11
- tf_keras/src/saving/serialization_lib.py +1 -10
- tf_keras/src/utils/data_utils.py +1 -1
- tf_keras/src/utils/steps_per_execution_tuning.py +1 -1
- tf_keras/src/utils/tf_utils.py +2 -2
- tf_keras/src/utils/timeseries_dataset.py +13 -5
- {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/METADATA +14 -3
- {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/RECORD +62 -62
- {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/WHEEL +1 -1
- {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/top_level.txt +0 -0
@@ -308,7 +308,7 @@ class LocallyConnected2D(Layer):
|
|
308
308
|
self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
|
309
309
|
else:
|
310
310
|
self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
|
311
|
-
|
311
|
+
super().build(input_shape)
|
312
312
|
|
313
313
|
@tf_utils.shape_type_conversion
|
314
314
|
def compute_output_shape(self, input_shape):
|
@@ -542,7 +542,7 @@ class BatchNormalizationBase(Layer):
|
|
542
542
|
finally:
|
543
543
|
if partitioner:
|
544
544
|
self._scope.set_partitioner(partitioner)
|
545
|
-
|
545
|
+
super().build(input_shape)
|
546
546
|
|
547
547
|
def call(self, inputs, training=None, mask=None):
|
548
548
|
inputs = tf.cast(inputs, self.compute_dtype)
|
@@ -60,7 +60,8 @@ class UnitNormalization(base_layer.Layer):
|
|
60
60
|
self.supports_masking = True
|
61
61
|
|
62
62
|
def build(self, input_shape):
|
63
|
-
|
63
|
+
tf_utils.validate_axis(self.axis, input_shape)
|
64
|
+
super().build(input_shape)
|
64
65
|
|
65
66
|
def call(self, inputs):
|
66
67
|
inputs = tf.cast(inputs, self.compute_dtype)
|
@@ -20,6 +20,7 @@ import tensorflow.compat.v2 as tf
|
|
20
20
|
|
21
21
|
from tf_keras.src import backend
|
22
22
|
from tf_keras.src.engine import base_layer
|
23
|
+
from tf_keras.src.engine.base_layer import Layer
|
23
24
|
from tf_keras.src.engine.input_spec import InputSpec
|
24
25
|
from tf_keras.src.layers.rnn.base_rnn import RNN
|
25
26
|
from tf_keras.src.utils import conv_utils
|
@@ -207,6 +208,8 @@ class ConvRNN(RNN):
|
|
207
208
|
|
208
209
|
@tf_utils.shape_type_conversion
|
209
210
|
def build(self, input_shape):
|
211
|
+
# Call Layer.build() to skip RNN.build() which we override here.
|
212
|
+
Layer.build(self, input_shape)
|
210
213
|
# Note input_shape will be list of shapes of initial states and
|
211
214
|
# constants if these are passed in __call__.
|
212
215
|
if self._num_constants is not None:
|
@@ -263,7 +266,6 @@ class ConvRNN(RNN):
|
|
263
266
|
]
|
264
267
|
if self.stateful:
|
265
268
|
self.reset_states()
|
266
|
-
self.built = True
|
267
269
|
|
268
270
|
def get_initial_state(self, inputs):
|
269
271
|
# (samples, timesteps, img_dims..., filters)
|
@@ -470,7 +470,8 @@ class Bidirectional(Wrapper):
|
|
470
470
|
self.forward_layer.build(input_shape)
|
471
471
|
with backend.name_scope(self.backward_layer.name):
|
472
472
|
self.backward_layer.build(input_shape)
|
473
|
-
|
473
|
+
# Call Layer.build() to skip Wrapper.build() which we override here.
|
474
|
+
Layer.build(self, input_shape)
|
474
475
|
|
475
476
|
def compute_mask(self, inputs, mask):
|
476
477
|
if isinstance(mask, list):
|
@@ -102,10 +102,10 @@ class _RNNCellWrapper(AbstractRNNCell):
|
|
102
102
|
inputs, state, cell_call_fn=self.cell.call, **kwargs
|
103
103
|
)
|
104
104
|
|
105
|
-
def build(self,
|
105
|
+
def build(self, input_shape):
|
106
106
|
"""Builds the wrapped cell."""
|
107
|
-
self.cell.build(
|
108
|
-
|
107
|
+
self.cell.build(input_shape)
|
108
|
+
super().build(input_shape)
|
109
109
|
|
110
110
|
@property
|
111
111
|
def wrapped_cell(self):
|
@@ -144,8 +144,6 @@ class CuDNNGRU(_CuDNNRNN):
|
|
144
144
|
constraint=self.bias_constraint,
|
145
145
|
)
|
146
146
|
|
147
|
-
self.built = True
|
148
|
-
|
149
147
|
def _process_batch(self, inputs, initial_state):
|
150
148
|
if not self.time_major:
|
151
149
|
inputs = tf.transpose(inputs, perm=(1, 0, 2))
|
@@ -172,6 +170,10 @@ class CuDNNGRU(_CuDNNRNN):
|
|
172
170
|
shape=self._vector_shape,
|
173
171
|
)
|
174
172
|
|
173
|
+
batch_dim = tf.shape(inputs)[1]
|
174
|
+
max_sequence_length = tf.shape(inputs)[0]
|
175
|
+
sequence_lengths = tf.fill([batch_dim], max_sequence_length)
|
176
|
+
|
175
177
|
args = {
|
176
178
|
"input": inputs,
|
177
179
|
"input_h": input_h,
|
@@ -179,9 +181,10 @@ class CuDNNGRU(_CuDNNRNN):
|
|
179
181
|
"params": params,
|
180
182
|
"is_training": True,
|
181
183
|
"rnn_mode": "gru",
|
184
|
+
"sequence_lengths": sequence_lengths,
|
182
185
|
}
|
183
186
|
|
184
|
-
outputs, h, _, _, _ = tf.raw_ops.
|
187
|
+
outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(**args)
|
185
188
|
|
186
189
|
if self.stateful or self.return_state:
|
187
190
|
h = h[0]
|
@@ -170,8 +170,6 @@ class CuDNNLSTM(_CuDNNRNN):
|
|
170
170
|
constraint=self.bias_constraint,
|
171
171
|
)
|
172
172
|
|
173
|
-
self.built = True
|
174
|
-
|
175
173
|
def _process_batch(self, inputs, initial_state):
|
176
174
|
if not self.time_major:
|
177
175
|
inputs = tf.transpose(inputs, perm=(1, 0, 2))
|
@@ -204,15 +202,20 @@ class CuDNNLSTM(_CuDNNRNN):
|
|
204
202
|
shape=self._vector_shape,
|
205
203
|
)
|
206
204
|
|
205
|
+
batch_dim = tf.shape(inputs)[1]
|
206
|
+
max_sequence_length = tf.shape(inputs)[0]
|
207
|
+
sequence_lengths = tf.fill([batch_dim], max_sequence_length)
|
208
|
+
|
207
209
|
args = {
|
208
210
|
"input": inputs,
|
209
211
|
"input_h": input_h,
|
210
212
|
"input_c": input_c,
|
211
213
|
"params": params,
|
212
214
|
"is_training": True,
|
215
|
+
"sequence_lengths": sequence_lengths,
|
213
216
|
}
|
214
217
|
|
215
|
-
outputs, h, c, _, _ = tf.raw_ops.
|
218
|
+
outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(**args)
|
216
219
|
|
217
220
|
if self.stateful or self.return_state:
|
218
221
|
h = h[0]
|
tf_keras/src/layers/rnn/gru.py
CHANGED
@@ -222,7 +222,6 @@ class GRUCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
|
|
222
222
|
)
|
223
223
|
else:
|
224
224
|
self.bias = None
|
225
|
-
self.built = True
|
226
225
|
|
227
226
|
def call(self, inputs, states, training=None):
|
228
227
|
h_tm1 = (
|
@@ -1034,11 +1033,13 @@ def gpu_gru(
|
|
1034
1033
|
mask, time_major
|
1035
1034
|
)
|
1036
1035
|
|
1037
|
-
if
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1036
|
+
seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
|
1037
|
+
|
1038
|
+
if sequence_lengths is None:
|
1039
|
+
max_sequence_length = tf.shape(inputs)[seq_axis]
|
1040
|
+
batch_size = tf.shape(inputs)[batch_axis]
|
1041
|
+
sequence_lengths = tf.fill([batch_size], max_sequence_length)
|
1042
|
+
|
1042
1043
|
# For init_h, cuDNN expects one more dim of num_layers before or after batch
|
1043
1044
|
# dim for time major or batch major inputs respectively
|
1044
1045
|
init_h = tf.expand_dims(init_h, axis=seq_axis)
|
@@ -1069,49 +1070,36 @@ def gpu_gru(
|
|
1069
1070
|
transpose_weights=True,
|
1070
1071
|
)
|
1071
1072
|
|
1072
|
-
if
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
inputs
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
batch_axis=batch_axis,
|
1084
|
-
)
|
1085
|
-
outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
|
1086
|
-
input=inputs,
|
1087
|
-
input_h=init_h,
|
1088
|
-
input_c=0,
|
1089
|
-
params=params,
|
1090
|
-
is_training=True,
|
1091
|
-
rnn_mode="gru",
|
1092
|
-
sequence_lengths=sequence_lengths,
|
1093
|
-
time_major=time_major,
|
1073
|
+
if go_backwards:
|
1074
|
+
# Three reversals are required. E.g.,
|
1075
|
+
# normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
|
1076
|
+
# reversed_input_to_cudnn = [3, 2, 1, 0, 0]
|
1077
|
+
# output_from_cudnn = [6, 5, 4, 0, 0]
|
1078
|
+
# expected_output = [0, 0, 6, 5 ,4]
|
1079
|
+
inputs = tf.reverse_sequence(
|
1080
|
+
inputs,
|
1081
|
+
sequence_lengths,
|
1082
|
+
seq_axis=seq_axis,
|
1083
|
+
batch_axis=batch_axis,
|
1094
1084
|
)
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
params=params,
|
1112
|
-
is_training=True,
|
1113
|
-
rnn_mode="gru",
|
1085
|
+
outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
|
1086
|
+
input=inputs,
|
1087
|
+
input_h=init_h,
|
1088
|
+
input_c=0,
|
1089
|
+
params=params,
|
1090
|
+
is_training=True,
|
1091
|
+
rnn_mode="gru",
|
1092
|
+
sequence_lengths=sequence_lengths,
|
1093
|
+
time_major=time_major,
|
1094
|
+
)
|
1095
|
+
if go_backwards:
|
1096
|
+
outputs = tf.reverse_sequence(
|
1097
|
+
outputs,
|
1098
|
+
sequence_lengths,
|
1099
|
+
seq_axis=seq_axis,
|
1100
|
+
batch_axis=batch_axis,
|
1114
1101
|
)
|
1102
|
+
outputs = tf.reverse(outputs, axis=[seq_axis])
|
1115
1103
|
|
1116
1104
|
last_output = outputs[-1]
|
1117
1105
|
if not time_major and sequence_lengths is None and return_sequences:
|
@@ -368,9 +368,9 @@ class DropoutWrapper(_RNNCellWrapperV1):
|
|
368
368
|
def wrapped_cell(self):
|
369
369
|
return self.cell
|
370
370
|
|
371
|
-
def build(self,
|
372
|
-
self.cell.build(
|
373
|
-
|
371
|
+
def build(self, input_shape):
|
372
|
+
self.cell.build(input_shape)
|
373
|
+
super().build(input_shape)
|
374
374
|
|
375
375
|
def _variational_recurrent_dropout_value(
|
376
376
|
self, unused_index, value, noise, keep_prob
|
@@ -246,11 +246,6 @@ class RNNCell(base_layer.Layer):
|
|
246
246
|
"""Integer or TensorShape: size of outputs produced by this cell."""
|
247
247
|
raise NotImplementedError("Abstract method")
|
248
248
|
|
249
|
-
def build(self, _):
|
250
|
-
# This tells the parent Layer object that it's OK to call
|
251
|
-
# self.add_weight() inside the call() method.
|
252
|
-
pass
|
253
|
-
|
254
249
|
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
|
255
250
|
if inputs is not None:
|
256
251
|
# Validate the given batch_size and dtype against inputs if
|
@@ -445,15 +440,15 @@ class BasicRNNCell(LayerRNNCell):
|
|
445
440
|
return self._num_units
|
446
441
|
|
447
442
|
@tf_utils.shape_type_conversion
|
448
|
-
def build(self,
|
449
|
-
if
|
443
|
+
def build(self, input_shape):
|
444
|
+
if input_shape[-1] is None:
|
450
445
|
raise ValueError(
|
451
446
|
"Expected inputs.shape[-1] to be known, "
|
452
|
-
f"received shape: {
|
447
|
+
f"received shape: {input_shape}"
|
453
448
|
)
|
454
449
|
_check_supported_dtypes(self.dtype)
|
455
450
|
|
456
|
-
input_depth =
|
451
|
+
input_depth = input_shape[-1]
|
457
452
|
self._kernel = self.add_weight(
|
458
453
|
_WEIGHTS_VARIABLE_NAME,
|
459
454
|
shape=[input_depth + self._num_units, self._num_units],
|
@@ -464,7 +459,7 @@ class BasicRNNCell(LayerRNNCell):
|
|
464
459
|
initializer=tf.compat.v1.zeros_initializer(dtype=self.dtype),
|
465
460
|
)
|
466
461
|
|
467
|
-
|
462
|
+
super().build(input_shape)
|
468
463
|
|
469
464
|
def call(self, inputs, state):
|
470
465
|
"""Most basic RNN: output = new_state = act(W * input + U * state +
|
@@ -563,14 +558,14 @@ class GRUCell(LayerRNNCell):
|
|
563
558
|
return self._num_units
|
564
559
|
|
565
560
|
@tf_utils.shape_type_conversion
|
566
|
-
def build(self,
|
567
|
-
if
|
561
|
+
def build(self, input_shape):
|
562
|
+
if input_shape[-1] is None:
|
568
563
|
raise ValueError(
|
569
564
|
"Expected inputs.shape[-1] to be known, "
|
570
|
-
f"received shape: {
|
565
|
+
f"received shape: {input_shape}"
|
571
566
|
)
|
572
567
|
_check_supported_dtypes(self.dtype)
|
573
|
-
input_depth =
|
568
|
+
input_depth = input_shape[-1]
|
574
569
|
self._gate_kernel = self.add_weight(
|
575
570
|
f"gates/{_WEIGHTS_VARIABLE_NAME}",
|
576
571
|
shape=[input_depth + self._num_units, 2 * self._num_units],
|
@@ -600,7 +595,7 @@ class GRUCell(LayerRNNCell):
|
|
600
595
|
),
|
601
596
|
)
|
602
597
|
|
603
|
-
|
598
|
+
super().build(input_shape)
|
604
599
|
|
605
600
|
def call(self, inputs, state):
|
606
601
|
"""Gated recurrent unit (GRU) with nunits cells."""
|
@@ -774,14 +769,14 @@ class BasicLSTMCell(LayerRNNCell):
|
|
774
769
|
return self._num_units
|
775
770
|
|
776
771
|
@tf_utils.shape_type_conversion
|
777
|
-
def build(self,
|
778
|
-
if
|
772
|
+
def build(self, input_shape):
|
773
|
+
if input_shape[-1] is None:
|
779
774
|
raise ValueError(
|
780
775
|
"Expected inputs.shape[-1] to be known, "
|
781
|
-
f"received shape: {
|
776
|
+
f"received shape: {input_shape}"
|
782
777
|
)
|
783
778
|
_check_supported_dtypes(self.dtype)
|
784
|
-
input_depth =
|
779
|
+
input_depth = input_shape[-1]
|
785
780
|
h_depth = self._num_units
|
786
781
|
self._kernel = self.add_weight(
|
787
782
|
_WEIGHTS_VARIABLE_NAME,
|
@@ -793,7 +788,7 @@ class BasicLSTMCell(LayerRNNCell):
|
|
793
788
|
initializer=tf.compat.v1.zeros_initializer(dtype=self.dtype),
|
794
789
|
)
|
795
790
|
|
796
|
-
|
791
|
+
super().build(input_shape)
|
797
792
|
|
798
793
|
def call(self, inputs, state):
|
799
794
|
"""Long short-term memory cell (LSTM).
|
@@ -1017,14 +1012,14 @@ class LSTMCell(LayerRNNCell):
|
|
1017
1012
|
return self._output_size
|
1018
1013
|
|
1019
1014
|
@tf_utils.shape_type_conversion
|
1020
|
-
def build(self,
|
1021
|
-
if
|
1015
|
+
def build(self, input_shape):
|
1016
|
+
if input_shape[-1] is None:
|
1022
1017
|
raise ValueError(
|
1023
1018
|
"Expected inputs.shape[-1] to be known, "
|
1024
|
-
f"received shape: {
|
1019
|
+
f"received shape: {input_shape}"
|
1025
1020
|
)
|
1026
1021
|
_check_supported_dtypes(self.dtype)
|
1027
|
-
input_depth =
|
1022
|
+
input_depth = input_shape[-1]
|
1028
1023
|
h_depth = self._num_units if self._num_proj is None else self._num_proj
|
1029
1024
|
maybe_partitioner = (
|
1030
1025
|
tf.compat.v1.fixed_size_partitioner(self._num_unit_shards)
|
@@ -1076,7 +1071,7 @@ class LSTMCell(LayerRNNCell):
|
|
1076
1071
|
partitioner=maybe_proj_partitioner,
|
1077
1072
|
)
|
1078
1073
|
|
1079
|
-
|
1074
|
+
super().build(input_shape)
|
1080
1075
|
|
1081
1076
|
def call(self, inputs, state):
|
1082
1077
|
"""Run one step of LSTM.
|
tf_keras/src/layers/rnn/lstm.py
CHANGED
@@ -236,7 +236,6 @@ class LSTMCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
|
|
236
236
|
)
|
237
237
|
else:
|
238
238
|
self.bias = None
|
239
|
-
self.built = True
|
240
239
|
|
241
240
|
def _compute_carry_and_output(self, x, h_tm1, c_tm1):
|
242
241
|
"""Computes carry and output using split kernels."""
|
@@ -1063,11 +1062,13 @@ def gpu_lstm(
|
|
1063
1062
|
mask, time_major
|
1064
1063
|
)
|
1065
1064
|
|
1066
|
-
if
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1065
|
+
seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
|
1066
|
+
|
1067
|
+
if sequence_lengths is None:
|
1068
|
+
max_sequence_length = tf.shape(inputs)[seq_axis]
|
1069
|
+
batch_size = tf.shape(inputs)[batch_axis]
|
1070
|
+
sequence_lengths = tf.fill([batch_size], max_sequence_length)
|
1071
|
+
|
1071
1072
|
# For init_h and init_c, cuDNN expects one more dim of num_layers before or
|
1072
1073
|
# after batch dim for time major or batch major inputs respectively
|
1073
1074
|
init_h = tf.expand_dims(init_h, axis=seq_axis)
|
@@ -1099,52 +1100,36 @@ def gpu_lstm(
|
|
1099
1100
|
transpose_weights=True,
|
1100
1101
|
)
|
1101
1102
|
|
1102
|
-
if
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
inputs
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
batch_axis=batch_axis,
|
1114
|
-
)
|
1115
|
-
outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
|
1116
|
-
input=inputs,
|
1117
|
-
input_h=init_h,
|
1118
|
-
input_c=init_c,
|
1119
|
-
params=params,
|
1120
|
-
is_training=True,
|
1121
|
-
rnn_mode="lstm",
|
1122
|
-
sequence_lengths=sequence_lengths,
|
1123
|
-
time_major=time_major,
|
1103
|
+
if go_backwards:
|
1104
|
+
# Three reversals are required. E.g.,
|
1105
|
+
# normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
|
1106
|
+
# reversed_input_to_cudnn = [3, 2, 1, 0, 0]
|
1107
|
+
# output_from_cudnn = [6, 5, 4, 0, 0]
|
1108
|
+
# expected_output = [0, 0, 6, 5 ,4]
|
1109
|
+
inputs = tf.reverse_sequence(
|
1110
|
+
inputs,
|
1111
|
+
sequence_lengths,
|
1112
|
+
seq_axis=seq_axis,
|
1113
|
+
batch_axis=batch_axis,
|
1124
1114
|
)
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
input=inputs,
|
1142
|
-
input_h=init_h,
|
1143
|
-
input_c=init_c,
|
1144
|
-
params=params,
|
1145
|
-
is_training=True,
|
1146
|
-
rnn_mode="lstm",
|
1115
|
+
outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
|
1116
|
+
input=inputs,
|
1117
|
+
input_h=init_h,
|
1118
|
+
input_c=init_c,
|
1119
|
+
params=params,
|
1120
|
+
is_training=True,
|
1121
|
+
rnn_mode="lstm",
|
1122
|
+
sequence_lengths=sequence_lengths,
|
1123
|
+
time_major=time_major,
|
1124
|
+
)
|
1125
|
+
if go_backwards:
|
1126
|
+
outputs = tf.reverse_sequence(
|
1127
|
+
outputs,
|
1128
|
+
sequence_lengths,
|
1129
|
+
seq_axis=seq_axis,
|
1130
|
+
batch_axis=batch_axis,
|
1147
1131
|
)
|
1132
|
+
outputs = tf.reverse(outputs, axis=[seq_axis])
|
1148
1133
|
|
1149
1134
|
last_output = outputs[-1]
|
1150
1135
|
if not time_major and sequence_lengths is None and return_sequences:
|
@@ -166,6 +166,7 @@ class StackedRNNCells(base_layer.Layer):
|
|
166
166
|
|
167
167
|
@tf_utils.shape_type_conversion
|
168
168
|
def build(self, input_shape):
|
169
|
+
super().build(input_shape)
|
169
170
|
if isinstance(input_shape, list):
|
170
171
|
input_shape = input_shape[0]
|
171
172
|
|
@@ -195,7 +196,6 @@ class StackedRNNCells(base_layer.Layer):
|
|
195
196
|
input_shape = tuple(
|
196
197
|
[batch_size] + tf.TensorShape(output_dim).as_list()
|
197
198
|
)
|
198
|
-
self.built = True
|
199
199
|
|
200
200
|
def get_config(self):
|
201
201
|
cells = []
|
@@ -135,7 +135,6 @@ class TimeDistributed(Wrapper):
|
|
135
135
|
)
|
136
136
|
child_input_shape = tf_utils.convert_shapes(child_input_shape)
|
137
137
|
super().build(tuple(child_input_shape))
|
138
|
-
self.built = True
|
139
138
|
|
140
139
|
def compute_output_shape(self, input_shape):
|
141
140
|
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
|
@@ -124,20 +124,21 @@ class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
|
|
124
124
|
def _should_cast(self):
|
125
125
|
"""Returns True if this variable should be casted when accessed."""
|
126
126
|
autocast_dtype = getattr(_autocast_dtype, "dtype", None)
|
127
|
-
return autocast_dtype is not None and self.
|
127
|
+
return autocast_dtype is not None and self.true_dtype != autocast_dtype
|
128
128
|
|
129
129
|
@property
|
130
130
|
def dtype(self):
|
131
|
-
"""The dtype
|
132
|
-
return self.
|
131
|
+
"""The dtype when the value is accessed, that is after casting."""
|
132
|
+
return self._cast_dtype
|
133
133
|
|
134
134
|
@property
|
135
135
|
def true_dtype(self):
|
136
|
-
"""
|
136
|
+
"""The dtype of the underlying variable, before any casts are done."""
|
137
137
|
return self._variable.dtype
|
138
138
|
|
139
139
|
@property
|
140
140
|
def _cast_dtype(self):
|
141
|
+
"""The dtype after casting."""
|
141
142
|
dtype = getattr(_autocast_dtype, "dtype", None)
|
142
143
|
return dtype or self._variable.dtype
|
143
144
|
|
@@ -202,7 +203,8 @@ class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
|
|
202
203
|
if tf.executing_eagerly() and not self._in_graph_mode:
|
203
204
|
repr_str = (
|
204
205
|
"<AutoCastVariable '{v.name}' shape={v.shape} "
|
205
|
-
"dtype={v.
|
206
|
+
"dtype={v.true_dtype.name} "
|
207
|
+
"dtype_to_cast_to={v._cast_dtype.name}, "
|
206
208
|
"numpy={np_repr}>"
|
207
209
|
)
|
208
210
|
return repr_str.format(
|
@@ -211,7 +213,8 @@ class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
|
|
211
213
|
else:
|
212
214
|
repr_str = (
|
213
215
|
"<AutoCastVariable '{v.name}' shape={v.shape} "
|
214
|
-
"dtype={v.
|
216
|
+
"dtype={v.true_dtype.name} "
|
217
|
+
"dtype_to_cast_to={v._cast_dtype.name}>"
|
215
218
|
)
|
216
219
|
return repr_str.format(v=self)
|
217
220
|
|
@@ -261,6 +264,9 @@ class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
|
|
261
264
|
def _apply_assign_update(
|
262
265
|
self, update_fn, value, use_locking=None, name=None, read_value=True
|
263
266
|
):
|
267
|
+
# In auto cast scope, we cast back to the actual variable dtype.
|
268
|
+
if self._should_cast():
|
269
|
+
value = tf.cast(value, self.true_dtype)
|
264
270
|
# TODO(b/146181571): This logic can be simplified once
|
265
271
|
# DistributedVariable.assign returns a DistributedVariable. Currently
|
266
272
|
# for MirroredStrategy, it returns a Mirrored value.
|