keras-nightly 3.12.0.dev2025083103-py3-none-any.whl → 3.14.0.dev2026011604-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/distillation/distiller.py
@@ -0,0 +1,598 @@
+import keras
+from keras.src import tree
+from keras.src.api_export import keras_export
+from keras.src.distillation.distillation_loss import _convert_loss_to_function
+from keras.src.models.model import Model
+from keras.src.saving import serialization_lib
+
+
+@keras_export("keras.distillation.Distiller")
+class Distiller(Model):
+    """Distillation model for transferring knowledge from teacher to student.
+
+    Knowledge distillation transfers knowledge from a large, complex model
+    (teacher) to a smaller, simpler model (student). The student learns
+    from both ground truth labels and the teacher's predictions, often
+    achieving better performance than training on labels alone.
+
+    Arguments:
+        teacher: A trained `keras.Model` that serves as the knowledge source.
+            The teacher model is frozen during distillation.
+        student: A `keras.Model` to be trained through distillation.
+        distillation_losses: List of distillation losses to apply. Can be a
+            single distillation loss or a list of distillation losses like
+            `keras.distillation.LogitsDistillation`,
+            `keras.distillation.FeatureDistillation`, or custom distillation
+            losses.
+        distillation_loss_weights: List of weights for each distillation loss.
+            Must have the same length as `distillation_losses`. If `None`,
+            equal weights are used.
+        student_loss_weight: Weight for the student's supervised loss
+            component. Must be between 0 and 1. Defaults to 0.5.
+        name: Name for the distiller model. Defaults to `"distiller"`.
+        **kwargs: Additional keyword arguments passed to the parent `Model`
+            class.
+
+    Attributes:
+        student: The student model being trained. Access this to get the
+            trained student model for independent use after distillation
+            training.
+        teacher: The teacher model providing knowledge. This model is frozen
+            during training.
+
+    Examples:
+
+    ```python
+    # Basic distillation with KerasHub models
+    import keras_hub as hub
+
+    teacher = hub.models.CausalLM.from_preset("gemma_2b_en")
+    student = hub.models.CausalLM.from_preset(
+        "gemma_1.1_2b_en", load_weights=False
+    )
+
+    # Single distillation loss
+    distiller = Distiller(
+        teacher=teacher,
+        student=student,
+        distillation_losses=LogitsDistillation(temperature=3.0),
+    )
+
+    # Compile the distiller (like any Keras model)
+    distiller.compile(
+        optimizer='adam',
+        loss='sparse_categorical_crossentropy',
+        metrics=['accuracy']
+    )
+
+    # Train the distiller
+    distiller.fit(x_train, y_train, epochs=10)
+
+    # Access the trained student model
+    trained_student = distiller.student
+
+    # Multiple distillation losses
+    distiller = Distiller(
+        teacher=teacher,
+        student=student,
+        distillation_losses=[
+            LogitsDistillation(temperature=3.0),
+            FeatureDistillation(
+                teacher_layer_name="dense_1",
+                student_layer_name="dense_1"
+            )
+        ],
+        distillation_loss_weights=[1.0, 0.5],
+    )
+
+    # Compile with custom settings
+    distiller.compile(
+        optimizer='adam',
+        loss='sparse_categorical_crossentropy',
+        metrics=['accuracy']
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        teacher,
+        student,
+        distillation_losses,
+        distillation_loss_weights=None,
+        student_loss_weight=0.5,
+        name="distiller",
+        **kwargs,
+    ):
+        super().__init__(name=name, **kwargs)
+
+        # Validate inputs
+        self._validate_models(teacher, student)
+
+        # Store configuration
+        self.teacher = teacher
+        self.student = student
+
+        # Validate student_loss_weight
+        if not isinstance(student_loss_weight, (int, float)):
+            raise ValueError(
+                f"student_loss_weight must be a number, got "
+                f"{type(student_loss_weight)}"
+            )
+        if student_loss_weight < 0.0 or student_loss_weight > 1.0:
+            raise ValueError(
+                f"student_loss_weight must be between 0.0 and 1.0, "
+                f"got {student_loss_weight}"
+            )
+        self.student_loss_weight = student_loss_weight
+
+        # Handle distillation losses configuration
+        if distillation_losses is None:
+            raise ValueError(
+                "'distillation_losses' cannot be `None`. Provide a "
+                "distillation loss (e.g., LogitsDistillation or "
+                "FeatureDistillation) or a list of distillation losses."
+            )
+
+        # Convert single distillation loss to list for uniform handling
+        if not isinstance(distillation_losses, (list, tuple)):
+            self.distillation_losses = [distillation_losses]
+            self.distillation_loss_weights = [1.0]
+        else:
+            self.distillation_losses = distillation_losses
+            # Set default weights if not provided
+            if distillation_loss_weights is None:
+                self.distillation_loss_weights = [1.0] * len(
+                    distillation_losses
+                )
+            else:
+                if len(distillation_loss_weights) != len(distillation_losses):
+                    raise ValueError(
+                        f"Number of distillation_loss_weights "
+                        f"({len(distillation_loss_weights)}) must match "
+                        f"number of distillation_losses "
+                        f"({len(distillation_losses)})"
+                    )
+                self.distillation_loss_weights = distillation_loss_weights
+
+        # Validate distillation loss compatibility and create extractors
+        for distillation_loss in self.distillation_losses:
+            self._validate_distillation_loss_compatibility(
+                teacher, student, distillation_loss
+            )
+
+        self._create_multi_feature_extractors()
+
+        # Freeze teacher model
+        self.teacher.trainable = False
+
+        # Initialize loss tracking metrics
+        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
+        self.distillation_loss_tracker = keras.metrics.Mean(
+            name="distillation_loss"
+        )
+        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
+
+    def _validate_models(self, teacher, student):
+        """Validate that teacher and student models are compatible."""
+        if not isinstance(teacher, keras.Model):
+            raise ValueError(
+                f"Teacher must be a keras.Model, got {type(teacher)}"
+            )
+        if not isinstance(student, keras.Model):
+            raise ValueError(
+                f"Student must be a keras.Model, got {type(student)}"
+            )
+
+        self._validate_input_compatibility(teacher, student)
+        self._validate_output_compatibility(teacher, student)
+        self._validate_dtype_compatibility(teacher, student)
+
+    def _assert_shapes_are_compatible(self, shape1, shape2, context):
+        """Assert that two shapes are compatible."""
+        if len(shape1) != len(shape2):
+            raise ValueError(
+                f"Teacher and student {context} shapes have different "
+                f"dimensions. Teacher: {shape1}, Student: {shape2}."
+            )
+
+        for dim1, dim2 in zip(shape1, shape2):
+            if dim1 is not None and dim2 is not None and dim1 != dim2:
+                raise ValueError(
+                    f"Teacher and student {context} shapes are incompatible. "
+                    f"Teacher: {shape1}, Student: {shape2}. "
+                    f"All dimensions must match."
+                )
+
+    def _assert_same_dtype(self, teacher_dtype, student_dtype, context):
+        """Assert that teacher and student dtypes are the same."""
+        if teacher_dtype != student_dtype:
+            raise ValueError(
+                f"Teacher and student {context} dtypes must match. "
+                f"Teacher: {teacher_dtype}, Student: {student_dtype}."
+            )
+
+    def _validate_input_compatibility(self, teacher, student):
+        """Validate that teacher and student have compatible input shapes."""
+        if not hasattr(teacher, "inputs") or not hasattr(student, "inputs"):
+            return
+        teacher_inputs = getattr(teacher, "inputs")
+        student_inputs = getattr(student, "inputs")
+        if teacher_inputs is None or student_inputs is None:
+            return
+
+        tree.map_structure(
+            lambda ti, si: self._assert_shapes_are_compatible(
+                ti.shape, si.shape, "input"
+            ),
+            teacher_inputs,
+            student_inputs,
+        )
+
+    def _validate_output_compatibility(self, teacher, student):
+        """Validate that teacher and student have compatible output shapes."""
+        if not hasattr(teacher, "outputs") or not hasattr(student, "outputs"):
+            return
+        teacher_outputs = getattr(teacher, "outputs")
+        student_outputs = getattr(student, "outputs")
+        if teacher_outputs is None or student_outputs is None:
+            return
+
+        tree.map_structure(
+            lambda to, so: self._assert_shapes_are_compatible(
+                to.shape, so.shape, "output"
+            ),
+            teacher_outputs,
+            student_outputs,
+        )
+
+    def _validate_dtype_compatibility(self, teacher, student):
+        """Validate that teacher and student have compatible data types."""
+        if not hasattr(teacher, "inputs") or not hasattr(student, "inputs"):
+            return
+        if teacher.inputs is None or student.inputs is None:
+            return
+
+        tree.map_structure(
+            lambda ti, si: self._assert_same_dtype(ti.dtype, si.dtype, "input"),
+            teacher.inputs,
+            student.inputs,
+        )
+
+        if not hasattr(teacher, "outputs") or not hasattr(student, "outputs"):
+            return
+        if teacher.outputs is None or student.outputs is None:
+            return
+
+        tree.map_structure(
+            lambda to, so: self._assert_same_dtype(
+                to.dtype, so.dtype, "output"
+            ),
+            teacher.outputs,
+            student.outputs,
+        )
+
+    def _validate_distillation_loss_compatibility(
+        self, teacher, student, distillation_loss
+    ):
+        """Validate that the distillation loss is compatible with teacher
+        and student models."""
+        distillation_loss.validate_model_compatibility(teacher, student)
+
+    def _create_multi_feature_extractors(self):
+        """Create feature extractors for efficient multi-layer extraction."""
+        teacher_layer_names = []
+        student_layer_names = []
+
+        for distillation_loss in self.distillation_losses:
+            if (
+                hasattr(distillation_loss, "teacher_layer_name")
+                and distillation_loss.teacher_layer_name
+            ):
+                if (
+                    distillation_loss.teacher_layer_name
+                    not in teacher_layer_names
+                ):
+                    teacher_layer_names.append(
+                        distillation_loss.teacher_layer_name
+                    )
+            if (
+                hasattr(distillation_loss, "student_layer_name")
+                and distillation_loss.student_layer_name
+            ):
+                if (
+                    distillation_loss.student_layer_name
+                    not in student_layer_names
+                ):
+                    student_layer_names.append(
+                        distillation_loss.student_layer_name
+                    )
+
+        self._teacher_feature_extractor = self._create_feature_extractor(
+            self.teacher, teacher_layer_names
+        )
+        self._student_feature_extractor = self._create_feature_extractor(
+            self.student, student_layer_names
+        )
+
+    def _create_feature_extractor(self, model, layer_names):
+        """Create a feature extractor for a model.
+
+        Arguments:
+            model: The model to create an extractor for.
+            layer_names: List of layer names to extract features from.
+
+        Returns:
+            Feature extractor model or `None` if no layer names provided.
+
+        Raises:
+            ValueError: If model has no symbolic inputs/outputs.
+        """
+        if not layer_names:
+            return None
+
+        if not hasattr(model, "inputs") or model.inputs is None:
+            raise ValueError(
+                f"Cannot create feature extractor for {model.name}. "
+                f"The model has no symbolic inputs attribute."
+            )
+
+        if isinstance(model, keras.Sequential):
+            final_output = model.layers[-1].output
+        else:
+            final_output = model.output
+
+        outputs = {"final_output": final_output}
+        for layer_name in layer_names:
+            layer = model.get_layer(name=layer_name)
+            outputs[layer_name] = layer.output
+
+        return keras.Model(
+            inputs=model.inputs,
+            outputs=outputs,
+            name=f"{model.name}_multi_feature_extractor",
+        )
+
+    def _extract_all_teacher_features(self, x):
+        """Extract all teacher features in a single forward pass."""
+        if self._teacher_feature_extractor is not None:
+            return self._teacher_feature_extractor(x, training=False)
+        else:
+            return {"final_output": self.teacher(x, training=False)}
+
+    def _extract_all_student_features(self, x, y_pred):
+        """Extract all student features in a single forward pass."""
+        if self._student_feature_extractor is not None:
+            return self._student_feature_extractor(x, training=True)
+        else:
+            return {"final_output": y_pred}
+
+    def _get_distillation_loss_features(
+        self, distillation_loss, all_features, is_teacher
+    ):
+        """Get the specific features needed by a distillation loss."""
+        if is_teacher:
+            layer_name = distillation_loss.teacher_layer_name or "final_output"
+        else:
+            layer_name = distillation_loss.student_layer_name or "final_output"
+
+        if layer_name not in all_features:
+            raise ValueError(
+                f"Layer '{layer_name}' not found in extracted features. "
+                f"Available: {list(all_features.keys())}"
+            )
+
+        return all_features[layer_name]
+
+    def compile(self, optimizer="adam", loss=None, metrics=None, **kwargs):
+        """Compile the distiller with proper integration.
+
+        Arguments:
+            optimizer: Optimizer for training the student model.
+            loss: Student loss function for the student's supervised learning.
+                Can be a string identifier or a loss function instance.
+            metrics: Additional metrics to track during training.
+            **kwargs: Additional arguments passed to parent compile.
+        """
+        if loss is None:
+            raise ValueError("'loss' cannot be `None`.")
+
+        self._student_loss = tree.map_structure(_convert_loss_to_function, loss)
+        self._student_loss_for_serialization = loss
+
+        if metrics is not None and not isinstance(metrics, (list, tuple)):
+            raise ValueError(
+                f"metrics must be a list or tuple, got {type(metrics)}"
+            )
+
+        super().compile(
+            optimizer=optimizer,
+            loss=None,
+            metrics=metrics,
+            **kwargs,
+        )
+
+    def call(self, inputs, training=None, **kwargs):
+        """Forward pass returns student predictions."""
+        return self.student(inputs, training=training, **kwargs)
+
+    def compute_loss(
+        self, x=None, y=None, y_pred=None, sample_weight=None, training=True
+    ):
+        """Compute combined distillation loss.
+
+        Arguments:
+            x: Input data.
+            y: Target data.
+            y_pred: Model predictions.
+            sample_weight: Sample weights (currently unused).
+            training: Whether the model is in training mode.
+
+        Returns:
+            Combined loss tensor.
+        """
+        # Handle case where y_pred is not provided
+        if y_pred is None:
+            y_pred = self(x, training=training)
+        # Compute student loss
+        student_loss = 0.0
+        if self.student_loss_weight > 0.0 and y is not None:
+            loss_values = tree.map_structure(
+                lambda l, o, o_pred: l(o, o_pred),
+                self._student_loss,
+                y,
+                y_pred,
+            )
+            flat_losses = tree.flatten(loss_values)
+            student_loss = (
+                keras.ops.sum(keras.ops.stack(flat_losses))
+                if len(flat_losses) > 1
+                else flat_losses[0]
+            )
+
+        # Ensure student_loss is a scalar
+        if hasattr(student_loss, "shape") and len(student_loss.shape) > 0:
+            student_loss = keras.ops.mean(student_loss)
+
+        # Compute distillation loss
+        distillation_loss = 0.0
+        if self.student_loss_weight < 1.0:
+            teacher_features = self._extract_all_teacher_features(x)
+            student_features = self._extract_all_student_features(x, y_pred)
+
+            # Apply distillation losses using pre-extracted features
+            for distillation_loss_fn, weight in zip(
+                self.distillation_losses, self.distillation_loss_weights
+            ):
+                # Get appropriate outputs/features for this distillation loss
+                if (
+                    hasattr(distillation_loss_fn, "teacher_layer_name")
+                    and distillation_loss_fn.teacher_layer_name is not None
+                ):
+                    # FeatureDistillation with specific layers
+                    try:
+                        distillation_loss_teacher_output = (
+                            self._get_distillation_loss_features(
+                                distillation_loss_fn,
+                                teacher_features,
+                                is_teacher=True,
+                            )
+                        )
+                        distillation_loss_student_output = (
+                            self._get_distillation_loss_features(
+                                distillation_loss_fn,
+                                student_features,
+                                is_teacher=False,
+                            )
+                        )
+                    except ValueError as e:
+                        # Re-raise with context about which loss failed
+                        raise RuntimeError(
+                            f"Failed to extract features for "
+                            f"{type(distillation_loss_fn).__name__} "
+                            f"targeting teacher layer "
+                            f"'{distillation_loss_fn.teacher_layer_name}' "
+                            f"and student layer "
+                            f"'{distillation_loss_fn.student_layer_name}'. "
+                            f"Original error: {e}"
+                        ) from e
+                else:
+                    # LogitsDistillation or FeatureDistillation (final outputs)
+                    distillation_loss_teacher_output = teacher_features[
+                        "final_output"
+                    ]
+                    distillation_loss_student_output = y_pred
+
+                # Validate outputs are compatible for this distillation loss
+                distillation_loss_fn.validate_outputs(
+                    distillation_loss_teacher_output,
+                    distillation_loss_student_output,
+                )
+
+                # Compute loss for this distillation loss
+                current_distillation_loss = distillation_loss_fn.compute_loss(
+                    distillation_loss_teacher_output,
+                    distillation_loss_student_output,
+                )
+
+                # Validate that distillation loss returns a scalar
+                if (
+                    hasattr(current_distillation_loss, "shape")
+                    and len(current_distillation_loss.shape) > 0
+                ):
+                    raise ValueError(
+                        f"Distillation loss "
+                        f"{distillation_loss_fn.__class__.__name__} "
+                        f"returned a non-scalar loss with shape "
+                        f"{current_distillation_loss.shape}. "
+                        f"The compute_loss method must return a scalar "
+                        f"tensor."
+                    )
+
+                # Apply weight and add to total
+                distillation_loss = keras.ops.add(
+                    distillation_loss,
+                    keras.ops.multiply(weight, current_distillation_loss),
+                )
+
+        # Combine losses
+        total_loss = keras.ops.add(
+            keras.ops.multiply(self.student_loss_weight, student_loss),
+            keras.ops.multiply(
+                keras.ops.subtract(1.0, self.student_loss_weight),
+                distillation_loss,
+            ),
+        )
+
+        # Update metrics
+        self.student_loss_tracker.update_state(student_loss)
+        self.distillation_loss_tracker.update_state(distillation_loss)
+        self.total_loss_tracker.update_state(total_loss)
+
+        return total_loss
+
+    def reset_metrics(self):
+        """Reset all metrics."""
+        super().reset_metrics()
+        self.student_loss_tracker.reset_state()
+        self.distillation_loss_tracker.reset_state()
+        self.total_loss_tracker.reset_state()
+
+    def get_config(self):
+        """Get configuration for serialization."""
+        config = super().get_config()
+        config.update(
+            {
+                "teacher": serialization_lib.serialize_keras_object(
+                    self.teacher
+                ),
+                "student": serialization_lib.serialize_keras_object(
+                    self.student
+                ),
+                "distillation_losses": [
+                    serialization_lib.serialize_keras_object(distillation_loss)
+                    for distillation_loss in self.distillation_losses
+                ],
+                "distillation_loss_weights": self.distillation_loss_weights,
+                "student_loss_weight": self.student_loss_weight,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        """Create instance from configuration."""
+        config = config.copy()
+
+        # Deserialize objects
+        config["teacher"] = serialization_lib.deserialize_keras_object(
+            config["teacher"]
+        )
+        config["student"] = serialization_lib.deserialize_keras_object(
+            config["student"]
+        )
+        config["distillation_losses"] = [
+            serialization_lib.deserialize_keras_object(distillation_loss)
+            for distillation_loss in config["distillation_losses"]
+        ]
+
+        return cls(**config)
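In `compute_loss` above, the combination rule is `total = w * student + (1 - w) * sum_i(weight_i * distill_i)`, with `w = student_loss_weight`. A minimal numeric sketch of that arithmetic (all loss values below are made up for illustration):

```python
# Sketch of the weighting performed by Distiller.compute_loss above.
# The loss values are hypothetical; only the arithmetic is the point.
student_loss_weight = 0.5               # `w`: balance between the two terms
student_loss = 0.80                     # supervised loss vs. ground truth
distillation_loss_values = [0.40, 0.20]
distillation_loss_weights = [1.0, 0.5]

# Weighted sum over the configured distillation losses.
distillation_loss = sum(
    w * v for w, v in zip(distillation_loss_weights, distillation_loss_values)
)  # 1.0 * 0.40 + 0.5 * 0.20 = 0.50

total_loss = (
    student_loss_weight * student_loss
    + (1.0 - student_loss_weight) * distillation_loss
)
print(total_loss)  # 0.5 * 0.80 + 0.5 * 0.50 = 0.65
```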
keras/src/distribution/distribution_lib.py
@@ -39,6 +39,20 @@ def list_devices(device_type=None):
     return distribution_lib.list_devices(device_type)
 
 
+@keras_export("keras.distribution.get_device_count")
+def get_device_count(device_type=None):
+    """Returns the number of available JAX devices.
+    Args:
+        device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
+            If `None`, it defaults to counting "gpu" or "tpu" devices if
+            available, otherwise it counts "cpu" devices. It does not
+            return the sum of all device types.
+    Returns:
+        int: The total number of JAX devices for the specified type.
+    """
+    return distribution_lib.get_device_count(device_type=device_type)
+
+
 @keras_export("keras.distribution.initialize")
 def initialize(job_addresses=None, num_processes=None, process_id=None):
     """Initialize the distribution system for multi-host/process setting.
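A short usage sketch for the new export (requires a build containing this change; the printed counts depend on the machine):

```python
import keras

# Counts "gpu" or "tpu" devices when available, otherwise "cpu" devices;
# it never sums across device types (see the docstring above).
print(keras.distribution.get_device_count())

# Count a specific device type explicitly.
print(keras.distribution.get_device_count(device_type="cpu"))
```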
keras/src/dtype_policies/__init__.py
@@ -2,18 +2,22 @@ from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.dtype_policies import dtype_policy
 from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
+from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
 from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
+from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
 from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
 from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 
 ALL_OBJECTS = {
+    AWQDTypePolicy,
     DTypePolicy,
     FloatDTypePolicy,
     QuantizedDTypePolicy,
     QuantizedFloat8DTypePolicy,
     DTypePolicyMap,
+    GPTQDTypePolicy,
 }
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
 
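Registering `AWQDTypePolicy` and `GPTQDTypePolicy` in `ALL_OBJECTS` is what lets them be resolved by class name, e.g. during deserialization. A minimal sketch of the registry mechanics shown above (not of the AWQ/GPTQ quantization workflow itself):

```python
from keras.src.dtype_policies import ALL_OBJECTS_DICT

# ALL_OBJECTS_DICT maps cls.__name__ -> cls, so the newly registered
# policies are now discoverable by name.
assert "AWQDTypePolicy" in ALL_OBJECTS_DICT
assert "GPTQDTypePolicy" in ALL_OBJECTS_DICT
gptq_policy_cls = ALL_OBJECTS_DICT["GPTQDTypePolicy"]
```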