keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/distillation/distillation_loss.py
@@ -0,0 +1,390 @@
+import keras
+from keras.src import tree
+from keras.src.api_export import keras_export
+from keras.src.saving import serialization_lib
+from keras.src.utils import tracking
+
+
+def _convert_loss_to_function(loss_item):
+    """Convert a loss string identifier to a loss function.
+
+    Arguments:
+        loss_item: Either a string identifier, a loss function instance,
+            or `None`.
+
+    Returns:
+        A loss function instance, or `None`.
+
+    Raises:
+        ValueError: If the loss string identifier is unknown.
+    """
+    if loss_item is None:
+        return None
+    elif isinstance(loss_item, str):
+        loss_fn = keras.losses.get(loss_item)
+        if loss_fn is None:
+            raise ValueError(f"Unknown loss function: '{loss_item}'.")
+        return loss_fn
+    else:
+        return loss_item
+
+
+@keras_export("keras.distillation.DistillationLoss")
+class DistillationLoss:
+    """Base class for distillation loss computation.
+
+    Distillation losses define how to compute the distillation loss
+    between teacher and student outputs. Each loss implements a specific
+    approach to knowledge transfer, from simple logits matching to
+    feature-based distillation.
+
+    To create custom distillation losses, subclass this class and
+    override the `compute_loss` method.
+    """
+
+    def compute_loss(self, teacher_outputs, student_outputs, **kwargs):
+        """Compute distillation loss between teacher and student outputs.
+
+        This method should implement the specific distillation logic for
+        transferring knowledge from teacher to student.
+
+        Arguments:
+            teacher_outputs: Outputs from the teacher model. Can be a single
+                tensor or a list/tuple of tensors for multi-output models.
+            student_outputs: Outputs from the student model. Can be a single
+                tensor or a list/tuple of tensors for multi-output models.
+            **kwargs: Additional arguments for custom distillation losses.
+        Returns:
+            Distillation loss tensor.
+        """
+        raise NotImplementedError("Subclasses must implement compute_loss")
+
+    def validate_outputs(self, teacher_outputs, student_outputs):
+        """Validate that teacher and student outputs are compatible.
+
+        Arguments:
+            teacher_outputs: Outputs from the teacher model.
+            student_outputs: Outputs from the student model.
+        Raises:
+            ValueError: If outputs are not compatible.
+        """
+        keras.tree.assert_same_structure(teacher_outputs, student_outputs)
+
+    def validate_model_compatibility(self, teacher, student):
+        """Validate that teacher and student models are compatible.
+
+        Arguments:
+            teacher: The teacher model.
+            student: The student model.
+        Raises:
+            ValueError: If models are not compatible with this distillation
+                loss.
+        """
+        pass
+
+
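The base class above only fixes a contract: subclasses override `compute_loss` and may tighten the two `validate_*` hooks. As a minimal sketch of what that looks like (illustrative only, not part of this diff; the class name and loss choice are hypothetical, presumably consumed by the `Distiller` added alongside in `distiller.py`):

```python
import keras


class SquaredErrorDistillation(keras.distillation.DistillationLoss):
    """Toy custom loss: mean squared difference between raw outputs."""

    def compute_loss(self, teacher_outputs, student_outputs, **kwargs):
        # Reuse the base-class structure check before computing anything.
        self.validate_outputs(teacher_outputs, student_outputs)
        # Per-output mean squared error over any nested output structure.
        losses = keras.tree.map_structure(
            lambda t, s: keras.ops.mean(
                keras.ops.square(keras.ops.subtract(t, s))
            ),
            teacher_outputs,
            student_outputs,
        )
        # Collapse the per-output scalars into a single scalar loss.
        return keras.ops.sum(keras.ops.stack(keras.tree.flatten(losses)))
```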
+@keras_export("keras.distillation.FeatureDistillation")
+class FeatureDistillation(DistillationLoss):
+    """Feature distillation loss.
+
+    Feature distillation transfers knowledge from intermediate layers of the
+    teacher model to corresponding layers of the student model. This approach
+    helps the student learn better internal representations and often leads
+    to better performance compared to logits-only distillation.
+
+    Arguments:
+        loss: Loss function to use for feature distillation. Can be:
+            - String identifier (e.g., 'mse', 'cosine_similarity', 'mae')
+            - Keras loss instance
+            - Nested structure of losses matching the layer output structure
+            - `None` to skip distillation for that output (useful for
+              multi-output models where you only want to distill some outputs)
+            At least one loss must be non-`None`. Defaults to 'mse'.
+        teacher_layer_name: Name of the teacher layer to extract features from.
+            If `None`, uses the final output. Defaults to `None`.
+        student_layer_name: Name of the student layer to extract features from.
+            If `None`, uses the final output. Defaults to `None`.
+
+    Examples:
+
+    ```python
+    # Basic feature distillation from final outputs
+    distillation_loss = FeatureDistillation(loss="mse")
+
+    # Distill from specific intermediate layers
+    distillation_loss = FeatureDistillation(
+        loss="mse",
+        teacher_layer_name="dense_1",
+        student_layer_name="dense_1"
+    )
+
+    # Use cosine similarity for different feature sizes
+    distillation_loss = FeatureDistillation(
+        loss="cosine_similarity",
+        teacher_layer_name="conv2d_2",
+        student_layer_name="conv2d_1"
+    )
+
+    # With custom loss instance
+    distillation_loss = FeatureDistillation(
+        loss=keras.losses.MeanAbsoluteError()
+    )
+
+    # For multi-output models
+    distillation_loss = FeatureDistillation(
+        loss=["mse", "cosine_similarity"]
+    )
+
+    # For multi-output models, only distill some outputs
+    distillation_loss = FeatureDistillation(
+        loss=["mse", None, "cosine_similarity"]  # Skip middle output
+    )
+    ```
+    """
+
+    @tracking.no_automatic_dependency_tracking
+    def __init__(
+        self, loss="mse", teacher_layer_name=None, student_layer_name=None
+    ):
+        self.teacher_layer_name = teacher_layer_name
+        self.student_layer_name = student_layer_name
+        self.loss = tree.map_structure(_convert_loss_to_function, loss)
+
+        flat_losses = tree.flatten(self.loss)
+        if all(l is None for l in flat_losses):
+            raise ValueError(
+                "The `loss` argument in `FeatureDistillation` must "
+                "contain at least one non-`None` value."
+            )
+
+    def validate_model_compatibility(self, teacher, student):
+        """Validate that teacher and student models are compatible for feature
+        distillation."""
+        if (
+            self.teacher_layer_name is not None
+            or self.student_layer_name is not None
+        ):
+            teacher_is_subclassed = (
+                not hasattr(teacher, "inputs") or teacher.inputs is None
+            )
+            student_is_subclassed = (
+                not hasattr(student, "inputs") or student.inputs is None
+            )
+
+            if teacher_is_subclassed or student_is_subclassed:
+                subclassed_models = []
+                if teacher_is_subclassed:
+                    subclassed_models.append("teacher")
+                if student_is_subclassed:
+                    subclassed_models.append("student")
+
+                models_str = " and ".join(subclassed_models)
+                raise ValueError(
+                    f"FeatureDistillation with specific layer names requires "
+                    f"Functional or Sequential models. The {models_str} "
+                    f"model(s) appear to be subclassed (no symbolic "
+                    f"inputs/outputs). Either use Functional/Sequential "
+                    f"models, or use FeatureDistillation without layer names "
+                    f"(to distill final outputs only), or use "
+                    f"LogitsDistillation instead."
+                )
+
+        if self.teacher_layer_name is not None:
+            try:
+                teacher.get_layer(name=self.teacher_layer_name)
+            except ValueError as e:
+                raise ValueError(f"In teacher model: {e}")
+
+        if self.student_layer_name is not None:
+            try:
+                student.get_layer(name=self.student_layer_name)
+            except ValueError as e:
+                raise ValueError(f"In student model: {e}")
+
+    def validate_outputs(self, teacher_outputs, student_outputs):
+        """Validate that outputs are compatible for feature distillation."""
+        super().validate_outputs(teacher_outputs, student_outputs)
+
+        try:
+            tree.assert_same_structure(self.loss, teacher_outputs)
+        except ValueError as e:
+            raise ValueError(
+                f"Loss structure mismatch. "
+                f"Loss structure: {tree.structure(self.loss)}, "
+                f"Output structure: {tree.structure(teacher_outputs)}. "
+                f"Error: {e}"
+            )
+
+    def compute_loss(self, teacher_outputs, student_outputs, **kwargs):
+        """Compute feature distillation loss using extracted features.
+
+        Arguments:
+            teacher_outputs: Extracted features from teacher layer.
+            student_outputs: Extracted features from student layer.
+            **kwargs: Additional arguments (ignored).
+        Returns:
+            Scalar distillation loss tensor.
+        """
+
+        def apply_loss(loss_fn, teacher_features, student_features):
+            if loss_fn is None:
+                return 0.0
+
+            loss = keras.ops.mean(loss_fn(teacher_features, student_features))
+
+            return loss
+
+        loss_values = tree.map_structure(
+            apply_loss, self.loss, teacher_outputs, student_outputs
+        )
+
+        flat_losses = tree.flatten(loss_values)
+        return keras.ops.sum(keras.ops.stack(flat_losses))
+
+    def get_config(self):
+        """Get configuration for serialization."""
+        return {
+            "loss": keras.losses.serialize(self.loss),
+            "teacher_layer_name": self.teacher_layer_name,
+            "student_layer_name": self.student_layer_name,
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        """Create instance from configuration."""
+        config = config.copy()
+        config["loss"] = keras.losses.deserialize(config["loss"])
+        return cls(**config)
+
+
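To see the per-output `None`-skip behavior concretely, here is a usage sketch on toy feature tensors (a sketch only, assuming the exported `keras.distillation` path shown above; the shapes and values are arbitrary):

```python
import numpy as np
import keras
from keras.distillation import FeatureDistillation

# Stand-ins for features extracted from two teacher/student layer pairs.
teacher_feats = [np.ones((4, 16), "float32"), np.ones((4, 8), "float32")]
student_feats = [np.zeros((4, 16), "float32"), np.zeros((4, 8), "float32")]

# Distill the first feature pair with MSE; skip the second with `None`.
fd = FeatureDistillation(loss=["mse", None])
fd.validate_outputs(teacher_feats, student_feats)
loss = fd.compute_loss(teacher_feats, student_feats)
print(float(loss))  # 1.0: MSE of ones vs. zeros; the skipped pair adds 0
```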
+@keras_export("keras.distillation.LogitsDistillation")
+class LogitsDistillation(DistillationLoss):
+    """Distillation loss that transfers knowledge from final model outputs.
+
+    This distillation loss applies temperature scaling to the teacher's logits
+    before computing the loss between teacher and student predictions. It's the
+    most common approach for knowledge distillation.
+
+    Arguments:
+        temperature: Temperature for softmax scaling. Higher values produce
+            softer probability distributions that are easier for the student to
+            learn. Typical values range from 3 to 5. Defaults to 3.0.
+        loss: Loss function to use for distillation. Can be:
+            - String identifier (e.g., 'kl_divergence',
+              'categorical_crossentropy')
+            - Keras loss instance
+            - Nested structure of losses matching the model output structure
+            - `None` to skip distillation for that output (useful for
+              multi-output models where you only want to distill some outputs)
+            At least one loss must be non-`None`. Defaults to 'kl_divergence'.
+
+    Examples:
+
+    ```python
+    # Basic logits distillation with KL divergence
+    distillation_loss = LogitsDistillation(temperature=3.0)
+
+    # With categorical crossentropy loss
+    distillation_loss = LogitsDistillation(
+        temperature=4.0,
+        loss="categorical_crossentropy"
+    )
+
+    # With custom loss instance
+    distillation_loss = LogitsDistillation(
+        temperature=4.0,
+        loss=keras.losses.CategoricalCrossentropy(from_logits=True)
+    )
+
+    # For multi-output models
+    distillation_loss = LogitsDistillation(
+        temperature=3.0,
+        loss=["kl_divergence", "categorical_crossentropy"]
+    )
+
+    # For multi-output models, only distill some outputs
+    distillation_loss = LogitsDistillation(
+        temperature=3.0,
+        loss=["kl_divergence", None]  # Skip second output
+    )
+    ```
+    """
+
+    @tracking.no_automatic_dependency_tracking
+    def __init__(
+        self,
+        temperature=3.0,
+        loss="kl_divergence",
+    ):
+        self.temperature = temperature
+        self.loss = tree.map_structure(_convert_loss_to_function, loss)
+
+        flat_losses = tree.flatten(self.loss)
+        if all(l is None for l in flat_losses):
+            raise ValueError("At least one loss must be non-`None`.")
+
+        if not isinstance(self.temperature, (int, float)):
+            raise ValueError(
+                f"temperature must be a number, got {type(self.temperature)}"
+            )
+        if self.temperature <= 0.0:
+            raise ValueError("temperature must be positive.")
+
+    def compute_loss(self, teacher_outputs, student_outputs, **kwargs):
+        """Compute distillation loss using the configured loss function.
+
+        Arguments:
+            teacher_outputs: Logits from teacher model. Can be a single tensor,
+                list/tuple of tensors, or dict of tensors.
+            student_outputs: Logits from student model. Can be a single tensor,
+                list/tuple of tensors, or dict of tensors.
+            **kwargs: Additional arguments (ignored).
+        Returns:
+            Distillation loss tensor.
+        """
+        # Apply temperature scaling using tree.map_structure
+        teacher_scaled = tree.map_structure(
+            lambda x: keras.ops.divide(x, self.temperature), teacher_outputs
+        )
+        student_scaled = tree.map_structure(
+            lambda x: keras.ops.divide(x, self.temperature), student_outputs
+        )
+
+        # Apply loss function(s) to corresponding outputs
+        def apply_loss(loss_fn, teacher_logits, student_logits):
+            if loss_fn is None:
+                return 0.0
+
+            # Special handling for KL divergence (needs probabilities)
+            if isinstance(loss_fn, keras.losses.KLDivergence):
+                teacher_probs = keras.ops.softmax(teacher_logits, axis=-1)
+                student_probs = keras.ops.softmax(student_logits, axis=-1)
+                loss = keras.ops.mean(loss_fn(teacher_probs, student_probs))
+                # Scale by temperature^2 for KL (per the literature)
+                return loss * (self.temperature**2)
+            else:
+                # For other losses, use logits directly
+                return keras.ops.mean(loss_fn(teacher_logits, student_logits))
+
+        # Apply losses using tree.map_structure
+        loss_values = tree.map_structure(
+            apply_loss, self.loss, teacher_scaled, student_scaled
+        )
+
+        # Sum all losses and return scalar
+        flat_losses = tree.flatten(loss_values)
+        return keras.ops.sum(keras.ops.stack(flat_losses))
+
+    def get_config(self):
+        """Get configuration for serialization."""
+        return {
+            "temperature": self.temperature,
+            "loss": serialization_lib.serialize_keras_object(self.loss),
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        """Create instance from configuration."""
+        config = config.copy()
+        config["loss"] = keras.losses.deserialize(config["loss"])
+        return cls(**config)
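The `temperature**2` factor in the KL branch follows the standard Hinton et al. (2015) recipe: dividing logits by T shrinks the softmax gradients by roughly 1/T², so multiplying the KL term by T² keeps its gradient scale comparable across temperatures. A small sanity-check sketch of what `compute_loss` does for a `KLDivergence` loss, written out by hand (illustrative only, assuming the exported `keras.distillation.LogitsDistillation` path):

```python
import numpy as np
import keras
from keras.distillation import LogitsDistillation

teacher_logits = np.array([[2.0, 1.0, 0.1]], dtype="float32")
student_logits = np.array([[1.5, 0.8, 0.3]], dtype="float32")
T = 3.0

# Manual computation: soften both logit sets with temperature T, compare
# the resulting probability distributions, then rescale by T**2.
tp = keras.ops.softmax(teacher_logits / T, axis=-1)
sp = keras.ops.softmax(student_logits / T, axis=-1)
kl = keras.losses.KLDivergence()
manual = keras.ops.mean(kl(tp, sp)) * T**2

# Same computation via the class under an explicit KLDivergence instance.
auto = LogitsDistillation(
    temperature=T, loss=keras.losses.KLDivergence()
).compute_loss(teacher_logits, student_logits)
# `manual` and `auto` agree up to float tolerance.
```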