keras-nightly 3.14.0.dev2026012804__py3-none-any.whl → 3.14.0.dev2026012904__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/quantizers/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/quantizers/__init__.py +3 -0
- keras/src/backend/jax/core.py +12 -2
- keras/src/callbacks/orbax_checkpoint.py +41 -8
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +80 -1
- keras/src/layers/core/dense.py +278 -95
- keras/src/layers/core/einsum_dense.py +350 -181
- keras/src/layers/core/embedding.py +236 -49
- keras/src/layers/core/reversible_embedding.py +177 -35
- keras/src/layers/preprocessing/discretization.py +30 -1
- keras/src/quantizers/__init__.py +6 -0
- keras/src/quantizers/quantization_config.py +98 -4
- keras/src/quantizers/quantizers.py +262 -32
- keras/src/saving/saving_api.py +66 -2
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/RECORD +22 -22
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/WHEEL +0 -0
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/top_level.txt +0 -0

keras/src/layers/core/reversible_embedding.py
CHANGED

@@ -1,4 +1,5 @@
 import copy
+import math

 from keras.src import dtype_policies
 from keras.src import layers

@@ -8,6 +9,8 @@ from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
 from keras.src.backend import set_keras_mask
 from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import get_block_size_for_layer
+from keras.src.quantizers.quantizers import dequantize_with_sz_map


 @keras_export("keras.layers.ReversibleEmbedding")

@@ -125,7 +128,7 @@ class ReversibleEmbedding(layers.Embedding):
             return result
         else:
             if self.tie_weights:
-                kernel = ops.transpose(
+                kernel = ops.transpose(self.embeddings)
             else:
                 kernel = self.reverse_embeddings
             if self.reverse_dtype is not None:

@@ -180,6 +183,9 @@ class ReversibleEmbedding(layers.Embedding):
             variable_spec.append("reverse_embeddings")
             if mode in ("int4", "int8"):
                 variable_spec.append("reverse_embeddings_scale")
+                if mode == "int4":
+                    # reverse_embeddings_zero only exists for sub-channel
+                    variable_spec.append("reverse_embeddings_zero")
         return _spec

     def quantized_build(self, embeddings_shape, mode, config=None):

@@ -235,13 +241,34 @@
             dtype="int8",
             trainable=False,
         )
+
+        # Determine block_size from config or dtype_policy
+        block_size = get_block_size_for_layer(self, config)
+
+        if block_size is None or block_size == -1:
+            # Per-channel: one scale per output unit (input_dim)
+            reverse_scale_shape = (self.input_dim,)
+        else:
+            # Grouped: scale per group along output_dim (axis=0)
+            n_groups = math.ceil(self.output_dim / block_size)
+            reverse_scale_shape = (n_groups, self.input_dim)
+
         self.reverse_embeddings_scale = self.add_weight(
             name="reverse_embeddings_scale",
-            shape=
+            shape=reverse_scale_shape,
             initializer="ones",
             trainable=False,
         )

+        # Zero point for asymmetric grouped quantization
+        if block_size is not None and block_size != -1:
+            self.reverse_embeddings_zero = self.add_weight(
+                name="reverse_embeddings_zero",
+                shape=reverse_scale_shape,
+                initializer="zeros",
+                trainable=False,
+            )
+
     def _int8_call(self, inputs, reverse=False):
         if not reverse:
             return super()._int8_call(inputs)
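
Note on the new shapes: for grouped int4, the reverse scale and zero point carry one entry per (group, input unit). A quick numeric check with made-up dimensions (none of these values appear in the diff):

import math

input_dim, output_dim, block_size = 32000, 4096, 128
n_groups = math.ceil(output_dim / block_size)      # 32 groups along output_dim
per_channel_shape = (input_dim,)                   # used when block_size is None or -1
grouped_shape = (n_groups, input_dim)              # shape of both scale and zero point
print(n_groups, per_channel_shape, grouped_shape)  # 32 (32000,) (32, 32000)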

@@ -272,23 +299,79 @@
         if not reverse:
             return super()._int4_call(inputs)
         else:
+            block_size = getattr(self, "_int4_block_size", None)
+
             if self.tie_weights:
                 embeddings = ops.transpose(self._embeddings)
-                scale =
+                scale = self.embeddings_scale
+                # For tied weights, scale shape is (input_dim,) or
+                # (input_dim, n_groups). For per-channel, transpose scale.
+                if block_size is None or block_size == -1:
+                    scale = ops.transpose(scale)
             else:
                 embeddings = self.reverse_embeddings
                 scale = self.reverse_embeddings_scale
+
             unpacked_embeddings = quantizers.unpack_int4(
                 embeddings, self.output_dim, axis=0
             )
+
             if self.inputs_quantizer:
                 inputs, inputs_scale = self.inputs_quantizer(inputs)
             else:
                 inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
-
-
-
-
+
+            if block_size is None or block_size == -1:
+                # Per-channel: do matmul then dequantize
+                logits = ops.matmul(inputs, unpacked_embeddings)
+                logits = ops.cast(logits, self.compute_dtype)
+                logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+            elif self.tie_weights:
+                # Sub-channel with asymmetric quantization (tied weights)
+                # Must dequantize embeddings before matmul for correctness
+                # unpacked_embeddings shape: (output_dim, input_dim)
+                # scale shape: (input_dim, n_groups)
+                # embeddings_zero shape: (input_dim, n_groups)
+                # g_idx shape: (output_dim,)
+
+                # Transpose scale/zero for dequantization:
+                # [input_dim, n_groups] -> [n_groups, input_dim]
+                scale_t = ops.transpose(scale)
+                zero_t = ops.transpose(self.embeddings_zero)
+
+                float_embeddings = dequantize_with_sz_map(
+                    ops.cast(unpacked_embeddings, self.compute_dtype),
+                    scale_t,
+                    zero_t,
+                    self.g_idx,
+                    group_axis=0,
+                )
+
+                # inputs shape: (batch, output_dim)
+                # float_embeddings shape: (output_dim, input_dim)
+                logits = ops.matmul(inputs, float_embeddings)
+                logits = ops.divide(logits, inputs_scale)
+            else:
+                # Untied weights with asymmetric grouped quantization
+                # Must dequantize embeddings before matmul for correctness
+                # unpacked_embeddings shape: (output_dim, input_dim)
+                # scale shape: (n_groups, input_dim)
+                # reverse_embeddings_zero shape: (n_groups, input_dim)
+                # g_idx shape: (output_dim,) - reuse from forward pass
+
+                float_embeddings = dequantize_with_sz_map(
+                    ops.cast(unpacked_embeddings, self.compute_dtype),
+                    scale,
+                    self.reverse_embeddings_zero,
+                    self.g_idx,
+                    group_axis=0,
+                )
+
+                # inputs shape: (batch, output_dim)
+                # float_embeddings shape: (output_dim, input_dim)
+                logits = ops.matmul(inputs, float_embeddings)
+                logits = ops.divide(logits, inputs_scale)
+
             # Optionally soft-cap logits.
             if self.logit_soft_cap is not None:
                 soft_cap = self.logit_soft_cap
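
The grouped branches above defer to `dequantize_with_sz_map`, whose implementation lives in keras/src/quantizers/quantizers.py and is not shown in this section. The sketch below only illustrates the scale/zero-map idea the new code relies on: standard asymmetric dequantization `(q - zero) * scale`, with per-group parameters gathered per row via `g_idx`. The helper name, the exact formula, and the way `g_idx` is built here are assumptions for illustration.

import numpy as np

def dequant_with_sz_map_sketch(q, scale, zero, g_idx, group_axis=0):
    # q: int4 values already unpacked to int8, shape (output_dim, input_dim)
    # scale, zero: per-group parameters, shape (n_groups, input_dim)
    # g_idx: group index of every row along `group_axis`, shape (output_dim,)
    scale_rows = np.take(scale, g_idx, axis=group_axis)  # (output_dim, input_dim)
    zero_rows = np.take(zero, g_idx, axis=group_axis)
    return (q.astype(np.float32) - zero_rows) * scale_rows

output_dim, input_dim, block_size = 8, 4, 4
n_groups = -(-output_dim // block_size)                  # ceil division -> 2
rng = np.random.default_rng(0)
q = rng.integers(-8, 8, size=(output_dim, input_dim)).astype(np.int8)
scale = rng.uniform(0.01, 0.1, size=(n_groups, input_dim)).astype(np.float32)
zero = rng.integers(-2, 3, size=(n_groups, input_dim)).astype(np.float32)
g_idx = np.arange(output_dim) // block_size              # [0, 0, 0, 0, 1, 1, 1, 1]
w = dequant_with_sz_map_sketch(q, scale, zero, g_idx)    # (8, 4) float32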

@@ -340,60 +423,119 @@
                 self.reverse_embeddings.assign(reverse_embeddings_value)
                 self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
         elif mode == "int4":
-
-
-                    self.quantization_config,
-                    quantizers.AbsMaxQuantizer(
-                        axis=-1,
-                        value_range=(-8, 7),
-                        output_dtype="int8",
-                    ),
-                )
-            embeddings_value, embeddings_scale = weight_quantizer(
-                self._embeddings, to_numpy=True
+            from keras.src.quantizers.quantization_config import (
+                Int4QuantizationConfig,
             )
-
-
-
-
-
-
-
-
+
+            block_size = None
+            if isinstance(self.quantization_config, Int4QuantizationConfig):
+                block_size = self.quantization_config.block_size
+
+            use_grouped = block_size is not None and block_size != -1
+
+            # Quantize forward embeddings
+            if not use_grouped:
+                # Per-channel quantization
+                weight_quantizer = (
                     QuantizationConfig.weight_quantizer_or_default(
                         self.quantization_config,
                         quantizers.AbsMaxQuantizer(
-                            axis
+                            axis=-1,
                             value_range=(-8, 7),
                             output_dtype="int8",
                         ),
                     )
                 )
-
-
-                        self.reverse_embeddings, to_numpy=True
-                    )
+                embeddings_value, embeddings_scale = weight_quantizer(
+                    self._embeddings, to_numpy=True
                 )
-
-
+                embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            else:
+                # Sub-channel quantization with asymmetric zero point
+                embeddings_t = ops.transpose(self._embeddings)
+                embeddings_value_t, scale_t, zero_t = (
+                    quantizers.abs_max_quantize_grouped_with_zero_point(
+                        embeddings_t,
+                        block_size=block_size,
+                        value_range=(-8, 7),
+                        dtype="int8",
+                        to_numpy=True,
+                    )
                )
-            #
+                # Transpose back to (input_dim, output_dim) layout
+                embeddings_value = ops.transpose(embeddings_value_t)
+                embeddings_scale = ops.transpose(scale_t)
+                embeddings_zero = ops.transpose(zero_t)
+
+            packed_embeddings_value, _, _ = quantizers.pack_int4(
+                embeddings_value, axis=-1
+            )
+            del self._embeddings
+
+            # Quantize reverse embeddings if not tied
+            if not self.tie_weights:
+                if not use_grouped:
+                    reverse_weight_quantizer = (
+                        QuantizationConfig.weight_quantizer_or_default(
+                            self.quantization_config,
+                            quantizers.AbsMaxQuantizer(
+                                axis=0,
+                                value_range=(-8, 7),
+                                output_dtype="int8",
+                            ),
+                        )
+                    )
+                    reverse_embeddings_value, reverse_embeddings_scale = (
+                        reverse_weight_quantizer(
+                            self.reverse_embeddings, to_numpy=True
+                        )
+                    )
+                    reverse_embeddings_scale = ops.squeeze(
+                        reverse_embeddings_scale, axis=0
+                    )
+                else:
+                    reverse_value, reverse_scale, reverse_zero = (
+                        quantizers.abs_max_quantize_grouped_with_zero_point(
+                            self.reverse_embeddings,
+                            block_size=block_size,
+                            value_range=(-8, 7),
+                            dtype="int8",
+                            to_numpy=True,
+                        )
+                    )
+                    reverse_embeddings_value = reverse_value
+                    reverse_embeddings_scale = reverse_scale
+                    reverse_embeddings_zero = reverse_zero
+
                 packed_reverse_embeddings_value, _, _ = quantizers.pack_int4(
                     reverse_embeddings_value, axis=0
                 )
                 del self.reverse_embeddings
+
             self.quantized_build(
                 embeddings_shape, mode, self.quantization_config
             )
             self._embeddings.assign(packed_embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
+            if use_grouped:
+                self.embeddings_zero.assign(embeddings_zero)
             if not self.tie_weights:
                 self.reverse_embeddings.assign(packed_reverse_embeddings_value)
                 self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
+                if use_grouped:
+                    self.reverse_embeddings_zero.assign(reverse_embeddings_zero)
         else:
             raise self._quantization_mode_error(mode)

         # Set new dtype policy.
         if self.dtype_policy.quantization_mode is None:
-
+            policy_name = mode
+            if mode == "int4":
+                # Include block_size in policy name for sub-channel quantization
+                block_size = get_block_size_for_layer(self, config)
+                block_size_value = -1 if block_size is None else block_size
+                policy_name = f"int4/{block_size_value}"
+            policy = dtype_policies.get(
+                f"{policy_name}_from_{self.dtype_policy.name}"
+            )
             self.dtype_policy = policy
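
The new dtype-policy naming is easy to trace from the code above: for int4 the block size is folded into the policy string. An illustration with made-up values:

mode = "int4"
block_size = 128                    # result of get_block_size_for_layer(self, config)
source_policy_name = "float32"      # self.dtype_policy.name before quantization
block_size_value = -1 if block_size is None else block_size
policy_name = f"int4/{block_size_value}"
print(f"{policy_name}_from_{source_policy_name}")   # int4/128_from_float32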

keras/src/layers/preprocessing/discretization.py
CHANGED

@@ -213,8 +213,37 @@ class Discretization(DataLayer):
             return
         self.summary = np.array([[], []], dtype="float32")

+    def compute_output_shape(self, input_shape):
+        if self.output_mode == "int":
+            return input_shape
+
+        # Calculate depth (number of bins)
+        depth = (
+            len(self.bin_boundaries) + 1
+            if self.bin_boundaries is not None
+            else self.num_bins
+        )
+
+        if self.output_mode == "one_hot":
+            # For one_hot mode, add depth dimension
+            # If last dimension is 1, replace it with depth, otherwise append
+            if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
+                return tuple(input_shape[:-1]) + (depth,)
+            else:
+                return tuple(input_shape) + (depth,)
+        else:
+            if input_shape and len(input_shape) >= 2:
+                # Match to eager tensor, remove second and append depth
+                out_shape = (
+                    (input_shape[0],) + tuple(input_shape[2:]) + (depth,)
+                )
+                return out_shape
+            else:
+                return (depth,)
+
     def compute_output_spec(self, inputs):
-
+        output_shape = self.compute_output_shape(inputs.shape)
+        return backend.KerasTensor(shape=output_shape, dtype=self.output_dtype)

     def load_own_variables(self, store):
         if len(store) == 1:
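
For reference, the added compute_output_shape logic replayed as a standalone helper. The layer itself derives depth from `len(bin_boundaries) + 1` or `num_bins`; the `depth=10` below is an assumed value for illustration.

def discretization_output_shape(input_shape, output_mode, depth):
    # Mirrors Discretization.compute_output_shape from the hunk above.
    if output_mode == "int":
        return input_shape
    if output_mode == "one_hot":
        if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
            return tuple(input_shape[:-1]) + (depth,)
        return tuple(input_shape) + (depth,)
    if input_shape and len(input_shape) >= 2:
        return (input_shape[0],) + tuple(input_shape[2:]) + (depth,)
    return (depth,)

print(discretization_output_shape((32, 1), "one_hot", 10))    # (32, 10)
print(discretization_output_shape((32, 5), "one_hot", 10))    # (32, 5, 10)
print(discretization_output_shape((32, 5), "multi_hot", 10))  # (32, 10)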
keras/src/quantizers/__init__.py
CHANGED

@@ -9,11 +9,17 @@ from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import AbsMaxQuantizer
 from keras.src.quantizers.quantizers import Quantizer
 from keras.src.quantizers.quantizers import abs_max_quantize
+from keras.src.quantizers.quantizers import (
+    abs_max_quantize_grouped_with_zero_point,
+)
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
+from keras.src.quantizers.quantizers import compute_quantization_parameters
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
 from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
 from keras.src.quantizers.quantizers import pack_int4
 from keras.src.quantizers.quantizers import quantize_and_dequantize
+from keras.src.quantizers.quantizers import quantize_with_sz_map
 from keras.src.quantizers.quantizers import unpack_int4
 from keras.src.saving import serialization_lib
 from keras.src.utils.naming import to_snake_case

keras/src/quantizers/quantization_config.py
CHANGED

@@ -99,14 +99,46 @@ class Int4QuantizationConfig(QuantizationConfig):
         weight_quantizer: Quantizer for weights.
         activation_quantizer: Quantizer for activations. If "default", uses
             AbsMaxQuantizer with axis=-1.
+        block_size: Size of groups along the input dimension for sub-channel
+            quantization. If a positive integer, uses sub-channel quantization
+            with `ceil(input_dim / block_size)` groups. If `None` or `-1`,
+            uses per-channel quantization (one scale per output channel).
+            Default: `128` (sub-channel with 128-element groups).
     """

-    def __init__(
-
-
+    def __init__(
+        self,
+        weight_quantizer=None,
+        activation_quantizer="default",
+        block_size=128,
+    ):
         if activation_quantizer == "default":
-
+            # Use weight-only quantization by default for int4
+            activation_quantizer = None
         super().__init__(weight_quantizer, activation_quantizer)
+
+        # Validate block_size
+        if block_size is not None and block_size != -1 and block_size <= 0:
+            raise ValueError(
+                f"block_size must be None, -1, or a positive integer. "
+                f"Received: block_size={block_size}"
+            )
+        self.block_size = block_size
+
+        # Sub-channel quantization does not support custom quantizers
+        is_sub_channel = block_size is not None and block_size > 0
+        has_custom_quantizer = (
+            self.weight_quantizer is not None
+            or self.activation_quantizer is not None
+        )
+        if is_sub_channel and has_custom_quantizer:
+            raise ValueError(
+                "Int4 sub-channel quantization (block_size > 0) does not "
+                "support custom quantizers. Either set block_size to None "
+                "or -1 for per-channel quantization, or remove the custom "
+                f"quantizer arguments. Received: block_size={block_size}"
+            )
+
         if self.weight_quantizer is not None:
             if self.weight_quantizer.value_range != (-8, 7):
                 raise ValueError(
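
A minimal usage sketch of the widened constructor, importing the class from its module path as shown in this diff. The custom quantizer mentioned in the last comment is hypothetical, standing in for any Quantizer instance.

from keras.src.quantizers.quantization_config import Int4QuantizationConfig

cfg = Int4QuantizationConfig()                        # block_size defaults to 128
per_channel = Int4QuantizationConfig(block_size=-1)   # per-channel, like block_size=None

# Both of the following raise ValueError under the new validation:
#   Int4QuantizationConfig(block_size=0)
#   Int4QuantizationConfig(weight_quantizer=my_quantizer, block_size=128)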

@@ -126,6 +158,28 @@ class Int4QuantizationConfig(QuantizationConfig):
     def mode(self):
         return "int4"

+    def get_config(self):
+        config = super().get_config()
+        config["block_size"] = self.block_size
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        weight_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("weight_quantizer")
+        )
+        activation_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("activation_quantizer")
+        )
+        # Default to None for backwards compatibility with models saved
+        # before block_size was introduced (those used per-channel mode)
+        block_size = config.get("block_size", None)
+        return cls(
+            weight_quantizer=weight_quantizer,
+            activation_quantizer=activation_quantizer,
+            block_size=block_size,
+        )
+

 @keras_export("keras.quantizers.Float8QuantizationConfig")
 class Float8QuantizationConfig(QuantizationConfig):
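
A sketch of the intended serialization round trip; the quantizer fields are handled by the parent QuantizationConfig, so only block_size is checked here.

from keras.src.quantizers.quantization_config import Int4QuantizationConfig

cfg = Int4QuantizationConfig(block_size=64)
config = cfg.get_config()                        # now includes "block_size": 64
restored = Int4QuantizationConfig.from_config(config)
assert restored.block_size == 64

# Configs saved before this change carry no "block_size" key and load
# in per-channel mode:
legacy = Int4QuantizationConfig.from_config(
    {"weight_quantizer": None, "activation_quantizer": None}
)
assert legacy.block_size is None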

@@ -244,3 +298,43 @@ def _validate_mode(mode):
             "Invalid quantization mode. "
             f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
         )
+
+
+def get_block_size_for_layer(layer, config):
+    """Determine the block size for int4 quantization.
+
+    The block size can be specified either through the `config` argument
+    or through the `dtype_policy` if it is of type `Int4DTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the block size should be specified in the `dtype_policy`.
+
+    Args:
+        layer: The layer being quantized.
+        config: An optional configuration object that may contain the
+            `block_size` attribute.
+    Returns:
+        int or None. The determined block size for int4 quantization.
+        Returns `None` or `-1` for per-channel quantization.
+    """
+    from keras.src.dtype_policies.dtype_policy import Int4DTypePolicy
+    from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
+
+    if config and isinstance(config, Int4QuantizationConfig):
+        return config.block_size
+    elif isinstance(layer.dtype_policy, Int4DTypePolicy):
+        block_size = layer.dtype_policy.block_size
+        # Convert -1 to None for consistency
+        return None if block_size == -1 else block_size
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if isinstance(policy, Int4DTypePolicy):
+            block_size = policy.block_size
+            return None if block_size == -1 else block_size
+        # Fall back to None for legacy QuantizedDTypePolicy
+        return None
+    else:
+        # For backwards compatibility with models that don't have
+        # Int4DTypePolicy (legacy per-channel mode)
+        return None