tf-keras-nightly 2.17.0.dev2024031909__py3-none-any.whl → 2.19.0.dev2025011410__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/src/__init__.py +1 -1
  3. tf_keras/src/backend.py +1 -1
  4. tf_keras/src/callbacks.py +24 -7
  5. tf_keras/src/datasets/boston_housing.py +14 -5
  6. tf_keras/src/datasets/cifar10.py +9 -1
  7. tf_keras/src/datasets/cifar100.py +7 -1
  8. tf_keras/src/datasets/fashion_mnist.py +16 -4
  9. tf_keras/src/datasets/imdb.py +8 -0
  10. tf_keras/src/datasets/mnist.py +9 -3
  11. tf_keras/src/datasets/reuters.py +8 -0
  12. tf_keras/src/engine/base_layer.py +10 -4
  13. tf_keras/src/engine/base_layer_v1.py +10 -4
  14. tf_keras/src/engine/node.py +8 -3
  15. tf_keras/src/layers/activation/prelu.py +1 -1
  16. tf_keras/src/layers/attention/base_dense_attention.py +2 -1
  17. tf_keras/src/layers/convolutional/base_conv.py +1 -1
  18. tf_keras/src/layers/convolutional/base_depthwise_conv.py +3 -1
  19. tf_keras/src/layers/convolutional/base_separable_conv.py +3 -1
  20. tf_keras/src/layers/convolutional/conv1d_transpose.py +3 -1
  21. tf_keras/src/layers/convolutional/conv2d_transpose.py +3 -1
  22. tf_keras/src/layers/convolutional/conv3d_transpose.py +3 -1
  23. tf_keras/src/layers/core/dense.py +1 -1
  24. tf_keras/src/layers/core/embedding.py +1 -1
  25. tf_keras/src/layers/locally_connected/locally_connected1d.py +1 -1
  26. tf_keras/src/layers/locally_connected/locally_connected2d.py +1 -1
  27. tf_keras/src/layers/normalization/batch_normalization.py +1 -1
  28. tf_keras/src/layers/normalization/layer_normalization.py +1 -1
  29. tf_keras/src/layers/normalization/unit_normalization.py +2 -1
  30. tf_keras/src/layers/rnn/abstract_rnn_cell.py +1 -1
  31. tf_keras/src/layers/rnn/base_conv_lstm.py +0 -1
  32. tf_keras/src/layers/rnn/base_conv_rnn.py +3 -1
  33. tf_keras/src/layers/rnn/base_rnn.py +1 -1
  34. tf_keras/src/layers/rnn/base_wrapper.py +1 -1
  35. tf_keras/src/layers/rnn/bidirectional.py +2 -1
  36. tf_keras/src/layers/rnn/cell_wrappers.py +3 -3
  37. tf_keras/src/layers/rnn/cudnn_gru.py +6 -3
  38. tf_keras/src/layers/rnn/cudnn_lstm.py +6 -3
  39. tf_keras/src/layers/rnn/gru.py +35 -47
  40. tf_keras/src/layers/rnn/legacy_cell_wrappers.py +3 -3
  41. tf_keras/src/layers/rnn/legacy_cells.py +20 -25
  42. tf_keras/src/layers/rnn/lstm.py +35 -50
  43. tf_keras/src/layers/rnn/simple_rnn.py +0 -1
  44. tf_keras/src/layers/rnn/stacked_rnn_cells.py +1 -1
  45. tf_keras/src/layers/rnn/time_distributed.py +0 -1
  46. tf_keras/src/mixed_precision/autocast_variable.py +12 -6
  47. tf_keras/src/mixed_precision/test_util.py +6 -5
  48. tf_keras/src/optimizers/legacy/optimizer_v2.py +9 -2
  49. tf_keras/src/optimizers/optimizer.py +18 -9
  50. tf_keras/src/premade_models/linear.py +2 -1
  51. tf_keras/src/saving/legacy/saved_model/json_utils.py +1 -1
  52. tf_keras/src/saving/saving_api.py +165 -127
  53. tf_keras/src/saving/saving_lib.py +1 -11
  54. tf_keras/src/saving/serialization_lib.py +1 -10
  55. tf_keras/src/utils/data_utils.py +1 -1
  56. tf_keras/src/utils/steps_per_execution_tuning.py +1 -1
  57. tf_keras/src/utils/tf_utils.py +2 -2
  58. tf_keras/src/utils/timeseries_dataset.py +13 -5
  59. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/METADATA +14 -3
  60. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/RECORD +62 -62
  61. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/WHEEL +1 -1
  62. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/top_level.txt +0 -0
@@ -308,7 +308,7 @@ class LocallyConnected2D(Layer):
308
308
  self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
309
309
  else:
310
310
  self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
311
- self.built = True
311
+ super().build(input_shape)
312
312
 
313
313
  @tf_utils.shape_type_conversion
314
314
  def compute_output_shape(self, input_shape):
@@ -542,7 +542,7 @@ class BatchNormalizationBase(Layer):
542
542
  finally:
543
543
  if partitioner:
544
544
  self._scope.set_partitioner(partitioner)
545
- self.built = True
545
+ super().build(input_shape)
546
546
 
547
547
  def call(self, inputs, training=None, mask=None):
548
548
  inputs = tf.cast(inputs, self.compute_dtype)
@@ -249,7 +249,7 @@ class LayerNormalization(Layer):
249
249
  self.beta = None
250
250
 
251
251
  self._fused = self._fused_can_be_used(rank)
252
- self.built = True
252
+ super().build(input_shape)
253
253
 
254
254
  def call(self, inputs):
255
255
  # TODO(b/229545225): Remove the RaggedTensor check.
@@ -60,7 +60,8 @@ class UnitNormalization(base_layer.Layer):
60
60
  self.supports_masking = True
61
61
 
62
62
  def build(self, input_shape):
63
- self.axis = tf_utils.validate_axis(self.axis, input_shape)
63
+ tf_utils.validate_axis(self.axis, input_shape)
64
+ super().build(input_shape)
64
65
 
65
66
  def call(self, inputs):
66
67
  inputs = tf.cast(inputs, self.compute_dtype)
@@ -56,7 +56,7 @@ class AbstractRNNCell(base_layer.Layer):
56
56
  shape=(self.units, self.units),
57
57
  initializer='uniform',
58
58
  name='recurrent_kernel')
59
- self.built = True
59
+ super().build(input_shape)
60
60
 
61
61
  def call(self, inputs, states):
62
62
  prev_output = states[0]
@@ -218,7 +218,6 @@ class ConvLSTMCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
218
218
  )
219
219
  else:
220
220
  self.bias = None
221
- self.built = True
222
221
 
223
222
  def call(self, inputs, states, training=None):
224
223
  h_tm1 = states[0] # previous memory state
@@ -20,6 +20,7 @@ import tensorflow.compat.v2 as tf
20
20
 
21
21
  from tf_keras.src import backend
22
22
  from tf_keras.src.engine import base_layer
23
+ from tf_keras.src.engine.base_layer import Layer
23
24
  from tf_keras.src.engine.input_spec import InputSpec
24
25
  from tf_keras.src.layers.rnn.base_rnn import RNN
25
26
  from tf_keras.src.utils import conv_utils
@@ -207,6 +208,8 @@ class ConvRNN(RNN):
207
208
 
208
209
  @tf_utils.shape_type_conversion
209
210
  def build(self, input_shape):
211
+ # Call Layer.build() to skip RNN.build() which we override here.
212
+ Layer.build(self, input_shape)
210
213
  # Note input_shape will be list of shapes of initial states and
211
214
  # constants if these are passed in __call__.
212
215
  if self._num_constants is not None:
@@ -263,7 +266,6 @@ class ConvRNN(RNN):
263
266
  ]
264
267
  if self.stateful:
265
268
  self.reset_states()
266
- self.built = True
267
269
 
268
270
  def get_initial_state(self, inputs):
269
271
  # (samples, timesteps, img_dims..., filters)
@@ -207,7 +207,7 @@ class RNN(base_layer.Layer):
207
207
  shape=(self.units, self.units),
208
208
  initializer='uniform',
209
209
  name='recurrent_kernel')
210
- self.built = True
210
+ super().build(input_shape)
211
211
 
212
212
  def call(self, inputs, states):
213
213
  prev_output = states[0]
@@ -56,7 +56,7 @@ class Wrapper(Layer):
56
56
  if not self.layer.built:
57
57
  self.layer.build(input_shape)
58
58
  self.layer.built = True
59
- self.built = True
59
+ super().build(input_shape)
60
60
 
61
61
  @property
62
62
  def activity_regularizer(self):
@@ -470,7 +470,8 @@ class Bidirectional(Wrapper):
470
470
  self.forward_layer.build(input_shape)
471
471
  with backend.name_scope(self.backward_layer.name):
472
472
  self.backward_layer.build(input_shape)
473
- self.built = True
473
+ # Call Layer.build() to skip Wrapper.build() which we override here.
474
+ Layer.build(self, input_shape)
474
475
 
475
476
  def compute_mask(self, inputs, mask):
476
477
  if isinstance(mask, list):
@@ -102,10 +102,10 @@ class _RNNCellWrapper(AbstractRNNCell):
102
102
  inputs, state, cell_call_fn=self.cell.call, **kwargs
103
103
  )
104
104
 
105
- def build(self, inputs_shape):
105
+ def build(self, input_shape):
106
106
  """Builds the wrapped cell."""
107
- self.cell.build(inputs_shape)
108
- self.built = True
107
+ self.cell.build(input_shape)
108
+ super().build(input_shape)
109
109
 
110
110
  @property
111
111
  def wrapped_cell(self):
@@ -144,8 +144,6 @@ class CuDNNGRU(_CuDNNRNN):
144
144
  constraint=self.bias_constraint,
145
145
  )
146
146
 
147
- self.built = True
148
-
149
147
  def _process_batch(self, inputs, initial_state):
150
148
  if not self.time_major:
151
149
  inputs = tf.transpose(inputs, perm=(1, 0, 2))
@@ -172,6 +170,10 @@ class CuDNNGRU(_CuDNNRNN):
172
170
  shape=self._vector_shape,
173
171
  )
174
172
 
173
+ batch_dim = tf.shape(inputs)[1]
174
+ max_sequence_length = tf.shape(inputs)[0]
175
+ sequence_lengths = tf.fill([batch_dim], max_sequence_length)
176
+
175
177
  args = {
176
178
  "input": inputs,
177
179
  "input_h": input_h,
@@ -179,9 +181,10 @@ class CuDNNGRU(_CuDNNRNN):
179
181
  "params": params,
180
182
  "is_training": True,
181
183
  "rnn_mode": "gru",
184
+ "sequence_lengths": sequence_lengths,
182
185
  }
183
186
 
184
- outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV2(**args)
187
+ outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(**args)
185
188
 
186
189
  if self.stateful or self.return_state:
187
190
  h = h[0]
@@ -170,8 +170,6 @@ class CuDNNLSTM(_CuDNNRNN):
170
170
  constraint=self.bias_constraint,
171
171
  )
172
172
 
173
- self.built = True
174
-
175
173
  def _process_batch(self, inputs, initial_state):
176
174
  if not self.time_major:
177
175
  inputs = tf.transpose(inputs, perm=(1, 0, 2))
@@ -204,15 +202,20 @@ class CuDNNLSTM(_CuDNNRNN):
204
202
  shape=self._vector_shape,
205
203
  )
206
204
 
205
+ batch_dim = tf.shape(inputs)[1]
206
+ max_sequence_length = tf.shape(inputs)[0]
207
+ sequence_lengths = tf.fill([batch_dim], max_sequence_length)
208
+
207
209
  args = {
208
210
  "input": inputs,
209
211
  "input_h": input_h,
210
212
  "input_c": input_c,
211
213
  "params": params,
212
214
  "is_training": True,
215
+ "sequence_lengths": sequence_lengths,
213
216
  }
214
217
 
215
- outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV2(**args)
218
+ outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(**args)
216
219
 
217
220
  if self.stateful or self.return_state:
218
221
  h = h[0]
@@ -222,7 +222,6 @@ class GRUCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
222
222
  )
223
223
  else:
224
224
  self.bias = None
225
- self.built = True
226
225
 
227
226
  def call(self, inputs, states, training=None):
228
227
  h_tm1 = (
@@ -1034,11 +1033,13 @@ def gpu_gru(
1034
1033
  mask, time_major
1035
1034
  )
1036
1035
 
1037
- if not time_major and sequence_lengths is None:
1038
- inputs = tf.transpose(inputs, perm=(1, 0, 2))
1039
- seq_axis, batch_axis = (0, 1)
1040
- else:
1041
- seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
1036
+ seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
1037
+
1038
+ if sequence_lengths is None:
1039
+ max_sequence_length = tf.shape(inputs)[seq_axis]
1040
+ batch_size = tf.shape(inputs)[batch_axis]
1041
+ sequence_lengths = tf.fill([batch_size], max_sequence_length)
1042
+
1042
1043
  # For init_h, cuDNN expects one more dim of num_layers before or after batch
1043
1044
  # dim for time major or batch major inputs respectively
1044
1045
  init_h = tf.expand_dims(init_h, axis=seq_axis)
@@ -1069,49 +1070,36 @@ def gpu_gru(
1069
1070
  transpose_weights=True,
1070
1071
  )
1071
1072
 
1072
- if sequence_lengths is not None:
1073
- if go_backwards:
1074
- # Three reversals are required. E.g.,
1075
- # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
1076
- # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
1077
- # output_from_cudnn = [6, 5, 4, 0, 0]
1078
- # expected_output = [0, 0, 6, 5 ,4]
1079
- inputs = tf.reverse_sequence(
1080
- inputs,
1081
- sequence_lengths,
1082
- seq_axis=seq_axis,
1083
- batch_axis=batch_axis,
1084
- )
1085
- outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
1086
- input=inputs,
1087
- input_h=init_h,
1088
- input_c=0,
1089
- params=params,
1090
- is_training=True,
1091
- rnn_mode="gru",
1092
- sequence_lengths=sequence_lengths,
1093
- time_major=time_major,
1073
+ if go_backwards:
1074
+ # Three reversals are required. E.g.,
1075
+ # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
1076
+ # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
1077
+ # output_from_cudnn = [6, 5, 4, 0, 0]
1078
+ # expected_output = [0, 0, 6, 5 ,4]
1079
+ inputs = tf.reverse_sequence(
1080
+ inputs,
1081
+ sequence_lengths,
1082
+ seq_axis=seq_axis,
1083
+ batch_axis=batch_axis,
1094
1084
  )
1095
- if go_backwards:
1096
- outputs = tf.reverse_sequence(
1097
- outputs,
1098
- sequence_lengths,
1099
- seq_axis=seq_axis,
1100
- batch_axis=batch_axis,
1101
- )
1102
- outputs = tf.reverse(outputs, axis=[seq_axis])
1103
- else:
1104
- if go_backwards:
1105
- # Reverse axis 0 since the input is already convert to time major.
1106
- inputs = tf.reverse(inputs, axis=[0])
1107
- outputs, h, _, _ = tf.raw_ops.CudnnRNN(
1108
- input=inputs,
1109
- input_h=init_h,
1110
- input_c=0,
1111
- params=params,
1112
- is_training=True,
1113
- rnn_mode="gru",
1085
+ outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
1086
+ input=inputs,
1087
+ input_h=init_h,
1088
+ input_c=0,
1089
+ params=params,
1090
+ is_training=True,
1091
+ rnn_mode="gru",
1092
+ sequence_lengths=sequence_lengths,
1093
+ time_major=time_major,
1094
+ )
1095
+ if go_backwards:
1096
+ outputs = tf.reverse_sequence(
1097
+ outputs,
1098
+ sequence_lengths,
1099
+ seq_axis=seq_axis,
1100
+ batch_axis=batch_axis,
1114
1101
  )
1102
+ outputs = tf.reverse(outputs, axis=[seq_axis])
1115
1103
 
1116
1104
  last_output = outputs[-1]
1117
1105
  if not time_major and sequence_lengths is None and return_sequences:
@@ -368,9 +368,9 @@ class DropoutWrapper(_RNNCellWrapperV1):
368
368
  def wrapped_cell(self):
369
369
  return self.cell
370
370
 
371
- def build(self, inputs_shape):
372
- self.cell.build(inputs_shape)
373
- self.built = True
371
+ def build(self, input_shape):
372
+ self.cell.build(input_shape)
373
+ super().build(input_shape)
374
374
 
375
375
  def _variational_recurrent_dropout_value(
376
376
  self, unused_index, value, noise, keep_prob
@@ -246,11 +246,6 @@ class RNNCell(base_layer.Layer):
246
246
  """Integer or TensorShape: size of outputs produced by this cell."""
247
247
  raise NotImplementedError("Abstract method")
248
248
 
249
- def build(self, _):
250
- # This tells the parent Layer object that it's OK to call
251
- # self.add_weight() inside the call() method.
252
- pass
253
-
254
249
  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
255
250
  if inputs is not None:
256
251
  # Validate the given batch_size and dtype against inputs if
@@ -445,15 +440,15 @@ class BasicRNNCell(LayerRNNCell):
445
440
  return self._num_units
446
441
 
447
442
  @tf_utils.shape_type_conversion
448
- def build(self, inputs_shape):
449
- if inputs_shape[-1] is None:
443
+ def build(self, input_shape):
444
+ if input_shape[-1] is None:
450
445
  raise ValueError(
451
446
  "Expected inputs.shape[-1] to be known, "
452
- f"received shape: {inputs_shape}"
447
+ f"received shape: {input_shape}"
453
448
  )
454
449
  _check_supported_dtypes(self.dtype)
455
450
 
456
- input_depth = inputs_shape[-1]
451
+ input_depth = input_shape[-1]
457
452
  self._kernel = self.add_weight(
458
453
  _WEIGHTS_VARIABLE_NAME,
459
454
  shape=[input_depth + self._num_units, self._num_units],
@@ -464,7 +459,7 @@ class BasicRNNCell(LayerRNNCell):
464
459
  initializer=tf.compat.v1.zeros_initializer(dtype=self.dtype),
465
460
  )
466
461
 
467
- self.built = True
462
+ super().build(input_shape)
468
463
 
469
464
  def call(self, inputs, state):
470
465
  """Most basic RNN: output = new_state = act(W * input + U * state +
@@ -563,14 +558,14 @@ class GRUCell(LayerRNNCell):
563
558
  return self._num_units
564
559
 
565
560
  @tf_utils.shape_type_conversion
566
- def build(self, inputs_shape):
567
- if inputs_shape[-1] is None:
561
+ def build(self, input_shape):
562
+ if input_shape[-1] is None:
568
563
  raise ValueError(
569
564
  "Expected inputs.shape[-1] to be known, "
570
- f"received shape: {inputs_shape}"
565
+ f"received shape: {input_shape}"
571
566
  )
572
567
  _check_supported_dtypes(self.dtype)
573
- input_depth = inputs_shape[-1]
568
+ input_depth = input_shape[-1]
574
569
  self._gate_kernel = self.add_weight(
575
570
  f"gates/{_WEIGHTS_VARIABLE_NAME}",
576
571
  shape=[input_depth + self._num_units, 2 * self._num_units],
@@ -600,7 +595,7 @@ class GRUCell(LayerRNNCell):
600
595
  ),
601
596
  )
602
597
 
603
- self.built = True
598
+ super().build(input_shape)
604
599
 
605
600
  def call(self, inputs, state):
606
601
  """Gated recurrent unit (GRU) with nunits cells."""
@@ -774,14 +769,14 @@ class BasicLSTMCell(LayerRNNCell):
774
769
  return self._num_units
775
770
 
776
771
  @tf_utils.shape_type_conversion
777
- def build(self, inputs_shape):
778
- if inputs_shape[-1] is None:
772
+ def build(self, input_shape):
773
+ if input_shape[-1] is None:
779
774
  raise ValueError(
780
775
  "Expected inputs.shape[-1] to be known, "
781
- f"received shape: {inputs_shape}"
776
+ f"received shape: {input_shape}"
782
777
  )
783
778
  _check_supported_dtypes(self.dtype)
784
- input_depth = inputs_shape[-1]
779
+ input_depth = input_shape[-1]
785
780
  h_depth = self._num_units
786
781
  self._kernel = self.add_weight(
787
782
  _WEIGHTS_VARIABLE_NAME,
@@ -793,7 +788,7 @@ class BasicLSTMCell(LayerRNNCell):
793
788
  initializer=tf.compat.v1.zeros_initializer(dtype=self.dtype),
794
789
  )
795
790
 
796
- self.built = True
791
+ super().build(input_shape)
797
792
 
798
793
  def call(self, inputs, state):
799
794
  """Long short-term memory cell (LSTM).
@@ -1017,14 +1012,14 @@ class LSTMCell(LayerRNNCell):
1017
1012
  return self._output_size
1018
1013
 
1019
1014
  @tf_utils.shape_type_conversion
1020
- def build(self, inputs_shape):
1021
- if inputs_shape[-1] is None:
1015
+ def build(self, input_shape):
1016
+ if input_shape[-1] is None:
1022
1017
  raise ValueError(
1023
1018
  "Expected inputs.shape[-1] to be known, "
1024
- f"received shape: {inputs_shape}"
1019
+ f"received shape: {input_shape}"
1025
1020
  )
1026
1021
  _check_supported_dtypes(self.dtype)
1027
- input_depth = inputs_shape[-1]
1022
+ input_depth = input_shape[-1]
1028
1023
  h_depth = self._num_units if self._num_proj is None else self._num_proj
1029
1024
  maybe_partitioner = (
1030
1025
  tf.compat.v1.fixed_size_partitioner(self._num_unit_shards)
@@ -1076,7 +1071,7 @@ class LSTMCell(LayerRNNCell):
1076
1071
  partitioner=maybe_proj_partitioner,
1077
1072
  )
1078
1073
 
1079
- self.built = True
1074
+ super().build(input_shape)
1080
1075
 
1081
1076
  def call(self, inputs, state):
1082
1077
  """Run one step of LSTM.
@@ -236,7 +236,6 @@ class LSTMCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
236
236
  )
237
237
  else:
238
238
  self.bias = None
239
- self.built = True
240
239
 
241
240
  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
242
241
  """Computes carry and output using split kernels."""
@@ -1063,11 +1062,13 @@ def gpu_lstm(
1063
1062
  mask, time_major
1064
1063
  )
1065
1064
 
1066
- if not time_major and sequence_lengths is None:
1067
- inputs = tf.transpose(inputs, perm=(1, 0, 2))
1068
- seq_axis, batch_axis = (0, 1)
1069
- else:
1070
- seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
1065
+ seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
1066
+
1067
+ if sequence_lengths is None:
1068
+ max_sequence_length = tf.shape(inputs)[seq_axis]
1069
+ batch_size = tf.shape(inputs)[batch_axis]
1070
+ sequence_lengths = tf.fill([batch_size], max_sequence_length)
1071
+
1071
1072
  # For init_h and init_c, cuDNN expects one more dim of num_layers before or
1072
1073
  # after batch dim for time major or batch major inputs respectively
1073
1074
  init_h = tf.expand_dims(init_h, axis=seq_axis)
@@ -1099,52 +1100,36 @@ def gpu_lstm(
1099
1100
  transpose_weights=True,
1100
1101
  )
1101
1102
 
1102
- if sequence_lengths is not None:
1103
- if go_backwards:
1104
- # Three reversals are required. E.g.,
1105
- # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
1106
- # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
1107
- # output_from_cudnn = [6, 5, 4, 0, 0]
1108
- # expected_output = [0, 0, 6, 5 ,4]
1109
- inputs = tf.reverse_sequence(
1110
- inputs,
1111
- sequence_lengths,
1112
- seq_axis=seq_axis,
1113
- batch_axis=batch_axis,
1114
- )
1115
- outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
1116
- input=inputs,
1117
- input_h=init_h,
1118
- input_c=init_c,
1119
- params=params,
1120
- is_training=True,
1121
- rnn_mode="lstm",
1122
- sequence_lengths=sequence_lengths,
1123
- time_major=time_major,
1103
+ if go_backwards:
1104
+ # Three reversals are required. E.g.,
1105
+ # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
1106
+ # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
1107
+ # output_from_cudnn = [6, 5, 4, 0, 0]
1108
+ # expected_output = [0, 0, 6, 5 ,4]
1109
+ inputs = tf.reverse_sequence(
1110
+ inputs,
1111
+ sequence_lengths,
1112
+ seq_axis=seq_axis,
1113
+ batch_axis=batch_axis,
1124
1114
  )
1125
- if go_backwards:
1126
- outputs = tf.reverse_sequence(
1127
- outputs,
1128
- sequence_lengths,
1129
- seq_axis=seq_axis,
1130
- batch_axis=batch_axis,
1131
- )
1132
- outputs = tf.reverse(outputs, axis=[seq_axis])
1133
- else:
1134
- # # Fill the array with shape [batch] with value of max timesteps.
1135
- # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
1136
- # array_ops.shape(inputs)[0])
1137
- if go_backwards:
1138
- # Reverse axis 0 since the input is already convert to time major.
1139
- inputs = tf.reverse(inputs, axis=[0])
1140
- outputs, h, c, _ = tf.raw_ops.CudnnRNN(
1141
- input=inputs,
1142
- input_h=init_h,
1143
- input_c=init_c,
1144
- params=params,
1145
- is_training=True,
1146
- rnn_mode="lstm",
1115
+ outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
1116
+ input=inputs,
1117
+ input_h=init_h,
1118
+ input_c=init_c,
1119
+ params=params,
1120
+ is_training=True,
1121
+ rnn_mode="lstm",
1122
+ sequence_lengths=sequence_lengths,
1123
+ time_major=time_major,
1124
+ )
1125
+ if go_backwards:
1126
+ outputs = tf.reverse_sequence(
1127
+ outputs,
1128
+ sequence_lengths,
1129
+ seq_axis=seq_axis,
1130
+ batch_axis=batch_axis,
1147
1131
  )
1132
+ outputs = tf.reverse(outputs, axis=[seq_axis])
1148
1133
 
1149
1134
  last_output = outputs[-1]
1150
1135
  if not time_major and sequence_lengths is None and return_sequences:
@@ -189,7 +189,6 @@ class SimpleRNNCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
189
189
  )
190
190
  else:
191
191
  self.bias = None
192
- self.built = True
193
192
 
194
193
  def call(self, inputs, states, training=None):
195
194
  prev_output = states[0] if tf.nest.is_nested(states) else states
@@ -166,6 +166,7 @@ class StackedRNNCells(base_layer.Layer):
166
166
 
167
167
  @tf_utils.shape_type_conversion
168
168
  def build(self, input_shape):
169
+ super().build(input_shape)
169
170
  if isinstance(input_shape, list):
170
171
  input_shape = input_shape[0]
171
172
 
@@ -195,7 +196,6 @@ class StackedRNNCells(base_layer.Layer):
195
196
  input_shape = tuple(
196
197
  [batch_size] + tf.TensorShape(output_dim).as_list()
197
198
  )
198
- self.built = True
199
199
 
200
200
  def get_config(self):
201
201
  cells = []
@@ -135,7 +135,6 @@ class TimeDistributed(Wrapper):
135
135
  )
136
136
  child_input_shape = tf_utils.convert_shapes(child_input_shape)
137
137
  super().build(tuple(child_input_shape))
138
- self.built = True
139
138
 
140
139
  def compute_output_shape(self, input_shape):
141
140
  input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
@@ -124,20 +124,21 @@ class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
124
124
  def _should_cast(self):
125
125
  """Returns True if this variable should be casted when accessed."""
126
126
  autocast_dtype = getattr(_autocast_dtype, "dtype", None)
127
- return autocast_dtype is not None and self.dtype != autocast_dtype
127
+ return autocast_dtype is not None and self.true_dtype != autocast_dtype
128
128
 
129
129
  @property
130
130
  def dtype(self):
131
- """The dtype of the underlying variable, before any casts are done."""
132
- return self._variable.dtype
131
+ """The dtype when the value is accessed, that is after casting."""
132
+ return self._cast_dtype
133
133
 
134
134
  @property
135
135
  def true_dtype(self):
136
- """Deprecated alias of `dtype`."""
136
+ """The dtype of the underlying variable, before any casts are done."""
137
137
  return self._variable.dtype
138
138
 
139
139
  @property
140
140
  def _cast_dtype(self):
141
+ """The dtype after casting."""
141
142
  dtype = getattr(_autocast_dtype, "dtype", None)
142
143
  return dtype or self._variable.dtype
143
144
 
@@ -202,7 +203,8 @@ class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
202
203
  if tf.executing_eagerly() and not self._in_graph_mode:
203
204
  repr_str = (
204
205
  "<AutoCastVariable '{v.name}' shape={v.shape} "
205
- "dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, "
206
+ "dtype={v.true_dtype.name} "
207
+ "dtype_to_cast_to={v._cast_dtype.name}, "
206
208
  "numpy={np_repr}>"
207
209
  )
208
210
  return repr_str.format(
@@ -211,7 +213,8 @@ class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
211
213
  else:
212
214
  repr_str = (
213
215
  "<AutoCastVariable '{v.name}' shape={v.shape} "
214
- "dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>"
216
+ "dtype={v.true_dtype.name} "
217
+ "dtype_to_cast_to={v._cast_dtype.name}>"
215
218
  )
216
219
  return repr_str.format(v=self)
217
220
 
@@ -261,6 +264,9 @@ class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
261
264
  def _apply_assign_update(
262
265
  self, update_fn, value, use_locking=None, name=None, read_value=True
263
266
  ):
267
+ # In auto cast scope, we cast back to the actual variable dtype.
268
+ if self._should_cast():
269
+ value = tf.cast(value, self.true_dtype)
264
270
  # TODO(b/146181571): This logic can be simplified once
265
271
  # DistributedVariable.assign returns a DistributedVariable. Currently
266
272
  # for MirroredStrategy, it returns a Mirrored value.