keras-nightly 3.14.0.dev2026010104__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +2 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +2 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +2 -0
  7. keras/ops/numpy/__init__.py +2 -0
  8. keras/quantizers/__init__.py +1 -0
  9. keras/src/backend/jax/nn.py +26 -9
  10. keras/src/backend/jax/numpy.py +10 -0
  11. keras/src/backend/numpy/numpy.py +15 -0
  12. keras/src/backend/openvino/numpy.py +338 -17
  13. keras/src/backend/tensorflow/numpy.py +24 -1
  14. keras/src/backend/tensorflow/rnn.py +17 -7
  15. keras/src/backend/torch/numpy.py +26 -0
  16. keras/src/backend/torch/rnn.py +28 -11
  17. keras/src/callbacks/orbax_checkpoint.py +75 -42
  18. keras/src/dtype_policies/__init__.py +2 -0
  19. keras/src/dtype_policies/dtype_policy.py +90 -1
  20. keras/src/layers/core/dense.py +122 -6
  21. keras/src/layers/core/einsum_dense.py +151 -7
  22. keras/src/layers/core/embedding.py +1 -1
  23. keras/src/layers/core/reversible_embedding.py +10 -1
  24. keras/src/layers/layer.py +5 -0
  25. keras/src/layers/preprocessing/feature_space.py +8 -4
  26. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  27. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
  28. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  29. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  30. keras/src/losses/losses.py +24 -0
  31. keras/src/models/model.py +18 -9
  32. keras/src/ops/image.py +106 -93
  33. keras/src/ops/numpy.py +138 -0
  34. keras/src/quantizers/__init__.py +2 -0
  35. keras/src/quantizers/awq.py +361 -0
  36. keras/src/quantizers/awq_config.py +140 -0
  37. keras/src/quantizers/awq_core.py +217 -0
  38. keras/src/quantizers/gptq.py +1 -2
  39. keras/src/quantizers/gptq_core.py +1 -1
  40. keras/src/quantizers/quantization_config.py +14 -0
  41. keras/src/quantizers/quantizers.py +61 -52
  42. keras/src/random/seed_generator.py +2 -2
  43. keras/src/saving/orbax_util.py +50 -0
  44. keras/src/saving/saving_api.py +37 -14
  45. keras/src/utils/jax_layer.py +69 -31
  46. keras/src/utils/module_utils.py +11 -0
  47. keras/src/utils/tracking.py +5 -5
  48. keras/src/version.py +1 -1
  49. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
  50. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +52 -48
  51. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
  52. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
keras/src/layers/core/dense.py CHANGED
@@ -128,7 +128,7 @@ class Dense(Layer):
                 mode=self.quantization_mode,
                 config=self.quantization_config,
             )
-        if self.quantization_mode not in ("int8", "int4", "gptq"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
             # it here.
@@ -165,15 +165,17 @@ class Dense(Layer):
 
         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )
 
         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and calibrated and gptq_bits != 4:
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -183,7 +185,15 @@ class Dense(Layer):
             # Handle int4 unpacking cases in one place
             if is_int4:
                 kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
-            elif is_gptq and calibrated and gptq_bits == 4:
+            elif is_gptq and gptq_calibrated and gptq_bits == 4:
+                kernel = quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+            elif is_awq and awq_calibrated:
+                # AWQ always uses 4-bit quantization
                 kernel = quantizers.unpack_int4(
                     self.quantized_kernel,
                     orig_len=self.units,
@@ -304,8 +314,9 @@ class Dense(Layer):
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # A saved GPTQ quantized model will always be calibrated.
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
         self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
 
         idx = 0
         for name in self.variable_serialization_spec[mode]:
@@ -395,6 +406,14 @@ class Dense(Layer):
                 "kernel_zero",
                 "g_idx",
             ],
+            "awq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "awq_scales",
+                "g_idx",
+            ],
         }
 
     def quantized_build(self, kernel_shape, mode, config=None):
@@ -406,6 +425,8 @@ class Dense(Layer):
             self._float8_build()
         elif mode == "gptq":
             self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
@@ -515,6 +536,97 @@ class Dense(Layer):
             y = self.activation(y)
         return y
 
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        # For 4-bit weights, we pack two values per byte.
+        units = (kernel_shape[1] + 1) // 2
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = (
+            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(kernel_shape[0],),
+            initializer="ones",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, units]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
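The `(kernel_shape[1] + 1) // 2` sizing above reflects 4-bit packing: two codes per uint8 byte, with one padding nibble when the count is odd. A minimal numpy sketch of such a layout (the nibble order here is illustrative, not necessarily the exact convention `quantizers.unpack_int4` uses):

```python
import numpy as np

# Five 4-bit codes (values 0..15); an odd count gets one padding nibble.
values = np.array([3, 12, 7, 15, 9], dtype=np.uint8)
padded = np.append(values, 0) if len(values) % 2 else values
packed = (padded[0::2] & 0x0F) | (padded[1::2] << 4)
print(packed.shape)  # (3,) == (5 + 1) // 2 bytes

# Unpacking reverses the layout (low nibble first, in this sketch).
unpacked = np.stack([packed & 0x0F, packed >> 4], axis=1).reshape(-1)
assert np.array_equal(unpacked[: len(values)], values)
```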
@@ -835,6 +947,8 @@ class Dense(Layer):
             self.kernel_scale.assign(kernel_scale)
         elif mode == "gptq":
             self.quantized_build(kernel_shape, mode, self.quantization_config)
+        elif mode == "awq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -847,6 +961,8 @@ class Dense(Layer):
         policy_name = mode
         if mode == "gptq":
             policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "awq":
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
@@ -881,7 +997,7 @@ class Dense(Layer):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         kernel_value = self._kernel
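Taken together, `_awq_call` reconstructs an effective float kernel as roughly `(codes - zero) * scale / awq_scales`. A minimal sketch of that arithmetic, assuming a single quantization group so `g_idx` can be ignored (real layers select a scale/zero group per input row via `g_idx`):

```python
import numpy as np

# Shapes follow _awq_call: codes are (units, input_dim) before the transpose.
w_q = np.array([[5, 9], [0, 15]], dtype=np.uint8)  # unpacked 4-bit codes
scale = np.array([[0.1], [0.2]])                   # per-unit scale, one group
zero = np.array([[8], [8]], dtype=np.uint8)        # per-unit zero point
awq_scales = np.array([2.0, 0.5])                  # per-input-channel scales

w = (w_q.astype(np.float32) - zero) * scale  # dequantize the codes
w = w.T / awq_scales[:, None]                # undo AWQ pre-scaling -> (input_dim, units)
print(w)
```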
keras/src/layers/core/einsum_dense.py CHANGED
@@ -180,7 +180,7 @@ class EinsumDense(Layer):
         # quantized to int8 or int4, because `quantized_build` has created the
         # appropriate kernel variable. For other modes (e.g., float8 or no
         # quantization), we still need the floating-point kernel.
-        if self.quantization_mode not in ("int8", "int4", "gptq"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8, `self._kernel` will be added
             # in `self._int8_build`. Therefore, we skip it here.
             self._kernel = self.add_weight(
@@ -219,15 +219,17 @@ class EinsumDense(Layer):
 
         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )
 
         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and calibrated and gptq_bits != 4:
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -241,13 +243,21 @@ class EinsumDense(Layer):
                     self._orig_length_along_pack_axis,
                     self._int4_pack_axis,
                 )
-            elif is_gptq and calibrated and gptq_bits == 4:
+            elif is_gptq and gptq_calibrated and gptq_bits == 4:
                 kernel = quantizers.unpack_int4(
                     self.quantized_kernel,
                     orig_len=self.gptq_unpacked_column_size,
                     axis=0,
                     dtype="uint8",
                 )
+            elif is_awq and awq_calibrated:
+                # AWQ always uses 4-bit quantization
+                kernel = quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.awq_unpacked_column_size,
+                    axis=0,
+                    dtype="uint8",
+                )
 
         # Apply LoRA if enabled
         if self.lora_enabled:
@@ -362,8 +372,9 @@ class EinsumDense(Layer):
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # A saved GPTQ quantized model will always be calibrated.
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
         self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
 
         idx = 0
         for name in self.variable_serialization_spec[mode]:
@@ -459,6 +470,14 @@ class EinsumDense(Layer):
                 "kernel_zero",
                 "g_idx",
             ],
+            "awq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "awq_scales",
+                "g_idx",
+            ],
         }
 
     def quantized_build(self, kernel_shape, mode, config=None):
@@ -470,6 +489,8 @@ class EinsumDense(Layer):
             self._float8_build()
         elif mode == "gptq":
             self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
@@ -616,6 +637,127 @@ class EinsumDense(Layer):
             y = self.activation(y)
         return y
 
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+
+        self.original_kernel_shape = kernel_shape
+        if len(kernel_shape) == 2:
+            rows = kernel_shape[0]
+            columns = kernel_shape[1]
+        elif len(kernel_shape) == 3:
+            shape = list(self.original_kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                rows, columns = (
+                    in_features,
+                    heads * head_dim,
+                )
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                rows, columns = (
+                    heads * head_dim,
+                    out_features,
+                )
+            else:
+                raise ValueError("Could not determine row/column split.")
+        else:
+            raise ValueError("AWQ quantization only supports 2D or 3D kernels.")
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = 1 if group_size == -1 else math.ceil(rows / group_size)
+
+        self.awq_unpacked_column_size = columns
+
+        # For 4-bit weights, we pack two values per byte.
+        kernel_columns = (columns + 1) // 2
+
+        self._set_quantization_info()
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(kernel_columns, rows),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(columns, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="zero_point",
+            shape=(columns, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(rows,),
+            initializer="ones",
+            trainable=False,
+        )
+
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(rows,),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.awq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = dequantize_with_sz_map(
+                W,
+                self.kernel_scale,
+                self.kernel_zero,
+                self.g_idx,
+            )
+            W = ops.transpose(W)
+
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, out_dim]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+            W = ops.reshape(W, self.original_kernel_shape)
+
+        y = ops.einsum(self.equation, inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
     def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
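The 3D branch above flattens an einsum kernel into a 2D `(rows, columns)` matrix before packing, inferring the split by locating the largest (model) dimension. A small sketch of that inference, with assumed shapes:

```python
import math

kernel_shape = (512, 8, 32)  # assumed QKV-style kernel: (in_features, heads, head_dim)
shape = list(kernel_shape)
d_model_dim_index = shape.index(max(shape))
if d_model_dim_index == 0:  # QKV projection case
    in_features, heads, head_dim = shape
    rows, columns = in_features, heads * head_dim
else:  # attention output case
    heads, head_dim, out_features = shape
    rows, columns = heads * head_dim, out_features

print(rows, columns)          # 512 256
print((columns + 1) // 2)     # 128 packed uint8 bytes per row
print(math.ceil(rows / 128))  # 4 scale/zero groups for group_size=128
```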
@@ -1010,7 +1152,7 @@ class EinsumDense(Layer):
         self.quantization_config = config
 
         kernel_shape = self._kernel.shape
-        if mode in ("int8", "int4", "gptq"):
+        if mode in ("int8", "int4", "gptq", "awq"):
             self._set_quantization_info()
 
         if mode == "int8":
@@ -1058,6 +1200,8 @@ class EinsumDense(Layer):
         policy_name = mode
         if mode == "gptq":
             policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "awq":
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
@@ -1121,7 +1265,7 @@ class EinsumDense(Layer):
                 This is `None` if the layer is not quantized.
         """
         # If not a quantized layer, return the full-precision kernel directly.
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         # If quantized but LoRA is not enabled, return the original quantized
keras/src/layers/core/embedding.py CHANGED
@@ -514,7 +514,7 @@ class Embedding(Layer):
             `embeddings_scale`: The quantization scale for the merged
                 embeddings. This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.embeddings, None
 
         embeddings_value = self._embeddings
keras/src/layers/core/reversible_embedding.py CHANGED
@@ -6,6 +6,7 @@ from keras.src import ops
 from keras.src import quantizers
 from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
+from keras.src.backend import set_keras_mask
 from keras.src.quantizers.quantization_config import QuantizationConfig
 
 
@@ -117,7 +118,11 @@ class ReversibleEmbedding(layers.Embedding):
 
     def call(self, inputs, reverse=False):
         if not reverse:
-            return super().call(inputs)
+            result = super().call(inputs)
+            mask = super().compute_mask(inputs)
+            if mask is not None:
+                set_keras_mask(result, mask)
+            return result
         else:
             if self.tie_weights:
                 kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
@@ -135,6 +140,10 @@ class ReversibleEmbedding(layers.Embedding):
             )
             return logits
 
+    def compute_mask(self, inputs, mask=None):
+        # Disable masking from super class, masking is done directly in call.
+        return None
+
     def compute_output_shape(self, input_shape, reverse=False):
         output_shape = list(input_shape)
         if reverse:
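The net effect: the forward (non-reverse) path attaches the padding mask to its output tensor directly, and `compute_mask` is disabled so the reverse path, which returns vocabulary logits, never inherits an input mask with the wrong semantics. A sketch of the intended behavior, assuming `ReversibleEmbedding` is importable from this module and that `get_keras_mask` is the read-side counterpart of `set_keras_mask`:

```python
from keras import ops
from keras.src.backend import get_keras_mask
from keras.src.layers.core.reversible_embedding import ReversibleEmbedding

layer = ReversibleEmbedding(input_dim=10, output_dim=4, mask_zero=True)
x = ops.array([[1, 2, 0, 0]])
y = layer(x)                  # forward: mask attached via set_keras_mask
print(get_keras_mask(y))      # [[ True  True False False]]
print(layer.compute_mask(x))  # None: masking handled inside call()
```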
keras/src/layers/layer.py CHANGED
@@ -1337,6 +1337,8 @@ class Layer(BackendLayer, Operation):
             return self._int4_call(*args, **kwargs)
         elif self.quantization_mode == "gptq":
             return self._gptq_call(*args, **kwargs)
+        elif self.quantization_mode == "awq":
+            return self._awq_call(*args, **kwargs)
         else:
             raise self._quantization_mode_error(self.quantization_mode)
 
@@ -1352,6 +1354,9 @@ class Layer(BackendLayer, Operation):
     def _gptq_call(self, *args, **kwargs):
         raise self._not_implemented_error(self._gptq_call)
 
+    def _awq_call(self, *args, **kwargs):
+        raise self._not_implemented_error(self._awq_call)
+
     def _not_implemented_error(self, attr, msg=None):
         if callable(attr):
             attr_name = attr.__name__
keras/src/layers/preprocessing/feature_space.py CHANGED
@@ -507,10 +507,14 @@ class FeatureSpace(Layer):
 
     def adapt(self, dataset):
         if not isinstance(dataset, tf.data.Dataset):
-            raise ValueError(
-                "`adapt()` can only be called on a tf.data.Dataset. "
-                f"Received instead: {dataset} (of type {type(dataset)})"
-            )
+            if isinstance(dataset, dict):
+                dataset = tf.data.Dataset.from_tensor_slices(dataset)
+            else:
+                raise ValueError(
+                    "`adapt()` can only be called on a tf.data.Dataset or a "
+                    "dict of arrays/lists. "
+                    f"Received instead: {dataset} (of type {type(dataset)})"
+                )
 
         for name in self._list_adaptable_preprocessors():
             # Call adapt() on each individual adaptable layer.
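With this change, `adapt()` also accepts a plain dict of columns and wraps it in `tf.data.Dataset.from_tensor_slices()` itself. A usage sketch (feature names and values are illustrative):

```python
import keras

feature_space = keras.utils.FeatureSpace(
    features={
        "age": keras.utils.FeatureSpace.float_normalized(),
        "color": keras.utils.FeatureSpace.string_categorical(),
    }
)
# Previously the caller had to wrap this dict in a tf.data.Dataset.
feature_space.adapt(
    {"age": [20.0, 35.0, 60.0], "color": ["red", "green", "blue"]}
)
```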
keras/src/layers/preprocessing/image_preprocessing/aug_mix.py CHANGED
@@ -316,8 +316,8 @@ class AugMix(BaseImagePreprocessingLayer):
     def get_config(self):
         config = {
             "value_range": self.value_range,
-            "num_chains": self.chain_depth,
-            "chain_depth": self.num_chains,
+            "num_chains": self.num_chains,
+            "chain_depth": self.chain_depth,
             "factor": self.factor,
             "alpha": self.alpha,
             "all_ops": self.all_ops,
keras/src/layers/preprocessing/image_preprocessing/center_crop.py CHANGED
@@ -183,28 +183,26 @@ class CenterCrop(BaseImagePreprocessingLayer):
 
     def transform_images(self, images, transformation=None, training=True):
         inputs = self.backend.cast(images, self.compute_dtype)
+        inputs_shape = self.backend.shape(inputs)
+
         if self.data_format == "channels_first":
-            init_height = inputs.shape[-2]
-            init_width = inputs.shape[-1]
+            init_height = inputs_shape[-2]
+            init_width = inputs_shape[-1]
         else:
-            init_height = inputs.shape[-3]
-            init_width = inputs.shape[-2]
-
-        if init_height is None or init_width is None:
-            # Dynamic size case. TODO.
-            raise ValueError(
-                "At this time, CenterCrop can only "
-                "process images with a static spatial "
-                f"shape. Received: inputs.shape={inputs.shape}"
-            )
+            init_height = inputs_shape[-3]
+            init_width = inputs_shape[-2]
 
+        # All these operations work both with ints (static sizes) and scalar
+        # tensors (dynamic sizes).
         h_diff = init_height - self.height
         w_diff = init_width - self.width
 
-        h_start = int(h_diff / 2)
-        w_start = int(w_diff / 2)
+        h_start = h_diff // 2
+        w_start = w_diff // 2
 
-        if h_diff >= 0 and w_diff >= 0:
+        if (not isinstance(h_diff, int) or h_diff >= 0) and (
+            not isinstance(w_diff, int) or w_diff >= 0
+        ):
             if len(inputs.shape) == 4:
                 if self.data_format == "channels_first":
                     return inputs[
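This rewrite makes the crop arithmetic shape-polymorphic: with static sizes, `init_height`/`init_width` are plain Python ints; with dynamic shapes, `backend.shape()` yields scalar tensors for which subtraction, `//`, and slicing still work. A sketch of the dynamic case that previously raised a ValueError, assuming the TensorFlow backend and a variable-size image pipeline:

```python
import numpy as np
import tensorflow as tf
import keras

crop = keras.layers.CenterCrop(64, 64)
images = [np.zeros((100, 80, 3), "float32"), np.zeros((90, 120, 3), "float32")]
ds = tf.data.Dataset.from_generator(
    lambda: iter(images),
    output_signature=tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32),
).map(crop)  # spatial dims unknown at trace time
for image in ds:
    print(image.shape)  # (64, 64, 3)
```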
keras/src/layers/preprocessing/image_preprocessing/random_contrast.py CHANGED
@@ -92,8 +92,8 @@ class RandomContrast(BaseImagePreprocessingLayer):
 
     def transform_images(self, images, transformation, training=True):
         if training:
-            constrast_factor = transformation["contrast_factor"]
-            outputs = self._adjust_constrast(images, constrast_factor)
+            contrast_factor = transformation["contrast_factor"]
+            outputs = self._adjust_contrast(images, contrast_factor)
             outputs = self.backend.numpy.clip(
                 outputs, self.value_range[0], self.value_range[1]
             )
@@ -117,7 +117,7 @@ class RandomContrast(BaseImagePreprocessingLayer):
         ):
             return segmentation_masks
 
-    def _adjust_constrast(self, inputs, contrast_factor):
+    def _adjust_contrast(self, inputs, contrast_factor):
         if self.data_format == "channels_first":
             height_axis = -2
             width_axis = -1
keras/src/layers/preprocessing/image_preprocessing/resizing.py CHANGED
@@ -66,6 +66,16 @@ class Resizing(BaseImagePreprocessingLayer):
             `~/.keras/keras.json`. If you never set it, then it will be
             `"channels_last"`.
         **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
+
+    Example:
+
+    ```python
+    (x_train, y_train), _ = keras.datasets.cifar10.load_data()
+    image = x_train[0]
+    resizer = keras.layers.Resizing(128, 128)
+    resized_image = resizer(image)
+    print("original:", image.shape, "resized:", resized_image.shape)
+    ```
     """
 
     _USE_BASE_FACTOR = False
keras/src/losses/losses.py CHANGED
@@ -73,6 +73,14 @@ class MeanSquaredError(LossFunctionWrapper):
             `"float32"` unless set to different value
             (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
             provided, then the `compute_dtype` will be utilized.
+
+    Examples:
+
+    >>> y_true = keras.ops.array([1.0, 0.0, 1.0])
+    >>> y_pred = keras.ops.array([0.9, 0.1, 0.8])
+    >>> loss = keras.losses.MeanSquaredError()
+    >>> loss(y_true, y_pred)
+    0.02
     """
 
     def __init__(
@@ -114,6 +122,14 @@ class MeanAbsoluteError(LossFunctionWrapper):
             `"float32"` unless set to different value
             (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
             provided, then the `compute_dtype` will be utilized.
+
+    Examples:
+
+    >>> y_true = keras.ops.array([1.0, 0.3, 1.0])
+    >>> y_pred = keras.ops.array([1.9, 0.3, 1.8])
+    >>> loss = keras.losses.MeanAbsoluteError()
+    >>> loss(y_true, y_pred)
+    0.5666667
     """
 
     def __init__(
@@ -155,6 +171,14 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
             `"float32"` unless set to different value
             (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
             provided, then the `compute_dtype` will be utilized.
+
+    Examples:
+
+    >>> y_true = keras.ops.array([100.0, 200.0, 300.0])
+    >>> y_pred = keras.ops.array([90.0, 210.0, 310.0])
+    >>> loss = keras.losses.MeanAbsolutePercentageError()
+    >>> loss(y_true, y_pred)
+    6.111111
     """
 
    def __init__(
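Hand-checking the MAPE doctest above, where the loss is `100 * mean(|y_true - y_pred| / y_true)`:

```python
import numpy as np

y_true = np.array([100.0, 200.0, 300.0])
y_pred = np.array([90.0, 210.0, 310.0])
# mean(10/100, 10/200, 10/300) * 100 = mean(0.1, 0.05, 0.0333...) * 100
print(100 * np.mean(np.abs(y_true - y_pred) / y_true))  # 6.111111...
```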