keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/models/model.py
CHANGED
|
@@ -2,13 +2,16 @@ import inspect
|
|
|
2
2
|
import json
|
|
3
3
|
import typing
|
|
4
4
|
import warnings
|
|
5
|
+
from collections.abc import Callable
|
|
5
6
|
|
|
6
7
|
from keras.src import backend
|
|
7
8
|
from keras.src import utils
|
|
8
9
|
from keras.src.api_export import keras_export
|
|
9
10
|
from keras.src.layers.layer import Layer
|
|
10
11
|
from keras.src.models.variable_mapping import map_saveable_variables
|
|
11
|
-
from keras.src.quantizers.
|
|
12
|
+
from keras.src.quantizers.awq_core import awq_quantize
|
|
13
|
+
from keras.src.quantizers.gptq_core import gptq_quantize
|
|
14
|
+
from keras.src.quantizers.utils import should_quantize_layer
|
|
12
15
|
from keras.src.saving import saving_api
|
|
13
16
|
from keras.src.trainers import trainer as base_trainer
|
|
14
17
|
from keras.src.utils import summary_utils
|
|
@@ -421,62 +424,168 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
|
421
424
|
**kwargs,
|
|
422
425
|
)
|
|
423
426
|
|
|
424
|
-
def
|
|
427
|
+
def get_quantization_layer_structure(self, mode=None):
|
|
428
|
+
"""Returns the quantization structure for the model.
|
|
429
|
+
|
|
430
|
+
This method is intended to be overridden by model authors to provide
|
|
431
|
+
topology information required for structure-aware quantization modes
|
|
432
|
+
like 'gptq' and 'awq'.
|
|
433
|
+
|
|
434
|
+
Args:
|
|
435
|
+
mode: The quantization mode.
|
|
436
|
+
|
|
437
|
+
Returns:
|
|
438
|
+
A dictionary describing the topology, e.g.:
|
|
439
|
+
`{'pre_block_layers': [list], 'sequential_blocks': [list]}`
|
|
440
|
+
or `None` if the mode does not require structure or is not
|
|
441
|
+
supported. `'pre_block_layers'` is a list of layers that
|
|
442
|
+
the inputs should be passed through, before being passed to
|
|
443
|
+
the sequential blocks. For example, inputs to an LLM must
|
|
444
|
+
first be passed through an embedding layer, followed by
|
|
445
|
+
the transformer.
|
|
446
|
+
"""
|
|
447
|
+
del mode # Unused.
|
|
448
|
+
return None
|
|
449
|
+
|
|
450
|
+
def quantize(self, mode=None, config=None, filters=None, **kwargs):
|
|
425
451
|
"""Quantize the weights of the model.
|
|
426
452
|
|
|
427
453
|
Note that the model must be built first before calling this method.
|
|
428
|
-
`quantize` will recursively call `quantize(
|
|
454
|
+
`quantize` will recursively call `quantize(...)` in all layers and
|
|
429
455
|
will be skipped if the layer doesn't implement the function.
|
|
430
456
|
|
|
457
|
+
This method can be called by passing a `mode` string, which uses the
|
|
458
|
+
default configuration for that mode. Alternatively, a `config` object
|
|
459
|
+
can be passed to customize the behavior of the quantization (e.g. to
|
|
460
|
+
use specific quantizers for weights or activations).
|
|
461
|
+
|
|
431
462
|
Args:
|
|
432
|
-
mode: The mode of the quantization.
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
463
|
+
mode: The mode of the quantization. Supported modes are:
|
|
464
|
+
`"int8"`, `"int4"`, `"float8"`, `"gptq"`, `"awq"`. This is
|
|
465
|
+
optional if `config` is provided.
|
|
466
|
+
config: The configuration object specifying additional
|
|
467
|
+
quantization options. This argument allows you to configure
|
|
468
|
+
the weight and activation quantizers. Must be an instance of
|
|
469
|
+
`keras.quantizers.QuantizationConfig`.
|
|
470
|
+
filters: Optional filters to apply to the quantization. Can be a
|
|
471
|
+
regex string, a list of regex strings, or a callable. Only the
|
|
472
|
+
layers which match the filter conditions will be quantized.
|
|
473
|
+
**kwargs: Additional keyword arguments.
|
|
436
474
|
|
|
437
|
-
|
|
438
|
-
if not isinstance(config, GPTQConfig):
|
|
439
|
-
raise ValueError(
|
|
440
|
-
"The `config` argument must be of type "
|
|
441
|
-
"`keras.quantizers.GPTQConfig`."
|
|
442
|
-
)
|
|
443
|
-
# The config object's own quantize method drives the process
|
|
444
|
-
config.quantize(self)
|
|
445
|
-
return
|
|
475
|
+
Example:
|
|
446
476
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
)
|
|
477
|
+
Quantize a model to int8 with default configuration:
|
|
478
|
+
|
|
479
|
+
```python
|
|
480
|
+
# Build the model
|
|
481
|
+
model = keras.Sequential([
|
|
482
|
+
keras.Input(shape=(10,)),
|
|
483
|
+
keras.layers.Dense(10),
|
|
484
|
+
])
|
|
485
|
+
model.build((None, 10))
|
|
453
486
|
|
|
487
|
+
# Quantize with default int8 config
|
|
488
|
+
model.quantize("int8")
|
|
489
|
+
```
|
|
490
|
+
|
|
491
|
+
Quantize a model to int8 with a custom configuration:
|
|
492
|
+
|
|
493
|
+
```python
|
|
494
|
+
from keras.quantizers import Int8QuantizationConfig
|
|
495
|
+
from keras.quantizers import AbsMaxQuantizer
|
|
496
|
+
|
|
497
|
+
# Build the model
|
|
498
|
+
model = keras.Sequential([
|
|
499
|
+
keras.Input(shape=(10,)),
|
|
500
|
+
keras.layers.Dense(10),
|
|
501
|
+
])
|
|
502
|
+
model.build((None, 10))
|
|
503
|
+
|
|
504
|
+
# Create a custom config
|
|
505
|
+
config = Int8QuantizationConfig(
|
|
506
|
+
weight_quantizer=AbsMaxQuantizer(
|
|
507
|
+
axis=0,
|
|
508
|
+
value_range=(-127, 127)
|
|
509
|
+
),
|
|
510
|
+
activation_quantizer=AbsMaxQuantizer(
|
|
511
|
+
axis=-1,
|
|
512
|
+
value_range=(-127, 127)
|
|
513
|
+
),
|
|
514
|
+
)
|
|
515
|
+
|
|
516
|
+
# Quantize with custom config
|
|
517
|
+
model.quantize(config=config)
|
|
518
|
+
```
|
|
519
|
+
"""
|
|
520
|
+
# Validate inputs.
|
|
454
521
|
type_check = kwargs.pop("type_check", True)
|
|
455
522
|
if kwargs:
|
|
456
523
|
raise ValueError(
|
|
457
524
|
"Unrecognized keyword arguments "
|
|
458
525
|
f"passed to {self.__class__.__name__}: {kwargs}"
|
|
459
526
|
)
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
527
|
+
|
|
528
|
+
if filters is not None:
|
|
529
|
+
if not isinstance(filters, (str, Callable, list, tuple)):
|
|
530
|
+
raise ValueError(
|
|
531
|
+
"The `filters` argument must be a regex string, a list of "
|
|
532
|
+
"regex strings, or a callable. Received: "
|
|
533
|
+
f"{type(filters)}"
|
|
534
|
+
)
|
|
535
|
+
|
|
536
|
+
graph_modified = False
|
|
466
537
|
for layer in self._flatten_layers():
|
|
467
|
-
|
|
468
|
-
if
|
|
538
|
+
# Apply filters
|
|
539
|
+
if not should_quantize_layer(layer, filters):
|
|
540
|
+
continue
|
|
541
|
+
|
|
542
|
+
if len(list(layer._flatten_layers())) == 1:
|
|
469
543
|
try:
|
|
470
|
-
layer.quantize(mode, type_check=type_check)
|
|
471
|
-
|
|
544
|
+
layer.quantize(mode, type_check=type_check, config=config)
|
|
545
|
+
graph_modified = True
|
|
472
546
|
except NotImplementedError as e:
|
|
473
547
|
warnings.warn(str(e))
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
548
|
+
except AttributeError:
|
|
549
|
+
pass
|
|
550
|
+
|
|
551
|
+
if mode in ["gptq", "awq"]:
|
|
552
|
+
# Resolve model structure.
|
|
553
|
+
# 1. If quantization_layer_structure is provided inside the config,
|
|
554
|
+
# use that.
|
|
555
|
+
structure = config.quantization_layer_structure
|
|
556
|
+
# 2. If no layer structure is provided in the config, try to fetch
|
|
557
|
+
# it using the `get_quantization_layer_structure` hook.
|
|
558
|
+
if structure is None:
|
|
559
|
+
structure = self.get_quantization_layer_structure(mode)
|
|
560
|
+
|
|
561
|
+
if structure is None:
|
|
562
|
+
raise ValueError(
|
|
563
|
+
f"For {mode=}, a valid quantization structure must be "
|
|
564
|
+
"provided either via `config.quantization_layer_structure` "
|
|
565
|
+
"or by overriding "
|
|
566
|
+
"`model.get_quantization_layer_structure(mode)`. The "
|
|
567
|
+
"structure should be a dictionary with keys "
|
|
568
|
+
"'pre_block_layers' and 'sequential_blocks'."
|
|
569
|
+
)
|
|
570
|
+
if mode == "gptq":
|
|
571
|
+
gptq_quantize(config, structure, filters=filters)
|
|
572
|
+
elif mode == "awq":
|
|
573
|
+
awq_quantize(config, structure, filters=filters)
|
|
574
|
+
|
|
575
|
+
# If any layer was changed, we must rebuild the execution functions.
|
|
576
|
+
if graph_modified:
|
|
477
577
|
self.train_function = None
|
|
478
578
|
self.test_function = None
|
|
479
579
|
self.predict_function = None
|
|
580
|
+
self._post_quantize(mode, **kwargs)
|
|
581
|
+
|
|
582
|
+
def _post_quantize(self, mode, **kwargs):
|
|
583
|
+
if backend.backend() == "torch":
|
|
584
|
+
# We need to manually retrack `torch_params`.
|
|
585
|
+
# The reason is that after quantization, the removed variables are
|
|
586
|
+
# still referenced by `torch_params` and cannot be gc.
|
|
587
|
+
for layer in self._flatten_layers():
|
|
588
|
+
layer._track_variables()
|
|
480
589
|
|
|
481
590
|
def build_from_config(self, config):
|
|
482
591
|
if not config:
|
|
@@ -556,8 +665,8 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
|
556
665
|
filepath: `str` or `pathlib.Path` object. The path to save the
|
|
557
666
|
artifact.
|
|
558
667
|
format: `str`. The export format. Supported values:
|
|
559
|
-
`"tf_saved_model"` and `"
|
|
560
|
-
`"tf_saved_model"`.
|
|
668
|
+
`"tf_saved_model"`, `"onnx"`, `"openvino"`, and `"litert"`.
|
|
669
|
+
Defaults to `"tf_saved_model"`.
|
|
561
670
|
verbose: `bool`. Whether to print a message during export. Defaults
|
|
562
671
|
to `None`, which uses the default value set by different
|
|
563
672
|
backends and formats.
|
|
@@ -580,6 +689,13 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
|
580
689
|
provided, they will be automatically computed.
|
|
581
690
|
- `opset_version`: Optional `int`. Specific to `format="onnx"`.
|
|
582
691
|
An integer value that specifies the ONNX opset version.
|
|
692
|
+
- LiteRT-specific options: Optional keyword arguments specific
|
|
693
|
+
to `format="litert"`. These are passed directly to the
|
|
694
|
+
TensorFlow Lite converter and include options like
|
|
695
|
+
`optimizations`, `representative_dataset`,
|
|
696
|
+
`experimental_new_quantizer`, `allow_custom_ops`,
|
|
697
|
+
`enable_select_tf_ops`, etc. See TensorFlow Lite
|
|
698
|
+
documentation for all available options.
|
|
583
699
|
|
|
584
700
|
**Note:** This feature is currently supported only with TensorFlow, JAX
|
|
585
701
|
and Torch backends.
|
|
@@ -614,18 +730,41 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
|
614
730
|
}
|
|
615
731
|
predictions = ort_session.run(None, ort_inputs)
|
|
616
732
|
```
|
|
733
|
+
|
|
734
|
+
Here's how to export a LiteRT (TFLite) model for inference.
|
|
735
|
+
|
|
736
|
+
```python
|
|
737
|
+
# Export the model as a LiteRT artifact
|
|
738
|
+
model.export("path/to/location", format="litert")
|
|
739
|
+
|
|
740
|
+
# Load the artifact in a different process/environment
|
|
741
|
+
interpreter = tf.lite.Interpreter(model_path="path/to/location")
|
|
742
|
+
interpreter.allocate_tensors()
|
|
743
|
+
interpreter.set_tensor(
|
|
744
|
+
interpreter.get_input_details()[0]['index'], input_data
|
|
745
|
+
)
|
|
746
|
+
interpreter.invoke()
|
|
747
|
+
output_data = interpreter.get_tensor(
|
|
748
|
+
interpreter.get_output_details()[0]['index']
|
|
749
|
+
)
|
|
750
|
+
```
|
|
617
751
|
"""
|
|
752
|
+
from keras.src.export import export_litert
|
|
618
753
|
from keras.src.export import export_onnx
|
|
619
754
|
from keras.src.export import export_openvino
|
|
620
755
|
from keras.src.export import export_saved_model
|
|
621
756
|
|
|
622
|
-
available_formats = ("tf_saved_model", "onnx", "openvino")
|
|
757
|
+
available_formats = ("tf_saved_model", "onnx", "openvino", "litert")
|
|
623
758
|
if format not in available_formats:
|
|
624
759
|
raise ValueError(
|
|
625
760
|
f"Unrecognized format={format}. Supported formats are: "
|
|
626
761
|
f"{list(available_formats)}."
|
|
627
762
|
)
|
|
628
763
|
|
|
764
|
+
# Check if LiteRT export is available (requires TensorFlow backend)
|
|
765
|
+
if format == "litert" and backend.backend() != "tensorflow":
|
|
766
|
+
raise ImportError("LiteRT export requires TensorFlow backend.")
|
|
767
|
+
|
|
629
768
|
if format == "tf_saved_model":
|
|
630
769
|
export_saved_model(
|
|
631
770
|
self,
|
|
@@ -650,6 +789,13 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
|
650
789
|
input_signature=input_signature,
|
|
651
790
|
**kwargs,
|
|
652
791
|
)
|
|
792
|
+
elif format == "litert":
|
|
793
|
+
export_litert(
|
|
794
|
+
self,
|
|
795
|
+
filepath,
|
|
796
|
+
input_signature=input_signature,
|
|
797
|
+
**kwargs,
|
|
798
|
+
)
|
|
653
799
|
|
|
654
800
|
@classmethod
|
|
655
801
|
def from_config(cls, config, custom_objects=None):
|
|
@@ -850,13 +996,18 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
|
850
996
|
self.non_trainable_variables, path_value_dict
|
|
851
997
|
)
|
|
852
998
|
elif k == "optimizer_variables":
|
|
853
|
-
self.
|
|
854
|
-
self.
|
|
855
|
-
|
|
999
|
+
if hasattr(self, "optimizer") and self.optimizer is not None:
|
|
1000
|
+
self._assign_variable_values(
|
|
1001
|
+
self.optimizer.variables, path_value_dict
|
|
1002
|
+
)
|
|
856
1003
|
elif k == "metrics_variables":
|
|
857
|
-
|
|
858
|
-
self
|
|
859
|
-
|
|
1004
|
+
if (
|
|
1005
|
+
hasattr(self, "metrics_variables")
|
|
1006
|
+
and self.metrics_variables
|
|
1007
|
+
):
|
|
1008
|
+
self._assign_variable_values(
|
|
1009
|
+
self.metrics_variables, path_value_dict
|
|
1010
|
+
)
|
|
860
1011
|
else:
|
|
861
1012
|
raise ValueError(f"Unknown variable name: {k}")
|
|
862
1013
|
|
keras/src/ops/image.py
CHANGED
|
@@ -565,6 +565,8 @@ class ExtractPatches(Operation):
|
|
|
565
565
|
if isinstance(size, int):
|
|
566
566
|
size = (size, size)
|
|
567
567
|
self.size = size
|
|
568
|
+
if strides is None:
|
|
569
|
+
strides = size
|
|
568
570
|
self.strides = strides
|
|
569
571
|
self.dilation_rate = dilation_rate
|
|
570
572
|
self.padding = padding
|
|
@@ -583,8 +585,6 @@ class ExtractPatches(Operation):
|
|
|
583
585
|
def compute_output_spec(self, images):
|
|
584
586
|
images_shape = list(images.shape)
|
|
585
587
|
original_ndim = len(images_shape)
|
|
586
|
-
if not self.strides:
|
|
587
|
-
strides = (self.size[0], self.size[1])
|
|
588
588
|
if self.data_format == "channels_last":
|
|
589
589
|
channels_in = images_shape[-1]
|
|
590
590
|
else:
|
|
@@ -597,7 +597,7 @@ class ExtractPatches(Operation):
|
|
|
597
597
|
images_shape,
|
|
598
598
|
filters,
|
|
599
599
|
kernel_size,
|
|
600
|
-
strides=strides,
|
|
600
|
+
strides=self.strides,
|
|
601
601
|
padding=self.padding,
|
|
602
602
|
data_format=self.data_format,
|
|
603
603
|
dilation_rate=self.dilation_rate,
|
|
@@ -616,42 +616,98 @@ def extract_patches(
|
|
|
616
616
|
padding="valid",
|
|
617
617
|
data_format=None,
|
|
618
618
|
):
|
|
619
|
-
"""Extracts patches from the image(s).
|
|
619
|
+
"""Extracts patches from the image(s) or volume(s).
|
|
620
|
+
|
|
621
|
+
This function supports both 2D and 3D patch extraction based on the
|
|
622
|
+
`size` argument length, similar to how `keras.ops.conv` handles
|
|
623
|
+
different dimensions.
|
|
620
624
|
|
|
621
625
|
Args:
|
|
622
|
-
images: Input image or batch of images.
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
626
|
+
images: Input image/volume or batch of images/volumes.
|
|
627
|
+
For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
|
|
628
|
+
For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
|
|
629
|
+
size: Patch size as int or tuple.
|
|
630
|
+
Length 2 tuple `(patch_height, patch_width)` or int for 2D patches.
|
|
631
|
+
Length 3 tuple `(patch_depth, patch_height, patch_width)` for
|
|
632
|
+
3D patches.
|
|
633
|
+
strides: Strides for patch extraction. If not specified, defaults
|
|
634
|
+
to `size` (non-overlapping patches).
|
|
635
|
+
dilation_rate: Dilation rate for patch extraction. Note that
|
|
636
|
+
`dilation_rate > 1` is not supported with `strides > 1`.
|
|
630
637
|
padding: The type of padding algorithm to use: `"same"` or `"valid"`.
|
|
631
638
|
data_format: A string specifying the data format of the input tensor.
|
|
632
639
|
It can be either `"channels_last"` or `"channels_first"`.
|
|
633
|
-
|
|
634
|
-
`(batch, height, width, channels)`, while `"channels_first"`
|
|
635
|
-
corresponds to inputs with shape `(batch, channels, height, width)`.
|
|
636
|
-
If not specified, the value will default to
|
|
637
|
-
`keras.config.image_data_format`.
|
|
640
|
+
If not specified, defaults to `keras.config.image_data_format`.
|
|
638
641
|
|
|
639
642
|
Returns:
|
|
640
|
-
Extracted patches
|
|
643
|
+
Extracted patches with shape depending on input and `size`:
|
|
644
|
+
- 2D patches: 3D (unbatched) or 4D (batched)
|
|
645
|
+
- 3D patches: 4D (unbatched) or 5D (batched)
|
|
641
646
|
|
|
642
647
|
Examples:
|
|
643
648
|
|
|
649
|
+
>>> # 2D patches from batch of images
|
|
644
650
|
>>> image = np.random.random(
|
|
645
651
|
... (2, 20, 20, 3)
|
|
646
|
-
... ).astype("float32")
|
|
652
|
+
... ).astype("float32")
|
|
647
653
|
>>> patches = keras.ops.image.extract_patches(image, (5, 5))
|
|
648
654
|
>>> patches.shape
|
|
649
655
|
(2, 4, 4, 75)
|
|
650
|
-
|
|
656
|
+
|
|
657
|
+
>>> # 2D patches from single image
|
|
658
|
+
>>> image = np.random.random((20, 20, 3)).astype("float32")
|
|
651
659
|
>>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1))
|
|
652
660
|
>>> patches.shape
|
|
653
661
|
(18, 18, 27)
|
|
662
|
+
|
|
663
|
+
>>> # 3D patches from batch of volumes
|
|
664
|
+
>>> volumes = np.random.random(
|
|
665
|
+
... (2, 10, 10, 10, 3)
|
|
666
|
+
... ).astype("float32")
|
|
667
|
+
>>> patches = keras.ops.image.extract_patches(volumes, (3, 3, 3))
|
|
668
|
+
>>> patches.shape
|
|
669
|
+
(2, 3, 3, 3, 81)
|
|
670
|
+
|
|
671
|
+
>>> # 3D patches from single volume
|
|
672
|
+
>>> volume = np.random.random((10, 10, 10, 3)).astype("float32")
|
|
673
|
+
>>> patches = keras.ops.image.extract_patches(volume, (3, 3, 3))
|
|
674
|
+
>>> patches.shape
|
|
675
|
+
(3, 3, 3, 81)
|
|
654
676
|
"""
|
|
677
|
+
# Validate size argument
|
|
678
|
+
if not isinstance(size, int):
|
|
679
|
+
if not isinstance(size, (tuple, list)):
|
|
680
|
+
raise TypeError(
|
|
681
|
+
"Invalid `size` argument. Expected an int or a tuple. "
|
|
682
|
+
f"Received: size={size} of type {type(size).__name__}"
|
|
683
|
+
)
|
|
684
|
+
if len(size) not in (2, 3):
|
|
685
|
+
raise ValueError(
|
|
686
|
+
"Invalid `size` argument. Expected a tuple of length 2 or 3. "
|
|
687
|
+
f"Received: size={size} with length {len(size)}"
|
|
688
|
+
)
|
|
689
|
+
|
|
690
|
+
# Determine 2D vs 3D based on size argument
|
|
691
|
+
if not isinstance(size, int) and len(size) == 3:
|
|
692
|
+
# 3D patch extraction
|
|
693
|
+
if any_symbolic_tensors((images,)):
|
|
694
|
+
return ExtractPatches3D(
|
|
695
|
+
size=size,
|
|
696
|
+
strides=strides,
|
|
697
|
+
dilation_rate=dilation_rate,
|
|
698
|
+
padding=padding,
|
|
699
|
+
data_format=data_format,
|
|
700
|
+
).symbolic_call(images)
|
|
701
|
+
return _extract_patches_3d(
|
|
702
|
+
images,
|
|
703
|
+
size,
|
|
704
|
+
strides,
|
|
705
|
+
dilation_rate,
|
|
706
|
+
padding,
|
|
707
|
+
data_format=data_format,
|
|
708
|
+
)
|
|
709
|
+
|
|
710
|
+
# 2D patch extraction (default)
|
|
655
711
|
if any_symbolic_tensors((images,)):
|
|
656
712
|
return ExtractPatches(
|
|
657
713
|
size=size,
|
|
@@ -712,6 +768,187 @@ def _extract_patches(
|
|
|
712
768
|
return patches
|
|
713
769
|
|
|
714
770
|
|
|
771
|
+
class ExtractPatches3D(Operation):
|
|
772
|
+
def __init__(
|
|
773
|
+
self,
|
|
774
|
+
size,
|
|
775
|
+
strides=None,
|
|
776
|
+
dilation_rate=1,
|
|
777
|
+
padding="valid",
|
|
778
|
+
data_format=None,
|
|
779
|
+
*,
|
|
780
|
+
name=None,
|
|
781
|
+
):
|
|
782
|
+
super().__init__(name=name)
|
|
783
|
+
if isinstance(size, int):
|
|
784
|
+
size = (size, size, size)
|
|
785
|
+
elif len(size) != 3:
|
|
786
|
+
raise TypeError(
|
|
787
|
+
"Invalid `size` argument. Expected an "
|
|
788
|
+
f"int or a tuple of length 3. Received: size={size}"
|
|
789
|
+
)
|
|
790
|
+
self.size = size
|
|
791
|
+
if strides is not None:
|
|
792
|
+
if isinstance(strides, int):
|
|
793
|
+
strides = (strides, strides, strides)
|
|
794
|
+
elif len(strides) != 3:
|
|
795
|
+
raise ValueError(f"Invalid `strides` argument. Got: {strides}")
|
|
796
|
+
else:
|
|
797
|
+
strides = size
|
|
798
|
+
self.strides = strides
|
|
799
|
+
self.dilation_rate = dilation_rate
|
|
800
|
+
self.padding = padding
|
|
801
|
+
self.data_format = backend.standardize_data_format(data_format)
|
|
802
|
+
|
|
803
|
+
def call(self, volumes):
|
|
804
|
+
return _extract_patches_3d(
|
|
805
|
+
volumes,
|
|
806
|
+
self.size,
|
|
807
|
+
self.strides,
|
|
808
|
+
self.dilation_rate,
|
|
809
|
+
self.padding,
|
|
810
|
+
self.data_format,
|
|
811
|
+
)
|
|
812
|
+
|
|
813
|
+
def compute_output_spec(self, volumes):
|
|
814
|
+
volumes_shape = list(volumes.shape)
|
|
815
|
+
original_ndim = len(volumes_shape)
|
|
816
|
+
strides = self.strides
|
|
817
|
+
if self.data_format == "channels_last":
|
|
818
|
+
channels_in = volumes_shape[-1]
|
|
819
|
+
else:
|
|
820
|
+
channels_in = volumes_shape[-4]
|
|
821
|
+
if original_ndim == 4:
|
|
822
|
+
volumes_shape = [1] + volumes_shape
|
|
823
|
+
filters = self.size[0] * self.size[1] * self.size[2] * channels_in
|
|
824
|
+
kernel_size = (self.size[0], self.size[1], self.size[2])
|
|
825
|
+
out_shape = compute_conv_output_shape(
|
|
826
|
+
volumes_shape,
|
|
827
|
+
filters,
|
|
828
|
+
kernel_size,
|
|
829
|
+
strides=strides,
|
|
830
|
+
padding=self.padding,
|
|
831
|
+
data_format=self.data_format,
|
|
832
|
+
dilation_rate=self.dilation_rate,
|
|
833
|
+
)
|
|
834
|
+
if original_ndim == 4:
|
|
835
|
+
out_shape = out_shape[1:]
|
|
836
|
+
return KerasTensor(shape=out_shape, dtype=volumes.dtype)
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
def _extract_patches_3d(
|
|
840
|
+
volumes,
|
|
841
|
+
size,
|
|
842
|
+
strides=None,
|
|
843
|
+
dilation_rate=1,
|
|
844
|
+
padding="valid",
|
|
845
|
+
data_format=None,
|
|
846
|
+
):
|
|
847
|
+
if isinstance(size, int):
|
|
848
|
+
patch_d = patch_h = patch_w = size
|
|
849
|
+
elif len(size) == 3:
|
|
850
|
+
patch_d, patch_h, patch_w = size
|
|
851
|
+
else:
|
|
852
|
+
raise TypeError(
|
|
853
|
+
"Invalid `size` argument. Expected an "
|
|
854
|
+
f"int or a tuple of length 3. Received: size={size}"
|
|
855
|
+
)
|
|
856
|
+
if strides is None:
|
|
857
|
+
strides = size
|
|
858
|
+
if isinstance(strides, int):
|
|
859
|
+
strides = (strides, strides, strides)
|
|
860
|
+
if len(strides) != 3:
|
|
861
|
+
raise ValueError(f"Invalid `strides` argument. Got: {strides}")
|
|
862
|
+
data_format = backend.standardize_data_format(data_format)
|
|
863
|
+
if data_format == "channels_last":
|
|
864
|
+
channels_in = volumes.shape[-1]
|
|
865
|
+
elif data_format == "channels_first":
|
|
866
|
+
channels_in = volumes.shape[-4]
|
|
867
|
+
out_dim = patch_d * patch_w * patch_h * channels_in
|
|
868
|
+
kernel = backend.numpy.eye(out_dim, dtype=volumes.dtype)
|
|
869
|
+
kernel = backend.numpy.reshape(
|
|
870
|
+
kernel, (patch_d, patch_h, patch_w, channels_in, out_dim)
|
|
871
|
+
)
|
|
872
|
+
_unbatched = False
|
|
873
|
+
if len(volumes.shape) == 4:
|
|
874
|
+
_unbatched = True
|
|
875
|
+
volumes = backend.numpy.expand_dims(volumes, axis=0)
|
|
876
|
+
patches = backend.nn.conv(
|
|
877
|
+
inputs=volumes,
|
|
878
|
+
kernel=kernel,
|
|
879
|
+
strides=strides,
|
|
880
|
+
padding=padding,
|
|
881
|
+
data_format=data_format,
|
|
882
|
+
dilation_rate=dilation_rate,
|
|
883
|
+
)
|
|
884
|
+
if _unbatched:
|
|
885
|
+
patches = backend.numpy.squeeze(patches, axis=0)
|
|
886
|
+
return patches
|
|
887
|
+
|
|
888
|
+
|
|
889
|
+
@keras_export("keras.ops.image.extract_patches_3d")
|
|
890
|
+
def extract_patches_3d(
|
|
891
|
+
volumes,
|
|
892
|
+
size,
|
|
893
|
+
strides=None,
|
|
894
|
+
dilation_rate=1,
|
|
895
|
+
padding="valid",
|
|
896
|
+
data_format=None,
|
|
897
|
+
):
|
|
898
|
+
"""Extracts patches from the volume(s).
|
|
899
|
+
|
|
900
|
+
Args:
|
|
901
|
+
volumes: Input volume or batch of volumes. Must be 4D or 5D.
|
|
902
|
+
size: Patch size as an int or tuple `(patch_depth, patch_height, patch_width)`.
|
|
903
|
+
strides: strides along depth, height, and width. If not specified, or
|
|
904
|
+
if `None`, it defaults to the same value as `size`.
|
|
905
|
+
dilation_rate: This is the input stride, specifying how far two
|
|
906
|
+
consecutive patch samples are in the input. Note that using
|
|
907
|
+
`dilation_rate > 1` is not supported in conjunction with
|
|
908
|
+
`strides > 1` on the TensorFlow backend.
|
|
909
|
+
padding: The type of padding algorithm to use: `"same"` or `"valid"`.
|
|
910
|
+
data_format: A string specifying the data format of the input tensor.
|
|
911
|
+
It can be either `"channels_last"` or `"channels_first"`.
|
|
912
|
+
`"channels_last"` corresponds to inputs with shape
|
|
913
|
+
`(batch, depth, height, width, channels)`, while `"channels_first"`
|
|
914
|
+
corresponds to inputs with shape
|
|
915
|
+
`(batch, channels, depth, height, width)`. If not specified,
|
|
916
|
+
the value will default to `keras.config.image_data_format()`.
|
|
917
|
+
|
|
918
|
+
Returns:
|
|
919
|
+
Extracted patches 4D (if not batched) or 5D (if batched)
|
|
920
|
+
|
|
921
|
+
Examples:
|
|
922
|
+
|
|
923
|
+
>>> import numpy as np
|
|
924
|
+
>>> import keras
|
|
925
|
+
>>> # Batched case
|
|
926
|
+
>>> volumes = np.random.random(
|
|
927
|
+
... (2, 10, 10, 10, 3)
|
|
928
|
+
... ).astype("float32") # batch of 2 volumes
|
|
929
|
+
>>> patches = keras.ops.image.extract_patches_3d(volumes, (3, 3, 3))
|
|
930
|
+
>>> patches.shape
|
|
931
|
+
(2, 3, 3, 3, 81)
|
|
932
|
+
>>> # Unbatched case
|
|
933
|
+
>>> volume = np.random.random((10, 10, 10, 3)).astype("float32") # 1 volume
|
|
934
|
+
>>> patches = keras.ops.image.extract_patches_3d(volume, (3, 3, 3))
|
|
935
|
+
>>> patches.shape
|
|
936
|
+
(3, 3, 3, 81)
|
|
937
|
+
"""
|
|
938
|
+
if any_symbolic_tensors((volumes,)):
|
|
939
|
+
return ExtractPatches3D(
|
|
940
|
+
size=size,
|
|
941
|
+
strides=strides,
|
|
942
|
+
dilation_rate=dilation_rate,
|
|
943
|
+
padding=padding,
|
|
944
|
+
data_format=data_format,
|
|
945
|
+
).symbolic_call(volumes)
|
|
946
|
+
|
|
947
|
+
return _extract_patches_3d(
|
|
948
|
+
volumes, size, strides, dilation_rate, padding, data_format=data_format
|
|
949
|
+
)
|
|
950
|
+
|
|
951
|
+
|
|
715
952
|
class MapCoordinates(Operation):
|
|
716
953
|
def __init__(self, order, fill_mode="constant", fill_value=0, *, name=None):
|
|
717
954
|
super().__init__(name=name)
|