keras-nightly 3.14.0.dev2026010104__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (52)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +2 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +2 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +2 -0
  7. keras/ops/numpy/__init__.py +2 -0
  8. keras/quantizers/__init__.py +1 -0
  9. keras/src/backend/jax/nn.py +26 -9
  10. keras/src/backend/jax/numpy.py +10 -0
  11. keras/src/backend/numpy/numpy.py +15 -0
  12. keras/src/backend/openvino/numpy.py +338 -17
  13. keras/src/backend/tensorflow/numpy.py +24 -1
  14. keras/src/backend/tensorflow/rnn.py +17 -7
  15. keras/src/backend/torch/numpy.py +26 -0
  16. keras/src/backend/torch/rnn.py +28 -11
  17. keras/src/callbacks/orbax_checkpoint.py +75 -42
  18. keras/src/dtype_policies/__init__.py +2 -0
  19. keras/src/dtype_policies/dtype_policy.py +90 -1
  20. keras/src/layers/core/dense.py +122 -6
  21. keras/src/layers/core/einsum_dense.py +151 -7
  22. keras/src/layers/core/embedding.py +1 -1
  23. keras/src/layers/core/reversible_embedding.py +10 -1
  24. keras/src/layers/layer.py +5 -0
  25. keras/src/layers/preprocessing/feature_space.py +8 -4
  26. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  27. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
  28. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  29. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  30. keras/src/losses/losses.py +24 -0
  31. keras/src/models/model.py +18 -9
  32. keras/src/ops/image.py +106 -93
  33. keras/src/ops/numpy.py +138 -0
  34. keras/src/quantizers/__init__.py +2 -0
  35. keras/src/quantizers/awq.py +361 -0
  36. keras/src/quantizers/awq_config.py +140 -0
  37. keras/src/quantizers/awq_core.py +217 -0
  38. keras/src/quantizers/gptq.py +1 -2
  39. keras/src/quantizers/gptq_core.py +1 -1
  40. keras/src/quantizers/quantization_config.py +14 -0
  41. keras/src/quantizers/quantizers.py +61 -52
  42. keras/src/random/seed_generator.py +2 -2
  43. keras/src/saving/orbax_util.py +50 -0
  44. keras/src/saving/saving_api.py +37 -14
  45. keras/src/utils/jax_layer.py +69 -31
  46. keras/src/utils/module_utils.py +11 -0
  47. keras/src/utils/tracking.py +5 -5
  48. keras/src/version.py +1 -1
  49. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
  50. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +52 -48
  51. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
  52. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
@@ -2125,6 +2125,22 @@ def moveaxis(x, source, destination):
2125
2125
  return tf.transpose(x, perm)
2126
2126
 
2127
2127
 
2128
+ def nansum(x, axis=None, keepdims=False):
2129
+ x = convert_to_tensor(x)
2130
+ dtype = standardize_dtype(x.dtype)
2131
+ x_clean = tf.where(
2132
+ tf.math.is_nan(cast(x, config.floatx())), tf.zeros((), dtype=dtype), x
2133
+ )
2134
+
2135
+ if dtype in ("bool", "int8", "int16"):
2136
+ dtype = "int32"
2137
+ elif dtype in ("uint8", "uint16"):
2138
+ dtype = "uint32"
2139
+ x_clean = cast(x_clean, dtype)
2140
+
2141
+ return tf.reduce_sum(x_clean, axis=axis, keepdims=keepdims)
2142
+
2143
+
2128
2144
  def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
2129
2145
  x = convert_to_tensor(x)
2130
2146
 
@@ -2151,7 +2167,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
2151
2167
 
2152
2168
  def ndim(x):
2153
2169
  x = convert_to_tensor(x)
2154
- return x.ndim
2170
+ return x.shape.rank
2155
2171
 
2156
2172
 
2157
2173
  def nonzero(x):
@@ -2215,6 +2231,13 @@ def prod(x, axis=None, keepdims=False, dtype=None):
2215
2231
  return tf.reduce_prod(x, axis=axis, keepdims=keepdims)
2216
2232
 
2217
2233
 
2234
+ def ptp(x, axis=None, keepdims=False):
2235
+ x = convert_to_tensor(x)
2236
+ return tf.reduce_max(x, axis=axis, keepdims=keepdims) - tf.reduce_min(
2237
+ x, axis=axis, keepdims=keepdims
2238
+ )
2239
+
2240
+
2218
2241
  def _quantile(x, q, axis=None, method="linear", keepdims=False):
2219
2242
  # ref: tfp.stats.percentile
2220
2243
  # float64 is needed here and below, else we get the wrong index if the array
@@ -539,11 +539,21 @@ def _do_lstm_arguments_support_cudnn(
539
539
 
540
540
 
541
541
  def _has_fully_masked_sequence(mask):
542
- # Cudnn kernel will error out if the input sequence contains any
543
- # fully masked data. We walk around this issue by rerouting the computation
544
- # to standard kernel, until the issue on cudnn side has been fixed. For a
545
- # fully masked sequence, it will contain all Falses. To make it easy to
546
- # check, we inverse the boolean, check if any of the sequence has all True.
542
+ """Check if input sequence contains any fully masked data.
543
+
544
+ cuDNN kernel will error out if the input sequence contains any fully masked
545
+ data. We work around this issue by rerouting the computation to the
546
+ standard kernel until the issue on the cuDNN side has been fixed. For a
547
+ fully masked sequence, it will contain all `False` values. To make it easy
548
+ to check, we invert the boolean and check if any of the sequences has all
549
+ `True` values.
550
+
551
+ Args:
552
+ mask: The mask tensor.
553
+
554
+ Returns:
555
+ A boolean tensor, `True` if the mask contains a fully masked sequence.
556
+ """
547
557
  return tf.reduce_any(
548
558
  tf.reduce_all(tf.logical_not(tf.cast(mask, dtype="bool")), axis=1)
549
559
  )
@@ -900,8 +910,8 @@ def _cudnn_lstm(
900
910
 
901
911
  if tf.sysconfig.get_build_info()["is_rocm_build"]:
902
912
  # ROCm MIOpen's weight sequence for LSTM is different from both
903
- # canonical and Cudnn format
904
- # MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o]
913
+ # canonical and cuDNN format
914
+ # MIOpen: [i, f, o, c] cuDNN/Canonical: [i, f, c, o]
905
915
  # i is input gate weights.
906
916
  # f is forget gate weights.
907
917
  # o is output gate weights.
@@ -1272,6 +1272,20 @@ def moveaxis(x, source, destination):
1272
1272
  return torch.moveaxis(x, source=source, destination=destination)
1273
1273
 
1274
1274
 
1275
+ def nansum(x, axis=None, keepdims=False):
1276
+ if isinstance(x, (list, tuple)):
1277
+ x = stack(x)
1278
+ x = convert_to_tensor(x)
1279
+ dtype = standardize_dtype(x.dtype)
1280
+
1281
+ if dtype in ("bool", "uint8", "int8", "int16"):
1282
+ dtype = "int32"
1283
+
1284
+ if axis == () or axis == []:
1285
+ return cast(torch.nan_to_num(x, nan=0), dtype)
1286
+ return cast(torch.nansum(x, dim=axis, keepdim=keepdims), dtype)
1287
+
1288
+
1275
1289
  def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
1276
1290
  x = convert_to_tensor(x)
1277
1291
  return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
@@ -1382,6 +1396,18 @@ def prod(x, axis=None, keepdims=False, dtype=None):
1382
1396
  return x
1383
1397
 
1384
1398
 
1399
+ def ptp(x, axis=None, keepdims=False):
1400
+ x = convert_to_tensor(x)
1401
+ if axis is None:
1402
+ return x.max() - x.min()
1403
+ elif axis == ():
1404
+ return torch.zeros_like(x)
1405
+ else:
1406
+ return torch.amax(x, dim=axis, keepdim=keepdims) - torch.amin(
1407
+ x, dim=axis, keepdim=keepdims
1408
+ )
1409
+
1410
+
1385
1411
  def quantile(x, q, axis=None, method="linear", keepdims=False):
1386
1412
  x = convert_to_tensor(x)
1387
1413
  q = convert_to_tensor(q)
@@ -413,11 +413,21 @@ def _is_sequence_right_padded(mask):
413
413
 
414
414
 
415
415
  def _has_fully_masked_sequence(mask):
416
- # Cudnn kernel will error out if the input sequence contains any
417
- # fully masked data. We walk around this issue by rerouting the computation
418
- # to standard kernel, until the issue on cudnn side has been fixed. For a
419
- # fully masked sequence, it will contain all Falses. To make it easy to
420
- # check, we inverse the boolean, check if any of the sequence has all True.
416
+ """Check if input sequence contains any fully masked data.
417
+
418
+ cuDNN kernel will error out if the input sequence contains any fully masked
419
+ data. We work around this issue by rerouting the computation to the
420
+ standard kernel until the issue on the cuDNN side has been fixed. For a
421
+ fully masked sequence, it will contain all `False` values. To make it easy
422
+ to check, we invert the boolean and check if any of the sequences has all
423
+ `True` values.
424
+
425
+ Args:
426
+ mask: The mask tensor.
427
+
428
+ Returns:
429
+ A boolean tensor, `True` if the mask contains a fully masked sequence.
430
+ """
421
431
  return torch.any(torch.all(~mask, dim=1))
422
432
 
423
433
 
@@ -447,8 +457,8 @@ def _compute_sequence_length_from_mask(mask, batch_first):
447
457
  The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For
448
458
  any timestep that should be masked, the corresponding field will be False.
449
459
  Consider the following example:
450
- a = [[True, True, False, False]
451
- [True, True, True, False]]
460
+ a = [[True, True, False, False]
461
+ [True, True, True, False]]
452
462
  It is a (2, 4) tensor, and the corresponding sequence length result should
453
463
  be 1D tensor with value [2, 3]. Note that the masking tensor must be right
454
464
  padded that could be checked by, e.g., `is_sequence_right_padded()`.
@@ -467,12 +477,19 @@ def _compute_sequence_length_from_mask(mask, batch_first):
467
477
 
468
478
 
469
479
  def prepare_lstm_weights(lstm, kernel, recurrent_kernel, bias, device):
470
- """Copies kernel and recurrent kernel weights in the Pytorch format
480
+ """Copies kernel and recurrent kernel weights into the PyTorch format.
481
+
471
482
  We split the kernel and recurrent kernel weights, create associated
472
- torch tensors adapted to be in line with the Cudnn optimization.
473
- After we have copied the weights, we ensure the paramters are on
474
- the same device and memory layout is optimized for Cudnn.
483
+ torch tensors adapted to be in line with the cuDNN optimization.
484
+ After we have copied the weights, we ensure the parameters are on
485
+ the same device and memory layout is optimized for cuDNN.
475
486
 
487
+ Args:
488
+ lstm: The PyTorch LSTM layer to prepare weights for.
489
+ kernel: The kernel weights tensor.
490
+ recurrent_kernel: The recurrent kernel weights tensor.
491
+ bias: The bias tensor.
492
+ device: The device to place the tensors on.
476
493
  """
477
494
 
478
495
  lstm = lstm.to(device)
@@ -8,7 +8,6 @@ from keras.src.api_export import keras_export
8
8
  from keras.src.callbacks.monitor_callback import (
9
9
  MonitorCallback, # For metric monitoring logic
10
10
  )
11
- from keras.src.utils.io_utils import print_msg
12
11
  from keras.src.utils.module_utils import ocp
13
12
 
14
13
  # Context and AsyncOptions are accessed through the lazy-loaded ocp module
@@ -62,6 +61,11 @@ class OrbaxCheckpoint(MonitorCallback):
62
61
  This callback saves the model's weights and optimizer state asynchronously
63
62
  using Orbax, allowing training to continue without blocking for I/O.
64
63
 
64
+ **Multi-host Support**: When running in a multi-host distributed training
65
+ environment with JAX backend, this callback automatically coordinates
66
+ checkpointing across all hosts to ensure consistency and proper
67
+ synchronization. Multi-host checkpointing is only supported on JAX.
68
+
65
69
  Example:
66
70
 
67
71
  ```python
@@ -92,10 +96,6 @@ class OrbaxCheckpoint(MonitorCallback):
92
96
  verbose: Verbosity mode, 0 or 1.
93
97
  save_best_only: if `save_best_only=True`, it only saves when the model
94
98
  is considered the "best" based on the monitored quantity.
95
- save_weights_only: if `save_weights_only=True`, only the model's
96
- weights will be saved. Otherwise, the full model state
97
- (weights, non-trainable variables, optimizer state, and
98
- metrics state) will be saved. Defaults to False.
99
99
  mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
100
100
  save_freq: `'epoch'` or integer. Frequency to save checkpoints.
101
101
  max_to_keep: Integer, maximum number of recent checkpoints to keep.
@@ -112,7 +112,6 @@ class OrbaxCheckpoint(MonitorCallback):
112
112
  monitor="val_loss",
113
113
  verbose=0,
114
114
  save_best_only=False,
115
- save_weights_only=False,
116
115
  mode="auto",
117
116
  save_freq="epoch",
118
117
  initial_value_threshold=None,
@@ -129,7 +128,6 @@ class OrbaxCheckpoint(MonitorCallback):
129
128
  self.directory = directory
130
129
  self.verbose = verbose
131
130
  self.save_best_only = save_best_only
132
- self.save_weights_only = save_weights_only
133
131
  self.save_freq = save_freq
134
132
  self.max_to_keep = max_to_keep
135
133
  self.save_on_background = save_on_background
@@ -138,6 +136,9 @@ class OrbaxCheckpoint(MonitorCallback):
138
136
  self._current_epoch = 0 # Keep track of epoch
139
137
  self._total_batches_seen = 0 # Global batch counter for step tracking
140
138
 
139
+ # Multi-host support
140
+ self._multihost_initialized = self._is_multihost_initialized()
141
+
141
142
  if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
142
143
  raise ValueError(
143
144
  f"Unrecognized save_freq: {self.save_freq}. "
@@ -151,14 +152,18 @@ class OrbaxCheckpoint(MonitorCallback):
151
152
  ocp.training.preservation_policies.LatestN(max_to_keep)
152
153
  )
153
154
 
154
- # Use AnyPreservationPolicy to combine them.
155
+ # Use AnyPreservationPolicy to combine them, or use directly
156
+ # if single policy
155
157
  preservation_policy = None
156
158
  if policies:
157
- preservation_policy = (
158
- ocp.training.preservation_policies.AnyPreservationPolicy(
159
- policies
159
+ if len(policies) == 1:
160
+ preservation_policy = policies[0]
161
+ else:
162
+ preservation_policy = (
163
+ ocp.training.preservation_policies.AnyPreservationPolicy(
164
+ policies
165
+ )
160
166
  )
161
- )
162
167
 
163
168
  # Create the V1 Checkpointer with direct parameter passing
164
169
  # Orbax will handle directory creation on all processes as needed
@@ -167,6 +172,54 @@ class OrbaxCheckpoint(MonitorCallback):
167
172
  preservation_policy=preservation_policy,
168
173
  )
169
174
 
175
+ def _is_multihost_initialized(self):
176
+ """Check if multi-host environment is initialized."""
177
+ # Multi-host checkpointing is only supported on JAX backend
178
+ if backend.backend() != "jax":
179
+ return False
180
+
181
+ multihost = ocp.multihost
182
+ # Check if JAX distributed client is initialized
183
+ # (indicates multihost setup)
184
+ return multihost.is_jax_distributed_client_initialized()
185
+
186
+ def _sync_processes(self, key=None):
187
+ """Synchronize all processes across hosts."""
188
+ if not self._multihost_initialized:
189
+ return # No-op for single host
190
+
191
+ multihost = ocp.multihost
192
+ sync_key = key or "orbax_checkpoint_sync"
193
+ multihost.sync_global_processes(sync_key)
194
+
195
+ def is_multihost_enabled(self):
196
+ """Return True if multi-host checkpointing is enabled and initialized.
197
+
198
+ This method can be used to check if the callback is operating in
199
+ a multi-host distributed training environment. Multi-host checkpointing
200
+ is only supported on JAX backend.
201
+
202
+ Returns:
203
+ bool: True if multi-host support is active, False otherwise.
204
+ """
205
+ return self._multihost_initialized
206
+
207
+ def is_primary_host(self):
208
+ """Return True if this process is the primary host in multi-host setup.
209
+
210
+ In multi-host environments, only the primary host typically handles
211
+ logging and coordination tasks. Multi-host checkpointing is only
212
+ supported on JAX backend.
213
+
214
+ Returns:
215
+ bool: True if this is the primary host, False otherwise.
216
+ Always returns True in single-host environments.
217
+ """
218
+ if not self._multihost_initialized:
219
+ return True # Single host is always primary
220
+ multihost = ocp.multihost
221
+ return multihost.is_primary_host()
222
+
170
223
  def _should_save_on_batch(self, batch):
171
224
  """Check if we should save on this batch."""
172
225
  if self.save_freq == "epoch":
@@ -186,32 +239,14 @@ class OrbaxCheckpoint(MonitorCallback):
186
239
  return False
187
240
 
188
241
  def _save_checkpoint(self, step, logs=None):
189
- """Save a checkpoint at the given step."""
242
+ """Save a checkpoint at the given step with multi-host coordination."""
190
243
 
191
244
  # --- Prepare Composite State (Backend-Agnostic) ---
192
245
  state_tree = _get_state_tree(self.model)
193
246
 
194
247
  # Save the nested state structures directly (preserving layer
195
248
  # names and structure)
196
- if self.save_weights_only:
197
- composite_state = {
198
- "trainable_variables": state_tree["trainable_variables"],
199
- }
200
- if "non_trainable_variables" in state_tree:
201
- composite_state["non_trainable_variables"] = state_tree[
202
- "non_trainable_variables"
203
- ]
204
- else:
205
- composite_state = state_tree
206
-
207
- # --- Save Logic (V1 API) ---
208
- # All processes participate in distributed checkpointing
209
- # Checkpointer is configured to save unconditionally when
210
- # save_pytree is called
211
- if self.verbose > 0:
212
- print_msg(
213
- f"OrbaxCheckpoint: Triggering async save for step {step}..."
214
- )
249
+ composite_state = state_tree
215
250
 
216
251
  # Use a single with statement. If context_options is empty,
217
252
  # Context() uses defaults.
@@ -282,18 +317,16 @@ class OrbaxCheckpoint(MonitorCallback):
282
317
  except Exception:
283
318
  pass # Ignore errors during cleanup
284
319
 
320
+ # Multi-host synchronization: ensure all hosts complete cleanup
321
+ self._sync_processes("checkpoint_cleanup")
322
+
285
323
  def wait_until_finished(self):
286
324
  """Wait for any in-progress checkpoint operations to complete.
287
325
  This method blocks until all asynchronous checkpoint save operations
288
- have completed. It should be called before attempting to load
289
- checkpoints if there might be pending save operations.
326
+ have completed across all hosts in a multi-host setup.
290
327
  """
291
- # Wait for any async operations to complete
292
- if hasattr(self.checkpointer, "wait"):
293
- self.checkpointer.wait()
294
- else:
295
- # Fallback for older Orbax versions that don't have wait() method
296
- while self.checkpointer.is_saving_in_progress():
297
- import time
328
+ # Wait for any async operations to complete on this host
329
+ self.checkpointer.wait()
298
330
 
299
- time.sleep(0.1)
331
+ # Multi-host synchronization: ensure all hosts complete
332
+ self._sync_processes("checkpoint_wait_complete")
@@ -2,6 +2,7 @@ from keras.src import backend
2
2
  from keras.src.api_export import keras_export
3
3
  from keras.src.dtype_policies import dtype_policy
4
4
  from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
5
+ from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
5
6
  from keras.src.dtype_policies.dtype_policy import DTypePolicy
6
7
  from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
7
8
  from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
@@ -10,6 +11,7 @@ from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
10
11
  from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
11
12
 
12
13
  ALL_OBJECTS = {
14
+ AWQDTypePolicy,
13
15
  DTypePolicy,
14
16
  FloatDTypePolicy,
15
17
  QuantizedDTypePolicy,
@@ -3,7 +3,7 @@ from keras.src import ops
3
3
  from keras.src.api_export import keras_export
4
4
  from keras.src.backend.common import global_state
5
5
 
6
- QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
6
+ QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq", "awq")
7
7
 
8
8
 
9
9
  @keras_export(
@@ -376,6 +376,93 @@ class GPTQDTypePolicy(QuantizedDTypePolicy):
376
376
  return config
377
377
 
378
378
 
379
+ @keras_export("keras.dtype_policies.AWQDTypePolicy")
380
+ class AWQDTypePolicy(QuantizedDTypePolicy):
381
+ """Quantized dtype policy for AWQ quantization.
382
+
383
+ This policy helps propagate quantization settings for AWQ
384
+ when loading an AWQ quantized model in Keras format.
385
+
386
+ Args:
387
+ mode: The quantization mode. This should be a string in the format
388
+ `"awq/<weight_bits>/<group_size>"`.
389
+ - `"awq"`: The identifier for the quantization algorithm.
390
+ - `<weight_bits>`: Number of bits to quantize weights to.
391
+ AWQ presently only supports 4-bit quantization.
392
+ - `<group_size>`: The group size for quantization. Supported
393
+ values are -1 (for per-channel quantization) or any
394
+ positive integer.
395
+ Example: `"awq/4/128"`.
396
+ source_name: The source dtype policy name, e.g. "float32".
397
+ """
398
+
399
+ def __init__(
400
+ self,
401
+ mode,
402
+ source_name=None,
403
+ ):
404
+ parts = mode.split("/")
405
+ expected_format = "'awq/<weight_bits>/<group_size>'"
406
+
407
+ # Validate format.
408
+ if len(parts) != 3 or parts[0] != "awq":
409
+ raise ValueError(
410
+ "Invalid mode for AWQDTypePolicy. Expected format "
411
+ f"{expected_format}, but got '{mode}'."
412
+ )
413
+
414
+ # Validate and cast weight_bits and group_size.
415
+ try:
416
+ weight_bits = int(parts[1])
417
+ group_size = int(parts[2])
418
+ except ValueError:
419
+ raise ValueError(
420
+ "Invalid mode for AWQDTypePolicy. <weight_bits> and "
421
+ "<group_size> must be integers. Expected format "
422
+ f"{expected_format}, but got '{mode}'."
423
+ )
424
+
425
+ # AWQ presently only supports 4-bit quantization.
426
+ if weight_bits != 4:
427
+ raise ValueError(
428
+ "Invalid weight_bits in mode. AWQ only supports 4-bit "
429
+ f"quantization, but got {weight_bits} from '{mode}'."
430
+ )
431
+
432
+ if group_size < -1 or group_size == 0:
433
+ raise ValueError(
434
+ "Invalid group_size in mode. Supported values are "
435
+ "-1 (per-channel) or a positive integer, "
436
+ f"but got {group_size} from '{mode}'."
437
+ )
438
+
439
+ base_mode = parts[0]
440
+ super().__init__(
441
+ mode=base_mode,
442
+ source_name=source_name,
443
+ )
444
+
445
+ self._name = f"{mode}_from_{source_name}"
446
+ self.mode = base_mode
447
+ self.weight_bits = weight_bits
448
+ self.group_size = group_size
449
+
450
+ def __eq__(self, other):
451
+ if super().__eq__(other) is False:
452
+ return False
453
+ return (
454
+ self.weight_bits == other.weight_bits
455
+ and self.group_size == other.group_size
456
+ )
457
+
458
+ def get_config(self):
459
+ config = super().get_config()
460
+ # Reconstruct the full mode string for serialization
461
+ mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
462
+ config.update({"mode": mode})
463
+ return config
464
+
465
+
379
466
  @keras_export(
380
467
  [
381
468
  "keras.config.set_dtype_policy",
@@ -442,6 +529,8 @@ def _get_quantized_dtype_policy_by_str(policy):
442
529
  return QuantizedDTypePolicy(mode, source_name)
443
530
  elif policy.startswith("gptq"):
444
531
  return GPTQDTypePolicy(mode, source_name)
532
+ elif policy.startswith("awq"):
533
+ return AWQDTypePolicy(mode, source_name)
445
534
  elif policy.startswith("float8"):
446
535
  return QuantizedFloat8DTypePolicy(mode, source_name)
447
536
  else: