keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
from keras.src.api_export import keras_export
|
|
2
|
+
from keras.src.dtype_policies import QUANTIZATION_MODES
|
|
3
|
+
from keras.src.saving import serialization_lib
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@keras_export("keras.quantizers.QuantizationConfig")
class QuantizationConfig:
    """Common base for quantization configuration objects.

    Stores a weight quantizer and an activation quantizer and provides
    `get_config` / `from_config` to serialize and restore both. The `mode`
    property is abstract here: concrete subclasses must override it, and
    this base class is not meant to be instantiated on its own.

    Args:
        weight_quantizer: Quantizer for weights. Defaults to `None`.
        activation_quantizer: Quantizer for activations. Defaults to `None`.
    """

    def __init__(self, weight_quantizer=None, activation_quantizer=None):
        self.activation_quantizer = activation_quantizer
        self.weight_quantizer = weight_quantizer

    @property
    def mode(self):
        """Quantization mode name; always raises in the base class."""
        message = (
            "Subclasses must implement this property. Do not instantiate "
            "QuantizationConfig directly."
        )
        raise NotImplementedError(message)

    def get_config(self):
        """Return a dict holding the serialized form of both quantizers."""
        serialize = serialization_lib.serialize_keras_object
        config = {}
        config["weight_quantizer"] = serialize(self.weight_quantizer)
        config["activation_quantizer"] = serialize(self.activation_quantizer)
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild an instance from a dict such as `get_config` returns.

        Entries are read with `dict.get`, so a key that is absent from
        `config` reaches the deserializer as `None`.
        """
        deserialize = serialization_lib.deserialize_keras_object
        restored = {
            key: deserialize(config.get(key))
            for key in ("weight_quantizer", "activation_quantizer")
        }
        return cls(**restored)

    @staticmethod
    def weight_quantizer_or_default(config, default):
        """Return `config.weight_quantizer` when it is set, else `default`.

        `default` is returned both when `config` is `None` and when its
        `weight_quantizer` attribute is `None`.
        """
        if config is None:
            return default
        chosen = config.weight_quantizer
        return default if chosen is None else chosen

    @staticmethod
    def activation_quantizer_or_default(config, default):
        """Return the activation quantizer of `config`, or `default`.

        `default` is used only when `config` itself is `None`. Unlike
        `weight_quantizer_or_default`, a config whose `activation_quantizer`
        is `None` yields `None` rather than `default`.
        """
        return default if config is None else config.activation_quantizer
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@keras_export("keras.quantizers.Int8QuantizationConfig")
class Int8QuantizationConfig(QuantizationConfig):
    """Quantization config whose `mode` is `"int8"`.

    Args:
        weight_quantizer: Quantizer for weights. When provided, its
            `output_dtype` must be `"int8"`.
        activation_quantizer: Quantizer for activations. The string
            `"default"` (the default value) is replaced by a
            default-constructed `AbsMaxQuantizer`; any other value is
            stored unchanged.

    Raises:
        ValueError: If `weight_quantizer` is provided and its
            `output_dtype` is not `"int8"`.
    """

    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
        # Imported inside the constructor rather than at module level.
        from keras.src.quantizers.quantizers import AbsMaxQuantizer

        use_default = activation_quantizer == "default"
        super().__init__(
            weight_quantizer,
            AbsMaxQuantizer() if use_default else activation_quantizer,
        )

        quantizer = self.weight_quantizer
        if quantizer is None:
            # Nothing to validate when no weight quantizer was supplied.
            return
        if quantizer.output_dtype != "int8":
            raise ValueError(
                "Int8QuantizationConfig requires a weight_quantizer "
                "with output_dtype='int8'. Received: "
                f"output_dtype={quantizer.output_dtype}"
            )

    @property
    def mode(self):
        """The quantization mode name, `"int8"`."""
        return "int8"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@keras_export("keras.quantizers.Int4QuantizationConfig")
class Int4QuantizationConfig(QuantizationConfig):
    """Quantization config whose `mode` is `"int4"`.

    Args:
        weight_quantizer: Quantizer for weights. When provided, its
            `value_range` must be `(-8, 7)` (a list `[-8, 7]` is accepted
            as equivalent) and its `output_dtype` must be `"int8"`.
        activation_quantizer: Quantizer for activations. The string
            `"default"` (the default value) is replaced by a
            default-constructed `AbsMaxQuantizer`; any other value is
            stored unchanged.

    Raises:
        ValueError: If `weight_quantizer` is provided and its `value_range`
            is not `(-8, 7)` or its `output_dtype` is not `"int8"`.
    """

    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
        # Imported inside the constructor rather than at module level.
        from keras.src.quantizers.quantizers import AbsMaxQuantizer

        if activation_quantizer == "default":
            activation_quantizer = AbsMaxQuantizer()
        super().__init__(weight_quantizer, activation_quantizer)
        if self.weight_quantizer is not None:
            value_range = self.weight_quantizer.value_range
            # A list never compares equal to a tuple in Python, so a range
            # given as `[-8, 7]` would be rejected even though it describes
            # the same bounds. Normalize lists before comparing.
            if isinstance(value_range, list):
                value_range = tuple(value_range)
            if value_range != (-8, 7):
                raise ValueError(
                    "Int4QuantizationConfig requires a weight_quantizer "
                    "with value_range=(-8, 7). Received: "
                    f"value_range={self.weight_quantizer.value_range}"
                )

            if self.weight_quantizer.output_dtype != "int8":
                raise ValueError(
                    "Int4QuantizationConfig requires a weight_quantizer "
                    "with output_dtype='int8'. Received: "
                    f"output_dtype={self.weight_quantizer.output_dtype}"
                )

    @property
    def mode(self):
        """The quantization mode name, `"int4"`."""
        return "int4"
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@keras_export("keras.quantizers.Float8QuantizationConfig")
class Float8QuantizationConfig(QuantizationConfig):
    """Marker config selecting the `"float8"` quantization mode.

    User-defined quantizers are not supported for FP8 mixed-precision
    training, so the constructor takes no arguments and leaves both
    quantizers set to `None`. The config only signals that FP8
    mixed-precision training should be used.
    """

    def __init__(self):
        super().__init__(weight_quantizer=None, activation_quantizer=None)

    @property
    def mode(self):
        """The quantization mode name, `"float8"`."""
        return "float8"

    def get_config(self):
        """Return an empty dict; this config has no state to serialize."""
        return dict()

    @classmethod
    def from_config(cls, config):
        """Build a fresh instance; the contents of `config` are ignored."""
        instance = cls()
        return instance
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def validate_and_resolve_config(mode, config):
    """Check a `(mode, config)` pair and return the config to use.

    A string passed as `config` is treated as the mode and replaces any
    `mode` argument. When no config object is supplied, a default config is
    created for the "int8", "int4" and "float8" modes. When a config object
    is supplied, its `mode` property must agree with `mode` if one is given.

    Args:
        mode: Quantization mode name, or `None` to take it from `config`.
        config: A `QuantizationConfig` instance, a mode string, or `None`.

    Returns:
        The `QuantizationConfig` instance that was passed in, or a newly
        created default config for `mode`.

    Raises:
        ValueError: If the mode is not a known quantization mode, if neither
            `mode` nor `config` is provided, if `config` is not a
            `QuantizationConfig`, if `mode` contradicts `config.mode`, or
            if the "gptq" / "awq" modes are used without a `GPTQConfig` /
            `AWQConfig` object respectively.
    """
    # 1. Backwards compatibility: a bare string in `config` is the mode.
    if isinstance(config, str):
        mode, config = config, None

    _validate_mode(mode)

    # 2. Turn the mode into a config object when none was supplied.
    if config is None:
        default_factories = (
            ("int8", Int8QuantizationConfig),
            ("int4", Int4QuantizationConfig),
            ("float8", Float8QuantizationConfig),
        )
        for mode_name, factory in default_factories:
            if mode == mode_name:
                config = factory()
                break
        else:
            # No default exists for this mode: report the precise reason.
            if mode == "gptq":
                raise ValueError(
                    "For GPTQ, you must pass a `GPTQConfig` object in the "
                    "`config` argument."
                )
            if mode == "awq":
                raise ValueError(
                    "For AWQ, you must pass an `AWQConfig` object in the "
                    "`config` argument."
                )
            if mode is not None:
                raise ValueError(
                    f"Invalid quantization mode. Received: mode={mode}"
                )
            raise ValueError(
                "You must provide either `mode` or `config` to `quantize`."
            )
    elif not isinstance(config, QuantizationConfig):
        raise ValueError(
            "Argument `config` must be an instance of "
            "`QuantizationConfig`. "
            f"Received: config={config} (of type {type(config)})"
        )

    # 3. An explicitly requested mode must match the config's own mode.
    if mode is not None and config.mode != mode:
        raise ValueError(
            f"Contradictory arguments: mode='{mode}' but "
            f"config.mode='{config.mode}'"
        )

    # From here on the config's mode is the one that counts; make sure it
    # is itself a known mode.
    mode = config.mode
    _validate_mode(mode)

    # The config classes for these modes are imported only when needed.
    if mode == "gptq":
        from keras.src.quantizers.gptq_config import GPTQConfig

        if not isinstance(config, GPTQConfig):
            raise ValueError(
                "Mode 'gptq' requires a valid `config` argument of type "
                f"`GPTQConfig`. Received: {type(config)}"
            )

    if mode == "awq":
        from keras.src.quantizers.awq_config import AWQConfig

        if not isinstance(config, AWQConfig):
            raise ValueError(
                "Mode 'awq' requires a valid `config` argument of type "
                f"`AWQConfig`. Received: {type(config)}"
            )

    return config
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _validate_mode(mode):
    """Raise `ValueError` unless `mode` is `None` or a known mode name.

    Known mode names are the members of `QUANTIZATION_MODES`.
    """
    if mode is None:
        return
    if mode in QUANTIZATION_MODES:
        return
    raise ValueError(
        "Invalid quantization mode. "
        f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
    )
|