keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +12 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +12 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +1 -1
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +33 -16
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +485 -20
- keras/src/backend/jax/numpy.py +92 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +76 -7
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1030 -185
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +264 -54
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +84 -8
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +299 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +191 -172
- keras/src/layers/core/einsum_dense.py +235 -186
- keras/src/layers/core/embedding.py +83 -93
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +390 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +40 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/losses/loss.py +1 -1
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +156 -27
- keras/src/ops/image.py +184 -3
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +541 -43
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +12 -1
- keras/src/quantizers/gptq.py +8 -6
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +150 -78
- keras/src/quantizers/quantization_config.py +232 -0
- keras/src/quantizers/quantizers.py +114 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +4 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +14 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +187 -36
- keras/src/utils/module_utils.py +18 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/rng_utils.py +9 -1
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/distillation/__init__.py (new file)

@@ -0,0 +1 @@
"""Distillation module for knowledge distillation in Keras."""
keras/src/distillation/distillation_loss.py (new file)

@@ -0,0 +1,390 @@
import keras
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.saving import serialization_lib
from keras.src.utils import tracking


def _convert_loss_to_function(loss_item):
    """Convert a loss string identifier to a loss function.

    Arguments:
        loss_item: Either a string identifier, a loss function instance,
            or `None`.

    Returns:
        A loss function instance, or `None`.

    Raises:
        ValueError: If the loss string identifier is unknown.
    """
    if loss_item is None:
        return None
    elif isinstance(loss_item, str):
        loss_fn = keras.losses.get(loss_item)
        if loss_fn is None:
            raise ValueError(f"Unknown loss function: '{loss_item}'.")
        return loss_fn
    else:
        return loss_item


@keras_export("keras.distillation.DistillationLoss")
class DistillationLoss:
    """Base class for distillation loss computation.

    Distillation losses define how to compute the distillation loss
    between teacher and student outputs. Each loss implements a specific
    approach to knowledge transfer, from simple logits matching to
    feature-based distillation.

    To create custom distillation losses, subclass this class and
    override the `compute_loss` method.
    """

    def compute_loss(self, teacher_outputs, student_outputs, **kwargs):
        """Compute distillation loss between teacher and student outputs.

        This method should implement the specific distillation logic for
        transferring knowledge from teacher to student.

        Arguments:
            teacher_outputs: Outputs from the teacher model. Can be a single
                tensor or a list/tuple of tensors for multi-output models.
            student_outputs: Outputs from the student model. Can be a single
                tensor or a list/tuple of tensors for multi-output models.
            **kwargs: Additional arguments for custom distillation_loss.

        Returns:
            Distillation loss tensor.
        """
        raise NotImplementedError("Subclasses must implement compute_loss")

    def validate_outputs(self, teacher_outputs, student_outputs):
        """Validate that teacher and student outputs are compatible.

        Arguments:
            teacher_outputs: Outputs from the teacher model.
            student_outputs: Outputs from the student model.

        Raises:
            ValueError: If outputs are not compatible.
        """
        keras.tree.assert_same_structure(teacher_outputs, student_outputs)

    def validate_model_compatibility(self, teacher, student):
        """Validate that teacher and student models are compatible.

        Arguments:
            teacher: The teacher model.
            student: The student model.

        Raises:
            ValueError: If models are not compatible with this distillation
                loss.
        """
        pass
@keras_export("keras.distillation.FeatureDistillation")
|
|
87
|
+
class FeatureDistillation(DistillationLoss):
|
|
88
|
+
"""Feature distillation loss.
|
|
89
|
+
|
|
90
|
+
Feature distillation transfers knowledge from intermediate layers of the
|
|
91
|
+
teacher model to corresponding layers of the student model. This approach
|
|
92
|
+
helps the student learn better internal representations and often leads
|
|
93
|
+
to better performance compared to logits-only distillation.
|
|
94
|
+
|
|
95
|
+
Arguments:
|
|
96
|
+
loss: Loss function to use for feature distillation. Can be:
|
|
97
|
+
- String identifier (e.g., 'mse', 'cosine_similarity', 'mae')
|
|
98
|
+
- Keras loss instance
|
|
99
|
+
- Nested structure of losses matching the layer output structure
|
|
100
|
+
- `None` to skip distillation for that output (useful for
|
|
101
|
+
multi-output models where you only want to distill some outputs)
|
|
102
|
+
At least one loss must be non-`None`. Defaults to 'mse'.
|
|
103
|
+
teacher_layer_name: Name of the teacher layer to extract features from.
|
|
104
|
+
If `None`, uses the final output. Defaults to `None`.
|
|
105
|
+
student_layer_name: Name of the student layer to extract features from.
|
|
106
|
+
If `None`, uses the final output. Defaults to `None`.
|
|
107
|
+
|
|
108
|
+
Examlpe(s):
|
|
109
|
+
|
|
110
|
+
```python
|
|
111
|
+
# Basic feature distillation from final outputs
|
|
112
|
+
distillation_loss = FeatureDistillation(loss="mse")
|
|
113
|
+
|
|
114
|
+
# Distill from specific intermediate layers
|
|
115
|
+
distillation_loss = FeatureDistillation(
|
|
116
|
+
loss="mse",
|
|
117
|
+
teacher_layer_name="dense_1",
|
|
118
|
+
student_layer_name="dense_1"
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
# Use cosine similarity for different feature sizes
|
|
122
|
+
distillation_loss = FeatureDistillation(
|
|
123
|
+
loss="cosine_similarity",
|
|
124
|
+
teacher_layer_name="conv2d_2",
|
|
125
|
+
student_layer_name="conv2d_1"
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
# With custom loss instance
|
|
129
|
+
distillation_loss = FeatureDistillation(
|
|
130
|
+
loss=keras.losses.MeanAbsoluteError()
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
# For multi-output models
|
|
134
|
+
distillation_loss = FeatureDistillation(
|
|
135
|
+
loss=["mse", "cosine_similarity"]
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
# For multi-output models, only distill some outputs
|
|
139
|
+
distillation_loss = FeatureDistillation(
|
|
140
|
+
loss=["mse", None, "cosine_similarity"] # Skip middle output
|
|
141
|
+
)
|
|
142
|
+
```
|
|
143
|
+
"""
|
|
144
|
+
|
|
145
|
+
@tracking.no_automatic_dependency_tracking
|
|
146
|
+
def __init__(
|
|
147
|
+
self, loss="mse", teacher_layer_name=None, student_layer_name=None
|
|
148
|
+
):
|
|
149
|
+
self.teacher_layer_name = teacher_layer_name
|
|
150
|
+
self.student_layer_name = student_layer_name
|
|
151
|
+
self.loss = tree.map_structure(_convert_loss_to_function, loss)
|
|
152
|
+
|
|
153
|
+
flat_losses = tree.flatten(self.loss)
|
|
154
|
+
if all(l is None for l in flat_losses):
|
|
155
|
+
raise ValueError(
|
|
156
|
+
"The `loss` argument in `FeatureDistillation` must "
|
|
157
|
+
"contain at least one non-`None` value."
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
def validate_model_compatibility(self, teacher, student):
|
|
161
|
+
"""Validate that teacher and student models are compatible for feature
|
|
162
|
+
distillation."""
|
|
163
|
+
if (
|
|
164
|
+
self.teacher_layer_name is not None
|
|
165
|
+
or self.student_layer_name is not None
|
|
166
|
+
):
|
|
167
|
+
teacher_is_subclassed = (
|
|
168
|
+
not hasattr(teacher, "inputs") or teacher.inputs is None
|
|
169
|
+
)
|
|
170
|
+
student_is_subclassed = (
|
|
171
|
+
not hasattr(student, "inputs") or student.inputs is None
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
if teacher_is_subclassed or student_is_subclassed:
|
|
175
|
+
subclassed_models = []
|
|
176
|
+
if teacher_is_subclassed:
|
|
177
|
+
subclassed_models.append("teacher")
|
|
178
|
+
if student_is_subclassed:
|
|
179
|
+
subclassed_models.append("student")
|
|
180
|
+
|
|
181
|
+
models_str = " and ".join(subclassed_models)
|
|
182
|
+
raise ValueError(
|
|
183
|
+
f"FeatureDistillation with specific layer names requires "
|
|
184
|
+
f"Functional or Sequential models. The {models_str} "
|
|
185
|
+
f"model(s) appear to be subclassed (no symbolic "
|
|
186
|
+
f"inputs/outputs). Either use Functional/Sequential "
|
|
187
|
+
f"models, or use FeatureDistillation without layer names "
|
|
188
|
+
f"(to distill final outputs only), or use "
|
|
189
|
+
f"LogitsDistillation instead."
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
if self.teacher_layer_name is not None:
|
|
193
|
+
try:
|
|
194
|
+
teacher.get_layer(name=self.teacher_layer_name)
|
|
195
|
+
except ValueError as e:
|
|
196
|
+
raise ValueError(f"In teacher model: {e}")
|
|
197
|
+
|
|
198
|
+
if self.student_layer_name is not None:
|
|
199
|
+
try:
|
|
200
|
+
student.get_layer(name=self.student_layer_name)
|
|
201
|
+
except ValueError as e:
|
|
202
|
+
raise ValueError(f"In student model: {e}")
|
|
203
|
+
|
|
204
|
+
def validate_outputs(self, teacher_outputs, student_outputs):
|
|
205
|
+
"""Validate that outputs are compatible for feature distillation."""
|
|
206
|
+
super().validate_outputs(teacher_outputs, student_outputs)
|
|
207
|
+
|
|
208
|
+
try:
|
|
209
|
+
tree.assert_same_structure(self.loss, teacher_outputs)
|
|
210
|
+
except ValueError as e:
|
|
211
|
+
raise ValueError(
|
|
212
|
+
f"Loss structure mismatch. "
|
|
213
|
+
f"Loss structure: {tree.structure(self.loss)}, "
|
|
214
|
+
f"Output structure: {tree.structure(teacher_outputs)}. "
|
|
215
|
+
f"Error: {e}"
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
def compute_loss(self, teacher_outputs, student_outputs, **kwargs):
|
|
219
|
+
"""Compute feature distillation loss using extracted features.
|
|
220
|
+
|
|
221
|
+
Arguments:
|
|
222
|
+
teacher_outputs: Extracted features from teacher layer.
|
|
223
|
+
student_outputs: Extracted features from student layer.
|
|
224
|
+
**kwargs: Additional arguments (ignored).
|
|
225
|
+
Returns:
|
|
226
|
+
Scalar distillation loss tensor.
|
|
227
|
+
"""
|
|
228
|
+
|
|
229
|
+
def apply_loss(loss_fn, teacher_features, student_features):
|
|
230
|
+
if loss_fn is None:
|
|
231
|
+
return 0.0
|
|
232
|
+
|
|
233
|
+
loss = keras.ops.mean(loss_fn(teacher_features, student_features))
|
|
234
|
+
|
|
235
|
+
return loss
|
|
236
|
+
|
|
237
|
+
loss_values = tree.map_structure(
|
|
238
|
+
apply_loss, self.loss, teacher_outputs, student_outputs
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
flat_losses = tree.flatten(loss_values)
|
|
242
|
+
return keras.ops.sum(keras.ops.stack(flat_losses))
|
|
243
|
+
|
|
244
|
+
def get_config(self):
|
|
245
|
+
"""Get configuration for serialization."""
|
|
246
|
+
return {
|
|
247
|
+
"loss": keras.losses.serialize(self.loss),
|
|
248
|
+
"teacher_layer_name": self.teacher_layer_name,
|
|
249
|
+
"student_layer_name": self.student_layer_name,
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
@classmethod
|
|
253
|
+
def from_config(cls, config):
|
|
254
|
+
"""Create instance from configuration."""
|
|
255
|
+
config = config.copy()
|
|
256
|
+
config["loss"] = keras.losses.deserialize(config["loss"])
|
|
257
|
+
return cls(**config)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@keras_export("keras.distillation.LogitsDistillation")
|
|
261
|
+
class LogitsDistillation(DistillationLoss):
|
|
262
|
+
"""Distillation loss that transfers knowledge from final model outputs.
|
|
263
|
+
|
|
264
|
+
This distillation loss applies temperature scaling to the teacher's logits
|
|
265
|
+
before computing the loss between teacher and student predictions. It's the
|
|
266
|
+
most common approach for knowledge distillation.
|
|
267
|
+
|
|
268
|
+
Arguments:
|
|
269
|
+
temperature: Temperature for softmax scaling. Higher values produce
|
|
270
|
+
softer probability distributions that are easier for the student to
|
|
271
|
+
learn. Typical values range from 3-5. Defaults to 3.0.
|
|
272
|
+
loss: Loss function to use for distillation. Can be:
|
|
273
|
+
- String identifier (e.g., 'kl_divergence',
|
|
274
|
+
'categorical_crossentropy')
|
|
275
|
+
- Keras loss instance
|
|
276
|
+
- Nested structure of losses matching the model output structure
|
|
277
|
+
- `None` to skip distillation for that output (useful for
|
|
278
|
+
multi-output models where you only want to distill some outputs)
|
|
279
|
+
At least one loss must be non-`None`. Defaults to 'kl_divergence'.
|
|
280
|
+
|
|
281
|
+
Examlpe(s):
|
|
282
|
+
|
|
283
|
+
```python
|
|
284
|
+
# Basic logits distillation with KL divergence
|
|
285
|
+
distillation_loss = LogitsDistillation(temperature=3.0)
|
|
286
|
+
|
|
287
|
+
# With categorical crossentropy loss
|
|
288
|
+
distillation_loss = LogitsDistillation(
|
|
289
|
+
temperature=4.0,
|
|
290
|
+
loss="categorical_crossentropy"
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
# With custom loss instance
|
|
294
|
+
distillation_loss = LogitsDistillation(
|
|
295
|
+
temperature=4.0,
|
|
296
|
+
loss=keras.losses.CategoricalCrossentropy(from_logits=True)
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
# For multi-output models
|
|
300
|
+
distillation_loss = LogitsDistillation(
|
|
301
|
+
temperature=3.0,
|
|
302
|
+
loss=["kl_divergence", "categorical_crossentropy"]
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
# For multi-output models, only distill some outputs
|
|
306
|
+
distillation_loss = LogitsDistillation(
|
|
307
|
+
temperature=3.0,
|
|
308
|
+
loss=["kl_divergence", None] # Skip second output
|
|
309
|
+
)
|
|
310
|
+
```
|
|
311
|
+
"""
|
|
312
|
+
|
|
313
|
+
@tracking.no_automatic_dependency_tracking
|
|
314
|
+
def __init__(
|
|
315
|
+
self,
|
|
316
|
+
temperature=3.0,
|
|
317
|
+
loss="kl_divergence",
|
|
318
|
+
):
|
|
319
|
+
self.temperature = temperature
|
|
320
|
+
self.loss = tree.map_structure(_convert_loss_to_function, loss)
|
|
321
|
+
|
|
322
|
+
flat_losses = tree.flatten(self.loss)
|
|
323
|
+
if all(l is None for l in flat_losses):
|
|
324
|
+
raise ValueError("At least one loss must be non-`None`.")
|
|
325
|
+
|
|
326
|
+
if not isinstance(self.temperature, (int, float)):
|
|
327
|
+
raise ValueError(
|
|
328
|
+
f"temperature must be a number, got {type(self.temperature)}"
|
|
329
|
+
)
|
|
330
|
+
if self.temperature <= 0.0:
|
|
331
|
+
raise ValueError("temperature must be positive.")
|
|
332
|
+
|
|
333
|
+
def compute_loss(self, teacher_outputs, student_outputs, **kwargs):
|
|
334
|
+
"""Compute distillation loss using the configured loss function.
|
|
335
|
+
|
|
336
|
+
Arguments:
|
|
337
|
+
teacher_outputs: Logits from teacher model. Can be a single tensor,
|
|
338
|
+
list/tuple of tensors, or dict of tensors.
|
|
339
|
+
student_outputs: Logits from student model. Can be a single tensor,
|
|
340
|
+
list/tuple of tensors, or dict of tensors.
|
|
341
|
+
**kwargs: Additional arguments (ignored).
|
|
342
|
+
Returns:
|
|
343
|
+
Distillation loss tensor.
|
|
344
|
+
"""
|
|
345
|
+
# Apply temperature scaling using tree.map_structure
|
|
346
|
+
teacher_scaled = tree.map_structure(
|
|
347
|
+
lambda x: keras.ops.divide(x, self.temperature), teacher_outputs
|
|
348
|
+
)
|
|
349
|
+
student_scaled = tree.map_structure(
|
|
350
|
+
lambda x: keras.ops.divide(x, self.temperature), student_outputs
|
|
351
|
+
)
|
|
352
|
+
|
|
353
|
+
# Apply loss function(s) to corresponding outputs
|
|
354
|
+
def apply_loss(loss_fn, teacher_logits, student_logits):
|
|
355
|
+
if loss_fn is None:
|
|
356
|
+
return 0.0
|
|
357
|
+
|
|
358
|
+
# Special handling for KL divergence (needs probabilities)
|
|
359
|
+
if isinstance(loss_fn, keras.losses.KLDivergence):
|
|
360
|
+
teacher_probs = keras.ops.softmax(teacher_logits, axis=-1)
|
|
361
|
+
student_probs = keras.ops.softmax(student_logits, axis=-1)
|
|
362
|
+
loss = keras.ops.mean(loss_fn(teacher_probs, student_probs))
|
|
363
|
+
# Scale by temperature^2 for KL (per literature)
|
|
364
|
+
return loss * (self.temperature**2)
|
|
365
|
+
else:
|
|
366
|
+
# For other losses, use logits directly
|
|
367
|
+
return keras.ops.mean(loss_fn(teacher_logits, student_logits))
|
|
368
|
+
|
|
369
|
+
# Apply losses using tree.map_structure
|
|
370
|
+
loss_values = tree.map_structure(
|
|
371
|
+
apply_loss, self.loss, teacher_scaled, student_scaled
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
# Sum all losses and return scalar
|
|
375
|
+
flat_losses = tree.flatten(loss_values)
|
|
376
|
+
return keras.ops.sum(keras.ops.stack(flat_losses))
|
|
377
|
+
|
|
378
|
+
def get_config(self):
|
|
379
|
+
"""Get configuration for serialization."""
|
|
380
|
+
return {
|
|
381
|
+
"temperature": self.temperature,
|
|
382
|
+
"loss": serialization_lib.serialize_keras_object(self.loss),
|
|
383
|
+
}
|
|
384
|
+
|
|
385
|
+
@classmethod
|
|
386
|
+
def from_config(cls, config):
|
|
387
|
+
"""Create instance from configuration."""
|
|
388
|
+
config = config.copy()
|
|
389
|
+
config["loss"] = keras.losses.deserialize(config["loss"])
|
|
390
|
+
return cls(**config)
|