keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,246 @@
1
+ from keras.src.api_export import keras_export
2
+ from keras.src.dtype_policies import QUANTIZATION_MODES
3
+ from keras.src.saving import serialization_lib
4
+
5
+
6
@keras_export("keras.quantizers.QuantizationConfig")
class QuantizationConfig:
    """Base class for quantization configs.

    Subclasses must implement the `mode` property. `get_config` and
    `from_config` are provided by this base class: they (de)serialize the
    `weight_quantizer` and `activation_quantizer` attributes with
    `serialization_lib`, and `from_config` rebuilds the instance by passing
    both as keyword arguments to `cls`. Only subclasses whose constructor
    does not accept those two keyword arguments need to override them.

    Args:
        weight_quantizer: Quantizer for weights, or `None`.
        activation_quantizer: Quantizer for activations, or `None`.
    """

    def __init__(self, weight_quantizer=None, activation_quantizer=None):
        self.weight_quantizer = weight_quantizer
        self.activation_quantizer = activation_quantizer

    @property
    def mode(self):
        """Quantization mode string; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, when accessed on the base class.
        """
        raise NotImplementedError(
            "Subclasses must implement this property. Do not instantiate "
            "QuantizationConfig directly."
        )

    def get_config(self):
        """Return a dict holding both quantizers in serialized form."""
        return {
            "weight_quantizer": serialization_lib.serialize_keras_object(
                self.weight_quantizer
            ),
            "activation_quantizer": serialization_lib.serialize_keras_object(
                self.activation_quantizer
            ),
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild an instance from a dict produced by `get_config`.

        Missing keys are read as `None` (via `dict.get`) before being
        passed to `deserialize_keras_object`.
        """
        weight_quantizer = serialization_lib.deserialize_keras_object(
            config.get("weight_quantizer")
        )
        activation_quantizer = serialization_lib.deserialize_keras_object(
            config.get("activation_quantizer")
        )
        return cls(
            weight_quantizer=weight_quantizer,
            activation_quantizer=activation_quantizer,
        )

    @staticmethod
    def weight_quantizer_or_default(config, default):
        """Return `config.weight_quantizer`, or `default` if unavailable.

        Falls back to `default` when `config` is `None` OR when the config
        carries no weight quantizer (`weight_quantizer is None`).
        """
        if config is not None and config.weight_quantizer is not None:
            return config.weight_quantizer
        return default

    @staticmethod
    def activation_quantizer_or_default(config, default):
        """Return `config.activation_quantizer`, or `default` if no config.

        Unlike `weight_quantizer_or_default`, `default` is used ONLY when
        `config` is `None`; an explicit `None` activation quantizer on a
        config is returned as-is. NOTE(review): presumably this lets `None`
        mean "do not quantize activations" — confirm against callers.
        """
        if config is not None:
            return config.activation_quantizer
        return default
63
+
64
+
65
@keras_export("keras.quantizers.Int8QuantizationConfig")
class Int8QuantizationConfig(QuantizationConfig):
    """Quantization config for the `"int8"` mode.

    Args:
        weight_quantizer: Optional quantizer for weights. When provided, its
            `output_dtype` must be `"int8"`.
        activation_quantizer: Quantizer for activations. The string
            `"default"` (the default value) is replaced by a default-built
            `AbsMaxQuantizer()`; any other value, including `None`, is
            stored unchanged.

    Raises:
        ValueError: If `weight_quantizer` is given and its `output_dtype`
            is not `"int8"`.
    """

    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
        from keras.src.quantizers.quantizers import AbsMaxQuantizer

        resolved_activation = (
            AbsMaxQuantizer()
            if activation_quantizer == "default"
            else activation_quantizer
        )
        super().__init__(weight_quantizer, resolved_activation)

        quantizer = self.weight_quantizer
        if quantizer is None:
            return
        dtype = quantizer.output_dtype
        if dtype != "int8":
            raise ValueError(
                "Int8QuantizationConfig requires a weight_quantizer "
                "with output_dtype='int8'. Received: "
                f"output_dtype={dtype}"
            )

    @property
    def mode(self):
        """The quantization mode string, always `"int8"`."""
        return "int8"
92
+
93
+
94
@keras_export("keras.quantizers.Int4QuantizationConfig")
class Int4QuantizationConfig(QuantizationConfig):
    """Quantization config for the `"int4"` mode.

    Args:
        weight_quantizer: Optional quantizer for weights. When provided, it
            must produce values in the signed 4-bit range
            (`value_range=(-8, 7)`; an equivalent list is also accepted)
            with `output_dtype="int8"`.
        activation_quantizer: Quantizer for activations. The string
            `"default"` (the default value) is replaced by a default-built
            `AbsMaxQuantizer()`; any other value, including `None`, is
            stored unchanged.

    Raises:
        ValueError: If `weight_quantizer` is given and its `value_range` is
            not `(-8, 7)` or its `output_dtype` is not `"int8"`.
    """

    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
        from keras.src.quantizers.quantizers import AbsMaxQuantizer

        if activation_quantizer == "default":
            activation_quantizer = AbsMaxQuantizer()
        super().__init__(weight_quantizer, activation_quantizer)
        if self.weight_quantizer is not None:
            value_range = self.weight_quantizer.value_range
            # Normalize to a tuple before comparing so that an equivalent
            # list such as `[-8, 7]` is not rejected by `!=` against a tuple.
            is_int4_range = isinstance(value_range, (list, tuple)) and (
                tuple(value_range) == (-8, 7)
            )
            if not is_int4_range:
                raise ValueError(
                    "Int4QuantizationConfig requires a weight_quantizer "
                    "with value_range=(-8, 7). Received: "
                    f"value_range={self.weight_quantizer.value_range}"
                )

            if self.weight_quantizer.output_dtype != "int8":
                raise ValueError(
                    "Int4QuantizationConfig requires a weight_quantizer "
                    "with output_dtype='int8'. Received: "
                    f"output_dtype={self.weight_quantizer.output_dtype}"
                )

    @property
    def mode(self):
        """The quantization mode string, always `"int4"`."""
        return "int4"
128
+
129
+
130
@keras_export("keras.quantizers.Float8QuantizationConfig")
class Float8QuantizationConfig(QuantizationConfig):
    """Quantization config for the `"float8"` mode.

    FP8 mixed-precision training does not support user-defined quantizers,
    so this config carries no state: both quantizer attributes are always
    `None`, it serializes to an empty dict, and `from_config` ignores its
    input. It only signals that FP8 mixed-precision training should be used.
    """

    def __init__(self):
        super().__init__(weight_quantizer=None, activation_quantizer=None)

    @property
    def mode(self):
        """The quantization mode string, always `"float8"`."""
        return "float8"

    def get_config(self):
        """Return an empty dict; there is nothing to serialize."""
        return dict()

    @classmethod
    def from_config(cls, config):
        """Build a fresh instance; `config` carries no information."""
        del config  # Intentionally unused.
        return cls()
152
+
153
+
154
def validate_and_resolve_config(mode, config):
    """Validate and resolve a quantization config.

    Accepts a mode string alone, a `QuantizationConfig` alone, both (which
    must agree), or a mode string passed as `config` (legacy shortcut).
    When no config object is given, a default one is built for `"int8"`,
    `"int4"` and `"float8"`; `"gptq"` and `"awq"` require an explicit config.

    Args:
        mode: Quantization mode string, or `None` to infer it from `config`.
        config: A `QuantizationConfig` instance, a mode string, or `None`.

    Returns:
        A `QuantizationConfig` instance whose `mode` matches `mode` when
        `mode` was given.

    Raises:
        ValueError: If the mode is unknown, neither argument is provided,
            `config` is not a `QuantizationConfig`, `mode` and `config`
            contradict each other, or a `"gptq"`/`"awq"` mode is not paired
            with a `GPTQConfig`/`AWQConfig`.
        NotImplementedError: Propagated from `config.mode` when `config` is
            a bare `QuantizationConfig` instance.
    """
    # 1. Backwards Compatibility: Handle string shortcuts.
    if isinstance(config, str):
        # Reject a conflicting `mode` rather than silently overwriting it,
        # consistent with the contradiction check in step 3.
        if mode is not None and mode != config:
            raise ValueError(
                f"Contradictory arguments: mode='{mode}' but "
                f"config='{config}'"
            )
        mode = config
        config = None

    _validate_mode(mode)

    # 2. Resolve "mode" into a Config object.
    if config is None:
        if mode == "int8":
            config = Int8QuantizationConfig()
        elif mode == "int4":
            config = Int4QuantizationConfig()
        elif mode == "float8":
            config = Float8QuantizationConfig()
        elif mode == "gptq":
            raise ValueError(
                "For GPTQ, you must pass a `GPTQConfig` object in the "
                "`config` argument."
            )
        elif mode == "awq":
            raise ValueError(
                "For AWQ, you must pass an `AWQConfig` object in the "
                "`config` argument."
            )
        else:
            if mode is not None:
                # Only reachable for a mode that passed `_validate_mode` but
                # has no default config above.
                raise ValueError(
                    f"Invalid quantization mode. Received: mode={mode}"
                )
            raise ValueError(
                "You must provide either `mode` or `config` to `quantize`."
            )
    else:
        if not isinstance(config, QuantizationConfig):
            raise ValueError(
                "Argument `config` must be an instance of "
                "`QuantizationConfig`. "
                f"Received: config={config} (of type {type(config)})"
            )

    # 3. Validation: Prevent contradictions.
    if mode is not None and config.mode != mode:
        raise ValueError(
            f"Contradictory arguments: mode='{mode}' but "
            f"config.mode='{config.mode}'"
        )

    # Ensure mode is consistent.
    mode = config.mode

    # Ensure the mode derived from the config is valid.
    _validate_mode(mode)

    if mode == "gptq":
        from keras.src.quantizers.gptq_config import GPTQConfig

        if not isinstance(config, GPTQConfig):
            raise ValueError(
                "Mode 'gptq' requires a valid `config` argument of type "
                f"`GPTQConfig`. Received: {type(config)}"
            )

    if mode == "awq":
        from keras.src.quantizers.awq_config import AWQConfig

        if not isinstance(config, AWQConfig):
            raise ValueError(
                "Mode 'awq' requires a valid `config` argument of type "
                f"`AWQConfig`. Received: {type(config)}"
            )

    return config
238
+
239
+
240
+ def _validate_mode(mode):
241
+ """Validates quantization mode."""
242
+ if mode is not None and mode not in QUANTIZATION_MODES:
243
+ raise ValueError(
244
+ "Invalid quantization mode. "
245
+ f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
246
+ )