keras-nightly 3.14.0.dev2025122704__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (53)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +3 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +3 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +3 -0
  7. keras/ops/numpy/__init__.py +3 -0
  8. keras/quantizers/__init__.py +1 -0
  9. keras/src/backend/jax/nn.py +26 -9
  10. keras/src/backend/jax/numpy.py +16 -0
  11. keras/src/backend/numpy/numpy.py +23 -0
  12. keras/src/backend/openvino/numpy.py +369 -16
  13. keras/src/backend/tensorflow/numpy.py +34 -1
  14. keras/src/backend/tensorflow/rnn.py +17 -7
  15. keras/src/backend/torch/numpy.py +36 -0
  16. keras/src/backend/torch/rnn.py +28 -11
  17. keras/src/callbacks/orbax_checkpoint.py +75 -42
  18. keras/src/dtype_policies/__init__.py +2 -0
  19. keras/src/dtype_policies/dtype_policy.py +90 -1
  20. keras/src/layers/core/dense.py +122 -6
  21. keras/src/layers/core/einsum_dense.py +151 -7
  22. keras/src/layers/core/embedding.py +1 -1
  23. keras/src/layers/core/reversible_embedding.py +10 -1
  24. keras/src/layers/layer.py +5 -0
  25. keras/src/layers/preprocessing/feature_space.py +8 -4
  26. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  27. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
  28. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  29. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  30. keras/src/losses/losses.py +24 -0
  31. keras/src/models/model.py +18 -9
  32. keras/src/ops/image.py +109 -96
  33. keras/src/ops/numpy.py +181 -0
  34. keras/src/quantizers/__init__.py +2 -0
  35. keras/src/quantizers/awq.py +361 -0
  36. keras/src/quantizers/awq_config.py +140 -0
  37. keras/src/quantizers/awq_core.py +217 -0
  38. keras/src/quantizers/gptq.py +1 -2
  39. keras/src/quantizers/gptq_core.py +1 -1
  40. keras/src/quantizers/quantization_config.py +14 -0
  41. keras/src/quantizers/quantizers.py +61 -52
  42. keras/src/random/seed_generator.py +2 -2
  43. keras/src/saving/file_editor.py +81 -6
  44. keras/src/saving/orbax_util.py +50 -0
  45. keras/src/saving/saving_api.py +37 -14
  46. keras/src/utils/jax_layer.py +69 -31
  47. keras/src/utils/module_utils.py +11 -0
  48. keras/src/utils/tracking.py +5 -5
  49. keras/src/version.py +1 -1
  50. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
  51. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +53 -49
  52. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
  53. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
keras/src/callbacks/orbax_checkpoint.py

@@ -8,7 +8,6 @@ from keras.src.api_export import keras_export
 from keras.src.callbacks.monitor_callback import (
     MonitorCallback,  # For metric monitoring logic
 )
-from keras.src.utils.io_utils import print_msg
 from keras.src.utils.module_utils import ocp
 
 # Context and AsyncOptions are accessed through the lazy-loaded ocp module
@@ -62,6 +61,11 @@ class OrbaxCheckpoint(MonitorCallback):
     This callback saves the model's weights and optimizer state asynchronously
     using Orbax, allowing training to continue without blocking for I/O.
 
+    **Multi-host Support**: When running in a multi-host distributed training
+    environment with JAX backend, this callback automatically coordinates
+    checkpointing across all hosts to ensure consistency and proper
+    synchronization. Multi-host checkpointing is only supported on JAX.
+
     Example:
 
     ```python
@@ -92,10 +96,6 @@ class OrbaxCheckpoint(MonitorCallback):
         verbose: Verbosity mode, 0 or 1.
         save_best_only: if `save_best_only=True`, it only saves when the model
             is considered the "best" based on the monitored quantity.
-        save_weights_only: if `save_weights_only=True`, only the model's
-            weights will be saved. Otherwise, the full model state
-            (weights, non-trainable variables, optimizer state, and
-            metrics state) will be saved. Defaults to False.
         mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
         save_freq: `'epoch'` or integer. Frequency to save checkpoints.
         max_to_keep: Integer, maximum number of recent checkpoints to keep.
@@ -112,7 +112,6 @@ class OrbaxCheckpoint(MonitorCallback):
         monitor="val_loss",
         verbose=0,
         save_best_only=False,
-        save_weights_only=False,
         mode="auto",
         save_freq="epoch",
         initial_value_threshold=None,
@@ -129,7 +128,6 @@ class OrbaxCheckpoint(MonitorCallback):
         self.directory = directory
         self.verbose = verbose
         self.save_best_only = save_best_only
-        self.save_weights_only = save_weights_only
         self.save_freq = save_freq
         self.max_to_keep = max_to_keep
         self.save_on_background = save_on_background
@@ -138,6 +136,9 @@ class OrbaxCheckpoint(MonitorCallback):
         self._current_epoch = 0  # Keep track of epoch
         self._total_batches_seen = 0  # Global batch counter for step tracking
 
+        # Multi-host support
+        self._multihost_initialized = self._is_multihost_initialized()
+
         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
             raise ValueError(
                 f"Unrecognized save_freq: {self.save_freq}. "
@@ -151,14 +152,18 @@ class OrbaxCheckpoint(MonitorCallback):
             ocp.training.preservation_policies.LatestN(max_to_keep)
         )
 
-        # Use AnyPreservationPolicy to combine them.
+        # Use AnyPreservationPolicy to combine them, or use directly
+        # if single policy
         preservation_policy = None
         if policies:
-            preservation_policy = (
-                ocp.training.preservation_policies.AnyPreservationPolicy(
-                    policies
+            if len(policies) == 1:
+                preservation_policy = policies[0]
+            else:
+                preservation_policy = (
+                    ocp.training.preservation_policies.AnyPreservationPolicy(
+                        policies
+                    )
                 )
-            )
 
         # Create the V1 Checkpointer with direct parameter passing
         # Orbax will handle directory creation on all processes as needed
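The rewritten branch avoids wrapping a lone policy in `AnyPreservationPolicy`. A minimal sketch of the resulting selection logic, using only the two policy classes that appear in this file (`LatestN` and `AnyPreservationPolicy`; `ocp` is the lazily imported Orbax module):

```python
# Sketch only: mirrors the branch above.
policies = [ocp.training.preservation_policies.LatestN(5)]

if len(policies) == 1:
    # A single policy is used directly; no combining wrapper needed.
    preservation_policy = policies[0]
else:
    # Multiple policies are combined; as the name suggests, a checkpoint
    # is preserved if any one of the wrapped policies would keep it.
    preservation_policy = (
        ocp.training.preservation_policies.AnyPreservationPolicy(policies)
    )
```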
@@ -167,6 +172,54 @@ class OrbaxCheckpoint(MonitorCallback):
             preservation_policy=preservation_policy,
         )
 
+    def _is_multihost_initialized(self):
+        """Check if multi-host environment is initialized."""
+        # Multi-host checkpointing is only supported on JAX backend
+        if backend.backend() != "jax":
+            return False
+
+        multihost = ocp.multihost
+        # Check if JAX distributed client is initialized
+        # (indicates multihost setup)
+        return multihost.is_jax_distributed_client_initialized()
+
+    def _sync_processes(self, key=None):
+        """Synchronize all processes across hosts."""
+        if not self._multihost_initialized:
+            return  # No-op for single host
+
+        multihost = ocp.multihost
+        sync_key = key or "orbax_checkpoint_sync"
+        multihost.sync_global_processes(sync_key)
+
+    def is_multihost_enabled(self):
+        """Return True if multi-host checkpointing is enabled and initialized.
+
+        This method can be used to check if the callback is operating in
+        a multi-host distributed training environment. Multi-host checkpointing
+        is only supported on JAX backend.
+
+        Returns:
+            bool: True if multi-host support is active, False otherwise.
+        """
+        return self._multihost_initialized
+
+    def is_primary_host(self):
+        """Return True if this process is the primary host in multi-host setup.
+
+        In multi-host environments, only the primary host typically handles
+        logging and coordination tasks. Multi-host checkpointing is only
+        supported on JAX backend.
+
+        Returns:
+            bool: True if this is the primary host, False otherwise.
+            Always returns True in single-host environments.
+        """
+        if not self._multihost_initialized:
+            return True  # Single host is always primary
+        multihost = ocp.multihost
+        return multihost.is_primary_host()
+
     def _should_save_on_batch(self, batch):
         """Check if we should save on this batch."""
         if self.save_freq == "epoch":
@@ -186,32 +239,14 @@ class OrbaxCheckpoint(MonitorCallback):
             return False
 
     def _save_checkpoint(self, step, logs=None):
-        """Save a checkpoint at the given step."""
+        """Save a checkpoint at the given step with multi-host coordination."""
 
         # --- Prepare Composite State (Backend-Agnostic) ---
         state_tree = _get_state_tree(self.model)
 
         # Save the nested state structures directly (preserving layer
         # names and structure)
-        if self.save_weights_only:
-            composite_state = {
-                "trainable_variables": state_tree["trainable_variables"],
-            }
-            if "non_trainable_variables" in state_tree:
-                composite_state["non_trainable_variables"] = state_tree[
-                    "non_trainable_variables"
-                ]
-        else:
-            composite_state = state_tree
-
-        # --- Save Logic (V1 API) ---
-        # All processes participate in distributed checkpointing
-        # Checkpointer is configured to save unconditionally when
-        # save_pytree is called
-        if self.verbose > 0:
-            print_msg(
-                f"OrbaxCheckpoint: Triggering async save for step {step}..."
-            )
+        composite_state = state_tree
 
         # Use a single with statement. If context_options is empty,
         # Context() uses defaults.
@@ -282,18 +317,16 @@ class OrbaxCheckpoint(MonitorCallback):
         except Exception:
             pass  # Ignore errors during cleanup
 
+        # Multi-host synchronization: ensure all hosts complete cleanup
+        self._sync_processes("checkpoint_cleanup")
+
     def wait_until_finished(self):
         """Wait for any in-progress checkpoint operations to complete.
         This method blocks until all asynchronous checkpoint save operations
-        have completed. It should be called before attempting to load
-        checkpoints if there might be pending save operations.
+        have completed across all hosts in a multi-host setup.
         """
-        # Wait for any async operations to complete
-        if hasattr(self.checkpointer, "wait"):
-            self.checkpointer.wait()
-        else:
-            # Fallback for older Orbax versions that don't have wait() method
-            while self.checkpointer.is_saving_in_progress():
-                import time
+        # Wait for any async operations to complete on this host
+        self.checkpointer.wait()
 
-                time.sleep(0.1)
+        # Multi-host synchronization: ensure all hosts complete
+        self._sync_processes("checkpoint_wait_complete")
keras/src/dtype_policies/__init__.py

@@ -2,6 +2,7 @@ from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.dtype_policies import dtype_policy
 from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
+from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
 from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
 from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
@@ -10,6 +11,7 @@ from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 
 ALL_OBJECTS = {
+    AWQDTypePolicy,
     DTypePolicy,
     FloatDTypePolicy,
     QuantizedDTypePolicy,
keras/src/dtype_policies/dtype_policy.py

@@ -3,7 +3,7 @@ from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
+QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq", "awq")
 
 
 @keras_export(
@@ -376,6 +376,93 @@ class GPTQDTypePolicy(QuantizedDTypePolicy):
         return config
 
 
+@keras_export("keras.dtype_policies.AWQDTypePolicy")
+class AWQDTypePolicy(QuantizedDTypePolicy):
+    """Quantized dtype policy for AWQ quantization.
+
+    This policy helps propagate quantization settings for AWQ
+    when loading an AWQ quantized model in Keras format.
+
+    Args:
+        mode: The quantization mode. This should be a string in the format
+            `"awq/<weight_bits>/<group_size>"`.
+            - `"awq"`: The identifier for the quantization algorithm.
+            - `<weight_bits>`: Number of bits to quantize weights to.
+              AWQ presently only supports 4-bit quantization.
+            - `<group_size>`: The group size for quantization. Supported
+              values are -1 (for per-channel quantization) or any
+              positive integer.
+            Example: `"awq/4/128"`.
+        source_name: The source dtype policy name, e.g. "float32".
+    """
+
+    def __init__(
+        self,
+        mode,
+        source_name=None,
+    ):
+        parts = mode.split("/")
+        expected_format = "'awq/<weight_bits>/<group_size>'"
+
+        # Validate format.
+        if len(parts) != 3 or parts[0] != "awq":
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # Validate and cast weight_bits and group_size.
+        try:
+            weight_bits = int(parts[1])
+            group_size = int(parts[2])
+        except ValueError:
+            raise ValueError(
+                "Invalid mode for AWQDTypePolicy. <weight_bits> and "
+                "<group_size> must be integers. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # AWQ presently only supports 4-bit quantization.
+        if weight_bits != 4:
+            raise ValueError(
+                "Invalid weight_bits in mode. AWQ only supports 4-bit "
+                f"quantization, but got {weight_bits} from '{mode}'."
+            )
+
+        if group_size < -1 or group_size == 0:
+            raise ValueError(
+                "Invalid group_size in mode. Supported values are "
+                "-1 (per-channel) or a positive integer, "
+                f"but got {group_size} from '{mode}'."
+            )
+
+        base_mode = parts[0]
+        super().__init__(
+            mode=base_mode,
+            source_name=source_name,
+        )
+
+        self._name = f"{mode}_from_{source_name}"
+        self.mode = base_mode
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+
+    def __eq__(self, other):
+        if super().__eq__(other) is False:
+            return False
+        return (
+            self.weight_bits == other.weight_bits
+            and self.group_size == other.group_size
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        # Reconstruct the full mode string for serialization
+        mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
+        config.update({"mode": mode})
+        return config
+
+
 @keras_export(
     [
         "keras.config.set_dtype_policy",
@@ -442,6 +529,8 @@ def _get_quantized_dtype_policy_by_str(policy):
         return QuantizedDTypePolicy(mode, source_name)
     elif policy.startswith("gptq"):
         return GPTQDTypePolicy(mode, source_name)
+    elif policy.startswith("awq"):
+        return AWQDTypePolicy(mode, source_name)
     elif policy.startswith("float8"):
         return QuantizedFloat8DTypePolicy(mode, source_name)
     else:
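With this branch in place, the serialized policy string produced by `AWQDTypePolicy` should round-trip through the prefix dispatch. A sketch, assuming `dtype_policies.get` splits such strings on `_from_` into `mode` and `source_name` (the same string shape Dense builds below via `f"{policy_name}_from_{self.dtype_policy.name}"`):

```python
from keras.src import dtype_policies

# "awq/4/128_from_float32" hits the new startswith("awq") branch above.
policy = dtype_policies.get("awq/4/128_from_float32")
print(type(policy).__name__, policy.weight_bits, policy.group_size)
# Expected: AWQDTypePolicy 4 128
```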
keras/src/layers/core/dense.py

@@ -128,7 +128,7 @@ class Dense(Layer):
             mode=self.quantization_mode,
             config=self.quantization_config,
         )
-        if self.quantization_mode not in ("int8", "int4", "gptq"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
             # it here.
@@ -165,15 +165,17 @@
 
         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )
 
         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and calibrated and gptq_bits != 4:
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -183,7 +185,15 @@
         # Handle int4 unpacking cases in one place
         if is_int4:
             kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
-        elif is_gptq and calibrated and gptq_bits == 4:
+        elif is_gptq and gptq_calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+        elif is_awq and awq_calibrated:
+            # AWQ always uses 4-bit quantization
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
                 orig_len=self.units,
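Both the GPTQ and the new AWQ branches lean on `quantizers.unpack_int4` to expand two 4-bit values per stored byte back to `orig_len` values along the packed axis. An illustrative pure-Python model of that packing contract (not the Keras implementation, which operates on tensors; the nibble order is an assumption here):

```python
def pack_nibbles(values):
    # Pair up 4-bit values (0..15); pad an odd trailing value with 0.
    if len(values) % 2:
        values = values + [0]
    return [
        (lo & 0xF) | ((hi & 0xF) << 4)
        for lo, hi in zip(values[::2], values[1::2])
    ]


def unpack_nibbles(packed, orig_len):
    out = []
    for byte in packed:
        out.append(byte & 0xF)         # low nibble
        out.append((byte >> 4) & 0xF)  # high nibble
    # orig_len trims the padding, like orig_len= in unpack_int4.
    return out[:orig_len]


assert unpack_nibbles(pack_nibbles([7, 12, 3]), orig_len=3) == [7, 12, 3]
```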
@@ -304,8 +314,9 @@
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # A saved GPTQ quantized model will always be calibrated.
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
         self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
 
         idx = 0
         for name in self.variable_serialization_spec[mode]:
@@ -395,6 +406,14 @@
                 "kernel_zero",
                 "g_idx",
             ],
+            "awq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "awq_scales",
+                "g_idx",
+            ],
         }
 
     def quantized_build(self, kernel_shape, mode, config=None):
  def quantized_build(self, kernel_shape, mode, config=None):
@@ -406,6 +425,8 @@ class Dense(Layer):
406
425
  self._float8_build()
407
426
  elif mode == "gptq":
408
427
  self._gptq_build(kernel_shape, config)
428
+ elif mode == "awq":
429
+ self._awq_build(kernel_shape, config)
409
430
  else:
410
431
  raise self._quantization_mode_error(mode)
411
432
  self._is_quantized = True
@@ -515,6 +536,97 @@
             y = self.activation(y)
         return y
 
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        # For 4-bit weights, we pack two values per byte.
+        units = (kernel_shape[1] + 1) // 2
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = (
+            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(kernel_shape[0],),
+            initializer="ones",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, units]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
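The shapes `_awq_build` allocates follow directly from the packing and grouping arithmetic above. A worked example for a hypothetical layer with `input_dim=1024`, `units=512`, and `group_size=128`:

```python
import math

kernel_shape = (1024, 512)  # (input_dim, units)
group_size = 128

# Two 4-bit values per uint8 byte along the units axis.
packed_units = (kernel_shape[1] + 1) // 2                 # 256
quantized_kernel_shape = (packed_units, kernel_shape[0])  # (256, 1024), uint8

# One scale/zero entry per (output unit, input group).
num_groups = 1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
kernel_scale_shape = (kernel_shape[1], num_groups)  # (512, 8)

# One AWQ scale per input channel, divided out at call time.
awq_scales_shape = (kernel_shape[0],)  # (1024,)

print(quantized_kernel_shape, kernel_scale_shape, num_groups)
# (256, 1024) (512, 8) 8
```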
@@ -835,6 +947,8 @@
             self.kernel_scale.assign(kernel_scale)
         elif mode == "gptq":
             self.quantized_build(kernel_shape, mode, self.quantization_config)
+        elif mode == "awq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -847,6 +961,8 @@
         policy_name = mode
         if mode == "gptq":
             policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "awq":
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
@@ -881,7 +997,7 @@
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         kernel_value = self._kernel