keras-nightly 3.14.0.dev2025122704__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +3 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +3 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +3 -0
  7. keras/ops/numpy/__init__.py +3 -0
  8. keras/quantizers/__init__.py +1 -0
  9. keras/src/backend/jax/nn.py +26 -9
  10. keras/src/backend/jax/numpy.py +16 -0
  11. keras/src/backend/numpy/numpy.py +23 -0
  12. keras/src/backend/openvino/numpy.py +369 -16
  13. keras/src/backend/tensorflow/numpy.py +34 -1
  14. keras/src/backend/tensorflow/rnn.py +17 -7
  15. keras/src/backend/torch/numpy.py +36 -0
  16. keras/src/backend/torch/rnn.py +28 -11
  17. keras/src/callbacks/orbax_checkpoint.py +75 -42
  18. keras/src/dtype_policies/__init__.py +2 -0
  19. keras/src/dtype_policies/dtype_policy.py +90 -1
  20. keras/src/layers/core/dense.py +122 -6
  21. keras/src/layers/core/einsum_dense.py +151 -7
  22. keras/src/layers/core/embedding.py +1 -1
  23. keras/src/layers/core/reversible_embedding.py +10 -1
  24. keras/src/layers/layer.py +5 -0
  25. keras/src/layers/preprocessing/feature_space.py +8 -4
  26. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  27. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
  28. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  29. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  30. keras/src/losses/losses.py +24 -0
  31. keras/src/models/model.py +18 -9
  32. keras/src/ops/image.py +109 -96
  33. keras/src/ops/numpy.py +181 -0
  34. keras/src/quantizers/__init__.py +2 -0
  35. keras/src/quantizers/awq.py +361 -0
  36. keras/src/quantizers/awq_config.py +140 -0
  37. keras/src/quantizers/awq_core.py +217 -0
  38. keras/src/quantizers/gptq.py +1 -2
  39. keras/src/quantizers/gptq_core.py +1 -1
  40. keras/src/quantizers/quantization_config.py +14 -0
  41. keras/src/quantizers/quantizers.py +61 -52
  42. keras/src/random/seed_generator.py +2 -2
  43. keras/src/saving/file_editor.py +81 -6
  44. keras/src/saving/orbax_util.py +50 -0
  45. keras/src/saving/saving_api.py +37 -14
  46. keras/src/utils/jax_layer.py +69 -31
  47. keras/src/utils/module_utils.py +11 -0
  48. keras/src/utils/tracking.py +5 -5
  49. keras/src/version.py +1 -1
  50. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
  51. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +53 -49
  52. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
  53. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
keras/src/ops/image.py CHANGED
@@ -565,6 +565,9 @@ class ExtractPatches(Operation):
565
565
  if isinstance(size, int):
566
566
  size = (size, size)
567
567
  self.size = size
568
+ self.is_3d = len(self.size) == 3
569
+ if strides is None:
570
+ strides = size
568
571
  self.strides = strides
569
572
  self.dilation_rate = dilation_rate
570
573
  self.padding = padding
@@ -583,29 +586,51 @@ class ExtractPatches(Operation):
583
586
  def compute_output_spec(self, images):
584
587
  images_shape = list(images.shape)
585
588
  original_ndim = len(images_shape)
586
- if not self.strides:
587
- strides = (self.size[0], self.size[1])
588
589
  if self.data_format == "channels_last":
589
590
  channels_in = images_shape[-1]
590
591
  else:
591
- channels_in = images_shape[-3]
592
- if original_ndim == 3:
593
- images_shape = [1] + images_shape
594
- filters = self.size[0] * self.size[1] * channels_in
595
- kernel_size = (self.size[0], self.size[1])
592
+ channels_in = images_shape[-4] if self.is_3d else images_shape[-3]
593
+
594
+ if self.is_3d:
595
+ # 3D patch extraction
596
+ if original_ndim == 4:
597
+ images_shape = [1] + images_shape
598
+ filters = self.size[0] * self.size[1] * self.size[2] * channels_in
599
+ kernel_size = (self.size[0], self.size[1], self.size[2])
600
+ else:
601
+ # 2D patch extraction
602
+ if original_ndim == 3:
603
+ images_shape = [1] + images_shape
604
+ filters = self.size[0] * self.size[1] * channels_in
605
+ kernel_size = (self.size[0], self.size[1])
606
+
596
607
  out_shape = compute_conv_output_shape(
597
608
  images_shape,
598
609
  filters,
599
610
  kernel_size,
600
- strides=strides,
611
+ strides=self.strides,
601
612
  padding=self.padding,
602
613
  data_format=self.data_format,
603
614
  dilation_rate=self.dilation_rate,
604
615
  )
605
- if original_ndim == 3:
606
- out_shape = out_shape[1:]
616
+
617
+ if self.is_3d:
618
+ if original_ndim == 4:
619
+ out_shape = out_shape[1:]
620
+ else:
621
+ if original_ndim == 3:
622
+ out_shape = out_shape[1:]
607
623
  return KerasTensor(shape=out_shape, dtype=images.dtype)
608
624
 
625
+ def get_config(self):
626
+ return {
627
+ "size": self.size,
628
+ "strides": self.strides,
629
+ "dilation_rate": self.dilation_rate,
630
+ "padding": self.padding,
631
+ "data_format": self.data_format,
632
+ }
633
+
609
634
 
610
635
  @keras_export("keras.ops.image.extract_patches")
611
636
  def extract_patches(
@@ -616,42 +641,78 @@ def extract_patches(
616
641
  padding="valid",
617
642
  data_format=None,
618
643
  ):
619
- """Extracts patches from the image(s).
644
+ """Extracts patches from the image(s) or volume(s).
645
+
646
+ This function supports both 2D and 3D patch extraction based on the
647
+ `size` argument length, similar to how `keras.ops.conv` handles
648
+ different dimensions.
620
649
 
621
650
  Args:
622
- images: Input image or batch of images. Must be 3D or 4D.
623
- size: Patch size int or tuple (patch_height, patch_width)
624
- strides: strides along height and width. If not specified, or
625
- if `None`, it defaults to the same value as `size`.
626
- dilation_rate: This is the input stride, specifying how far two
627
- consecutive patch samples are in the input. For value other than 1,
628
- strides must be 1. NOTE: `strides > 1` is not supported in
629
- conjunction with `dilation_rate > 1`
651
+ images: Input image/volume or batch of images/volumes.
652
+ For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
653
+ For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
654
+ size: Patch size as int or tuple.
655
+ Length 2 tuple `(patch_height, patch_width)` or int for 2D patches.
656
+ Length 3 tuple `(patch_depth, patch_height, patch_width)` for
657
+ 3D patches.
658
+ strides: Strides for patch extraction. If not specified, defaults
659
+ to `size` (non-overlapping patches).
660
+ dilation_rate: Dilation rate for patch extraction. Note that
661
+ `dilation_rate > 1` is not supported with `strides > 1`.
630
662
  padding: The type of padding algorithm to use: `"same"` or `"valid"`.
631
663
  data_format: A string specifying the data format of the input tensor.
632
664
  It can be either `"channels_last"` or `"channels_first"`.
633
- `"channels_last"` corresponds to inputs with shape
634
- `(batch, height, width, channels)`, while `"channels_first"`
635
- corresponds to inputs with shape `(batch, channels, height, width)`.
636
- If not specified, the value will default to
637
- `keras.config.image_data_format`.
665
+ If not specified, defaults to `keras.config.image_data_format`.
638
666
 
639
667
  Returns:
640
- Extracted patches 3D (if not batched) or 4D (if batched)
668
+ Extracted patches with shape depending on input and `size`:
669
+ - 2D patches: 3D (unbatched) or 4D (batched)
670
+ - 3D patches: 4D (unbatched) or 5D (batched)
641
671
 
642
672
  Examples:
643
673
 
674
+ >>> # 2D patches from batch of images
644
675
  >>> image = np.random.random(
645
676
  ... (2, 20, 20, 3)
646
- ... ).astype("float32") # batch of 2 RGB images
677
+ ... ).astype("float32")
647
678
  >>> patches = keras.ops.image.extract_patches(image, (5, 5))
648
679
  >>> patches.shape
649
680
  (2, 4, 4, 75)
650
- >>> image = np.random.random((20, 20, 3)).astype("float32") # 1 RGB image
681
+
682
+ >>> # 2D patches from single image
683
+ >>> image = np.random.random((20, 20, 3)).astype("float32")
651
684
  >>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1))
652
685
  >>> patches.shape
653
686
  (18, 18, 27)
687
+
688
+ >>> # 3D patches from batch of volumes
689
+ >>> volumes = np.random.random(
690
+ ... (2, 10, 10, 10, 3)
691
+ ... ).astype("float32")
692
+ >>> patches = keras.ops.image.extract_patches(volumes, (3, 3, 3))
693
+ >>> patches.shape
694
+ (2, 3, 3, 3, 81)
695
+
696
+ >>> # 3D patches from single volume
697
+ >>> volume = np.random.random((10, 10, 10, 3)).astype("float32")
698
+ >>> patches = keras.ops.image.extract_patches(volume, (3, 3, 3))
699
+ >>> patches.shape
700
+ (3, 3, 3, 81)
654
701
  """
702
+ # Validate size argument
703
+ if not isinstance(size, int):
704
+ if not isinstance(size, (tuple, list)):
705
+ raise TypeError(
706
+ "Invalid `size` argument. Expected an int or a tuple. "
707
+ f"Received: size={size} of type {type(size).__name__}"
708
+ )
709
+ if len(size) not in (2, 3):
710
+ raise ValueError(
711
+ "Invalid `size` argument. Expected a tuple of length 2 or 3. "
712
+ f"Received: size={size} with length {len(size)}"
713
+ )
714
+
715
+ # 2D patch extraction (default)
655
716
  if any_symbolic_tensors((images,)):
656
717
  return ExtractPatches(
657
718
  size=size,
@@ -673,6 +734,23 @@ def _extract_patches(
673
734
  dilation_rate=1,
674
735
  padding="valid",
675
736
  data_format=None,
737
+ ):
738
+ if not isinstance(size, int) and len(size) == 3:
739
+ return _extract_patches_3d(
740
+ images, size, strides, dilation_rate, padding, data_format
741
+ )
742
+ return _extract_patches_2d(
743
+ images, size, strides, dilation_rate, padding, data_format
744
+ )
745
+
746
+
747
+ def _extract_patches_2d(
748
+ images,
749
+ size,
750
+ strides=None,
751
+ dilation_rate=1,
752
+ padding="valid",
753
+ data_format=None,
676
754
  ):
677
755
  if isinstance(size, int):
678
756
  patch_h = patch_w = size
@@ -712,74 +790,6 @@ def _extract_patches(
712
790
  return patches
713
791
 
714
792
 
715
- class ExtractPatches3D(Operation):
716
- def __init__(
717
- self,
718
- size,
719
- strides=None,
720
- dilation_rate=1,
721
- padding="valid",
722
- data_format=None,
723
- *,
724
- name=None,
725
- ):
726
- super().__init__(name=name)
727
- if isinstance(size, int):
728
- size = (size, size, size)
729
- elif len(size) != 3:
730
- raise TypeError(
731
- "Invalid `size` argument. Expected an "
732
- f"int or a tuple of length 3. Received: size={size}"
733
- )
734
- self.size = size
735
- if strides is not None:
736
- if isinstance(strides, int):
737
- strides = (strides, strides, strides)
738
- elif len(strides) != 3:
739
- raise ValueError(f"Invalid `strides` argument. Got: {strides}")
740
- else:
741
- strides = size
742
- self.strides = strides
743
- self.dilation_rate = dilation_rate
744
- self.padding = padding
745
- self.data_format = backend.standardize_data_format(data_format)
746
-
747
- def call(self, volumes):
748
- return _extract_patches_3d(
749
- volumes,
750
- self.size,
751
- self.strides,
752
- self.dilation_rate,
753
- self.padding,
754
- self.data_format,
755
- )
756
-
757
- def compute_output_spec(self, volumes):
758
- volumes_shape = list(volumes.shape)
759
- original_ndim = len(volumes_shape)
760
- strides = self.strides
761
- if self.data_format == "channels_last":
762
- channels_in = volumes_shape[-1]
763
- else:
764
- channels_in = volumes_shape[-4]
765
- if original_ndim == 4:
766
- volumes_shape = [1] + volumes_shape
767
- filters = self.size[0] * self.size[1] * self.size[2] * channels_in
768
- kernel_size = (self.size[0], self.size[1], self.size[2])
769
- out_shape = compute_conv_output_shape(
770
- volumes_shape,
771
- filters,
772
- kernel_size,
773
- strides=strides,
774
- padding=self.padding,
775
- data_format=self.data_format,
776
- dilation_rate=self.dilation_rate,
777
- )
778
- if original_ndim == 4:
779
- out_shape = out_shape[1:]
780
- return KerasTensor(shape=out_shape, dtype=volumes.dtype)
781
-
782
-
783
793
  def _extract_patches_3d(
784
794
  volumes,
785
795
  size,
@@ -879,8 +889,11 @@ def extract_patches_3d(
879
889
  >>> patches.shape
880
890
  (3, 3, 3, 81)
881
891
  """
892
+ # Convert int to 3-tuple for 3D
893
+ if isinstance(size, int):
894
+ size = (size, size, size)
882
895
  if any_symbolic_tensors((volumes,)):
883
- return ExtractPatches3D(
896
+ return ExtractPatches(
884
897
  size=size,
885
898
  strides=strides,
886
899
  dilation_rate=dilation_rate,
keras/src/ops/numpy.py CHANGED
@@ -5064,6 +5064,67 @@ def moveaxis(x, source, destination):
5064
5064
  return backend.numpy.moveaxis(x, source=source, destination=destination)
5065
5065
 
5066
5066
 
5067
+ class Nansum(Operation):
5068
+ def __init__(self, axis=None, keepdims=False, *, name=None):
5069
+ super().__init__(name=name)
5070
+ self.axis = axis
5071
+ self.keepdims = keepdims
5072
+
5073
+ def call(self, x):
5074
+ return backend.numpy.nansum(x, axis=self.axis, keepdims=self.keepdims)
5075
+
5076
+ def compute_output_spec(self, x):
5077
+ dtype = dtypes.result_type(getattr(x, "dtype", backend.floatx()))
5078
+
5079
+ if dtype in ("bool", "int8", "int16"):
5080
+ dtype = "int32"
5081
+ elif dtype in ("uint8", "uint16"):
5082
+ dtype = "uint32"
5083
+
5084
+ if backend.backend() == "torch" and dtype == "uint32":
5085
+ dtype = "int32"
5086
+ sparse = getattr(x, "sparse", False)
5087
+ return KerasTensor(
5088
+ reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),
5089
+ dtype=dtype,
5090
+ sparse=sparse,
5091
+ )
5092
+
5093
+
5094
+ @keras_export(["keras.ops.nansum", "keras.ops.numpy.nansum"])
5095
+ def nansum(x, axis=None, keepdims=False):
5096
+ """Sum of a tensor over the given axes, ignoring NaNs.
5097
+
5098
+ Args:
5099
+ x: Input tensor.
5100
+ axis: Axis or axes along which the sum is computed. The default is to
5101
+ compute the sum of the flattened tensor.
5102
+ keepdims: If this is set to `True`, the axes which are reduced are left
5103
+ in the result as dimensions with size one.
5104
+
5105
+ Returns:
5106
+ Output tensor containing the sum, with NaN values ignored.
5107
+
5108
+ Examples:
5109
+ >>> import numpy as np
5110
+ >>> from keras import ops
5111
+ >>> x = np.array([[1.0, np.nan, 3.0],
5112
+ ... [np.nan, 2.0, 1.0]])
5113
+ >>> ops.nansum(x)
5114
+ 7.0
5115
+
5116
+ >>> ops.nansum(x, axis=1)
5117
+ array([4., 3.])
5118
+
5119
+ >>> ops.nansum(x, axis=1, keepdims=True)
5120
+ array([[4.],
5121
+ [3.]])
5122
+ """
5123
+ if any_symbolic_tensors((x,)):
5124
+ return Nansum(axis=axis, keepdims=keepdims).symbolic_call(x)
5125
+ return backend.numpy.nansum(x, axis=axis, keepdims=keepdims)
5126
+
5127
+
5067
5128
  class NanToNum(Operation):
5068
5129
  def __init__(self, nan=0.0, posinf=None, neginf=None, *, name=None):
5069
5130
  super().__init__(name=name)
@@ -5456,6 +5517,74 @@ def prod(x, axis=None, keepdims=False, dtype=None):
5456
5517
  return backend.numpy.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)
5457
5518
 
5458
5519
 
5520
+ class Ptp(Operation):
5521
+ def __init__(self, axis=None, keepdims=False, *, name=None):
5522
+ super().__init__(name=name)
5523
+ self.axis = axis
5524
+ self.keepdims = keepdims
5525
+
5526
+ def call(self, x):
5527
+ return backend.numpy.ptp(
5528
+ x,
5529
+ axis=self.axis,
5530
+ keepdims=self.keepdims,
5531
+ )
5532
+
5533
+ def compute_output_spec(self, x):
5534
+ dtype = backend.standardize_dtype(x.dtype)
5535
+ return KerasTensor(
5536
+ reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),
5537
+ dtype=dtype,
5538
+ )
5539
+
5540
+
5541
+ @keras_export(["keras.ops.ptp", "keras.ops.numpy.ptp"])
5542
+ def ptp(x, axis=None, keepdims=False):
5543
+ """Return the peak-to-peak (max - min) value of tensor elements
5544
+ over a given axis.
5545
+
5546
+ The peak-to-peak value is defined as the difference between the
5547
+ maximum and minimum values along the specified axis.
5548
+
5549
+ Args:
5550
+ x: Input tensor.
5551
+ axis: Axis or axes along which the peak-to-peak value is computed.
5552
+ The default, `axis=None`, will compute the peak-to-peak value
5553
+ over all elements in the input tensor.
5554
+ keepdims: If this is set to `True`, the axes which are reduced
5555
+ are left in the result as dimensions with size one.
5556
+
5557
+ Returns:
5558
+ A tensor containing the peak-to-peak values of `x` over the
5559
+ given axis or axes.
5560
+
5561
+ Examples:
5562
+ >>> x = keras.ops.array([[1., 3., 2.],
5563
+ ... [4., 0., 5.]])
5564
+
5565
+ >>> # Peak-to-peak over all elements
5566
+ >>> keras.ops.ptp(x)
5567
+ 5.0
5568
+
5569
+ >>> # Peak-to-peak along axis 1
5570
+ >>> keras.ops.ptp(x, axis=1)
5571
+ array([2., 5.], dtype=float32)
5572
+
5573
+ >>> # Peak-to-peak over multiple axes
5574
+ >>> x = keras.ops.reshape(x, (1, 2, 3))
5575
+ >>> keras.ops.ptp(x, axis=(1, 2))
5576
+ array([5.], dtype=float32)
5577
+
5578
+ >>> # Keep reduced dimensions
5579
+ >>> keras.ops.ptp(x, axis=2, keepdims=True)
5580
+ array([[[2.],
5581
+ [5.]]], dtype=float32)
5582
+ """
5583
+ if any_symbolic_tensors((x,)):
5584
+ return Ptp(axis=axis, keepdims=keepdims).symbolic_call(x)
5585
+ return backend.numpy.ptp(x, axis=axis, keepdims=keepdims)
5586
+
5587
+
5459
5588
  class Quantile(Operation):
5460
5589
  def __init__(
5461
5590
  self, axis=None, method="linear", keepdims=False, *, name=None
@@ -7104,6 +7233,49 @@ def negative(x):
7104
7233
  return backend.numpy.negative(x)
7105
7234
 
7106
7235
 
7236
+ class Nextafter(Operation):
7237
+ def call(self, x1, x2):
7238
+ return backend.numpy.nextafter(x1, x2)
7239
+
7240
+ def compute_output_spec(self, x1, x2):
7241
+ x1_shape = getattr(x1, "shape", [])
7242
+ x2_shape = getattr(x2, "shape", [])
7243
+ output_shape = broadcast_shapes(x1_shape, x2_shape)
7244
+
7245
+ x1_type = backend.standardize_dtype(getattr(x1, "dtype", type(x1)))
7246
+ x2_type = backend.standardize_dtype(getattr(x2, "dtype", type(x2)))
7247
+ dtype = dtypes.result_type(x1_type, x2_type, float)
7248
+ return KerasTensor(output_shape, dtype=dtype)
7249
+
7250
+
7251
+ @keras_export(["keras.ops.nextafter", "keras.ops.numpy.nextafter"])
7252
+ def nextafter(x1, x2):
7253
+ """
7254
+ Return the next representable floating-point value after `x1` towards `x2`.
7255
+
7256
+ This function computes the next floating-point value
7257
+ following `x1` in the direction of `x2`, element-wise.
7258
+
7259
+ Args:
7260
+ x1: Input tensor whose values will be moved to the next
7261
+ representable floating-point value.
7262
+ x2: Input tensor indicating the direction toward which
7263
+ `x1` is moved.
7264
+
7265
+ Returns:
7266
+ Output tensor
7267
+
7268
+ Example:
7269
+ >>> x1 = keras.ops.convert_to_tensor([1.0, 1.0])
7270
+ >>> x2 = keras.ops.convert_to_tensor([2.0, 0.0])
7271
+ >>> keras.ops.nextafter(x1, x2)
7272
+ array([1.0000001, 0.99999994], dtype=float32)
7273
+ """
7274
+ if any_symbolic_tensors((x1, x2)):
7275
+ return Nextafter().symbolic_call(x1, x2)
7276
+ return backend.numpy.nextafter(x1, x2)
7277
+
7278
+
7107
7279
  class Square(Operation):
7108
7280
  def call(self, x):
7109
7281
  return backend.numpy.square(x)
@@ -7691,6 +7863,15 @@ def correlate(x1, x2, mode="valid"):
7691
7863
 
7692
7864
  Returns:
7693
7865
  Output tensor, cross-correlation of `x1` and `x2`.
7866
+
7867
+ Notes:
7868
+ Complex-valued inputs are currently not fully supported on the
7869
+ TensorFlow and PyTorch backends. When complex tensors are passed,
7870
+ they are cast to floating-point types and the imaginary component
7871
+ is discarded.
7872
+
7873
+ This behavior is documented for clarity and may change in the
7874
+ future. See discussion in issue #21617.
7694
7875
  """
7695
7876
  if any_symbolic_tensors((x1, x2)):
7696
7877
  return Correlate(mode=mode).symbolic_call(x1, x2)
@@ -1,6 +1,7 @@
1
1
  import inspect
2
2
 
3
3
  from keras.src.api_export import keras_export
4
+ from keras.src.quantizers.awq_config import AWQConfig
4
5
  from keras.src.quantizers.quantization_config import Float8QuantizationConfig
5
6
  from keras.src.quantizers.quantization_config import Int4QuantizationConfig
6
7
  from keras.src.quantizers.quantization_config import Int8QuantizationConfig
@@ -24,6 +25,7 @@ ALL_OBJECTS = {
24
25
  Int8QuantizationConfig,
25
26
  Int4QuantizationConfig,
26
27
  Float8QuantizationConfig,
28
+ AWQConfig,
27
29
  }
28
30
  ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
29
31
  ALL_OBJECTS_DICT.update(