keras-nightly 3.14.0.dev2026010104__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/ops/__init__.py +2 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +2 -0
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/ops/__init__.py +2 -0
- keras/ops/numpy/__init__.py +2 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/backend/jax/nn.py +26 -9
- keras/src/backend/jax/numpy.py +10 -0
- keras/src/backend/numpy/numpy.py +15 -0
- keras/src/backend/openvino/numpy.py +338 -17
- keras/src/backend/tensorflow/numpy.py +24 -1
- keras/src/backend/tensorflow/rnn.py +17 -7
- keras/src/backend/torch/numpy.py +26 -0
- keras/src/backend/torch/rnn.py +28 -11
- keras/src/callbacks/orbax_checkpoint.py +75 -42
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/layers/core/dense.py +122 -6
- keras/src/layers/core/einsum_dense.py +151 -7
- keras/src/layers/core/embedding.py +1 -1
- keras/src/layers/core/reversible_embedding.py +10 -1
- keras/src/layers/layer.py +5 -0
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/losses/losses.py +24 -0
- keras/src/models/model.py +18 -9
- keras/src/ops/image.py +106 -93
- keras/src/ops/numpy.py +138 -0
- keras/src/quantizers/__init__.py +2 -0
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +1 -2
- keras/src/quantizers/gptq_core.py +1 -1
- keras/src/quantizers/quantization_config.py +14 -0
- keras/src/quantizers/quantizers.py +61 -52
- keras/src/random/seed_generator.py +2 -2
- keras/src/saving/orbax_util.py +50 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/utils/jax_layer.py +69 -31
- keras/src/utils/module_utils.py +11 -0
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +52 -48
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
keras/src/backend/tensorflow/numpy.py CHANGED

@@ -2125,6 +2125,22 @@ def moveaxis(x, source, destination):
     return tf.transpose(x, perm)
 
 
+def nansum(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    dtype = standardize_dtype(x.dtype)
+    x_clean = tf.where(
+        tf.math.is_nan(cast(x, config.floatx())), tf.zeros((), dtype=dtype), x
+    )
+
+    if dtype in ("bool", "int8", "int16"):
+        dtype = "int32"
+    elif dtype in ("uint8", "uint16"):
+        dtype = "uint32"
+    x_clean = cast(x_clean, dtype)
+
+    return tf.reduce_sum(x_clean, axis=axis, keepdims=keepdims)
+
+
 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
     x = convert_to_tensor(x)
 
@@ -2151,7 +2167,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
 
 def ndim(x):
     x = convert_to_tensor(x)
-    return x.
+    return x.shape.rank
 
 
 def nonzero(x):
@@ -2215,6 +2231,13 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return tf.reduce_prod(x, axis=axis, keepdims=keepdims)
 
 
+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    return tf.reduce_max(x, axis=axis, keepdims=keepdims) - tf.reduce_min(
+        x, axis=axis, keepdims=keepdims
+    )
+
+
 def _quantile(x, q, axis=None, method="linear", keepdims=False):
     # ref: tfp.stats.percentile
     # float64 is needed here and below, else we get the wrong index if the array
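
For context, a rough usage sketch of the two numpy-style ops added across the backends in this release. It assumes they are exported as `keras.ops.nansum` and `keras.ops.ptp`, mirroring `numpy.nansum` and `numpy.ptp`; the export names are inferred from the `keras/ops/numpy/__init__.py` entries in the file list above, not shown in this hunk.

```python
import numpy as np
import keras

# nansum treats NaN entries as zero before reducing.
x = keras.ops.convert_to_tensor(np.array([[1.0, np.nan], [3.0, 4.0]]))
print(keras.ops.nansum(x))          # ~8.0
print(keras.ops.nansum(x, axis=1))  # ~[1.0, 7.0]

# ptp ("peak to peak") is reduce_max - reduce_min over the same axis.
y = keras.ops.convert_to_tensor(np.array([[1.0, 5.0], [3.0, 4.0]]))
print(keras.ops.ptp(y, axis=1))     # ~[4.0, 1.0]
```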
keras/src/backend/tensorflow/rnn.py CHANGED

@@ -539,11 +539,21 @@ def _do_lstm_arguments_support_cudnn(
 
 
 def _has_fully_masked_sequence(mask):
-
-
-
-
-
+    """Check if input sequence contains any fully masked data.
+
+    cuDNN kernel will error out if the input sequence contains any fully masked
+    data. We work around this issue by rerouting the computation to the
+    standard kernel until the issue on the cuDNN side has been fixed. For a
+    fully masked sequence, it will contain all `False` values. To make it easy
+    to check, we invert the boolean and check if any of the sequences has all
+    `True` values.
+
+    Args:
+        mask: The mask tensor.
+
+    Returns:
+        A boolean tensor, `True` if the mask contains a fully masked sequence.
+    """
     return tf.reduce_any(
         tf.reduce_all(tf.logical_not(tf.cast(mask, dtype="bool")), axis=1)
     )
@@ -900,8 +910,8 @@ def _cudnn_lstm(
 
     if tf.sysconfig.get_build_info()["is_rocm_build"]:
         # ROCm MIOpen's weight sequence for LSTM is different from both
-        # canonical and
-        # MIOpen: [i, f, o, c]
+        # canonical and cuDNN format
+        # MIOpen: [i, f, o, c] cuDNN/Canonical: [i, f, c, o]
        # i is input gate weights.
        # f is forget gate weights.
        # o is output gate weights.
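
A minimal standalone sketch of the fully-masked-sequence check described by the new docstring; the reduction mirrors the `tf.reduce_any` / `tf.reduce_all` lines in the hunk above, the example mask is illustrative.

```python
import tensorflow as tf

# A sequence is "fully masked" when its mask row is all False; the cuDNN LSTM
# kernel cannot handle that case, so Keras reroutes to the standard kernel.
mask = tf.constant([[True, True, False],
                    [False, False, False]])  # second sequence fully masked

has_fully_masked = tf.reduce_any(
    tf.reduce_all(tf.logical_not(tf.cast(mask, dtype="bool")), axis=1)
)
print(has_fully_masked.numpy())  # True
```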
keras/src/backend/torch/numpy.py CHANGED

@@ -1272,6 +1272,20 @@ def moveaxis(x, source, destination):
     return torch.moveaxis(x, source=source, destination=destination)
 
 
+def nansum(x, axis=None, keepdims=False):
+    if isinstance(x, (list, tuple)):
+        x = stack(x)
+    x = convert_to_tensor(x)
+    dtype = standardize_dtype(x.dtype)
+
+    if dtype in ("bool", "uint8", "int8", "int16"):
+        dtype = "int32"
+
+    if axis == () or axis == []:
+        return cast(torch.nan_to_num(x, nan=0), dtype)
+    return cast(torch.nansum(x, dim=axis, keepdim=keepdims), dtype)
+
+
 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
     x = convert_to_tensor(x)
     return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
@@ -1382,6 +1396,18 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return x
 
 
+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    if axis is None:
+        return x.max() - x.min()
+    elif axis == ():
+        return torch.zeros_like(x)
+    else:
+        return torch.amax(x, dim=axis, keepdim=keepdims) - torch.amin(
+            x, dim=axis, keepdim=keepdims
+        )
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     x = convert_to_tensor(x)
     q = convert_to_tensor(q)
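
The torch `nansum` above special-cases an empty `axis` tuple. A quick sketch of the two underlying torch calls it dispatches to (plain torch, outside Keras, illustrative values only):

```python
import torch

x = torch.tensor([[1.0, float("nan")], [3.0, 4.0]])

# axis=() path: return the input with NaNs replaced by 0, no reduction.
print(torch.nan_to_num(x, nan=0.0))  # tensor([[1., 0.], [3., 4.]])

# Normal path: torch.nansum reduces while ignoring NaNs.
print(torch.nansum(x, dim=1))        # tensor([1., 7.])
```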
keras/src/backend/torch/rnn.py CHANGED

@@ -413,11 +413,21 @@ def _is_sequence_right_padded(mask):
 
 
 def _has_fully_masked_sequence(mask):
-
-
-
-
-
+    """Check if input sequence contains any fully masked data.
+
+    cuDNN kernel will error out if the input sequence contains any fully masked
+    data. We work around this issue by rerouting the computation to the
+    standard kernel until the issue on the cuDNN side has been fixed. For a
+    fully masked sequence, it will contain all `False` values. To make it easy
+    to check, we invert the boolean and check if any of the sequences has all
+    `True` values.
+
+    Args:
+        mask: The mask tensor.
+
+    Returns:
+        A boolean tensor, `True` if the mask contains a fully masked sequence.
+    """
     return torch.any(torch.all(~mask, dim=1))
 
 
@@ -447,8 +457,8 @@ def _compute_sequence_length_from_mask(mask, batch_first):
     The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For
     any timestep that should be masked, the corresponding field will be False.
     Consider the following example:
-
-
+    a = [[True, True, False, False]
+         [True, True, True, False]]
     It is a (2, 4) tensor, and the corresponding sequence length result should
     be 1D tensor with value [2, 3]. Note that the masking tensor must be right
     padded that could be checked by, e.g., `is_sequence_right_padded()`.
@@ -467,12 +477,19 @@ def _compute_sequence_length_from_mask(mask, batch_first):
 
 
 def prepare_lstm_weights(lstm, kernel, recurrent_kernel, bias, device):
-    """Copies kernel and recurrent kernel weights
+    """Copies kernel and recurrent kernel weights into the PyTorch format.
+
     We split the kernel and recurrent kernel weights, create associated
-    torch tensors adapted to be in line with the
-    After we have copied the weights, we ensure the
-    the same device and memory layout is optimized for
+    torch tensors adapted to be in line with the cuDNN optimization.
+    After we have copied the weights, we ensure the parameters are on
+    the same device and memory layout is optimized for cuDNN.
 
+    Args:
+        lstm: The PyTorch LSTM layer to prepare weights for.
+        kernel: The kernel weights tensor.
+        recurrent_kernel: The recurrent kernel weights tensor.
+        bias: The bias tensor.
+        device: The device to place the tensors on.
     """
 
     lstm = lstm.to(device)
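
The `[2, 3]` example added to the `_compute_sequence_length_from_mask` docstring can be checked directly. A sketch only: the actual helper may compute lengths differently; this just illustrates the expected result for a right-padded mask.

```python
import torch

# Right-padded boolean mask: True marks valid timesteps.
a = torch.tensor([[True, True, False, False],
                  [True, True, True, False]])

# Summing True values along the time axis yields the per-sequence lengths.
print(a.sum(dim=1))  # tensor([2, 3])
```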
keras/src/callbacks/orbax_checkpoint.py CHANGED

@@ -8,7 +8,6 @@ from keras.src.api_export import keras_export
 from keras.src.callbacks.monitor_callback import (
     MonitorCallback,  # For metric monitoring logic
 )
-from keras.src.utils.io_utils import print_msg
 from keras.src.utils.module_utils import ocp
 
 # Context and AsyncOptions are accessed through the lazy-loaded ocp module
@@ -62,6 +61,11 @@ class OrbaxCheckpoint(MonitorCallback):
     This callback saves the model's weights and optimizer state asynchronously
     using Orbax, allowing training to continue without blocking for I/O.
 
+    **Multi-host Support**: When running in a multi-host distributed training
+    environment with JAX backend, this callback automatically coordinates
+    checkpointing across all hosts to ensure consistency and proper
+    synchronization. Multi-host checkpointing is only supported on JAX.
+
     Example:
 
     ```python
@@ -92,10 +96,6 @@ class OrbaxCheckpoint(MonitorCallback):
         verbose: Verbosity mode, 0 or 1.
         save_best_only: if `save_best_only=True`, it only saves when the model
             is considered the "best" based on the monitored quantity.
-        save_weights_only: if `save_weights_only=True`, only the model's
-            weights will be saved. Otherwise, the full model state
-            (weights, non-trainable variables, optimizer state, and
-            metrics state) will be saved. Defaults to False.
         mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
         save_freq: `'epoch'` or integer. Frequency to save checkpoints.
         max_to_keep: Integer, maximum number of recent checkpoints to keep.
@@ -112,7 +112,6 @@ class OrbaxCheckpoint(MonitorCallback):
         monitor="val_loss",
         verbose=0,
         save_best_only=False,
-        save_weights_only=False,
         mode="auto",
         save_freq="epoch",
         initial_value_threshold=None,
@@ -129,7 +128,6 @@ class OrbaxCheckpoint(MonitorCallback):
         self.directory = directory
         self.verbose = verbose
         self.save_best_only = save_best_only
-        self.save_weights_only = save_weights_only
         self.save_freq = save_freq
         self.max_to_keep = max_to_keep
         self.save_on_background = save_on_background
@@ -138,6 +136,9 @@ class OrbaxCheckpoint(MonitorCallback):
         self._current_epoch = 0  # Keep track of epoch
         self._total_batches_seen = 0  # Global batch counter for step tracking
 
+        # Multi-host support
+        self._multihost_initialized = self._is_multihost_initialized()
+
         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
             raise ValueError(
                 f"Unrecognized save_freq: {self.save_freq}. "
@@ -151,14 +152,18 @@ class OrbaxCheckpoint(MonitorCallback):
                 ocp.training.preservation_policies.LatestN(max_to_keep)
             )
 
-        # Use AnyPreservationPolicy to combine them
+        # Use AnyPreservationPolicy to combine them, or use directly
+        # if single policy
         preservation_policy = None
         if policies:
-
-
-
+            if len(policies) == 1:
+                preservation_policy = policies[0]
+            else:
+                preservation_policy = (
+                    ocp.training.preservation_policies.AnyPreservationPolicy(
+                        policies
+                    )
                 )
-            )
 
         # Create the V1 Checkpointer with direct parameter passing
         # Orbax will handle directory creation on all processes as needed
@@ -167,6 +172,54 @@ class OrbaxCheckpoint(MonitorCallback):
             preservation_policy=preservation_policy,
         )
 
+    def _is_multihost_initialized(self):
+        """Check if multi-host environment is initialized."""
+        # Multi-host checkpointing is only supported on JAX backend
+        if backend.backend() != "jax":
+            return False
+
+        multihost = ocp.multihost
+        # Check if JAX distributed client is initialized
+        # (indicates multihost setup)
+        return multihost.is_jax_distributed_client_initialized()
+
+    def _sync_processes(self, key=None):
+        """Synchronize all processes across hosts."""
+        if not self._multihost_initialized:
+            return  # No-op for single host
+
+        multihost = ocp.multihost
+        sync_key = key or "orbax_checkpoint_sync"
+        multihost.sync_global_processes(sync_key)
+
+    def is_multihost_enabled(self):
+        """Return True if multi-host checkpointing is enabled and initialized.
+
+        This method can be used to check if the callback is operating in
+        a multi-host distributed training environment. Multi-host checkpointing
+        is only supported on JAX backend.
+
+        Returns:
+            bool: True if multi-host support is active, False otherwise.
+        """
+        return self._multihost_initialized
+
+    def is_primary_host(self):
+        """Return True if this process is the primary host in multi-host setup.
+
+        In multi-host environments, only the primary host typically handles
+        logging and coordination tasks. Multi-host checkpointing is only
+        supported on JAX backend.
+
+        Returns:
+            bool: True if this is the primary host, False otherwise.
+            Always returns True in single-host environments.
+        """
+        if not self._multihost_initialized:
+            return True  # Single host is always primary
+        multihost = ocp.multihost
+        return multihost.is_primary_host()
+
     def _should_save_on_batch(self, batch):
         """Check if we should save on this batch."""
         if self.save_freq == "epoch":
@@ -186,32 +239,14 @@ class OrbaxCheckpoint(MonitorCallback):
             return False
 
     def _save_checkpoint(self, step, logs=None):
-        """Save a checkpoint at the given step."""
+        """Save a checkpoint at the given step with multi-host coordination."""
 
         # --- Prepare Composite State (Backend-Agnostic) ---
         state_tree = _get_state_tree(self.model)
 
         # Save the nested state structures directly (preserving layer
         # names and structure)
-
-        composite_state = {
-            "trainable_variables": state_tree["trainable_variables"],
-        }
-        if "non_trainable_variables" in state_tree:
-            composite_state["non_trainable_variables"] = state_tree[
-                "non_trainable_variables"
-            ]
-        else:
-            composite_state = state_tree
-
-        # --- Save Logic (V1 API) ---
-        # All processes participate in distributed checkpointing
-        # Checkpointer is configured to save unconditionally when
-        # save_pytree is called
-        if self.verbose > 0:
-            print_msg(
-                f"OrbaxCheckpoint: Triggering async save for step {step}..."
-            )
+        composite_state = state_tree
 
         # Use a single with statement. If context_options is empty,
         # Context() uses defaults.
@@ -282,18 +317,16 @@ class OrbaxCheckpoint(MonitorCallback):
         except Exception:
             pass  # Ignore errors during cleanup
 
+        # Multi-host synchronization: ensure all hosts complete cleanup
+        self._sync_processes("checkpoint_cleanup")
+
     def wait_until_finished(self):
         """Wait for any in-progress checkpoint operations to complete.
         This method blocks until all asynchronous checkpoint save operations
-        have completed
-        checkpoints if there might be pending save operations.
+        have completed across all hosts in a multi-host setup.
         """
-        # Wait for any async operations to complete
-
-            self.checkpointer.wait()
-        else:
-            # Fallback for older Orbax versions that don't have wait() method
-            while self.checkpointer.is_saving_in_progress():
-                import time
+        # Wait for any async operations to complete on this host
+        self.checkpointer.wait()
 
-
+        # Multi-host synchronization: ensure all hosts complete
+        self._sync_processes("checkpoint_wait_complete")
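
A hedged usage sketch of the updated callback: `save_weights_only` is removed, and the new multi-host helpers can be queried. Constructor arguments below are limited to those visible in the diff; the import path is the module shown above and the public export name may differ.

```python
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint

# Arguments shown here all appear in the diffed __init__; anything else
# about the API is not asserted by this sketch.
checkpoint_cb = OrbaxCheckpoint(
    directory="/tmp/orbax_ckpts",
    monitor="val_loss",
    save_best_only=False,
    save_freq="epoch",
    max_to_keep=3,
)

# New multi-host helpers (defaults apply outside a JAX multi-host setup).
if checkpoint_cb.is_primary_host():
    print("multi-host enabled:", checkpoint_cb.is_multihost_enabled())

# model.fit(..., callbacks=[checkpoint_cb])
# checkpoint_cb.wait_until_finished()  # block until async saves finish on all hosts
```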
keras/src/dtype_policies/__init__.py CHANGED

@@ -2,6 +2,7 @@ from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.dtype_policies import dtype_policy
 from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
+from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
 from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
 from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
@@ -10,6 +11,7 @@ from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 
 ALL_OBJECTS = {
+    AWQDTypePolicy,
     DTypePolicy,
     FloatDTypePolicy,
     QuantizedDTypePolicy,
keras/src/dtype_policies/dtype_policy.py CHANGED

@@ -3,7 +3,7 @@ from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
+QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq", "awq")
 
 
 @keras_export(
@@ -376,6 +376,93 @@ class GPTQDTypePolicy(QuantizedDTypePolicy):
         return config
 
 
+@keras_export("keras.dtype_policies.AWQDTypePolicy")
+class AWQDTypePolicy(QuantizedDTypePolicy):
+    """Quantized dtype policy for AWQ quantization.
+
+    This policy helps propagate quantization settings for AWQ
+    when loading an AWQ quantized model in Keras format.
+
+    Args:
+        mode: The quantization mode. This should be a string in the format
+            `"awq/<weight_bits>/<group_size>"`.
+            - `"awq"`: The identifier for the quantization algorithm.
+            - `<weight_bits>`: Number of bits to quantize weights to.
+              AWQ presently only supports 4-bit quantization.
+            - `<group_size>`: The group size for quantization. Supported
+              values are -1 (for per-channel quantization) or any
+              positive integer.
+            Example: `"awq/4/128"`.
+        source_name: The source dtype policy name, e.g. "float32".
+    """
+
+    def __init__(
+        self,
+        mode,
+        source_name=None,
+    ):
+        parts = mode.split("/")
+        expected_format = "'awq/<weight_bits>/<group_size>'"
+
+        # Validate format.
+        if len(parts) != 3 or parts[0] != "awq":
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # Validate and cast weight_bits and group_size.
+        try:
+            weight_bits = int(parts[1])
+            group_size = int(parts[2])
+        except ValueError:
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. <weight_bits> and "
+                "<group_size> must be integers. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # AWQ presently only supports 4-bit quantization.
+        if weight_bits != 4:
+            raise ValueError(
+                "Invalid weight_bits in mode. AWQ only supports 4-bit "
+                f"quantization, but got {weight_bits} from '{mode}'."
+            )
+
+        if group_size < -1 or group_size == 0:
+            raise ValueError(
+                "Invalid group_size in mode. Supported values are "
+                "-1 (per-channel) or a positive integer, "
+                f"but got {group_size} from '{mode}'."
+            )
+
+        base_mode = parts[0]
+        super().__init__(
+            mode=base_mode,
+            source_name=source_name,
+        )
+
+        self._name = f"{mode}_from_{source_name}"
+        self.mode = base_mode
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+
+    def __eq__(self, other):
+        if super().__eq__(other) is False:
+            return False
+        return (
+            self.weight_bits == other.weight_bits
+            and self.group_size == other.group_size
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        # Reconstruct the full mode string for serialization
+        mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
+        config.update({"mode": mode})
+        return config
+
+
 @keras_export(
     [
         "keras.config.set_dtype_policy",
@@ -442,6 +529,8 @@ def _get_quantized_dtype_policy_by_str(policy):
         return QuantizedDTypePolicy(mode, source_name)
     elif policy.startswith("gptq"):
         return GPTQDTypePolicy(mode, source_name)
+    elif policy.startswith("awq"):
+        return AWQDTypePolicy(mode, source_name)
     elif policy.startswith("float8"):
         return QuantizedFloat8DTypePolicy(mode, source_name)
     else:
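
A short sketch of how the new AWQ policy's mode string is parsed, based on the class added above; values other than the `"awq/4/128"` docstring example are illustrative.

```python
from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy

# The mode encodes the algorithm, weight bits, and group size.
policy = AWQDTypePolicy(mode="awq/4/128", source_name="float32")
print(policy.weight_bits)  # 4
print(policy.group_size)   # 128
print(policy.name)         # "awq/4/128_from_float32"

# Invalid modes raise ValueError, e.g. AWQ only supports 4-bit weights:
# AWQDTypePolicy(mode="awq/8/128", source_name="float32")  -> ValueError
```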
|