tf-keras-nightly 2.17.0.dev2024050509__py3-none-any.whl → 2.19.0.dev2024101709__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/src/__init__.py +1 -1
- tf_keras/src/callbacks.py +24 -7
- tf_keras/src/engine/base_layer.py +10 -4
- tf_keras/src/engine/base_layer_v1.py +10 -4
- tf_keras/src/engine/node.py +8 -3
- tf_keras/src/layers/activation/prelu.py +1 -1
- tf_keras/src/layers/attention/base_dense_attention.py +2 -1
- tf_keras/src/layers/convolutional/base_conv.py +1 -1
- tf_keras/src/layers/convolutional/base_depthwise_conv.py +3 -1
- tf_keras/src/layers/convolutional/base_separable_conv.py +3 -1
- tf_keras/src/layers/convolutional/conv1d_transpose.py +3 -1
- tf_keras/src/layers/convolutional/conv2d_transpose.py +3 -1
- tf_keras/src/layers/convolutional/conv3d_transpose.py +3 -1
- tf_keras/src/layers/core/dense.py +1 -1
- tf_keras/src/layers/core/embedding.py +1 -1
- tf_keras/src/layers/locally_connected/locally_connected1d.py +1 -1
- tf_keras/src/layers/locally_connected/locally_connected2d.py +1 -1
- tf_keras/src/layers/normalization/batch_normalization.py +1 -1
- tf_keras/src/layers/normalization/layer_normalization.py +1 -1
- tf_keras/src/layers/rnn/abstract_rnn_cell.py +1 -1
- tf_keras/src/layers/rnn/base_conv_lstm.py +0 -1
- tf_keras/src/layers/rnn/base_conv_rnn.py +3 -1
- tf_keras/src/layers/rnn/base_rnn.py +1 -1
- tf_keras/src/layers/rnn/base_wrapper.py +1 -1
- tf_keras/src/layers/rnn/bidirectional.py +2 -1
- tf_keras/src/layers/rnn/cell_wrappers.py +3 -3
- tf_keras/src/layers/rnn/cudnn_gru.py +6 -3
- tf_keras/src/layers/rnn/cudnn_lstm.py +6 -3
- tf_keras/src/layers/rnn/gru.py +35 -47
- tf_keras/src/layers/rnn/legacy_cell_wrappers.py +3 -3
- tf_keras/src/layers/rnn/legacy_cells.py +20 -25
- tf_keras/src/layers/rnn/lstm.py +35 -50
- tf_keras/src/layers/rnn/simple_rnn.py +0 -1
- tf_keras/src/layers/rnn/stacked_rnn_cells.py +1 -1
- tf_keras/src/layers/rnn/time_distributed.py +0 -1
- tf_keras/src/mixed_precision/autocast_variable.py +12 -6
- tf_keras/src/mixed_precision/test_util.py +6 -5
- tf_keras/src/optimizers/legacy/optimizer_v2.py +9 -2
- tf_keras/src/optimizers/optimizer.py +18 -9
- tf_keras/src/premade_models/linear.py +2 -1
- tf_keras/src/utils/data_utils.py +1 -1
- tf_keras/src/utils/steps_per_execution_tuning.py +1 -1
- tf_keras/src/utils/timeseries_dataset.py +13 -5
- {tf_keras_nightly-2.17.0.dev2024050509.dist-info → tf_keras_nightly-2.19.0.dev2024101709.dist-info}/METADATA +2 -2
- {tf_keras_nightly-2.17.0.dev2024050509.dist-info → tf_keras_nightly-2.19.0.dev2024101709.dist-info}/RECORD +48 -48
- {tf_keras_nightly-2.17.0.dev2024050509.dist-info → tf_keras_nightly-2.19.0.dev2024101709.dist-info}/WHEEL +1 -1
- {tf_keras_nightly-2.17.0.dev2024050509.dist-info → tf_keras_nightly-2.19.0.dev2024101709.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py
CHANGED
tf_keras/src/__init__.py
CHANGED
@@ -35,7 +35,7 @@ from tf_keras.src.testing_infra import test_utils
 from tensorflow.python import tf2
 from tensorflow.python.util.tf_export import keras_export
 
-__version__ = "2.17.0"
+__version__ = "2.19.0"
 
 keras_export("keras.__version__").export_constant(__name__, "__version__")
 
tf_keras/src/callbacks.py
CHANGED
@@ -1423,20 +1423,20 @@ class ModelCheckpoint(Callback):
         if mode == "min":
             self.monitor_op = np.less
             if self.best is None:
-                self.best = np.Inf
+                self.best = np.inf
         elif mode == "max":
             self.monitor_op = np.greater
             if self.best is None:
-                self.best = -np.Inf
+                self.best = -np.inf
         else:
             if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                 self.monitor_op = np.greater
                 if self.best is None:
-                    self.best = -np.Inf
+                    self.best = -np.inf
             else:
                 self.monitor_op = np.less
                 if self.best is None:
-                    self.best = np.Inf
+                    self.best = np.inf
 
         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
             raise ValueError(
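Note: this hunk, together with the EarlyStopping and ReduceLROnPlateau hunks below, replaces the capitalized `np.Inf` alias with `np.inf`. A minimal sketch of why the spelling matters, assuming the motivation is NumPy 2.x compatibility (NumPy 2.0 removed the `np.Inf` alias):

import numpy as np

# np.inf works on both NumPy 1.x and 2.x; np.Inf raises AttributeError on 2.x.
best = np.inf
print(np.less(0.5, best))  # True: any finite metric improves on an initial best of +inf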
@@ -1903,6 +1903,23 @@ class BackupAndRestore(Callback):
                 "only supports empty strategy, "
                 "MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy."
             )
+
+        # Re-initialize the optimizer.
+        if self.model.built:
+            if (
+                self.model.optimizer is not None
+                and callable(getattr(self.model.optimizer, "build", None))
+                and not getattr(self.model.optimizer, "_built", False)
+            ):
+                self.model.optimizer.build(self.model.trainable_variables)
+        else:
+            logging.warning(
+                "To use the BackupAndRestore callback, "
+                "you model must be built before you call `fit()`. "
+                f"Model {self.model} is unbuilt. You can build it "
+                "beforehand by calling it on a batch of data."
+            )
+
         self.model._training_state = worker_training_state.WorkerTrainingState(
             self.model,
             self.backup_dir,
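Note: the added block only rebuilds optimizer state when the model is already built. A minimal usage sketch against the public tf.keras API (the toy data and backup path are illustrative):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(16, 4).astype("float32")
y = np.random.rand(16, 1).astype("float32")

# Build the model on a batch of data first, as the warning above suggests, so the
# callback can rebuild optimizer slot variables when restoring from a backup.
model(x[:1])
model.fit(
    x,
    y,
    epochs=2,
    callbacks=[tf.keras.callbacks.BackupAndRestore(backup_dir="/tmp/keras_backup")],
)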
@@ -2095,7 +2112,7 @@ class EarlyStopping(Callback):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
-        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+        self.best = np.inf if self.monitor_op == np.less else -np.inf
         self.best_weights = None
         self.best_epoch = 0
 
@@ -3098,10 +3115,10 @@ class ReduceLROnPlateau(Callback):
             self.mode == "auto" and "acc" not in self.monitor
         ):
             self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
-            self.best = np.Inf
+            self.best = np.inf
         else:
             self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
-            self.best = -np.Inf
+            self.best = -np.inf
         self.cooldown_counter = 0
         self.wait = 0
 
tf_keras/src/engine/base_layer.py
CHANGED
@@ -578,7 +578,8 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
             Accepted values are constants defined in the class
             `tf.VariableAggregation`.
           **kwargs: Additional keyword arguments. Accepted values are `getter`,
-            `collections`, `experimental_autocast` and `caching_device`.
+            `collections`, `autocast`, `experimental_autocast` and
+            `caching_device`.
 
         Returns:
           The variable created.
@@ -594,6 +595,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
         # Validate optional keyword arguments.
         for kwarg in kwargs:
             if kwarg not in [
+                "autocast",
                 "collections",
                 "experimental_autocast",
                 "caching_device",
@@ -603,9 +605,13 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
             ]:
                 raise TypeError("Unknown keyword argument:", kwarg)
         collections_arg = kwargs.pop("collections", None)
-        # 'experimental_autocast' can be set to False by the caller to indicate
-        # an AutoCastVariable should never be created.
-        autocast = kwargs.pop("experimental_autocast", True)
+        # 'autocast' or 'experimental_autocast' can be set to False by the
+        # caller to indicate an AutoCastVariable should never be created.
+        autocast = kwargs.pop("autocast", None)
+        if autocast is None:
+            autocast = kwargs.pop("experimental_autocast", None)
+        if autocast is None:
+            autocast = True
         # See the docstring for tf.Variable about the details for
         # caching_device.
         caching_device = kwargs.pop("caching_device", None)
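Note: add_weight() now accepts `autocast` alongside the older `experimental_autocast` spelling (the same change is applied to the v1 base layer below). A minimal sketch of a custom layer opting a weight out of AutoCastVariable wrapping; the layer itself is hypothetical:

import tensorflow as tf

class ScaledActivation(tf.keras.layers.Layer):
    def build(self, input_shape):
        # autocast=False keeps the variable in its storage dtype under a
        # mixed-precision policy instead of wrapping it in an AutoCastVariable;
        # experimental_autocast=False is still accepted from older callers.
        self.scale = self.add_weight(
            name="scale",
            shape=(),
            initializer="ones",
            autocast=False,
        )
        super().build(input_shape)

    def call(self, inputs):
        return inputs * tf.cast(self.scale, inputs.dtype)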
tf_keras/src/engine/base_layer_v1.py
CHANGED
@@ -352,7 +352,8 @@ class Layer(base_layer.Layer):
             Accepted values are constants defined in the class
             `tf.VariableAggregation`.
           **kwargs: Additional keyword arguments. Accepted values are `getter`,
-            `collections`, `experimental_autocast` and `caching_device`.
+            `collections`, `autocast`, `experimental_autocast` and
+            `caching_device`.
 
         Returns:
           The created variable. Usually either a `Variable` or
@@ -371,6 +372,7 @@ class Layer(base_layer.Layer):
         # Validate optional keyword arguments.
         for kwarg in kwargs:
             if kwarg not in [
+                "autocast",
                 "getter",
                 "collections",
                 "experimental_autocast",
@@ -380,9 +382,13 @@ class Layer(base_layer.Layer):
         has_custom_getter = "getter" in kwargs
         getter = kwargs.pop("getter", base_layer_utils.make_variable)
         collections_arg = kwargs.pop("collections", None)
-        # 'experimental_autocast' can be set to False by the caller to indicate
-        # an AutoCastVariable should never be created.
-        autocast = kwargs.pop("experimental_autocast", True)
+        # 'autocast' or 'experimental_autocast' can be set to False by the
+        # caller to indicate an AutoCastVariable should never be created.
+        autocast = kwargs.pop("autocast", None)
+        if autocast is None:
+            autocast = kwargs.pop("experimental_autocast", None)
+        if autocast is None:
+            autocast = True
         # See the docstring for tf.Variable about the details for
         # caching_device.
         caching_device = kwargs.pop("caching_device", None)
tf_keras/src/engine/node.py
CHANGED
@@ -84,9 +84,10 @@ class Node:
         self.call_args = call_args
         self.call_kwargs = call_kwargs
 
-        # Cached for performance.
+        # Cached for performance. Put kwargs in order of the call method instead
+        # of using the sorted key order from `tf.nest.flatten`.
         self._flat_arguments = tf.nest.flatten(
-            (self.call_args, self.call_kwargs)
+            (self.call_args, self.call_kwargs.values())
         )
         # Used to avoid expensive `nest` operations in the most common case.
         self._single_positional_tensor_passed = (
@@ -176,9 +177,13 @@ class Node:
         for kt_id, kt_index in self._keras_inputs_ids_and_indices:
             flat_arguments[kt_index] = tensor_dict[kt_id].pop()
 
+        # Pack the same way as `self._flat_arguments`, i.e. `kwargs` as a
+        # list in the original order.
         args, kwargs = tf.nest.pack_sequence_as(
-            (self.call_args, self.call_kwargs), flat_arguments
+            (self.call_args, self.call_kwargs.values()), flat_arguments
         )
+        # Add the keys to `kwargs` to go from a list to a dict.
+        kwargs = {k: v for k, v in zip(self.call_kwargs.keys(), kwargs)}
         return args, kwargs
 
     def serialize(self, make_node_key, node_conversion_map):
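Note: the two hunks above stop flattening `call_kwargs` as a dict (which `tf.nest.flatten` orders by sorted key) and instead flatten the values in call order, re-attaching the keys afterwards. A small standalone illustration of the ordering difference, not the Node internals:

import tensorflow as tf

call_kwargs = {"beta": 2, "alpha": 1}  # insertion order: beta, alpha

# Flattening the dict sorts the keys, so the flat order becomes alpha, beta:
print(tf.nest.flatten(call_kwargs))                 # [1, 2]
# Flattening the values as a list preserves the original kwargs order:
print(tf.nest.flatten(list(call_kwargs.values())))  # [2, 1]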
tf_keras/src/layers/activation/prelu.py
CHANGED
@@ -102,7 +102,7 @@ class PReLU(Layer):
             if i not in self.shared_axes:
                 axes[i] = input_shape[i]
         self.input_spec = InputSpec(ndim=len(input_shape), axes=axes)
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs):
         pos = backend.relu(inputs)
tf_keras/src/layers/attention/base_dense_attention.py
CHANGED
@@ -86,7 +86,8 @@ class BaseDenseAttention(base_layer.BaseRandomLayer):
         # be purely stateless, with no reference to any variable.
         if self.dropout > 0:
             super().build(input_shape)
-        self.built = True
+        else:
+            base_layer.Layer.build(self, input_shape)
 
     def _calculate_scores(self, query, key):
         """Calculates attention scores.
tf_keras/src/layers/convolutional/base_conv.py
CHANGED
@@ -248,7 +248,7 @@ class Conv(Layer):
         self.input_spec = InputSpec(
             min_ndim=self.rank + 2, axes={channel_axis: input_channel}
         )
-        self.built = True
+        super().build(input_shape)
 
     def convolution_op(self, inputs, kernel):
         if self.padding == "causal":
tf_keras/src/layers/convolutional/base_depthwise_conv.py
CHANGED
@@ -20,6 +20,7 @@ import tensorflow.compat.v2 as tf
 from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.base_conv import Conv
 
@@ -202,7 +203,8 @@ class DepthwiseConv(Conv):
         self.input_spec = InputSpec(
             min_ndim=self.rank + 2, axes={channel_axis: input_dim}
         )
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         raise NotImplementedError
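Note: this file and the separable/transposed convolution files below now end build() with `Layer.build(self, input_shape)` instead of setting `self.built = True` directly, deliberately skipping the immediate parent's build(). A minimal sketch of the pattern using public tf.keras classes (both layers are hypothetical):

import tensorflow as tf

class Base(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.kernel = self.add_weight(name="kernel", shape=(int(input_shape[-1]), 4))
        super().build(input_shape)

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

class Variant(Base):
    def build(self, input_shape):
        self.kernel = self.add_weight(name="kernel", shape=(int(input_shape[-1]), 8))
        # Base.build() would create a second, wrongly shaped kernel, so call the
        # grandparent Layer.build() directly; it only records that the layer is built.
        tf.keras.layers.Layer.build(self, input_shape)

print(Variant()(tf.ones([2, 3])).shape)  # (2, 8)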
tf_keras/src/layers/convolutional/base_separable_conv.py
CHANGED
@@ -21,6 +21,7 @@ from tf_keras.src import activations
 from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.base_conv import Conv
 
@@ -203,7 +204,8 @@ class SeparableConv(Conv):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         raise NotImplementedError
tf_keras/src/layers/convolutional/conv1d_transpose.py
CHANGED
@@ -22,6 +22,7 @@ from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
 from tf_keras.src.dtensor import utils
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.conv1d import Conv1D
 from tf_keras.src.utils import conv_utils
@@ -214,7 +215,8 @@ class Conv1DTranspose(Conv1D):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
tf_keras/src/layers/convolutional/conv2d_transpose.py
CHANGED
@@ -23,6 +23,7 @@ from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
 from tf_keras.src.dtensor import utils
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.conv2d import Conv2D
 from tf_keras.src.utils import conv_utils
@@ -240,7 +241,8 @@ class Conv2DTranspose(Conv2D):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
tf_keras/src/layers/convolutional/conv3d_transpose.py
CHANGED
@@ -22,6 +22,7 @@ from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
 from tf_keras.src.dtensor import utils
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.conv3d import Conv3D
 from tf_keras.src.utils import conv_utils
@@ -247,7 +248,8 @@ class Conv3DTranspose(Conv3D):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
tf_keras/src/layers/locally_connected/locally_connected1d.py
CHANGED
@@ -284,7 +284,7 @@ class LocallyConnected1D(Layer):
             self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
         else:
             self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
-        self.built = True
+        super().build(input_shape)
 
     @tf_utils.shape_type_conversion
     def compute_output_shape(self, input_shape):
tf_keras/src/layers/locally_connected/locally_connected2d.py
CHANGED
@@ -308,7 +308,7 @@ class LocallyConnected2D(Layer):
             self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
         else:
             self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
-        self.built = True
+        super().build(input_shape)
 
     @tf_utils.shape_type_conversion
     def compute_output_shape(self, input_shape):
tf_keras/src/layers/normalization/batch_normalization.py
CHANGED
@@ -542,7 +542,7 @@ class BatchNormalizationBase(Layer):
         finally:
             if partitioner:
                 self._scope.set_partitioner(partitioner)
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, training=None, mask=None):
         inputs = tf.cast(inputs, self.compute_dtype)
tf_keras/src/layers/rnn/base_conv_rnn.py
CHANGED
@@ -20,6 +20,7 @@ import tensorflow.compat.v2 as tf
 
 from tf_keras.src import backend
 from tf_keras.src.engine import base_layer
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.rnn.base_rnn import RNN
 from tf_keras.src.utils import conv_utils
@@ -207,6 +208,8 @@ class ConvRNN(RNN):
 
     @tf_utils.shape_type_conversion
     def build(self, input_shape):
+        # Call Layer.build() to skip RNN.build() which we override here.
+        Layer.build(self, input_shape)
         # Note input_shape will be list of shapes of initial states and
         # constants if these are passed in __call__.
         if self._num_constants is not None:
@@ -263,7 +266,6 @@ class ConvRNN(RNN):
             ]
         if self.stateful:
             self.reset_states()
-        self.built = True
 
     def get_initial_state(self, inputs):
         # (samples, timesteps, img_dims..., filters)
tf_keras/src/layers/rnn/bidirectional.py
CHANGED
@@ -470,7 +470,8 @@ class Bidirectional(Wrapper):
             self.forward_layer.build(input_shape)
         with backend.name_scope(self.backward_layer.name):
             self.backward_layer.build(input_shape)
-        self.built = True
+        # Call Layer.build() to skip Wrapper.build() which we override here.
+        Layer.build(self, input_shape)
 
     def compute_mask(self, inputs, mask):
         if isinstance(mask, list):
tf_keras/src/layers/rnn/cell_wrappers.py
CHANGED
@@ -102,10 +102,10 @@ class _RNNCellWrapper(AbstractRNNCell):
             inputs, state, cell_call_fn=self.cell.call, **kwargs
         )
 
-    def build(self, inputs_shape):
+    def build(self, input_shape):
         """Builds the wrapped cell."""
-        self.cell.build(inputs_shape)
-        self.built = True
+        self.cell.build(input_shape)
+        super().build(input_shape)
 
     @property
     def wrapped_cell(self):
tf_keras/src/layers/rnn/cudnn_gru.py
CHANGED
@@ -144,8 +144,6 @@ class CuDNNGRU(_CuDNNRNN):
             constraint=self.bias_constraint,
         )
 
-        self.built = True
-
     def _process_batch(self, inputs, initial_state):
         if not self.time_major:
             inputs = tf.transpose(inputs, perm=(1, 0, 2))
@@ -172,6 +170,10 @@ class CuDNNGRU(_CuDNNRNN):
             shape=self._vector_shape,
         )
 
+        batch_dim = tf.shape(inputs)[1]
+        max_sequence_length = tf.shape(inputs)[0]
+        sequence_lengths = tf.fill([batch_dim], max_sequence_length)
+
         args = {
             "input": inputs,
             "input_h": input_h,
@@ -179,9 +181,10 @@ class CuDNNGRU(_CuDNNRNN):
             "params": params,
             "is_training": True,
             "rnn_mode": "gru",
+            "sequence_lengths": sequence_lengths,
         }
 
-        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV2(**args)
+        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(**args)
 
         if self.stateful or self.return_state:
             h = h[0]
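Note: the CuDNN kernel invoked above (and in the CuDNNLSTM and GPU GRU/LSTM paths below) now receives an explicit `sequence_lengths` vector. With no mask involved it is simply the full time dimension repeated once per batch entry, as a quick standalone check shows:

import tensorflow as tf

inputs = tf.zeros([5, 3, 8])  # time-major: (max_time=5, batch=3, features=8)
batch_dim = tf.shape(inputs)[1]
max_sequence_length = tf.shape(inputs)[0]
sequence_lengths = tf.fill([batch_dim], max_sequence_length)
print(sequence_lengths.numpy())  # [5 5 5]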
tf_keras/src/layers/rnn/cudnn_lstm.py
CHANGED
@@ -170,8 +170,6 @@ class CuDNNLSTM(_CuDNNRNN):
             constraint=self.bias_constraint,
         )
 
-        self.built = True
-
     def _process_batch(self, inputs, initial_state):
         if not self.time_major:
             inputs = tf.transpose(inputs, perm=(1, 0, 2))
@@ -204,15 +202,20 @@ class CuDNNLSTM(_CuDNNRNN):
             shape=self._vector_shape,
         )
 
+        batch_dim = tf.shape(inputs)[1]
+        max_sequence_length = tf.shape(inputs)[0]
+        sequence_lengths = tf.fill([batch_dim], max_sequence_length)
+
         args = {
             "input": inputs,
             "input_h": input_h,
             "input_c": input_c,
             "params": params,
             "is_training": True,
+            "sequence_lengths": sequence_lengths,
         }
 
-        outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV2(**args)
+        outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(**args)
 
         if self.stateful or self.return_state:
             h = h[0]
tf_keras/src/layers/rnn/gru.py
CHANGED
@@ -222,7 +222,6 @@ class GRUCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
             )
         else:
             self.bias = None
-        self.built = True
 
     def call(self, inputs, states, training=None):
         h_tm1 = (
@@ -1034,11 +1033,13 @@ def gpu_gru(
             mask, time_major
         )
 
-    if not time_major and sequence_lengths is None:
-        inputs = tf.transpose(inputs, perm=(1, 0, 2))
-        seq_axis, batch_axis = (0, 1)
-    else:
-        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
+    seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
+
+    if sequence_lengths is None:
+        max_sequence_length = tf.shape(inputs)[seq_axis]
+        batch_size = tf.shape(inputs)[batch_axis]
+        sequence_lengths = tf.fill([batch_size], max_sequence_length)
+
     # For init_h, cuDNN expects one more dim of num_layers before or after batch
     # dim for time major or batch major inputs respectively
     init_h = tf.expand_dims(init_h, axis=seq_axis)
@@ -1069,49 +1070,36 @@ def gpu_gru(
         transpose_weights=True,
     )
 
-    if sequence_lengths is not None:
-        if go_backwards:
-            # Three reversals are required. E.g.,
-            # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
-            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
-            # output_from_cudnn = [6, 5, 4, 0, 0]
-            # expected_output = [0, 0, 6, 5 ,4]
-            inputs = tf.reverse_sequence(
-                inputs,
-                sequence_lengths,
-                seq_axis=seq_axis,
-                batch_axis=batch_axis,
-            )
-        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
-            input=inputs,
-            input_h=init_h,
-            input_c=0,
-            params=params,
-            is_training=True,
-            rnn_mode="gru",
-            sequence_lengths=sequence_lengths,
-            time_major=time_major,
+    if go_backwards:
+        # Three reversals are required. E.g.,
+        # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
+        # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
+        # output_from_cudnn = [6, 5, 4, 0, 0]
+        # expected_output = [0, 0, 6, 5 ,4]
+        inputs = tf.reverse_sequence(
+            inputs,
+            sequence_lengths,
+            seq_axis=seq_axis,
+            batch_axis=batch_axis,
         )
-        if go_backwards:
-            outputs = tf.reverse_sequence(
-                outputs,
-                sequence_lengths,
-                seq_axis=seq_axis,
-                batch_axis=batch_axis,
-            )
-            outputs = tf.reverse(outputs, axis=[seq_axis])
-    else:
-        if go_backwards:
-            # Reverse axis 0 since the input is already convert to time major.
-            inputs = tf.reverse(inputs, axis=[0])
-        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV2(
-            input=inputs,
-            input_h=init_h,
-            input_c=0,
-            params=params,
-            is_training=True,
-            rnn_mode="gru",
+    outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
+        input=inputs,
+        input_h=init_h,
+        input_c=0,
+        params=params,
+        is_training=True,
+        rnn_mode="gru",
+        sequence_lengths=sequence_lengths,
+        time_major=time_major,
+    )
+    if go_backwards:
+        outputs = tf.reverse_sequence(
+            outputs,
+            sequence_lengths,
+            seq_axis=seq_axis,
+            batch_axis=batch_axis,
         )
+        outputs = tf.reverse(outputs, axis=[seq_axis])
 
     last_output = outputs[-1]
     if not time_major and sequence_lengths is None and return_sequences:
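Note: with an explicit `sequence_lengths`, the `go_backwards` path above reverses only the valid prefix of each padded sequence before the CuDNN call and reverses the outputs again afterwards (the "three reversals" described in the added comments). A standalone illustration of `tf.reverse_sequence` on a padded, batch-major example:

import tensorflow as tf

x = tf.constant([[1, 2, 3, 0, 0]])  # one sequence of length 3, padded to 5
lengths = tf.constant([3])
print(tf.reverse_sequence(x, lengths, seq_axis=1, batch_axis=0).numpy())
# [[3 2 1 0 0]] : padding stays in place, only the valid prefix is reversed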
tf_keras/src/layers/rnn/legacy_cell_wrappers.py
CHANGED
@@ -368,9 +368,9 @@ class DropoutWrapper(_RNNCellWrapperV1):
     def wrapped_cell(self):
         return self.cell
 
-    def build(self, inputs_shape):
-        self.cell.build(inputs_shape)
-        self.built = True
+    def build(self, input_shape):
+        self.cell.build(input_shape)
+        super().build(input_shape)
 
     def _variational_recurrent_dropout_value(
         self, unused_index, value, noise, keep_prob