keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/distillation/distiller.py
@@ -0,0 +1,598 @@
+ import keras
+ from keras.src import tree
+ from keras.src.api_export import keras_export
+ from keras.src.distillation.distillation_loss import _convert_loss_to_function
+ from keras.src.models.model import Model
+ from keras.src.saving import serialization_lib
+
+
+ @keras_export("keras.distillation.Distiller")
+ class Distiller(Model):
+     """Distillation model for transferring knowledge from teacher to student.
+
+     Knowledge distillation transfers knowledge from a large, complex model
+     (teacher) to a smaller, simpler model (student). The student learns
+     from both ground truth labels and the teacher's predictions, often
+     achieving better performance than training on labels alone.
+
+     Arguments:
+         teacher: A trained `keras.Model` that serves as the knowledge source.
+             The teacher model is frozen during distillation.
+         student: A `keras.Model` to be trained through distillation.
+         distillation_losses: List of distillation losses to apply. Can be a
+             single distillation loss or a list of distillation losses like
+             `keras.distillation.LogitsDistillation`,
+             `keras.distillation.FeatureDistillation`, or custom distillation
+             losses.
+         distillation_loss_weights: List of weights for each distillation loss.
+             Must have the same length as `distillation_losses`. If `None`,
+             equal weights are used.
+         student_loss_weight: Weight for the student's supervised loss component.
+             Must be between 0 and 1. Defaults to 0.5.
+         name: Name for the distiller model. Defaults to `"distiller"`.
+         **kwargs: Additional keyword arguments passed to the parent `Model`
+             class.
+
+     Attributes:
+         student: The student model being trained. Access this to get the trained
+             student model for independent use after distillation training.
+         teacher: The teacher model providing knowledge. This model is frozen
+             during training.
+
+     Examples:
+
+     ```python
+     # Basic distillation with KerasHub models
+     import keras_hub as hub
+
+     teacher = hub.models.CausalLM.from_preset("gemma_2b_en")
+     student = hub.models.CausalLM.from_preset(
+         "gemma_1.1_2b_en", load_weights=False
+     )
+
+     # Single distillation loss
+     distiller = Distiller(
+         teacher=teacher,
+         student=student,
+         distillation_losses=LogitsDistillation(temperature=3.0),
+     )
+
+     # Compile the distiller (like any Keras model)
+     distiller.compile(
+         optimizer='adam',
+         loss='sparse_categorical_crossentropy',
+         metrics=['accuracy']
+     )
+
+     # Train the distiller
+     distiller.fit(x_train, y_train, epochs=10)
+
+     # Access the trained student model
+     trained_student = distiller.student
+
+     # Multiple distillation losses
+     distiller = Distiller(
+         teacher=teacher,
+         student=student,
+         distillation_losses=[
+             LogitsDistillation(temperature=3.0),
+             FeatureDistillation(
+                 teacher_layer_name="dense_1",
+                 student_layer_name="dense_1"
+             )
+         ],
+         distillation_loss_weights=[1.0, 0.5],
+     )
+
+     # Compile with custom settings
+     distiller.compile(
+         optimizer='adam',
+         loss='sparse_categorical_crossentropy',
+         metrics=['accuracy']
+     )
+     ```
+     """
+
+     def __init__(
+         self,
+         teacher,
+         student,
+         distillation_losses,
+         distillation_loss_weights=None,
+         student_loss_weight=0.5,
+         name="distiller",
+         **kwargs,
+     ):
+         super().__init__(name=name, **kwargs)
+
+         # Validate inputs
+         self._validate_models(teacher, student)
+
+         # Store configuration
+         self.teacher = teacher
+         self.student = student
+
+         # Validate student_loss_weight
+         if not isinstance(student_loss_weight, (int, float)):
+             raise ValueError(
+                 f"student_loss_weight must be a number, got "
+                 f"{type(student_loss_weight)}"
+             )
+         if student_loss_weight < 0.0 or student_loss_weight > 1.0:
+             raise ValueError(
+                 f"student_loss_weight must be between 0.0 and 1.0, "
+                 f"got {student_loss_weight}"
+             )
+         self.student_loss_weight = student_loss_weight
+
+         # Handle distillation losses configuration
+         if distillation_losses is None:
+             raise ValueError(
+                 "'distillation_losses' cannot be `None`. Provide a "
+                 "distillation loss (e.g., LogitsDistillation or "
+                 "FeatureDistillation) or a list of distillation losses."
+             )
+
+         # Convert single distillation loss to list for uniform handling
+         if not isinstance(distillation_losses, (list, tuple)):
+             self.distillation_losses = [distillation_losses]
+             self.distillation_loss_weights = [1.0]
+         else:
+             self.distillation_losses = distillation_losses
+             # Set default weights if not provided
+             if distillation_loss_weights is None:
+                 self.distillation_loss_weights = [1.0] * len(
+                     distillation_losses
+                 )
+             else:
+                 if len(distillation_loss_weights) != len(distillation_losses):
+                     raise ValueError(
+                         f"Number of distillation_loss_weights "
+                         f"({len(distillation_loss_weights)}) must match "
+                         f"number of distillation_losses "
+                         f"({len(distillation_losses)})"
+                     )
+                 self.distillation_loss_weights = distillation_loss_weights
+
+         # Validate distillation loss compatibility and create extractors
+         for distillation_loss in self.distillation_losses:
+             self._validate_distillation_loss_compatibility(
+                 teacher, student, distillation_loss
+             )
+
+         self._create_multi_feature_extractors()
+
+         # Freeze teacher model
+         self.teacher.trainable = False
+
+         # Initialize loss tracking metrics
+         self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
+         self.distillation_loss_tracker = keras.metrics.Mean(
+             name="distillation_loss"
+         )
+         self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
+
+     def _validate_models(self, teacher, student):
+         """Validate that teacher and student models are compatible."""
+         if not isinstance(teacher, keras.Model):
+             raise ValueError(
+                 f"Teacher must be a keras.Model, got {type(teacher)}"
+             )
+         if not isinstance(student, keras.Model):
+             raise ValueError(
+                 f"Student must be a keras.Model, got {type(student)}"
+             )
+
+         self._validate_input_compatibility(teacher, student)
+         self._validate_output_compatibility(teacher, student)
+         self._validate_dtype_compatibility(teacher, student)
+
+     def _assert_shapes_are_compatible(self, shape1, shape2, context):
+         """Assert that two shapes are compatible."""
+         if len(shape1) != len(shape2):
+             raise ValueError(
+                 f"Teacher and student {context} shapes have different "
+                 f"dimensions. Teacher: {shape1}, Student: {shape2}."
+             )
+
+         for dim1, dim2 in zip(shape1, shape2):
+             if dim1 is not None and dim2 is not None and dim1 != dim2:
+                 raise ValueError(
+                     f"Teacher and student {context} shapes are incompatible. "
+                     f"Teacher: {shape1}, Student: {shape2}. "
+                     f"All dimensions must match."
+                 )
+
+     def _assert_same_dtype(self, teacher_dtype, student_dtype, context):
+         """Assert that teacher and student dtypes are the same."""
+         if teacher_dtype != student_dtype:
+             raise ValueError(
+                 f"Teacher and student {context} dtypes must match. "
+                 f"Teacher: {teacher_dtype}, Student: {student_dtype}."
+             )
+
+     def _validate_input_compatibility(self, teacher, student):
+         """Validate that teacher and student have compatible input shapes."""
+         if not hasattr(teacher, "inputs") or not hasattr(student, "inputs"):
+             return
+         teacher_inputs = getattr(teacher, "inputs")
+         student_inputs = getattr(student, "inputs")
+         if teacher_inputs is None or student_inputs is None:
+             return
+
+         tree.map_structure(
+             lambda ti, si: self._assert_shapes_are_compatible(
+                 ti.shape, si.shape, "input"
+             ),
+             teacher_inputs,
+             student_inputs,
+         )
+
+     def _validate_output_compatibility(self, teacher, student):
+         """Validate that teacher and student have compatible output shapes."""
+         if not hasattr(teacher, "outputs") or not hasattr(student, "outputs"):
+             return
+         teacher_outputs = getattr(teacher, "outputs")
+         student_outputs = getattr(student, "outputs")
+         if teacher_outputs is None or student_outputs is None:
+             return
+
+         tree.map_structure(
+             lambda to, so: self._assert_shapes_are_compatible(
+                 to.shape, so.shape, "output"
+             ),
+             teacher_outputs,
+             student_outputs,
+         )
+
+     def _validate_dtype_compatibility(self, teacher, student):
+         """Validate that teacher and student have compatible data types."""
+         if not hasattr(teacher, "inputs") or not hasattr(student, "inputs"):
+             return
+         if teacher.inputs is None or student.inputs is None:
+             return
+
+         tree.map_structure(
+             lambda ti, si: self._assert_same_dtype(ti.dtype, si.dtype, "input"),
+             teacher.inputs,
+             student.inputs,
+         )
+
+         if not hasattr(teacher, "outputs") or not hasattr(student, "outputs"):
+             return
+         if teacher.outputs is None or student.outputs is None:
+             return
+
+         tree.map_structure(
+             lambda to, so: self._assert_same_dtype(
+                 to.dtype, so.dtype, "output"
+             ),
+             teacher.outputs,
+             student.outputs,
+         )
+
+     def _validate_distillation_loss_compatibility(
+         self, teacher, student, distillation_loss
+     ):
+         """Validate that the distillation loss is compatible with teacher
+         and student models."""
+         distillation_loss.validate_model_compatibility(teacher, student)
+
+     def _create_multi_feature_extractors(self):
+         """Create feature extractors for efficient multi-layer extraction."""
+         teacher_layer_names = []
+         student_layer_names = []
+
+         for distillation_loss in self.distillation_losses:
+             if (
+                 hasattr(distillation_loss, "teacher_layer_name")
+                 and distillation_loss.teacher_layer_name
+             ):
+                 if (
+                     distillation_loss.teacher_layer_name
+                     not in teacher_layer_names
+                 ):
+                     teacher_layer_names.append(
+                         distillation_loss.teacher_layer_name
+                     )
+             if (
+                 hasattr(distillation_loss, "student_layer_name")
+                 and distillation_loss.student_layer_name
+             ):
+                 if (
+                     distillation_loss.student_layer_name
+                     not in student_layer_names
+                 ):
+                     student_layer_names.append(
+                         distillation_loss.student_layer_name
+                     )
+
+         self._teacher_feature_extractor = self._create_feature_extractor(
+             self.teacher, teacher_layer_names
+         )
+         self._student_feature_extractor = self._create_feature_extractor(
+             self.student, student_layer_names
+         )
+
+     def _create_feature_extractor(self, model, layer_names):
+         """Create a feature extractor for a model.
+
+         Arguments:
+             model: The model to create an extractor for.
+             layer_names: List of layer names to extract features from.
+
+         Returns:
+             Feature extractor model or `None` if no layer names provided.
+
+         Raises:
+             ValueError: If model has no symbolic inputs/outputs.
+         """
+         if not layer_names:
+             return None
+
+         if not hasattr(model, "inputs") or model.inputs is None:
+             raise ValueError(
+                 f"Cannot create feature extractor for {model.name}. "
+                 f"The model has no symbolic inputs attribute."
+             )
+
+         if isinstance(model, keras.Sequential):
+             final_output = model.layers[-1].output
+         else:
+             final_output = model.output
+
+         outputs = {"final_output": final_output}
+         for layer_name in layer_names:
+             layer = model.get_layer(name=layer_name)
+             outputs[layer_name] = layer.output
+
+         return keras.Model(
+             inputs=model.inputs,
+             outputs=outputs,
+             name=f"{model.name}_multi_feature_extractor",
+         )
+
+     def _extract_all_teacher_features(self, x):
+         """Extract all teacher features in a single forward pass."""
+         if self._teacher_feature_extractor is not None:
+             return self._teacher_feature_extractor(x, training=False)
+         else:
+             return {"final_output": self.teacher(x, training=False)}
+
+     def _extract_all_student_features(self, x, y_pred):
+         """Extract all student features in a single forward pass."""
+         if self._student_feature_extractor is not None:
+             return self._student_feature_extractor(x, training=True)
+         else:
+             return {"final_output": y_pred}
+
+     def _get_distillation_loss_features(
+         self, distillation_loss, all_features, is_teacher
+     ):
+         """Get the specific features needed by a distillation loss."""
+         if is_teacher:
+             layer_name = distillation_loss.teacher_layer_name or "final_output"
+         else:
+             layer_name = distillation_loss.student_layer_name or "final_output"
+
+         if layer_name not in all_features:
+             raise ValueError(
+                 f"Layer '{layer_name}' not found in extracted features. "
+                 f"Available: {list(all_features.keys())}"
+             )
+
+         return all_features[layer_name]
+
+     def compile(self, optimizer="adam", loss=None, metrics=None, **kwargs):
+         """Compile the distiller with proper integration.
+
+         Arguments:
+             optimizer: Optimizer for training the student model.
+             loss: Student loss function for the student's supervised learning.
+                 Can be a string identifier or a loss function instance.
+             metrics: Additional metrics to track during training.
+             **kwargs: Additional arguments passed to parent compile.
+         """
+         if loss is None:
+             raise ValueError("'loss' cannot be `None`.")
+
+         self._student_loss = tree.map_structure(_convert_loss_to_function, loss)
+         self._student_loss_for_serialization = loss
+
+         if metrics is not None and not isinstance(metrics, (list, tuple)):
+             raise ValueError(
+                 f"metrics must be a list or tuple, got {type(metrics)}"
+             )
+
+         super().compile(
+             optimizer=optimizer,
+             loss=None,
+             metrics=metrics,
+             **kwargs,
+         )
+
+     def call(self, inputs, training=None, **kwargs):
+         """Forward pass returns student predictions."""
+         return self.student(inputs, training=training, **kwargs)
+
+     def compute_loss(
+         self, x=None, y=None, y_pred=None, sample_weight=None, training=True
+     ):
+         """Compute combined distillation loss.
+
+         Arguments:
+             x: Input data.
+             y: Target data.
+             y_pred: Model predictions.
+             sample_weight: Sample weights (currently unused).
+             training: Whether the model is in training mode.
+
+         Returns:
+             Combined loss tensor.
+         """
+         # Handle case where y_pred is not provided
+         if y_pred is None:
+             y_pred = self(x, training=training)
+         # Compute student loss
+         student_loss = 0.0
+         if self.student_loss_weight > 0.0 and y is not None:
+             loss_values = tree.map_structure(
+                 lambda l, o, o_pred: l(o, o_pred),
+                 self._student_loss,
+                 y,
+                 y_pred,
+             )
+             flat_losses = tree.flatten(loss_values)
+             student_loss = (
+                 keras.ops.sum(keras.ops.stack(flat_losses))
+                 if len(flat_losses) > 1
+                 else flat_losses[0]
+             )
+
+         # Ensure student_loss is a scalar
+         if hasattr(student_loss, "shape") and len(student_loss.shape) > 0:
+             student_loss = keras.ops.mean(student_loss)
+
+         # Compute distillation loss
+         distillation_loss = 0.0
+         if self.student_loss_weight < 1.0:
+             teacher_features = self._extract_all_teacher_features(x)
+             student_features = self._extract_all_student_features(x, y_pred)
+
+             # Apply distillation losses using pre-extracted features
+             for distillation_loss_fn, weight in zip(
+                 self.distillation_losses, self.distillation_loss_weights
+             ):
+                 # Get appropriate outputs/features for this distillation loss
+                 if (
+                     hasattr(distillation_loss_fn, "teacher_layer_name")
+                     and distillation_loss_fn.teacher_layer_name is not None
+                 ):
+                     # FeatureDistillation with specific layers
+                     try:
+                         distillation_loss_teacher_output = (
+                             self._get_distillation_loss_features(
+                                 distillation_loss_fn,
+                                 teacher_features,
+                                 is_teacher=True,
+                             )
+                         )
+                         distillation_loss_student_output = (
+                             self._get_distillation_loss_features(
+                                 distillation_loss_fn,
+                                 student_features,
+                                 is_teacher=False,
+                             )
+                         )
+                     except ValueError as e:
+                         # Re-raise with context about which loss failed
+                         raise RuntimeError(
+                             f"Failed to extract features for "
+                             f"{type(distillation_loss_fn).__name__} "
+                             f"targeting teacher layer "
+                             f"'{distillation_loss_fn.teacher_layer_name}' "
+                             f"and student layer "
+                             f"'{distillation_loss_fn.student_layer_name}'. "
+                             f"Original error: {e}"
+                         ) from e
+                 else:
+                     # LogitsDistillation or FeatureDistillation (final outputs)
+                     distillation_loss_teacher_output = teacher_features[
+                         "final_output"
+                     ]
+                     distillation_loss_student_output = y_pred
+
+                 # Validate outputs are compatible for this distillation loss
+                 distillation_loss_fn.validate_outputs(
+                     distillation_loss_teacher_output,
+                     distillation_loss_student_output,
+                 )
+
+                 # Compute loss for this distillation loss
+                 current_distillation_loss = distillation_loss_fn.compute_loss(
+                     distillation_loss_teacher_output,
+                     distillation_loss_student_output,
+                 )
+
+                 # Validate that distillation loss returns a scalar
+                 if (
+                     hasattr(current_distillation_loss, "shape")
+                     and len(current_distillation_loss.shape) > 0
+                 ):
+                     raise ValueError(
+                         f"Distillation loss "
+                         f"{distillation_loss_fn.__class__.__name__} "
+                         f"returned a non-scalar loss with shape "
+                         f"{current_distillation_loss.shape}. "
+                         f"The compute_loss method must return a scalar "
+                         f"tensor."
+                     )
+
+                 # Apply weight and add to total
+                 distillation_loss = keras.ops.add(
+                     distillation_loss,
+                     keras.ops.multiply(weight, current_distillation_loss),
+                 )
+
+         # Combine losses
+         total_loss = keras.ops.add(
+             keras.ops.multiply(self.student_loss_weight, student_loss),
+             keras.ops.multiply(
+                 keras.ops.subtract(1.0, self.student_loss_weight),
+                 distillation_loss,
+             ),
+         )
+
+         # Update metrics
+         self.student_loss_tracker.update_state(student_loss)
+         self.distillation_loss_tracker.update_state(distillation_loss)
+         self.total_loss_tracker.update_state(total_loss)
+
+         return total_loss
+
+     def reset_metrics(self):
+         """Reset all metrics."""
+         super().reset_metrics()
+         self.student_loss_tracker.reset_state()
+         self.distillation_loss_tracker.reset_state()
+         self.total_loss_tracker.reset_state()
+
+     def get_config(self):
+         """Get configuration for serialization."""
+         config = super().get_config()
+         config.update(
+             {
+                 "teacher": serialization_lib.serialize_keras_object(
+                     self.teacher
+                 ),
+                 "student": serialization_lib.serialize_keras_object(
+                     self.student
+                 ),
+                 "distillation_losses": [
+                     serialization_lib.serialize_keras_object(distillation_loss)
+                     for distillation_loss in self.distillation_losses
+                 ],
+                 "distillation_loss_weights": self.distillation_loss_weights,
+                 "student_loss_weight": self.student_loss_weight,
+             }
+         )
+         return config
+
+     @classmethod
+     def from_config(cls, config):
+         """Create instance from configuration."""
+         config = config.copy()
+
+         # Deserialize objects
+         config["teacher"] = serialization_lib.deserialize_keras_object(
+             config["teacher"]
+         )
+         config["student"] = serialization_lib.deserialize_keras_object(
+             config["student"]
+         )
+         config["distillation_losses"] = [
+             serialization_lib.deserialize_keras_object(distillation_loss)
+             for distillation_loss in config["distillation_losses"]
+         ]
+
+         return cls(**config)
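For orientation, here is a minimal end-to-end sketch of the new `keras.distillation` API introduced by this file. It assumes the exports shown in the hunk above (`Distiller`, plus `LogitsDistillation` from the new `distillation_loss.py`); the toy models, shapes, and hyperparameters are illustrative only. Note from `compute_loss` that the combined objective is `student_loss_weight * student_loss + (1 - student_loss_weight) * sum(weight_i * distillation_loss_i)`.

```python
# Hypothetical usage sketch of the Distiller API shown above; toy models
# and random data are illustrative, not taken from the package.
import numpy as np
import keras
from keras.distillation import Distiller, LogitsDistillation


def make_classifier(width):
    # Teacher and student must share input/output shapes and dtypes,
    # per Distiller's compatibility validation.
    return keras.Sequential(
        [
            keras.Input(shape=(32,)),
            keras.layers.Dense(width, activation="relu"),
            keras.layers.Dense(10),
        ]
    )


teacher = make_classifier(256)  # would be pretrained in practice
student = make_classifier(32)

distiller = Distiller(
    teacher=teacher,  # frozen automatically (teacher.trainable = False)
    student=student,
    distillation_losses=LogitsDistillation(temperature=3.0),
    student_loss_weight=0.5,  # 0.5 * student loss + 0.5 * distillation loss
)
distiller.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

x = np.random.rand(256, 32).astype("float32")
y = np.random.randint(0, 10, size=(256,))
distiller.fit(x, y, epochs=1, batch_size=32)

trained_student = distiller.student  # usable on its own after training
```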
keras/src/distribution/distribution_lib.py
@@ -39,6 +39,20 @@ def list_devices(device_type=None):
      return distribution_lib.list_devices(device_type)


+ @keras_export("keras.distribution.get_device_count")
+ def get_device_count(device_type=None):
+     """Returns the number of available JAX devices.
+     Args:
+         device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
+             If `None`, it defaults to counting "gpu" or "tpu" devices if
+             available, otherwise it counts "cpu" devices. It does not
+             return the sum of all device types.
+     Returns:
+         int: The total number of JAX devices for the specified type.
+     """
+     return distribution_lib.get_device_count(device_type=device_type)
+
+
  @keras_export("keras.distribution.initialize")
  def initialize(job_addresses=None, num_processes=None, process_id=None):
      """Initialize the distribution system for multi-host/process setting.
keras/src/dtype_policies/__init__.py
@@ -2,18 +2,22 @@ from keras.src import backend
  from keras.src.api_export import keras_export
  from keras.src.dtype_policies import dtype_policy
  from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
+ from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
  from keras.src.dtype_policies.dtype_policy import DTypePolicy
  from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
+ from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
  from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
  from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
  from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap

  ALL_OBJECTS = {
+     AWQDTypePolicy,
      DTypePolicy,
      FloatDTypePolicy,
      QuantizedDTypePolicy,
      QuantizedFloat8DTypePolicy,
      DTypePolicyMap,
+     GPTQDTypePolicy,
  }
  ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
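The practical effect of this change is that the two new quantization policies are registered for name-based lookup, e.g. during deserialization. A minimal check is sketched below; the constructor signatures of `AWQDTypePolicy` and `GPTQDTypePolicy` are not shown in this diff, so none are assumed here.

```python
# Sketch: the new policy classes are discoverable by class name via the
# registry built in keras/src/dtype_policies/__init__.py.
from keras.src.dtype_policies import ALL_OBJECTS_DICT

assert "AWQDTypePolicy" in ALL_OBJECTS_DICT
assert "GPTQDTypePolicy" in ALL_OBJECTS_DICT
print(sorted(ALL_OBJECTS_DICT))  # now includes the AWQ and GPTQ policies
```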