keras-nightly 3.14.0.dev2026012804__py3-none-any.whl → 3.14.0.dev2026012904__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
 import copy
+import math
 
 from keras.src import dtype_policies
 from keras.src import layers
@@ -8,6 +9,8 @@ from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
 from keras.src.backend import set_keras_mask
 from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import get_block_size_for_layer
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
 
 
 @keras_export("keras.layers.ReversibleEmbedding")
@@ -125,7 +128,7 @@ class ReversibleEmbedding(layers.Embedding):
             return result
         else:
             if self.tie_weights:
-                kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
+                kernel = ops.transpose(self.embeddings)
             else:
                 kernel = self.reverse_embeddings
             if self.reverse_dtype is not None:
@@ -180,6 +183,9 @@ class ReversibleEmbedding(layers.Embedding):
             variable_spec.append("reverse_embeddings")
             if mode in ("int4", "int8"):
                 variable_spec.append("reverse_embeddings_scale")
+                if mode == "int4":
+                    # reverse_embeddings_zero only exists for sub-channel
+                    variable_spec.append("reverse_embeddings_zero")
         return _spec
 
     def quantized_build(self, embeddings_shape, mode, config=None):
@@ -235,13 +241,34 @@ class ReversibleEmbedding(layers.Embedding):
                 dtype="int8",
                 trainable=False,
             )
+
+            # Determine block_size from config or dtype_policy
+            block_size = get_block_size_for_layer(self, config)
+
+            if block_size is None or block_size == -1:
+                # Per-channel: one scale per output unit (input_dim)
+                reverse_scale_shape = (self.input_dim,)
+            else:
+                # Grouped: scale per group along output_dim (axis=0)
+                n_groups = math.ceil(self.output_dim / block_size)
+                reverse_scale_shape = (n_groups, self.input_dim)
+
             self.reverse_embeddings_scale = self.add_weight(
                 name="reverse_embeddings_scale",
-                shape=(self.input_dim,),
+                shape=reverse_scale_shape,
                 initializer="ones",
                 trainable=False,
             )
 
+            # Zero point for asymmetric grouped quantization
+            if block_size is not None and block_size != -1:
+                self.reverse_embeddings_zero = self.add_weight(
+                    name="reverse_embeddings_zero",
+                    shape=reverse_scale_shape,
+                    initializer="zeros",
+                    trainable=False,
+                )
+
     def _int8_call(self, inputs, reverse=False):
         if not reverse:
             return super()._int8_call(inputs)
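
The grouped-scale bookkeeping added above reduces to a few lines of arithmetic. A standalone restatement (illustrative only, not the keras implementation):

import math

def reverse_scale_shape(input_dim, output_dim, block_size):
    # Mirrors the add_weight shapes chosen in quantized_build above:
    # per-channel keeps one scale per vocabulary entry; grouped mode keeps
    # one (scale, zero) pair per block of `block_size` output features.
    if block_size is None or block_size == -1:
        return (input_dim,)
    n_groups = math.ceil(output_dim / block_size)
    return (n_groups, input_dim)

# e.g. a 32000 x 4096 embedding with 128-wide groups keeps
# ceil(4096 / 128) = 32 scales (and zero points) per vocabulary entry.
print(reverse_scale_shape(32000, 4096, None))  # (32000,)
print(reverse_scale_shape(32000, 4096, 128))   # (32, 32000)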
@@ -272,23 +299,79 @@ class ReversibleEmbedding(layers.Embedding):
         if not reverse:
             return super()._int4_call(inputs)
         else:
+            block_size = getattr(self, "_int4_block_size", None)
+
             if self.tie_weights:
                 embeddings = ops.transpose(self._embeddings)
-                scale = ops.transpose(self.embeddings_scale)
+                scale = self.embeddings_scale
+                # For tied weights, scale shape is (input_dim,) or
+                # (input_dim, n_groups). For per-channel, transpose scale.
+                if block_size is None or block_size == -1:
+                    scale = ops.transpose(scale)
             else:
                 embeddings = self.reverse_embeddings
                 scale = self.reverse_embeddings_scale
+
             unpacked_embeddings = quantizers.unpack_int4(
                 embeddings, self.output_dim, axis=0
             )
+
             if self.inputs_quantizer:
                 inputs, inputs_scale = self.inputs_quantizer(inputs)
             else:
                 inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
-            logits = ops.matmul(inputs, unpacked_embeddings)
-            # De-scale outputs
-            logits = ops.cast(logits, self.compute_dtype)
-            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+
+            if block_size is None or block_size == -1:
+                # Per-channel: do matmul then dequantize
+                logits = ops.matmul(inputs, unpacked_embeddings)
+                logits = ops.cast(logits, self.compute_dtype)
+                logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+            elif self.tie_weights:
+                # Sub-channel with asymmetric quantization (tied weights)
+                # Must dequantize embeddings before matmul for correctness
+                # unpacked_embeddings shape: (output_dim, input_dim)
+                # scale shape: (input_dim, n_groups)
+                # embeddings_zero shape: (input_dim, n_groups)
+                # g_idx shape: (output_dim,)
+
+                # Transpose scale/zero for dequantization:
+                # [input_dim, n_groups] -> [n_groups, input_dim]
+                scale_t = ops.transpose(scale)
+                zero_t = ops.transpose(self.embeddings_zero)
+
+                float_embeddings = dequantize_with_sz_map(
+                    ops.cast(unpacked_embeddings, self.compute_dtype),
+                    scale_t,
+                    zero_t,
+                    self.g_idx,
+                    group_axis=0,
+                )
+
+                # inputs shape: (batch, output_dim)
+                # float_embeddings shape: (output_dim, input_dim)
+                logits = ops.matmul(inputs, float_embeddings)
+                logits = ops.divide(logits, inputs_scale)
+            else:
+                # Untied weights with asymmetric grouped quantization
+                # Must dequantize embeddings before matmul for correctness
+                # unpacked_embeddings shape: (output_dim, input_dim)
+                # scale shape: (n_groups, input_dim)
+                # reverse_embeddings_zero shape: (n_groups, input_dim)
+                # g_idx shape: (output_dim,) - reuse from forward pass
+
+                float_embeddings = dequantize_with_sz_map(
+                    ops.cast(unpacked_embeddings, self.compute_dtype),
+                    scale,
+                    self.reverse_embeddings_zero,
+                    self.g_idx,
+                    group_axis=0,
+                )
+
+                # inputs shape: (batch, output_dim)
+                # float_embeddings shape: (output_dim, input_dim)
+                logits = ops.matmul(inputs, float_embeddings)
+                logits = ops.divide(logits, inputs_scale)
+
             # Optionally soft-cap logits.
             if self.logit_soft_cap is not None:
                 soft_cap = self.logit_soft_cap
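
The sub-channel branches above defer to dequantize_with_sz_map, whose implementation is not part of this diff. Based on the shape comments (per-group scale/zero looked up through g_idx along group_axis), a NumPy sketch of the assumed (q - zero) * scale mapping; the name and exact formula are assumptions, not keras code:

import numpy as np

def dequantize_with_group_map(q, scale, zero, g_idx, group_axis=0):
    # q:      int4 values stored as integers, e.g. shape (output_dim, input_dim)
    # scale:  per-group scales, shape (n_groups, input_dim)
    # zero:   per-group zero points, same shape as scale
    # g_idx:  maps each index along `group_axis` to its group, shape (output_dim,)
    scale_rows = np.take(scale, g_idx, axis=group_axis)
    zero_rows = np.take(zero, g_idx, axis=group_axis)
    # Assumed asymmetric dequantization: x = (q - zero) * scale
    return (q.astype(np.float32) - zero_rows) * scale_rows

output_dim, input_dim, block = 8, 4, 4
rng = np.random.default_rng(0)
q = rng.integers(-8, 8, size=(output_dim, input_dim))
g_idx = np.arange(output_dim) // block               # rows 0-3 -> group 0, rows 4-7 -> group 1
scale = rng.uniform(0.01, 0.1, size=(2, input_dim))  # (n_groups, input_dim)
zero = rng.integers(-2, 3, size=(2, input_dim)).astype(np.float32)
w = dequantize_with_group_map(q, scale, zero, g_idx, group_axis=0)
print(w.shape)  # (8, 4)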
@@ -340,60 +423,119 @@ class ReversibleEmbedding(layers.Embedding):
                 self.reverse_embeddings.assign(reverse_embeddings_value)
                 self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
         elif mode == "int4":
-            # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
-            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
-                self.quantization_config,
-                quantizers.AbsMaxQuantizer(
-                    axis=-1,
-                    value_range=(-8, 7),
-                    output_dtype="int8",
-                ),
-            )
-            embeddings_value, embeddings_scale = weight_quantizer(
-                self._embeddings, to_numpy=True
+            from keras.src.quantizers.quantization_config import (
+                Int4QuantizationConfig,
             )
-            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
-            # 2. Pack two int4 values into a single int8 byte.
-            packed_embeddings_value, _, _ = quantizers.pack_int4(
-                embeddings_value, axis=-1
-            )
-            del self._embeddings
-            if not self.tie_weights:
-                reverse_weight_quantizer = (
+
+            block_size = None
+            if isinstance(self.quantization_config, Int4QuantizationConfig):
+                block_size = self.quantization_config.block_size
+
+            use_grouped = block_size is not None and block_size != -1
+
+            # Quantize forward embeddings
+            if not use_grouped:
+                # Per-channel quantization
+                weight_quantizer = (
                     QuantizationConfig.weight_quantizer_or_default(
                         self.quantization_config,
                         quantizers.AbsMaxQuantizer(
-                            axis=0,
+                            axis=-1,
                             value_range=(-8, 7),
                             output_dtype="int8",
                         ),
                     )
                 )
-                reverse_embeddings_value, reverse_embeddings_scale = (
-                    reverse_weight_quantizer(
-                        self.reverse_embeddings, to_numpy=True
-                    )
+                embeddings_value, embeddings_scale = weight_quantizer(
+                    self._embeddings, to_numpy=True
                 )
-                reverse_embeddings_scale = ops.squeeze(
-                    reverse_embeddings_scale, axis=0
+                embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            else:
+                # Sub-channel quantization with asymmetric zero point
+                embeddings_t = ops.transpose(self._embeddings)
+                embeddings_value_t, scale_t, zero_t = (
+                    quantizers.abs_max_quantize_grouped_with_zero_point(
+                        embeddings_t,
+                        block_size=block_size,
+                        value_range=(-8, 7),
+                        dtype="int8",
+                        to_numpy=True,
+                    )
                 )
-                # Pack two int4 values into a single int8 byte.
+                # Transpose back to (input_dim, output_dim) layout
+                embeddings_value = ops.transpose(embeddings_value_t)
+                embeddings_scale = ops.transpose(scale_t)
+                embeddings_zero = ops.transpose(zero_t)
+
+            packed_embeddings_value, _, _ = quantizers.pack_int4(
+                embeddings_value, axis=-1
+            )
+            del self._embeddings
+
+            # Quantize reverse embeddings if not tied
+            if not self.tie_weights:
+                if not use_grouped:
+                    reverse_weight_quantizer = (
+                        QuantizationConfig.weight_quantizer_or_default(
+                            self.quantization_config,
+                            quantizers.AbsMaxQuantizer(
+                                axis=0,
+                                value_range=(-8, 7),
+                                output_dtype="int8",
+                            ),
+                        )
+                    )
+                    reverse_embeddings_value, reverse_embeddings_scale = (
+                        reverse_weight_quantizer(
+                            self.reverse_embeddings, to_numpy=True
+                        )
+                    )
+                    reverse_embeddings_scale = ops.squeeze(
+                        reverse_embeddings_scale, axis=0
+                    )
+                else:
+                    reverse_value, reverse_scale, reverse_zero = (
+                        quantizers.abs_max_quantize_grouped_with_zero_point(
+                            self.reverse_embeddings,
+                            block_size=block_size,
+                            value_range=(-8, 7),
+                            dtype="int8",
+                            to_numpy=True,
+                        )
+                    )
+                    reverse_embeddings_value = reverse_value
+                    reverse_embeddings_scale = reverse_scale
+                    reverse_embeddings_zero = reverse_zero
+
                 packed_reverse_embeddings_value, _, _ = quantizers.pack_int4(
                     reverse_embeddings_value, axis=0
                 )
                 del self.reverse_embeddings
+
             self.quantized_build(
                 embeddings_shape, mode, self.quantization_config
             )
             self._embeddings.assign(packed_embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
+            if use_grouped:
+                self.embeddings_zero.assign(embeddings_zero)
             if not self.tie_weights:
                 self.reverse_embeddings.assign(packed_reverse_embeddings_value)
                 self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
+                if use_grouped:
+                    self.reverse_embeddings_zero.assign(reverse_embeddings_zero)
         else:
             raise self._quantization_mode_error(mode)
 
         # Set new dtype policy.
         if self.dtype_policy.quantization_mode is None:
-            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
+            policy_name = mode
+            if mode == "int4":
+                # Include block_size in policy name for sub-channel quantization
+                block_size = get_block_size_for_layer(self, config)
+                block_size_value = -1 if block_size is None else block_size
+                policy_name = f"int4/{block_size_value}"
+            policy = dtype_policies.get(
+                f"{policy_name}_from_{self.dtype_policy.name}"
+            )
             self.dtype_policy = policy
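
abs_max_quantize_grouped_with_zero_point is likewise only called, not defined, in this diff. The sketch below is a generic grouped asymmetric int4 quantizer with the same return shape convention (values, per-group scale, per-group zero); keras's helper may derive scale and zero differently, so treat this as an illustration of the scheme rather than its implementation:

import numpy as np

def quantize_grouped_asymmetric(w, block_size, qmin=-8, qmax=7):
    # w: float weights, shape (output_dim, input_dim). Groups are blocks of
    # `block_size` rows along axis 0, one (scale, zero) pair per group per column.
    output_dim, input_dim = w.shape
    n_groups = -(-output_dim // block_size)  # ceil division
    scale = np.empty((n_groups, input_dim), dtype=np.float32)
    zero = np.empty((n_groups, input_dim), dtype=np.float32)
    q = np.empty_like(w, dtype=np.int8)
    for g in range(n_groups):
        rows = slice(g * block_size, min((g + 1) * block_size, output_dim))
        lo, hi = w[rows].min(axis=0), w[rows].max(axis=0)
        s = np.maximum(hi - lo, 1e-7) / (qmax - qmin)   # affine scale per column
        z = np.round(qmin - lo / s)                      # zero point per column
        q[rows] = np.clip(np.round(w[rows] / s + z), qmin, qmax).astype(np.int8)
        scale[g], zero[g] = s, z
    return q, scale, zero

w = np.random.default_rng(1).normal(size=(256, 16)).astype(np.float32)
q, scale, zero = quantize_grouped_asymmetric(w, block_size=128)
print(q.shape, scale.shape, zero.shape)  # (256, 16) (2, 16) (2, 16)

Note that the hunk above also folds the group size into the new dtype policy name (policy_name = f"int4/{block_size_value}", e.g. "int4/128_from_float32"), which is what lets get_block_size_for_layer recover the block size after the layer is saved and reloaded.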
@@ -213,8 +213,37 @@ class Discretization(DataLayer):
             return
         self.summary = np.array([[], []], dtype="float32")
 
+    def compute_output_shape(self, input_shape):
+        if self.output_mode == "int":
+            return input_shape
+
+        # Calculate depth (number of bins)
+        depth = (
+            len(self.bin_boundaries) + 1
+            if self.bin_boundaries is not None
+            else self.num_bins
+        )
+
+        if self.output_mode == "one_hot":
+            # For one_hot mode, add depth dimension
+            # If last dimension is 1, replace it with depth, otherwise append
+            if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
+                return tuple(input_shape[:-1]) + (depth,)
+            else:
+                return tuple(input_shape) + (depth,)
+        else:
+            if input_shape and len(input_shape) >= 2:
+                # Match to eager tensor, remove second and append depth
+                out_shape = (
+                    (input_shape[0],) + tuple(input_shape[2:]) + (depth,)
+                )
+                return out_shape
+            else:
+                return (depth,)
+
     def compute_output_spec(self, inputs):
-        return backend.KerasTensor(shape=inputs.shape, dtype=self.output_dtype)
+        output_shape = self.compute_output_shape(inputs.shape)
+        return backend.KerasTensor(shape=output_shape, dtype=self.output_dtype)
 
     def load_own_variables(self, store):
         if len(store) == 1:
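
A standalone restatement of the shape rule introduced above, with a few worked cases (depth is len(bin_boundaries) + 1 when boundaries are given, otherwise num_bins); this mirrors the new method rather than calling keras:

def discretization_output_shape(input_shape, output_mode, depth):
    # Same branching as Discretization.compute_output_shape in the hunk above.
    if output_mode == "int":
        return tuple(input_shape)
    if output_mode == "one_hot":
        if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
            return tuple(input_shape[:-1]) + (depth,)
        return tuple(input_shape) + (depth,)
    # multi_hot / count: drop the second axis and append depth
    if input_shape and len(input_shape) >= 2:
        return (input_shape[0],) + tuple(input_shape[2:]) + (depth,)
    return (depth,)

# With bin_boundaries=[0.0, 1.0, 2.0] there are len(boundaries) + 1 = 4 bins.
print(discretization_output_shape((32, 1), "one_hot", 4))     # (32, 4)
print(discretization_output_shape((32, 10), "one_hot", 4))    # (32, 10, 4)
print(discretization_output_shape((32, 10), "multi_hot", 4))  # (32, 4)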
@@ -9,11 +9,17 @@ from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import AbsMaxQuantizer
 from keras.src.quantizers.quantizers import Quantizer
 from keras.src.quantizers.quantizers import abs_max_quantize
+from keras.src.quantizers.quantizers import (
+    abs_max_quantize_grouped_with_zero_point,
+)
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
+from keras.src.quantizers.quantizers import compute_quantization_parameters
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
 from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
 from keras.src.quantizers.quantizers import pack_int4
 from keras.src.quantizers.quantizers import quantize_and_dequantize
+from keras.src.quantizers.quantizers import quantize_with_sz_map
 from keras.src.quantizers.quantizers import unpack_int4
 from keras.src.saving import serialization_lib
 from keras.src.utils.naming import to_snake_case
@@ -99,14 +99,46 @@ class Int4QuantizationConfig(QuantizationConfig):
         weight_quantizer: Quantizer for weights.
         activation_quantizer: Quantizer for activations. If "default", uses
             AbsMaxQuantizer with axis=-1.
+        block_size: Size of groups along the input dimension for sub-channel
+            quantization. If a positive integer, uses sub-channel quantization
+            with `ceil(input_dim / block_size)` groups. If `None` or `-1`,
+            uses per-channel quantization (one scale per output channel).
+            Default: `128` (sub-channel with 128-element groups).
     """
 
-    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
-        from keras.src.quantizers.quantizers import AbsMaxQuantizer
-
+    def __init__(
+        self,
+        weight_quantizer=None,
+        activation_quantizer="default",
+        block_size=128,
+    ):
         if activation_quantizer == "default":
-            activation_quantizer = AbsMaxQuantizer()
+            # Use weight-only quantization by default for int4
+            activation_quantizer = None
         super().__init__(weight_quantizer, activation_quantizer)
+
+        # Validate block_size
+        if block_size is not None and block_size != -1 and block_size <= 0:
+            raise ValueError(
+                f"block_size must be None, -1, or a positive integer. "
+                f"Received: block_size={block_size}"
+            )
+        self.block_size = block_size
+
+        # Sub-channel quantization does not support custom quantizers
+        is_sub_channel = block_size is not None and block_size > 0
+        has_custom_quantizer = (
+            self.weight_quantizer is not None
+            or self.activation_quantizer is not None
+        )
+        if is_sub_channel and has_custom_quantizer:
+            raise ValueError(
+                "Int4 sub-channel quantization (block_size > 0) does not "
+                "support custom quantizers. Either set block_size to None "
+                "or -1 for per-channel quantization, or remove the custom "
+                f"quantizer arguments. Received: block_size={block_size}"
+            )
+
         if self.weight_quantizer is not None:
             if self.weight_quantizer.value_range != (-8, 7):
                 raise ValueError(
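
A minimal usage sketch of the new constructor arguments. The keras.quantizers.Int4QuantizationConfig export path is an assumption (only Float8QuantizationConfig's export decorator appears in this diff); the signature and validation behavior come from the hunk above:

# Assumes Int4QuantizationConfig is exported under keras.quantizers.
from keras import quantizers

grouped = quantizers.Int4QuantizationConfig(block_size=128)      # sub-channel, 128-wide groups
per_channel = quantizers.Int4QuantizationConfig(block_size=-1)   # one scale per output channel

try:
    # Sub-channel mode rejects custom quantizers (see the ValueError above).
    quantizers.Int4QuantizationConfig(
        weight_quantizer=quantizers.AbsMaxQuantizer(
            axis=-1, value_range=(-8, 7), output_dtype="int8"
        ),
        block_size=128,
    )
except ValueError as e:
    print(e)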
@@ -126,6 +158,28 @@ class Int4QuantizationConfig(QuantizationConfig):
     def mode(self):
         return "int4"
 
+    def get_config(self):
+        config = super().get_config()
+        config["block_size"] = self.block_size
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        weight_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("weight_quantizer")
+        )
+        activation_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("activation_quantizer")
+        )
+        # Default to None for backwards compatibility with models saved
+        # before block_size was introduced (those used per-channel mode)
+        block_size = config.get("block_size", None)
+        return cls(
+            weight_quantizer=weight_quantizer,
+            activation_quantizer=activation_quantizer,
+            block_size=block_size,
+        )
+
 
 @keras_export("keras.quantizers.Float8QuantizationConfig")
 class Float8QuantizationConfig(QuantizationConfig):
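
A round-trip sketch for the new get_config/from_config pair, under the same export-path assumption as above; the legacy fallback mirrors the config.get("block_size", None) default in the hunk:

from keras import quantizers

config = quantizers.Int4QuantizationConfig(block_size=64)
payload = config.get_config()  # now carries "block_size": 64
restored = quantizers.Int4QuantizationConfig.from_config(payload)
assert restored.block_size == 64

# Configs saved before this change have no "block_size" key, so
# from_config falls back to None, i.e. per-channel quantization.
legacy = dict(payload)
legacy.pop("block_size")
assert quantizers.Int4QuantizationConfig.from_config(legacy).block_size is None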
@@ -244,3 +298,43 @@ def _validate_mode(mode):
             "Invalid quantization mode. "
             f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
         )
+
+
+def get_block_size_for_layer(layer, config):
+    """Determine the block size for int4 quantization.
+
+    The block size can be specified either through the `config` argument
+    or through the `dtype_policy` if it is of type `Int4DTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the block size should be specified in the `dtype_policy`.
+
+    Args:
+        layer: The layer being quantized.
+        config: An optional configuration object that may contain the
+            `block_size` attribute.
+    Returns:
+        int or None. The determined block size for int4 quantization.
+        Returns `None` or `-1` for per-channel quantization.
+    """
+    from keras.src.dtype_policies.dtype_policy import Int4DTypePolicy
+    from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
+
+    if config and isinstance(config, Int4QuantizationConfig):
+        return config.block_size
+    elif isinstance(layer.dtype_policy, Int4DTypePolicy):
+        block_size = layer.dtype_policy.block_size
+        # Convert -1 to None for consistency
+        return None if block_size == -1 else block_size
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if isinstance(policy, Int4DTypePolicy):
+            block_size = policy.block_size
+            return None if block_size == -1 else block_size
+        # Fall back to None for legacy QuantizedDTypePolicy
+        return None
+    else:
+        # For backwards compatibility with models that don't have
+        # Int4DTypePolicy (legacy per-channel mode)
+        return None