tf-keras-nightly 2.17.0.dev2024050509__py3-none-any.whl → 2.19.0.dev2024101709__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (48)
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/src/__init__.py +1 -1
  3. tf_keras/src/callbacks.py +24 -7
  4. tf_keras/src/engine/base_layer.py +10 -4
  5. tf_keras/src/engine/base_layer_v1.py +10 -4
  6. tf_keras/src/engine/node.py +8 -3
  7. tf_keras/src/layers/activation/prelu.py +1 -1
  8. tf_keras/src/layers/attention/base_dense_attention.py +2 -1
  9. tf_keras/src/layers/convolutional/base_conv.py +1 -1
  10. tf_keras/src/layers/convolutional/base_depthwise_conv.py +3 -1
  11. tf_keras/src/layers/convolutional/base_separable_conv.py +3 -1
  12. tf_keras/src/layers/convolutional/conv1d_transpose.py +3 -1
  13. tf_keras/src/layers/convolutional/conv2d_transpose.py +3 -1
  14. tf_keras/src/layers/convolutional/conv3d_transpose.py +3 -1
  15. tf_keras/src/layers/core/dense.py +1 -1
  16. tf_keras/src/layers/core/embedding.py +1 -1
  17. tf_keras/src/layers/locally_connected/locally_connected1d.py +1 -1
  18. tf_keras/src/layers/locally_connected/locally_connected2d.py +1 -1
  19. tf_keras/src/layers/normalization/batch_normalization.py +1 -1
  20. tf_keras/src/layers/normalization/layer_normalization.py +1 -1
  21. tf_keras/src/layers/rnn/abstract_rnn_cell.py +1 -1
  22. tf_keras/src/layers/rnn/base_conv_lstm.py +0 -1
  23. tf_keras/src/layers/rnn/base_conv_rnn.py +3 -1
  24. tf_keras/src/layers/rnn/base_rnn.py +1 -1
  25. tf_keras/src/layers/rnn/base_wrapper.py +1 -1
  26. tf_keras/src/layers/rnn/bidirectional.py +2 -1
  27. tf_keras/src/layers/rnn/cell_wrappers.py +3 -3
  28. tf_keras/src/layers/rnn/cudnn_gru.py +6 -3
  29. tf_keras/src/layers/rnn/cudnn_lstm.py +6 -3
  30. tf_keras/src/layers/rnn/gru.py +35 -47
  31. tf_keras/src/layers/rnn/legacy_cell_wrappers.py +3 -3
  32. tf_keras/src/layers/rnn/legacy_cells.py +20 -25
  33. tf_keras/src/layers/rnn/lstm.py +35 -50
  34. tf_keras/src/layers/rnn/simple_rnn.py +0 -1
  35. tf_keras/src/layers/rnn/stacked_rnn_cells.py +1 -1
  36. tf_keras/src/layers/rnn/time_distributed.py +0 -1
  37. tf_keras/src/mixed_precision/autocast_variable.py +12 -6
  38. tf_keras/src/mixed_precision/test_util.py +6 -5
  39. tf_keras/src/optimizers/legacy/optimizer_v2.py +9 -2
  40. tf_keras/src/optimizers/optimizer.py +18 -9
  41. tf_keras/src/premade_models/linear.py +2 -1
  42. tf_keras/src/utils/data_utils.py +1 -1
  43. tf_keras/src/utils/steps_per_execution_tuning.py +1 -1
  44. tf_keras/src/utils/timeseries_dataset.py +13 -5
  45. {tf_keras_nightly-2.17.0.dev2024050509.dist-info → tf_keras_nightly-2.19.0.dev2024101709.dist-info}/METADATA +2 -2
  46. {tf_keras_nightly-2.17.0.dev2024050509.dist-info → tf_keras_nightly-2.19.0.dev2024101709.dist-info}/RECORD +48 -48
  47. {tf_keras_nightly-2.17.0.dev2024050509.dist-info → tf_keras_nightly-2.19.0.dev2024101709.dist-info}/WHEEL +1 -1
  48. {tf_keras_nightly-2.17.0.dev2024050509.dist-info → tf_keras_nightly-2.19.0.dev2024101709.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py CHANGED
@@ -27,4 +27,4 @@ from tf_keras.src.engine.sequential import Sequential
 from tf_keras.src.engine.training import Model


-__version__ = "2.17.0.dev2024050509"
+__version__ = "2.19.0.dev2024101709"
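Note (for context, not part of the upstream diff): a quick way to confirm which wheel is installed is to read this string back at runtime. Assuming the newer nightly is installed:

    import tf_keras

    # Prints the nightly tag set in tf_keras/__init__.py, e.g.
    # "2.19.0.dev2024101709" for the newer wheel.
    print(tf_keras.__version__)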
tf_keras/src/__init__.py CHANGED
@@ -35,7 +35,7 @@ from tf_keras.src.testing_infra import test_utils
 from tensorflow.python import tf2
 from tensorflow.python.util.tf_export import keras_export

-__version__ = "2.17.0"
+__version__ = "2.19.0"

 keras_export("keras.__version__").export_constant(__name__, "__version__")

tf_keras/src/callbacks.py CHANGED
@@ -1423,20 +1423,20 @@ class ModelCheckpoint(Callback):
         if mode == "min":
             self.monitor_op = np.less
             if self.best is None:
-                self.best = np.Inf
+                self.best = np.inf
         elif mode == "max":
             self.monitor_op = np.greater
             if self.best is None:
-                self.best = -np.Inf
+                self.best = -np.inf
         else:
             if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                 self.monitor_op = np.greater
                 if self.best is None:
-                    self.best = -np.Inf
+                    self.best = -np.inf
             else:
                 self.monitor_op = np.less
                 if self.best is None:
-                    self.best = np.Inf
+                    self.best = np.inf

         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
             raise ValueError(
@@ -1903,6 +1903,23 @@ class BackupAndRestore(Callback):
                 "only supports empty strategy, "
                 "MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy."
             )
+
+        # Re-initialize the optimizer.
+        if self.model.built:
+            if (
+                self.model.optimizer is not None
+                and callable(getattr(self.model.optimizer, "build", None))
+                and not getattr(self.model.optimizer, "_built", False)
+            ):
+                self.model.optimizer.build(self.model.trainable_variables)
+        else:
+            logging.warning(
+                "To use the BackupAndRestore callback, "
+                "you model must be built before you call `fit()`. "
+                f"Model {self.model} is unbuilt. You can build it "
+                "beforehand by calling it on a batch of data."
+            )
+
         self.model._training_state = worker_training_state.WorkerTrainingState(
             self.model,
             self.backup_dir,
@@ -2095,7 +2112,7 @@ class EarlyStopping(Callback):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
-        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+        self.best = np.inf if self.monitor_op == np.less else -np.inf
         self.best_weights = None
         self.best_epoch = 0

@@ -3098,10 +3115,10 @@ class ReduceLROnPlateau(Callback):
             self.mode == "auto" and "acc" not in self.monitor
         ):
             self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
-            self.best = np.Inf
+            self.best = np.inf
         else:
             self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
-            self.best = -np.Inf
+            self.best = -np.inf
         self.cooldown_counter = 0
         self.wait = 0

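Note (for context, not part of the upstream diff): the np.Inf -> np.inf renames above read as a NumPy 2.x compatibility cleanup: np.inf exists on both NumPy 1.x and 2.x, while the np.Inf alias was removed in NumPy 2.0. A minimal illustration with a hypothetical monitor_op:

    import numpy as np

    monitor_op = np.less                                 # "lower is better" metric
    best = np.inf if monitor_op == np.less else -np.inf  # fine on NumPy 1.x and 2.x
    # best = np.Inf                                      # AttributeError on NumPy 2.x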
tf_keras/src/engine/base_layer.py CHANGED
@@ -578,7 +578,8 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
             Accepted values are constants defined in the class
             `tf.VariableAggregation`.
           **kwargs: Additional keyword arguments. Accepted values are `getter`,
-            `collections`, `experimental_autocast` and `caching_device`.
+            `collections`, `autocast`, `experimental_autocast` and
+            `caching_device`.

         Returns:
           The variable created.
@@ -594,6 +595,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
         # Validate optional keyword arguments.
         for kwarg in kwargs:
             if kwarg not in [
+                "autocast",
                 "collections",
                 "experimental_autocast",
                 "caching_device",
@@ -603,9 +605,13 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
             ]:
                 raise TypeError("Unknown keyword argument:", kwarg)
         collections_arg = kwargs.pop("collections", None)
-        # 'experimental_autocast' can be set to False by the caller to indicate
-        # an AutoCastVariable should never be created.
-        autocast = kwargs.pop("experimental_autocast", True)
+        # 'autocast' or 'experimental_autocast' can be set to False by the
+        # caller to indicate an AutoCastVariable should never be created.
+        autocast = kwargs.pop("autocast", None)
+        if autocast is None:
+            autocast = kwargs.pop("experimental_autocast", None)
+        if autocast is None:
+            autocast = True
         # See the docstring for tf.Variable about the details for
         # caching_device.
         caching_device = kwargs.pop("caching_device", None)
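Note (for context, not part of the upstream diff): add_weight() now prefers an autocast keyword and falls back to the older experimental_autocast. A minimal sketch of a hypothetical custom layer opting one variable out of AutoCastVariable wrapping (layer name and weight are illustrative only):

    import tensorflow as tf
    import tf_keras as keras

    class ScaledActivation(keras.layers.Layer):
        def build(self, input_shape):
            # Under a mixed-precision policy this weight would normally be
            # wrapped in an AutoCastVariable; autocast=False (or the older
            # experimental_autocast=False) opts it out.
            self.scale = self.add_weight(
                name="scale", shape=(), initializer="ones", autocast=False
            )
            super().build(input_shape)

        def call(self, inputs):
            return inputs * tf.cast(self.scale, inputs.dtype)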
tf_keras/src/engine/base_layer_v1.py CHANGED
@@ -352,7 +352,8 @@ class Layer(base_layer.Layer):
             Accepted values are constants defined in the class
             `tf.VariableAggregation`.
           **kwargs: Additional keyword arguments. Accepted values are `getter`,
-            `collections`, `experimental_autocast` and `caching_device`.
+            `collections`, `autocast`, `experimental_autocast` and
+            `caching_device`.

         Returns:
           The created variable. Usually either a `Variable` or
@@ -371,6 +372,7 @@ class Layer(base_layer.Layer):
         # Validate optional keyword arguments.
         for kwarg in kwargs:
             if kwarg not in [
+                "autocast",
                 "getter",
                 "collections",
                 "experimental_autocast",
@@ -380,9 +382,13 @@ class Layer(base_layer.Layer):
         has_custom_getter = "getter" in kwargs
         getter = kwargs.pop("getter", base_layer_utils.make_variable)
         collections_arg = kwargs.pop("collections", None)
-        # 'experimental_autocast' can be set to False by the caller to indicate
-        # an AutoCastVariable should never be created.
-        autocast = kwargs.pop("experimental_autocast", True)
+        # 'autocast' or 'experimental_autocast' can be set to False by the
+        # caller to indicate an AutoCastVariable should never be created.
+        autocast = kwargs.pop("autocast", None)
+        if autocast is None:
+            autocast = kwargs.pop("experimental_autocast", None)
+        if autocast is None:
+            autocast = True
         # See the docstring for tf.Variable about the details for
         # caching_device.
         caching_device = kwargs.pop("caching_device", None)
tf_keras/src/engine/node.py CHANGED
@@ -84,9 +84,10 @@ class Node:
         self.call_args = call_args
         self.call_kwargs = call_kwargs

-        # Cached for performance.
+        # Cached for performance. Put kwargs in order of the call method instead
+        # of using the sorted key order from `tf.nest.flatten`.
         self._flat_arguments = tf.nest.flatten(
-            (self.call_args, self.call_kwargs)
+            (self.call_args, self.call_kwargs.values())
         )
         # Used to avoid expensive `nest` operations in the most common case.
         self._single_positional_tensor_passed = (
@@ -176,9 +177,13 @@ class Node:
         for kt_id, kt_index in self._keras_inputs_ids_and_indices:
             flat_arguments[kt_index] = tensor_dict[kt_id].pop()

+        # Pack the same way as `self._flat_arguments`, i.e. `kwargs` as a
+        # list in the original order.
         args, kwargs = tf.nest.pack_sequence_as(
-            (self.call_args, self.call_kwargs), flat_arguments
+            (self.call_args, self.call_kwargs.values()), flat_arguments
         )
+        # Add the keys to `kwargs` to go from a list to a dict.
+        kwargs = {k: v for k, v in zip(self.call_kwargs.keys(), kwargs)}
         return args, kwargs

     def serialize(self, make_node_key, node_conversion_map):
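Note (for context, not part of the upstream diff): the change above flattens call_kwargs.values() so the cached flat arguments follow the call signature's keyword order rather than tf.nest's sorted-key order, then re-attaches the keys when packing back. The ordering difference is easy to see on a plain dict (illustrative values):

    import tensorflow as tf

    kwargs = {"beta": 2, "alpha": 1}  # insertion (call) order: beta, alpha
    print(tf.nest.flatten(kwargs))    # [1, 2] -- flattened in sorted key order
    print(list(kwargs.values()))      # [2, 1] -- original call order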
tf_keras/src/layers/activation/prelu.py CHANGED
@@ -102,7 +102,7 @@ class PReLU(Layer):
                 if i not in self.shared_axes:
                     axes[i] = input_shape[i]
         self.input_spec = InputSpec(ndim=len(input_shape), axes=axes)
-        self.built = True
+        super().build(input_shape)

     def call(self, inputs):
         pos = backend.relu(inputs)
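Note (for context, not part of the upstream diff): this release replaces the old idiom of setting self.built = True at the end of build() with a call to super().build(input_shape) in many layers, letting the base class do the built bookkeeping. A minimal sketch of the pattern in a hypothetical custom layer:

    import tf_keras as keras

    class TinyDense(keras.layers.Layer):  # illustrative only
        def __init__(self, units=4, **kwargs):
            super().__init__(**kwargs)
            self.units = units

        def build(self, input_shape):
            self.kernel = self.add_weight(
                name="kernel", shape=(int(input_shape[-1]), self.units)
            )
            # Delegate to Layer.build() instead of setting self.built = True.
            super().build(input_shape)

        def call(self, inputs):
            return inputs @ self.kernel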
tf_keras/src/layers/attention/base_dense_attention.py CHANGED
@@ -86,7 +86,8 @@ class BaseDenseAttention(base_layer.BaseRandomLayer):
         # be purely stateless, with no reference to any variable.
         if self.dropout > 0:
             super().build(input_shape)
-        self.built = True
+        else:
+            base_layer.Layer.build(self, input_shape)

     def _calculate_scores(self, query, key):
         """Calculates attention scores.
tf_keras/src/layers/convolutional/base_conv.py CHANGED
@@ -248,7 +248,7 @@ class Conv(Layer):
         self.input_spec = InputSpec(
             min_ndim=self.rank + 2, axes={channel_axis: input_channel}
         )
-        self.built = True
+        super().build(input_shape)

     def convolution_op(self, inputs, kernel):
         if self.padding == "causal":
tf_keras/src/layers/convolutional/base_depthwise_conv.py CHANGED
@@ -20,6 +20,7 @@ import tensorflow.compat.v2 as tf
 from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.base_conv import Conv

@@ -202,7 +203,8 @@ class DepthwiseConv(Conv):
         self.input_spec = InputSpec(
             min_ndim=self.rank + 2, axes={channel_axis: input_dim}
         )
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         raise NotImplementedError
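Note (for context, not part of the upstream diff): where a subclass completely replaces its parent's build() (DepthwiseConv, SeparableConv, the transposed convolutions, ConvRNN, Bidirectional), the new code calls Layer.build(self, input_shape) directly, so the immediate parent's build() is skipped while the base-class bookkeeping still runs. A rough sketch of the idea with hypothetical layers:

    import tf_keras as keras
    from tf_keras.src.engine.base_layer import Layer

    class ParentBlock(keras.layers.Layer):  # illustrative only
        def build(self, input_shape):
            self.kernel = self.add_weight(
                name="kernel", shape=(int(input_shape[-1]), 8)
            )
            super().build(input_shape)

    class ChildBlock(ParentBlock):  # illustrative only
        def build(self, input_shape):
            # ChildBlock replaces ParentBlock's weights entirely, so calling
            # super().build() would create an unwanted kernel; Layer.build()
            # runs only the base-class bookkeeping.
            self.kernel = self.add_weight(
                name="kernel", shape=(int(input_shape[-1]), 8, 2)
            )
            Layer.build(self, input_shape)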
tf_keras/src/layers/convolutional/base_separable_conv.py CHANGED
@@ -21,6 +21,7 @@ from tf_keras.src import activations
 from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.base_conv import Conv

@@ -203,7 +204,8 @@ class SeparableConv(Conv):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         raise NotImplementedError
tf_keras/src/layers/convolutional/conv1d_transpose.py CHANGED
@@ -22,6 +22,7 @@ from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
 from tf_keras.src.dtensor import utils
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.conv1d import Conv1D
 from tf_keras.src.utils import conv_utils
@@ -214,7 +215,8 @@ class Conv1DTranspose(Conv1D):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
tf_keras/src/layers/convolutional/conv2d_transpose.py CHANGED
@@ -23,6 +23,7 @@ from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
 from tf_keras.src.dtensor import utils
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.conv2d import Conv2D
 from tf_keras.src.utils import conv_utils
@@ -240,7 +241,8 @@ class Conv2DTranspose(Conv2D):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
tf_keras/src/layers/convolutional/conv3d_transpose.py CHANGED
@@ -22,6 +22,7 @@ from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
 from tf_keras.src.dtensor import utils
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.conv3d import Conv3D
 from tf_keras.src.utils import conv_utils
@@ -247,7 +248,8 @@ class Conv3DTranspose(Conv3D):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
tf_keras/src/layers/core/dense.py CHANGED
@@ -174,7 +174,7 @@ class Dense(Layer):
             )
         else:
             self.bias = None
-        self.built = True
+        super().build(input_shape)

     def call(self, inputs):
         if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
tf_keras/src/layers/core/embedding.py CHANGED
@@ -185,7 +185,7 @@ class Embedding(Layer):
             constraint=self.embeddings_constraint,
             experimental_autocast=False,
         )
-        self.built = True
+        super().build(input_shape)

     def compute_mask(self, inputs, mask=None):
         if not self.mask_zero:
tf_keras/src/layers/locally_connected/locally_connected1d.py CHANGED
@@ -284,7 +284,7 @@ class LocallyConnected1D(Layer):
             self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
         else:
             self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
-        self.built = True
+        super().build(input_shape)

     @tf_utils.shape_type_conversion
     def compute_output_shape(self, input_shape):
tf_keras/src/layers/locally_connected/locally_connected2d.py CHANGED
@@ -308,7 +308,7 @@ class LocallyConnected2D(Layer):
             self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
         else:
             self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
-        self.built = True
+        super().build(input_shape)

     @tf_utils.shape_type_conversion
     def compute_output_shape(self, input_shape):
tf_keras/src/layers/normalization/batch_normalization.py CHANGED
@@ -542,7 +542,7 @@ class BatchNormalizationBase(Layer):
         finally:
             if partitioner:
                 self._scope.set_partitioner(partitioner)
-        self.built = True
+        super().build(input_shape)

     def call(self, inputs, training=None, mask=None):
         inputs = tf.cast(inputs, self.compute_dtype)
tf_keras/src/layers/normalization/layer_normalization.py CHANGED
@@ -249,7 +249,7 @@ class LayerNormalization(Layer):
             self.beta = None

         self._fused = self._fused_can_be_used(rank)
-        self.built = True
+        super().build(input_shape)

     def call(self, inputs):
         # TODO(b/229545225): Remove the RaggedTensor check.
tf_keras/src/layers/rnn/abstract_rnn_cell.py CHANGED
@@ -56,7 +56,7 @@ class AbstractRNNCell(base_layer.Layer):
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
-        self.built = True
+        super().build(input_shape)

       def call(self, inputs, states):
         prev_output = states[0]
tf_keras/src/layers/rnn/base_conv_lstm.py CHANGED
@@ -218,7 +218,6 @@ class ConvLSTMCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
             )
         else:
             self.bias = None
-        self.built = True

     def call(self, inputs, states, training=None):
         h_tm1 = states[0]  # previous memory state
tf_keras/src/layers/rnn/base_conv_rnn.py CHANGED
@@ -20,6 +20,7 @@ import tensorflow.compat.v2 as tf

 from tf_keras.src import backend
 from tf_keras.src.engine import base_layer
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.rnn.base_rnn import RNN
 from tf_keras.src.utils import conv_utils
@@ -207,6 +208,8 @@ class ConvRNN(RNN):

     @tf_utils.shape_type_conversion
     def build(self, input_shape):
+        # Call Layer.build() to skip RNN.build() which we override here.
+        Layer.build(self, input_shape)
         # Note input_shape will be list of shapes of initial states and
         # constants if these are passed in __call__.
         if self._num_constants is not None:
@@ -263,7 +266,6 @@ class ConvRNN(RNN):
             ]
         if self.stateful:
             self.reset_states()
-        self.built = True

     def get_initial_state(self, inputs):
         # (samples, timesteps, img_dims..., filters)
tf_keras/src/layers/rnn/base_rnn.py CHANGED
@@ -207,7 +207,7 @@ class RNN(base_layer.Layer):
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
-        self.built = True
+        super().build(input_shape)

       def call(self, inputs, states):
         prev_output = states[0]
tf_keras/src/layers/rnn/base_wrapper.py CHANGED
@@ -56,7 +56,7 @@ class Wrapper(Layer):
         if not self.layer.built:
             self.layer.build(input_shape)
             self.layer.built = True
-        self.built = True
+        super().build(input_shape)

     @property
     def activity_regularizer(self):
tf_keras/src/layers/rnn/bidirectional.py CHANGED
@@ -470,7 +470,8 @@ class Bidirectional(Wrapper):
             self.forward_layer.build(input_shape)
         with backend.name_scope(self.backward_layer.name):
             self.backward_layer.build(input_shape)
-        self.built = True
+        # Call Layer.build() to skip Wrapper.build() which we override here.
+        Layer.build(self, input_shape)

     def compute_mask(self, inputs, mask):
         if isinstance(mask, list):
tf_keras/src/layers/rnn/cell_wrappers.py CHANGED
@@ -102,10 +102,10 @@ class _RNNCellWrapper(AbstractRNNCell):
             inputs, state, cell_call_fn=self.cell.call, **kwargs
         )

-    def build(self, inputs_shape):
+    def build(self, input_shape):
         """Builds the wrapped cell."""
-        self.cell.build(inputs_shape)
-        self.built = True
+        self.cell.build(input_shape)
+        super().build(input_shape)

     @property
     def wrapped_cell(self):
tf_keras/src/layers/rnn/cudnn_gru.py CHANGED
@@ -144,8 +144,6 @@ class CuDNNGRU(_CuDNNRNN):
             constraint=self.bias_constraint,
         )

-        self.built = True
-
     def _process_batch(self, inputs, initial_state):
         if not self.time_major:
             inputs = tf.transpose(inputs, perm=(1, 0, 2))
@@ -172,6 +170,10 @@
             shape=self._vector_shape,
         )

+        batch_dim = tf.shape(inputs)[1]
+        max_sequence_length = tf.shape(inputs)[0]
+        sequence_lengths = tf.fill([batch_dim], max_sequence_length)
+
         args = {
             "input": inputs,
             "input_h": input_h,
@@ -179,9 +181,10 @@
             "params": params,
             "is_training": True,
             "rnn_mode": "gru",
+            "sequence_lengths": sequence_lengths,
         }

-        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV2(**args)
+        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(**args)

         if self.stateful or self.return_state:
             h = h[0]
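Note (for context, not part of the upstream diff): at this point in _process_batch() the inputs have already been transposed to time-major (time, batch, features), so the added lines build a per-example length vector that simply repeats the padded length for every batch entry; tf.raw_ops.CudnnRNNV3 takes an explicit sequence_lengths input, which the previously used CudnnRNNV2 did not. Illustrative shapes:

    import tensorflow as tf

    inputs = tf.zeros([7, 32, 16])             # (time, batch, features), hypothetical
    batch_dim = tf.shape(inputs)[1]            # 32
    max_sequence_length = tf.shape(inputs)[0]  # 7
    sequence_lengths = tf.fill([batch_dim], max_sequence_length)
    print(sequence_lengths.shape)              # (32,), every entry equal to 7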
tf_keras/src/layers/rnn/cudnn_lstm.py CHANGED
@@ -170,8 +170,6 @@ class CuDNNLSTM(_CuDNNRNN):
             constraint=self.bias_constraint,
         )

-        self.built = True
-
     def _process_batch(self, inputs, initial_state):
         if not self.time_major:
             inputs = tf.transpose(inputs, perm=(1, 0, 2))
@@ -204,15 +202,20 @@
             shape=self._vector_shape,
         )

+        batch_dim = tf.shape(inputs)[1]
+        max_sequence_length = tf.shape(inputs)[0]
+        sequence_lengths = tf.fill([batch_dim], max_sequence_length)
+
         args = {
             "input": inputs,
             "input_h": input_h,
             "input_c": input_c,
             "params": params,
             "is_training": True,
+            "sequence_lengths": sequence_lengths,
         }

-        outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV2(**args)
+        outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(**args)

         if self.stateful or self.return_state:
             h = h[0]
tf_keras/src/layers/rnn/gru.py CHANGED
@@ -222,7 +222,6 @@ class GRUCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
             )
         else:
             self.bias = None
-        self.built = True

     def call(self, inputs, states, training=None):
         h_tm1 = (
@@ -1034,11 +1033,13 @@ def gpu_gru(
         mask, time_major
     )

-    if not time_major and sequence_lengths is None:
-        inputs = tf.transpose(inputs, perm=(1, 0, 2))
-        seq_axis, batch_axis = (0, 1)
-    else:
-        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
+    seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
+
+    if sequence_lengths is None:
+        max_sequence_length = tf.shape(inputs)[seq_axis]
+        batch_size = tf.shape(inputs)[batch_axis]
+        sequence_lengths = tf.fill([batch_size], max_sequence_length)
+
     # For init_h, cuDNN expects one more dim of num_layers before or after batch
     # dim for time major or batch major inputs respectively
     init_h = tf.expand_dims(init_h, axis=seq_axis)
@@ -1069,49 +1070,36 @@ def gpu_gru(
         transpose_weights=True,
     )

-    if sequence_lengths is not None:
-        if go_backwards:
-            # Three reversals are required. E.g.,
-            # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
-            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
-            # output_from_cudnn = [6, 5, 4, 0, 0]
-            # expected_output = [0, 0, 6, 5 ,4]
-            inputs = tf.reverse_sequence(
-                inputs,
-                sequence_lengths,
-                seq_axis=seq_axis,
-                batch_axis=batch_axis,
-            )
-        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
-            input=inputs,
-            input_h=init_h,
-            input_c=0,
-            params=params,
-            is_training=True,
-            rnn_mode="gru",
-            sequence_lengths=sequence_lengths,
-            time_major=time_major,
+    if go_backwards:
+        # Three reversals are required. E.g.,
+        # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
+        # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
+        # output_from_cudnn = [6, 5, 4, 0, 0]
+        # expected_output = [0, 0, 6, 5 ,4]
+        inputs = tf.reverse_sequence(
+            inputs,
+            sequence_lengths,
+            seq_axis=seq_axis,
+            batch_axis=batch_axis,
         )
-        if go_backwards:
-            outputs = tf.reverse_sequence(
-                outputs,
-                sequence_lengths,
-                seq_axis=seq_axis,
-                batch_axis=batch_axis,
-            )
-            outputs = tf.reverse(outputs, axis=[seq_axis])
-    else:
-        if go_backwards:
-            # Reverse axis 0 since the input is already convert to time major.
-            inputs = tf.reverse(inputs, axis=[0])
-        outputs, h, _, _ = tf.raw_ops.CudnnRNN(
-            input=inputs,
-            input_h=init_h,
-            input_c=0,
-            params=params,
-            is_training=True,
-            rnn_mode="gru",
+    outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
+        input=inputs,
+        input_h=init_h,
+        input_c=0,
+        params=params,
+        is_training=True,
+        rnn_mode="gru",
+        sequence_lengths=sequence_lengths,
+        time_major=time_major,
+    )
+    if go_backwards:
+        outputs = tf.reverse_sequence(
+            outputs,
+            sequence_lengths,
+            seq_axis=seq_axis,
+            batch_axis=batch_axis,
        )
+        outputs = tf.reverse(outputs, axis=[seq_axis])

     last_output = outputs[-1]
     if not time_major and sequence_lengths is None and return_sequences:
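Note (for context, not part of the upstream diff): gpu_gru() now always has a sequence_lengths vector (defaulting to the full padded length), so the single CudnnRNNV3 path replaces the old CudnnRNN/CudnnRNNV3 split. The "three reversals" comment kept above can be traced with a tiny example (hypothetical values: time-major, one sequence of valid length 3 padded to 5):

    import tensorflow as tf

    x = tf.constant([[1], [2], [3], [0], [0]])  # (time=5, batch=1)
    lengths = tf.constant([3])

    # 1) Reverse only the valid prefix before handing the batch to cuDNN:
    rev_in = tf.reverse_sequence(x, lengths, seq_axis=0, batch_axis=1)
    print(tf.squeeze(rev_in).numpy())           # [3 2 1 0 0]
    # 2) The cuDNN outputs come back aligned to this reversed input, so they are
    #    reverse_sequence'd again, and 3) reversed along the time axis so the
    #    padding ends up in front, e.g. [6, 5, 4, 0, 0] -> [0, 0, 6, 5, 4].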
tf_keras/src/layers/rnn/legacy_cell_wrappers.py CHANGED
@@ -368,9 +368,9 @@ class DropoutWrapper(_RNNCellWrapperV1):
     def wrapped_cell(self):
         return self.cell

-    def build(self, inputs_shape):
-        self.cell.build(inputs_shape)
-        self.built = True
+    def build(self, input_shape):
+        self.cell.build(input_shape)
+        super().build(input_shape)

     def _variational_recurrent_dropout_value(
         self, unused_index, value, noise, keep_prob