keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/openvino/random.py:

@@ -22,21 +22,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
 
 
 def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
     dtype = dtype or floatx()
-        seed1, seed2 = convert_to_numpy(seed)
+    seed_val = draw_seed(seed)
+    if isinstance(seed_val, OpenVINOKerasTensor):
+        seed_data = convert_to_numpy(seed_val)
     else:
-    shape = list(shape)
-    output_shape_const = ov_opset.constant(shape, dtype=Type.i32)
-    random_uniform = ov_opset.random_uniform(
-        output_shape_const, minval_const, maxval_const, ov_type, seed1, seed2
-    )
-    return OpenVINOKerasTensor(random_uniform.output(0))
+        seed_data = seed_val.data
+    rng = np.random.default_rng(seed_data)
+    random_values = rng.uniform(minval, maxval, size=shape).astype(dtype)
+    return OpenVINOKerasTensor(ov_opset.constant(random_values).output(0))
 
 
 def categorical(logits, num_samples, dtype="int64", seed=None):
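
The rewritten `uniform` trades OpenVINO's `random_uniform` op for host-side sampling: values are drawn once with NumPy and folded into the graph as a constant. A minimal NumPy sketch of that sampling step (variable names mirror the diff; the concrete seed value is illustrative):

```python
import numpy as np

# Host-side sampling as in the new uniform(): draw once with NumPy,
# then embed the result as a graph constant.
seed_data = 1337  # illustrative; in the diff this comes from draw_seed()
shape, minval, maxval, dtype = (2, 3), 0.0, 1.0, "float32"

rng = np.random.default_rng(seed_data)
random_values = rng.uniform(minval, maxval, size=shape).astype(dtype)

# The same seed reproduces the same constant: sampling happens at
# graph-construction time, not at inference time.
print(random_values.shape)  # (2, 3)
```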
keras/src/backend/tensorflow/layer.py:

@@ -13,7 +13,6 @@ class TFLayer(KerasAutoTrackable):
         self._saved_model_arg_spec = None
         self._tracked = []
 
-    @tf.__internal__.tracking.no_automatic_dependency_tracking
     def _set_save_spec(self, inputs, args=None, kwargs=None):
         """Defines the save spec so that serialization can trace layer calls.
 
@@ -45,6 +44,7 @@ class TFLayer(KerasAutoTrackable):
             kwargs_spec,
         )
 
+    @tf.__internal__.tracking.no_automatic_dependency_tracking
     def _trackable_children(self, save_type="checkpoint", **kwargs):
         if save_type == "savedmodel":
             # SavedModel needs to ignore the execution functions.
@@ -62,17 +62,51 @@ class TFLayer(KerasAutoTrackable):
             self.test_function = test_function
             self.predict_function = predict_function
 
-        for tracked_attr in self._tracked:
-            tracked_item = getattr(self, tracked_attr)
-            if isinstance(tracked_item, tracking.TrackedList):
-                children[tracked_attr] = list(tracked_item)
-            if isinstance(tracked_item, tracking.TrackedDict):
-                children[tracked_attr] = dict(tracked_item)
-            if isinstance(tracked_item, tracking.TrackedSet):
-                children[tracked_attr] = list(tracked_item)
+        # Convert Keras tracked collections to plain Python structures
+        # without creating TensorFlow trackable dependencies
+        self._convert_tracked_collections(children)
 
         return children
 
+    def _convert_tracked_collections(self, children):
+        """Convert TrackedList/Dict/Set to plain Python structures."""
+        for tracked_attr in self._tracked:
+            tracked_item = getattr(self, tracked_attr)
+            if isinstance(tracked_item, tracking.TrackedList):
+                children[tracked_attr] = list(tracked_item)
+            if isinstance(tracked_item, tracking.TrackedDict):
+                children[tracked_attr] = dict(tracked_item)
+            if isinstance(tracked_item, tracking.TrackedSet):
+                children[tracked_attr] = list(tracked_item)
+
+    def _get_save_spec(self, dynamic_batch=True):
+        """Compatibility shim for TensorFlow saving utilities.
+
+        TensorFlow's SavedModel / TFLite export paths (e.g.,
+        tf.lite.TFLiteConverter.from_keras_model) expect a `_get_save_spec`
+        method on models. This method generates TensorSpec objects
+        describing the model's input signature.
+
+        Args:
+            dynamic_batch: whether to set the batch dimension to `None`.
+
+        Returns:
+            A TensorSpec, list or dict mirroring the model inputs, or
+            `None` when specs cannot be inferred.
+        """
+        # Lazy import to avoid circular dependency
+        from keras.src.export.export_utils import make_tf_tensor_spec
+
+        # Fall back to building specs from `self.inputs`
+        inputs = getattr(self, "inputs", None)
+        if inputs is None:
+            return None
+
+        return tree.map_structure(
+            lambda x: make_tf_tensor_spec(x, dynamic_batch=dynamic_batch),
+            inputs,
+        )
+
     @property
     def _default_save_signature(self):
         """For SavedModel support: returns the default serving signature."""
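
Per the `_get_save_spec` docstring above, the shim exists so TensorFlow export utilities that probe `_get_save_spec` can work with Keras 3 models. A sketch of the export path it targets (whether conversion succeeds end-to-end depends on the installed TensorFlow and Keras versions):

```python
import numpy as np
import tensorflow as tf
import keras

# A small functional model; its `inputs` are what the shim turns into
# TensorSpec objects (with a dynamic batch dimension by default).
inp = keras.Input(shape=(4,), dtype="float32")
model = keras.Model(inp, keras.layers.Dense(2)(inp))
model(np.zeros((1, 4), dtype="float32"))  # build the weights

# from_keras_model probes _get_save_spec internally; the shim lets it
# recover the model's input signature.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_bytes = converter.convert()
```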
keras/src/backend/tensorflow/linalg.py:

@@ -244,3 +244,27 @@ def lstsq(a, b, rcond=None):
     if b_orig_ndim == 1:
         x = tf.reshape(x, [-1])
     return x
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    primal_flat = tf.nest.flatten(primals)
+    tangent_flat = tf.nest.flatten(tangents)
+
+    tangent_flat = [
+        tf.cast(t, p.dtype) for t, p in zip(tangent_flat, primal_flat)
+    ]
+
+    with tf.autodiff.ForwardAccumulator(primal_flat, tangent_flat) as acc:
+        if has_aux:
+            primals_out, aux = fun(*primals)
+        else:
+            primals_out = fun(*primals)
+
+    primals_out_flat = tf.nest.flatten(primals_out)
+    tangents_out_flat = [acc.jvp(po) for po in primals_out_flat]
+
+    tangents_out = tf.nest.pack_sequence_as(primals_out, tangents_out_flat)
+
+    if has_aux:
+        return primals_out, tangents_out, aux
+    return primals_out, tangents_out
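
A quick check of the forward-mode mechanics the new `jvp` builds on: `tf.autodiff.ForwardAccumulator` carries a tangent alongside the primal computation, and `acc.jvp(y)` reads out the directional derivative.

```python
import tensorflow as tf

def f(x):
    return x * x + tf.sin(x)

x = tf.constant(2.0)
v = tf.constant(1.0)  # tangent direction

# Same pattern as the jvp() added above, specialized to one scalar input.
with tf.autodiff.ForwardAccumulator(x, v) as acc:
    y = f(x)

print(y.numpy())           # f(2)  = 4 + sin(2)
print(acc.jvp(y).numpy())  # f'(2) = 2*2 + cos(2)
```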
keras/src/backend/tensorflow/nn.py:

@@ -4,6 +4,9 @@ import warnings
 import tensorflow as tf
 
 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_output_shape,
 )
@@ -268,6 +271,486 @@ def average_pool(
     return outputs
 
 
+def _compute_static_gather_indices(
+    input_dim, output_size, small_window, big_window
+):
+    """Compute gather indices for Two-Pool Gather method (corrected)."""
+    window_starts = tf.cast(
+        tf.floor(
+            tf.cast(tf.range(output_size), tf.float32)
+            * tf.cast(input_dim, tf.float32)
+            / tf.cast(output_size, tf.float32)
+        ),
+        tf.int32,
+    )
+    window_ends = tf.cast(
+        tf.math.ceil(
+            tf.cast(tf.range(1, output_size + 1), tf.float32)
+            * tf.cast(input_dim, tf.float32)
+            / tf.cast(output_size, tf.float32)
+        ),
+        tf.int32,
+    )
+
+    window_ends = tf.minimum(window_ends, input_dim)
+    window_starts = tf.minimum(window_starts, input_dim - 1)
+
+    window_sizes = window_ends - window_starts
+    is_big_window = tf.equal(window_sizes, big_window)
+
+    small_pool_len = max(1, input_dim - small_window + 1)
+
+    small_indices = window_starts
+    big_indices = window_starts + small_pool_len
+
+    gather_indices = tf.where(is_big_window, big_indices, small_indices)
+    return tf.cast(gather_indices, tf.int32)
+
+
+def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 1))
+
+    static_shape = inputs.shape.as_list()
+    l_static = static_shape[1]
+    out_l = output_size[0]
+
+    if l_static is None:
+        raise ValueError(
+            "Input length must be statically known for adaptive pooling"
+        )
+
+    small_l, big_l = compute_adaptive_pooling_window_sizes(l_static, out_l)
+    gather_l = _compute_static_gather_indices(l_static, out_l, small_l, big_l)
+
+    small_pool_l = tf.nn.pool(
+        inputs,
+        window_shape=(small_l,),
+        pooling_type="AVG",
+        strides=(1,),
+        padding="VALID",
+        data_format="NWC",
+    )
+    big_pool_l = tf.nn.pool(
+        inputs,
+        window_shape=(big_l,),
+        pooling_type="AVG",
+        strides=(1,),
+        padding="VALID",
+        data_format="NWC",
+    )
+
+    combined_l = tf.concat([small_pool_l, big_pool_l], axis=1)
+    pooled_l = tf.gather(combined_l, gather_l, axis=1)
+
+    if data_format == "channels_first":
+        pooled_l = tf.transpose(pooled_l, (0, 2, 1))
+    return pooled_l
+
+
+def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 1))
+
+    static_shape = inputs.shape.as_list()
+    l_static = static_shape[1]
+    out_l = output_size[0]
+
+    if l_static is None:
+        raise ValueError(
+            "Input length must be statically known for adaptive pooling"
+        )
+
+    small_l, big_l = compute_adaptive_pooling_window_sizes(l_static, out_l)
+    gather_l = _compute_static_gather_indices(l_static, out_l, small_l, big_l)
+
+    small_pool_l = tf.nn.pool(
+        inputs,
+        window_shape=(small_l,),
+        pooling_type="MAX",
+        strides=(1,),
+        padding="VALID",
+        data_format="NWC",
+    )
+    big_pool_l = tf.nn.pool(
+        inputs,
+        window_shape=(big_l,),
+        pooling_type="MAX",
+        strides=(1,),
+        padding="VALID",
+        data_format="NWC",
+    )
+
+    combined_l = tf.concat([small_pool_l, big_pool_l], axis=1)
+    pooled_l = tf.gather(combined_l, gather_l, axis=1)
+
+    if data_format == "channels_first":
+        pooled_l = tf.transpose(pooled_l, (0, 2, 1))
+    return pooled_l
+
+
+def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 3, 1))
+
+    static_shape = inputs.shape.as_list()
+    h_static = static_shape[1]
+    w_static = static_shape[2]
+    out_h, out_w = output_size
+
+    if h_static is None or w_static is None:
+        raise ValueError(
+            "Input spatial dimensions must be "
+            "statically known for adaptive pooling"
+        )
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)
+
+    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)
+    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)
+
+    small_pool_h = tf.nn.pool(
+        inputs,
+        window_shape=(small_h, 1),
+        pooling_type="AVG",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+    big_pool_h = tf.nn.pool(
+        inputs,
+        window_shape=(big_h, 1),
+        pooling_type="AVG",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+
+    combined_h = tf.concat([small_pool_h, big_pool_h], axis=1)
+    pooled_h = tf.gather(combined_h, gather_h, axis=1)
+
+    small_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, small_w),
+        pooling_type="AVG",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+    big_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, big_w),
+        pooling_type="AVG",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+
+    combined_w = tf.concat([small_pool_w, big_pool_w], axis=2)
+    pooled_w = tf.gather(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2))
+
+    return pooled_w
+
+
+def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
+    """Adaptive Max Pooling 2D using Two-Pool Gather method."""
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 3, 1))
+
+    static_shape = inputs.shape.as_list()
+    h_static = static_shape[1]
+    w_static = static_shape[2]
+    out_h, out_w = output_size
+
+    if h_static is None or w_static is None:
+        raise ValueError(
+            "Input spatial dimensions must be "
+            "statically known for adaptive pooling"
+        )
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)
+
+    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)
+    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)
+
+    small_pool_h = tf.nn.pool(
+        inputs,
+        window_shape=(small_h, 1),
+        pooling_type="MAX",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+    big_pool_h = tf.nn.pool(
+        inputs,
+        window_shape=(big_h, 1),
+        pooling_type="MAX",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+
+    combined_h = tf.concat([small_pool_h, big_pool_h], axis=1)
+    pooled_h = tf.gather(combined_h, gather_h, axis=1)
+
+    small_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, small_w),
+        pooling_type="MAX",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+    big_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, big_w),
+        pooling_type="MAX",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+
+    combined_w = tf.concat([small_pool_w, big_pool_w], axis=2)
+    pooled_w = tf.gather(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2))
+
+    return pooled_w
+
+
+def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 3, 4, 1))
+
+    static_shape = inputs.shape.as_list()
+    d_static = static_shape[1]
+    h_static = static_shape[2]
+    w_static = static_shape[3]
+    out_d, out_h, out_w = output_size
+
+    if d_static is None or h_static is None or w_static is None:
+        raise ValueError(
+            "Input spatial dimensions must be "
+            "statically known for adaptive pooling"
+        )
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d_static, out_d)
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)
+
+    gather_d = _compute_static_gather_indices(d_static, out_d, small_d, big_d)
+    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)
+    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)
+
+    small_pool_d = tf.nn.pool(
+        inputs,
+        window_shape=(small_d, 1, 1),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_d = tf.nn.pool(
+        inputs,
+        window_shape=(big_d, 1, 1),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_d = tf.concat([small_pool_d, big_pool_d], axis=1)
+    pooled_d = tf.gather(combined_d, gather_d, axis=1)
+
+    small_pool_h = tf.nn.pool(
+        pooled_d,
+        window_shape=(1, small_h, 1),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_h = tf.nn.pool(
+        pooled_d,
+        window_shape=(1, big_h, 1),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_h = tf.concat([small_pool_h, big_pool_h], axis=2)
+    pooled_h = tf.gather(combined_h, gather_h, axis=2)
+
+    small_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, 1, small_w),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, 1, big_w),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_w = tf.concat([small_pool_w, big_pool_w], axis=3)
+    pooled_w = tf.gather(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3))
+
+    return pooled_w
+
+
+def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
+    """Adaptive Max Pooling 3D using Two-Pool Gather method."""
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 3, 4, 1))
+
+    static_shape = inputs.shape.as_list()
+    d_static = static_shape[1]
+    h_static = static_shape[2]
+    w_static = static_shape[3]
+    out_d, out_h, out_w = output_size
+
+    if d_static is None or h_static is None or w_static is None:
+        raise ValueError(
+            "Input spatial dimensions must be "
+            "statically known for adaptive pooling"
+        )
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d_static, out_d)
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)
+
+    gather_d = _compute_static_gather_indices(d_static, out_d, small_d, big_d)
+    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)
+    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)
+
+    small_pool_d = tf.nn.pool(
+        inputs,
+        window_shape=(small_d, 1, 1),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_d = tf.nn.pool(
+        inputs,
+        window_shape=(big_d, 1, 1),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_d = tf.concat([small_pool_d, big_pool_d], axis=1)
+    pooled_d = tf.gather(combined_d, gather_d, axis=1)
+
+    small_pool_h = tf.nn.pool(
+        pooled_d,
+        window_shape=(1, small_h, 1),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_h = tf.nn.pool(
+        pooled_d,
+        window_shape=(1, big_h, 1),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_h = tf.concat([small_pool_h, big_pool_h], axis=2)
+    pooled_h = tf.gather(combined_h, gather_h, axis=2)
+
+    small_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, 1, small_w),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, 1, big_w),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_w = tf.concat([small_pool_w, big_pool_w], axis=3)
+    pooled_w = tf.gather(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3))
+
+    return pooled_w
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    ndims = len(inputs.shape) - 2
+    if ndims == 1:
+        return _adaptive_average_pool1d(inputs, output_size, data_format)
+    elif ndims == 2:
+        return _adaptive_average_pool2d(inputs, output_size, data_format)
+    elif ndims == 3:
+        return _adaptive_average_pool3d(inputs, output_size, data_format)
+    else:
+        raise ValueError(
+            "adaptive_average_pool supports 1D, 2D, or 3D inputs only."
+        )
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    ndims = len(inputs.shape) - 2
+    if ndims == 1:
+        return _adaptive_max_pool1d(inputs, output_size, data_format)
+    elif ndims == 2:
+        return _adaptive_max_pool2d(inputs, output_size, data_format)
+    elif ndims == 3:
+        return _adaptive_max_pool3d(inputs, output_size, data_format)
+    else:
+        raise ValueError(
+            "adaptive_max_pool supports 1D, 2D, or 3D inputs only."
+        )
+
+
 def _convert_data_format(data_format, ndim):
     if data_format == "channels_last":
         if ndim == 3:
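
The "Two-Pool Gather" trick above relies on a property of adaptive pooling: with window bounds `start_i = floor(i * in / out)` and `end_i = ceil((i + 1) * in / out)`, every output window has one of just two lengths, differing by one. Two dense stride-1 pools (one per window size) plus a `tf.gather` therefore reproduce the per-window result exactly. A NumPy reference sketch, assuming those floor/ceil bounds are what `compute_adaptive_pooling_window_sizes` and `_compute_static_gather_indices` encode:

```python
import numpy as np

def adaptive_avg_pool1d_reference(x, out_l):
    """Per-window reference for adaptive average pooling over axis -1."""
    in_l = x.shape[-1]
    starts = np.floor(np.arange(out_l) * in_l / out_l).astype(int)
    ends = np.ceil(np.arange(1, out_l + 1) * in_l / out_l).astype(int)
    return np.stack(
        [x[..., s:e].mean(axis=-1) for s, e in zip(starts, ends)], axis=-1
    )

x = np.arange(10.0)
print(adaptive_avg_pool1d_reference(x, 4))  # [1. 3. 6. 8.]

# Window sizes take at most two distinct values, which is what makes the
# small-pool/big-pool decomposition exact.
sizes = (
    np.ceil(np.arange(1, 4) * 8 / 3) - np.floor(np.arange(3) * 8 / 3)
).astype(int)
print(sizes)  # [3 4 3] for in=8, out=3
```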
keras/src/backend/tensorflow/nn.py:

@@ -310,7 +793,7 @@ def conv(
 ):
     def _conv():
         tf_data_format = _convert_data_format(data_format, len(inputs.shape))
-        return tf.nn.convolution(
+        result = tf.nn.convolution(
             inputs,
             kernel,
             strides,
@@ -318,6 +801,20 @@ def conv(
             data_format=tf_data_format,
             dilations=dilation_rate,
         )
+        result_shape = result.shape
+        if (
+            result_shape.is_fully_defined()
+            and math.prod(result_shape.as_list()) == 0
+        ):
+            raise ValueError(
+                "The convolution operation resulted in an empty output. "
+                "Output shape:"
+                f" {result_shape}. This can happen if the input is too small "
+                "for the given kernel size, strides, dilation rate, and "
+                "padding mode. Please check the input shape and convolution "
+                "parameters."
+            )
+        return result
 
     # Certain ops are are broken in Tensorflow on CPU only.
     # We can work around by compiling the op with XLA.
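
The new check converts a silently empty convolution result into a descriptive error. A sketch of a configuration that should trigger it on the TensorFlow backend (a 4x4 input with a 5x5 kernel under `valid` padding yields a 0x0 spatial output):

```python
import numpy as np
from keras import ops

x = np.zeros((1, 4, 4, 3), dtype="float32")       # too small for the kernel
kernel = np.zeros((5, 5, 3, 8), dtype="float32")  # HWIO layout

try:
    ops.conv(x, kernel, strides=1, padding="valid")
except ValueError as e:
    print(e)  # "The convolution operation resulted in an empty output. ..."
```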
keras/src/backend/tensorflow/nn.py:

@@ -1077,3 +1574,50 @@ def dot_product_attention(
     return _dot_product_attention_xla(
         query, key, value, bias, mask, is_causal, scale
     )
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """Tensorflow implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+    k = (
+        (kernel_size, kernel_size)
+        if isinstance(kernel_size, int)
+        else kernel_size
+    )
+    d = (dilation, dilation) if isinstance(dilation, int) else dilation
+    p = (padding, padding) if isinstance(padding, int) else padding
+    s = (stride, stride) if isinstance(stride, int) else stride
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = tf.pad(input, [[0, 0], [0, 0], [p[0], p[0]], [p[1], p[1]]])
+    x = tf.transpose(input, [0, 2, 3, 1])  # (N, H, W, C)
+    patches = tf.image.extract_patches(
+        images=x,
+        sizes=[1, k[0], k[1], 1],
+        strides=[1, s[0], s[1], 1],
+        rates=[1, d[0], d[1], 1],
+        padding="VALID",
+    )  # (N, nH, nW, kH*kW*C)
+
+    N, nH, nW, D = patches.shape
+    patches = tf.reshape(
+        patches, [N, nH, nW, k[0], k[1], C]
+    )  # (N, nH, nW, kH, kW, C)
+    patches = tf.transpose(
+        patches, [0, 5, 3, 4, 1, 2]
+    )  # (N, C, kH, kW, nH, nW)
+    patches = tf.reshape(patches, [N, C * k[0] * k[1], nH * nW])
+    return patches
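
A shape check for the `unfold` above, which mirrors `torch.nn.functional.unfold` semantics (the import from `keras.src` assumes the TensorFlow backend is active):

```python
import tensorflow as tf
from keras.src.backend.tensorflow.nn import unfold

# N=1, C=2, H=W=4 with a 3x3 kernel, stride 1, no padding:
# L = 2*2 = 4 sliding positions, C*kH*kW = 18 values per position.
x = tf.reshape(tf.range(32, dtype=tf.float32), (1, 2, 4, 4))
out = unfold(x, kernel_size=3)
print(out.shape)  # (1, 18, 4), the (N, C*kH*kW, L) layout torch uses
```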