keras-nightly 3.14.0.dev2025122704__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/ops/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +3 -0
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/ops/__init__.py +3 -0
- keras/ops/numpy/__init__.py +3 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/backend/jax/nn.py +26 -9
- keras/src/backend/jax/numpy.py +16 -0
- keras/src/backend/numpy/numpy.py +23 -0
- keras/src/backend/openvino/numpy.py +369 -16
- keras/src/backend/tensorflow/numpy.py +34 -1
- keras/src/backend/tensorflow/rnn.py +17 -7
- keras/src/backend/torch/numpy.py +36 -0
- keras/src/backend/torch/rnn.py +28 -11
- keras/src/callbacks/orbax_checkpoint.py +75 -42
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/layers/core/dense.py +122 -6
- keras/src/layers/core/einsum_dense.py +151 -7
- keras/src/layers/core/embedding.py +1 -1
- keras/src/layers/core/reversible_embedding.py +10 -1
- keras/src/layers/layer.py +5 -0
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/losses/losses.py +24 -0
- keras/src/models/model.py +18 -9
- keras/src/ops/image.py +109 -96
- keras/src/ops/numpy.py +181 -0
- keras/src/quantizers/__init__.py +2 -0
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +1 -2
- keras/src/quantizers/gptq_core.py +1 -1
- keras/src/quantizers/quantization_config.py +14 -0
- keras/src/quantizers/quantizers.py +61 -52
- keras/src/random/seed_generator.py +2 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +50 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/utils/jax_layer.py +69 -31
- keras/src/utils/module_utils.py +11 -0
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +53 -49
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
keras/src/callbacks/orbax_checkpoint.py
CHANGED
@@ -8,7 +8,6 @@ from keras.src.api_export import keras_export
 from keras.src.callbacks.monitor_callback import (
     MonitorCallback,  # For metric monitoring logic
 )
-from keras.src.utils.io_utils import print_msg
 from keras.src.utils.module_utils import ocp
 
 # Context and AsyncOptions are accessed through the lazy-loaded ocp module
@@ -62,6 +61,11 @@ class OrbaxCheckpoint(MonitorCallback):
     This callback saves the model's weights and optimizer state asynchronously
     using Orbax, allowing training to continue without blocking for I/O.
 
+    **Multi-host Support**: When running in a multi-host distributed training
+    environment with JAX backend, this callback automatically coordinates
+    checkpointing across all hosts to ensure consistency and proper
+    synchronization. Multi-host checkpointing is only supported on JAX.
+
     Example:
 
     ```python
@@ -92,10 +96,6 @@ class OrbaxCheckpoint(MonitorCallback):
         verbose: Verbosity mode, 0 or 1.
         save_best_only: if `save_best_only=True`, it only saves when the model
             is considered the "best" based on the monitored quantity.
-        save_weights_only: if `save_weights_only=True`, only the model's
-            weights will be saved. Otherwise, the full model state
-            (weights, non-trainable variables, optimizer state, and
-            metrics state) will be saved. Defaults to False.
         mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
         save_freq: `'epoch'` or integer. Frequency to save checkpoints.
         max_to_keep: Integer, maximum number of recent checkpoints to keep.
@@ -112,7 +112,6 @@ class OrbaxCheckpoint(MonitorCallback):
         monitor="val_loss",
         verbose=0,
         save_best_only=False,
-        save_weights_only=False,
         mode="auto",
         save_freq="epoch",
         initial_value_threshold=None,
@@ -129,7 +128,6 @@ class OrbaxCheckpoint(MonitorCallback):
         self.directory = directory
         self.verbose = verbose
         self.save_best_only = save_best_only
-        self.save_weights_only = save_weights_only
         self.save_freq = save_freq
         self.max_to_keep = max_to_keep
         self.save_on_background = save_on_background
@@ -138,6 +136,9 @@ class OrbaxCheckpoint(MonitorCallback):
         self._current_epoch = 0  # Keep track of epoch
         self._total_batches_seen = 0  # Global batch counter for step tracking
 
+        # Multi-host support
+        self._multihost_initialized = self._is_multihost_initialized()
+
         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
             raise ValueError(
                 f"Unrecognized save_freq: {self.save_freq}. "
@@ -151,14 +152,18 @@ class OrbaxCheckpoint(MonitorCallback):
                 ocp.training.preservation_policies.LatestN(max_to_keep)
             )
 
-        # Use AnyPreservationPolicy to combine them
+        # Use AnyPreservationPolicy to combine them, or use directly
+        # if single policy
         preservation_policy = None
         if policies:
-            preservation_policy = (
-                ocp.training.preservation_policies.AnyPreservationPolicy(
-                    policies
+            if len(policies) == 1:
+                preservation_policy = policies[0]
+            else:
+                preservation_policy = (
+                    ocp.training.preservation_policies.AnyPreservationPolicy(
+                        policies
+                    )
                 )
-            )
 
         # Create the V1 Checkpointer with direct parameter passing
         # Orbax will handle directory creation on all processes as needed
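The selection logic in this hunk, restated as a standalone sketch (hypothetical helper name; `ocp` is the lazily imported Orbax module used throughout this file):

```python
def combine_preservation_policies(policies):
    """Pick one policy directly, or combine several under AnyPreservationPolicy."""
    if not policies:
        return None  # let the Checkpointer fall back to its defaults
    if len(policies) == 1:
        return policies[0]  # no wrapper needed for a single policy
    return ocp.training.preservation_policies.AnyPreservationPolicy(policies)
```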
@@ -167,6 +172,54 @@ class OrbaxCheckpoint(MonitorCallback):
             preservation_policy=preservation_policy,
         )
 
+    def _is_multihost_initialized(self):
+        """Check if multi-host environment is initialized."""
+        # Multi-host checkpointing is only supported on JAX backend
+        if backend.backend() != "jax":
+            return False
+
+        multihost = ocp.multihost
+        # Check if JAX distributed client is initialized
+        # (indicates multihost setup)
+        return multihost.is_jax_distributed_client_initialized()
+
+    def _sync_processes(self, key=None):
+        """Synchronize all processes across hosts."""
+        if not self._multihost_initialized:
+            return  # No-op for single host
+
+        multihost = ocp.multihost
+        sync_key = key or "orbax_checkpoint_sync"
+        multihost.sync_global_processes(sync_key)
+
+    def is_multihost_enabled(self):
+        """Return True if multi-host checkpointing is enabled and initialized.
+
+        This method can be used to check if the callback is operating in
+        a multi-host distributed training environment. Multi-host checkpointing
+        is only supported on JAX backend.
+
+        Returns:
+            bool: True if multi-host support is active, False otherwise.
+        """
+        return self._multihost_initialized
+
+    def is_primary_host(self):
+        """Return True if this process is the primary host in multi-host setup.
+
+        In multi-host environments, only the primary host typically handles
+        logging and coordination tasks. Multi-host checkpointing is only
+        supported on JAX backend.
+
+        Returns:
+            bool: True if this is the primary host, False otherwise.
+            Always returns True in single-host environments.
+        """
+        if not self._multihost_initialized:
+            return True  # Single host is always primary
+        multihost = ocp.multihost
+        return multihost.is_primary_host()
+
     def _should_save_on_batch(self, batch):
         """Check if we should save on this batch."""
         if self.save_freq == "epoch":
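A usage sketch of the new helpers (`model` and `dataset` are assumed to exist; the import path follows the file location above, and a JAX multi-process setup is assumed to have been initialized, e.g. via `jax.distributed.initialize()`):

```python
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint

checkpoint_cb = OrbaxCheckpoint(directory="/tmp/ckpts", save_freq="epoch")
model.fit(dataset, epochs=10, callbacks=[checkpoint_cb])

# Host-conditional logging using the new query methods:
if checkpoint_cb.is_multihost_enabled() and checkpoint_cb.is_primary_host():
    print("Primary host: checkpoints are coordinated across all processes.")
```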
@@ -186,32 +239,14 @@ class OrbaxCheckpoint(MonitorCallback):
             return False
 
     def _save_checkpoint(self, step, logs=None):
-        """Save a checkpoint at the given step."""
+        """Save a checkpoint at the given step with multi-host coordination."""
 
         # --- Prepare Composite State (Backend-Agnostic) ---
         state_tree = _get_state_tree(self.model)
 
         # Save the nested state structures directly (preserving layer
         # names and structure)
-
-        composite_state = {
-            "trainable_variables": state_tree["trainable_variables"],
-        }
-        if "non_trainable_variables" in state_tree:
-            composite_state["non_trainable_variables"] = state_tree[
-                "non_trainable_variables"
-            ]
-        else:
-            composite_state = state_tree
-
-        # --- Save Logic (V1 API) ---
-        # All processes participate in distributed checkpointing
-        # Checkpointer is configured to save unconditionally when
-        # save_pytree is called
-        if self.verbose > 0:
-            print_msg(
-                f"OrbaxCheckpoint: Triggering async save for step {step}..."
-            )
+        composite_state = state_tree
 
         # Use a single with statement. If context_options is empty,
         # Context() uses defaults.
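For reference, a sketch of the state being saved, assuming `_get_state_tree` builds on Keras's `Model.get_state_tree()` (the exact helper is internal to this module):

```python
state_tree = model.get_state_tree()  # model is a built/compiled keras.Model
# Nested dicts keyed by layer/variable path, roughly:
# {
#     "trainable_variables": {"dense": {"kernel": ..., "bias": ...}},
#     "non_trainable_variables": {...},
#     "optimizer_variables": {...},
#     "metrics_variables": {...},
# }
```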
@@ -282,18 +317,16 @@ class OrbaxCheckpoint(MonitorCallback):
         except Exception:
             pass  # Ignore errors during cleanup
 
+        # Multi-host synchronization: ensure all hosts complete cleanup
+        self._sync_processes("checkpoint_cleanup")
+
     def wait_until_finished(self):
         """Wait for any in-progress checkpoint operations to complete.
         This method blocks until all asynchronous checkpoint save operations
-        have completed
-        checkpoints if there might be pending save operations.
+        have completed across all hosts in a multi-host setup.
         """
-        # Wait for any async operations to complete
-        if hasattr(self.checkpointer, "wait"):
-            self.checkpointer.wait()
-        else:
-            # Fallback for older Orbax versions that don't have wait() method
-            while self.checkpointer.is_saving_in_progress():
-                import time
+        # Wait for any async operations to complete on this host
+        self.checkpointer.wait()
 
-
+        # Multi-host synchronization: ensure all hosts complete
+        self._sync_processes("checkpoint_wait_complete")
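Because saves run in the background, a caller that immediately restores or inspects checkpoints should drain them first; a short sketch reusing `checkpoint_cb` from the earlier example:

```python
model.fit(dataset, epochs=5, callbacks=[checkpoint_cb])

# Block until async saves finish on this host; on multi-host setups
# this also synchronizes all hosts (see the hunk above).
checkpoint_cb.wait_until_finished()
```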
keras/src/dtype_policies/__init__.py
CHANGED
@@ -2,6 +2,7 @@ from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.dtype_policies import dtype_policy
 from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
+from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
 from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
 from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
@@ -10,6 +11,7 @@ from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 
 ALL_OBJECTS = {
+    AWQDTypePolicy,
     DTypePolicy,
     FloatDTypePolicy,
     QuantizedDTypePolicy,
keras/src/dtype_policies/dtype_policy.py
CHANGED
@@ -3,7 +3,7 @@ from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
+QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq", "awq")
 
 
 @keras_export(
@@ -376,6 +376,93 @@ class GPTQDTypePolicy(QuantizedDTypePolicy):
         return config
 
 
+@keras_export("keras.dtype_policies.AWQDTypePolicy")
+class AWQDTypePolicy(QuantizedDTypePolicy):
+    """Quantized dtype policy for AWQ quantization.
+
+    This policy helps propagate quantization settings for AWQ
+    when loading an AWQ quantized model in Keras format.
+
+    Args:
+        mode: The quantization mode. This should be a string in the format
+            `"awq/<weight_bits>/<group_size>"`.
+            - `"awq"`: The identifier for the quantization algorithm.
+            - `<weight_bits>`: Number of bits to quantize weights to.
+                AWQ presently only supports 4-bit quantization.
+            - `<group_size>`: The group size for quantization. Supported
+                values are -1 (for per-channel quantization) or any
+                positive integer.
+            Example: `"awq/4/128"`.
+        source_name: The source dtype policy name, e.g. "float32".
+    """
+
+    def __init__(
+        self,
+        mode,
+        source_name=None,
+    ):
+        parts = mode.split("/")
+        expected_format = "'awq/<weight_bits>/<group_size>'"
+
+        # Validate format.
+        if len(parts) != 3 or parts[0] != "awq":
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # Validate and cast weight_bits and group_size.
+        try:
+            weight_bits = int(parts[1])
+            group_size = int(parts[2])
+        except ValueError:
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. <weight_bits> and "
+                "<group_size> must be integers. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # AWQ presently only supports 4-bit quantization.
+        if weight_bits != 4:
+            raise ValueError(
+                "Invalid weight_bits in mode. AWQ only supports 4-bit "
+                f"quantization, but got {weight_bits} from '{mode}'."
+            )
+
+        if group_size < -1 or group_size == 0:
+            raise ValueError(
+                "Invalid group_size in mode. Supported values are "
+                "-1 (per-channel) or a positive integer, "
+                f"but got {group_size} from '{mode}'."
+            )
+
+        base_mode = parts[0]
+        super().__init__(
+            mode=base_mode,
+            source_name=source_name,
+        )
+
+        self._name = f"{mode}_from_{source_name}"
+        self.mode = base_mode
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+
+    def __eq__(self, other):
+        if super().__eq__(other) is False:
+            return False
+        return (
+            self.weight_bits == other.weight_bits
+            and self.group_size == other.group_size
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        # Reconstruct the full mode string for serialization
+        mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
+        config.update({"mode": mode})
+        return config
+
+
 @keras_export(
     [
         "keras.config.set_dtype_policy",
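A small sketch exercising the mode-string contract defined by this class (behavior inferred directly from the hunk above):

```python
from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy

policy = AWQDTypePolicy(mode="awq/4/128", source_name="float32")
assert policy.weight_bits == 4 and policy.group_size == 128
# get_config() reconstructs the full mode string for serialization:
assert policy.get_config()["mode"] == "awq/4/128"

# Anything other than 4-bit is rejected:
# AWQDTypePolicy(mode="awq/8/128", source_name="float32")  # raises ValueError
```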
@@ -442,6 +529,8 @@ def _get_quantized_dtype_policy_by_str(policy):
         return QuantizedDTypePolicy(mode, source_name)
     elif policy.startswith("gptq"):
         return GPTQDTypePolicy(mode, source_name)
+    elif policy.startswith("awq"):
+        return AWQDTypePolicy(mode, source_name)
     elif policy.startswith("float8"):
         return QuantizedFloat8DTypePolicy(mode, source_name)
     else:
keras/src/layers/core/dense.py
CHANGED
@@ -128,7 +128,7 @@ class Dense(Layer):
                 mode=self.quantization_mode,
                 config=self.quantization_config,
             )
-        if self.quantization_mode not in ("int8", "int4", "gptq"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
             # it here.
@@ -165,15 +165,17 @@ class Dense(Layer):
 
         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )
 
         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and self.is_gptq_calibrated and gptq_bits != 4:
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -183,7 +185,15 @@ class Dense(Layer):
         # Handle int4 unpacking cases in one place
         if is_int4:
             kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
-        elif is_gptq and self.is_gptq_calibrated and gptq_bits == 4:
+        elif is_gptq and gptq_calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+        elif is_awq and awq_calibrated:
+            # AWQ always uses 4-bit quantization
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
                 orig_len=self.units,
@@ -304,8 +314,9 @@ class Dense(Layer):
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # A saved GPTQ quantized model will always be calibrated.
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
         self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
 
         idx = 0
         for name in self.variable_serialization_spec[mode]:
@@ -395,6 +406,14 @@ class Dense(Layer):
             "kernel_zero",
             "g_idx",
         ],
+        "awq": [
+            "bias",
+            "quantized_kernel",
+            "kernel_scale",
+            "kernel_zero",
+            "awq_scales",
+            "g_idx",
+        ],
     }
 
     def quantized_build(self, kernel_shape, mode, config=None):
@@ -406,6 +425,8 @@ class Dense(Layer):
             self._float8_build()
         elif mode == "gptq":
             self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
@@ -515,6 +536,97 @@ class Dense(Layer):
             y = self.activation(y)
         return y
 
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        # For 4-bit weights, we pack two values per byte.
+        units = (kernel_shape[1] + 1) // 2
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = (
+            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(kernel_shape[0],),
+            initializer="ones",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, units]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
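As background for the `(kernel_shape[1] + 1) // 2` sizing in `_awq_build` above, a minimal standalone sketch of 4-bit packing (not the Keras implementation; the nibble order is an assumption):

```python
import numpy as np

def pack_int4(values):
    """Pack pairs of 4-bit values (0..15) into single uint8 bytes."""
    if len(values) % 2:
        values = np.append(values, np.uint8(0))  # pad to an even length
    lo, hi = values[0::2], values[1::2]
    return ((hi << 4) | lo).astype(np.uint8)

def unpack_int4(packed, orig_len):
    """Inverse of pack_int4: recover orig_len 4-bit values."""
    lo = packed & 0x0F
    hi = packed >> 4
    out = np.empty(2 * len(packed), dtype=np.uint8)
    out[0::2], out[1::2] = lo, hi
    return out[:orig_len]

vals = np.array([1, 15, 7, 2, 9], dtype=np.uint8)
packed = pack_int4(vals)
assert len(packed) == (len(vals) + 1) // 2  # two values per byte
assert np.array_equal(unpack_int4(packed, len(vals)), vals)
```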
@@ -835,6 +947,8 @@ class Dense(Layer):
             self.kernel_scale.assign(kernel_scale)
         elif mode == "gptq":
             self.quantized_build(kernel_shape, mode, self.quantization_config)
+        elif mode == "awq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -847,6 +961,8 @@ class Dense(Layer):
         policy_name = mode
         if mode == "gptq":
             policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "awq":
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
@@ -881,7 +997,7 @@ class Dense(Layer):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         kernel_value = self._kernel