keras-nightly 3.12.0.dev2025082103__py3-none-any.whl → 3.12.0.dev2025082303__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. keras/_tf_keras/keras/ops/__init__.py +1 -0
  2. keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
  3. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  4. keras/ops/__init__.py +1 -0
  5. keras/ops/numpy/__init__.py +1 -0
  6. keras/quantizers/__init__.py +1 -0
  7. keras/src/applications/convnext.py +20 -20
  8. keras/src/applications/densenet.py +21 -21
  9. keras/src/applications/efficientnet.py +16 -16
  10. keras/src/applications/efficientnet_v2.py +28 -28
  11. keras/src/applications/inception_resnet_v2.py +7 -7
  12. keras/src/applications/inception_v3.py +5 -5
  13. keras/src/applications/mobilenet_v2.py +13 -20
  14. keras/src/applications/mobilenet_v3.py +15 -15
  15. keras/src/applications/nasnet.py +7 -8
  16. keras/src/applications/resnet.py +32 -32
  17. keras/src/applications/xception.py +10 -10
  18. keras/src/backend/common/dtypes.py +8 -3
  19. keras/src/backend/common/variables.py +3 -1
  20. keras/src/backend/jax/export.py +1 -1
  21. keras/src/backend/jax/numpy.py +6 -0
  22. keras/src/backend/jax/trainer.py +1 -1
  23. keras/src/backend/numpy/numpy.py +28 -0
  24. keras/src/backend/openvino/numpy.py +5 -1
  25. keras/src/backend/tensorflow/numpy.py +22 -0
  26. keras/src/backend/tensorflow/trainer.py +19 -1
  27. keras/src/backend/torch/core.py +6 -9
  28. keras/src/backend/torch/nn.py +1 -2
  29. keras/src/backend/torch/numpy.py +16 -0
  30. keras/src/backend/torch/trainer.py +1 -1
  31. keras/src/callbacks/backup_and_restore.py +2 -2
  32. keras/src/callbacks/csv_logger.py +1 -1
  33. keras/src/callbacks/model_checkpoint.py +1 -1
  34. keras/src/callbacks/tensorboard.py +6 -6
  35. keras/src/constraints/constraints.py +9 -7
  36. keras/src/datasets/boston_housing.py +1 -1
  37. keras/src/datasets/california_housing.py +1 -1
  38. keras/src/datasets/cifar10.py +1 -1
  39. keras/src/datasets/cifar100.py +2 -2
  40. keras/src/datasets/imdb.py +2 -2
  41. keras/src/datasets/mnist.py +1 -1
  42. keras/src/datasets/reuters.py +2 -2
  43. keras/src/dtype_policies/dtype_policy.py +1 -1
  44. keras/src/dtype_policies/dtype_policy_map.py +1 -1
  45. keras/src/export/tf2onnx_lib.py +1 -3
  46. keras/src/initializers/constant_initializers.py +9 -5
  47. keras/src/layers/input_spec.py +6 -6
  48. keras/src/layers/layer.py +1 -1
  49. keras/src/layers/preprocessing/category_encoding.py +3 -3
  50. keras/src/layers/preprocessing/data_layer.py +159 -0
  51. keras/src/layers/preprocessing/discretization.py +3 -3
  52. keras/src/layers/preprocessing/feature_space.py +4 -4
  53. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
  54. keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
  55. keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
  56. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
  57. keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
  58. keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
  59. keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
  60. keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
  61. keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
  62. keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
  63. keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
  64. keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
  65. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
  66. keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
  67. keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
  68. keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
  69. keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
  70. keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
  71. keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
  72. keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
  73. keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
  74. keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
  75. keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
  76. keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
  77. keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
  78. keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
  79. keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
  80. keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
  81. keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
  82. keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
  83. keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
  84. keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
  85. keras/src/layers/preprocessing/normalization.py +5 -2
  86. keras/src/layers/preprocessing/rescaling.py +3 -3
  87. keras/src/layers/rnn/bidirectional.py +4 -4
  88. keras/src/legacy/backend.py +9 -23
  89. keras/src/legacy/preprocessing/image.py +11 -22
  90. keras/src/legacy/preprocessing/text.py +1 -1
  91. keras/src/models/functional.py +2 -2
  92. keras/src/models/model.py +21 -3
  93. keras/src/ops/function.py +1 -1
  94. keras/src/ops/numpy.py +49 -5
  95. keras/src/ops/operation.py +3 -2
  96. keras/src/optimizers/base_optimizer.py +3 -4
  97. keras/src/optimizers/schedules/learning_rate_schedule.py +16 -9
  98. keras/src/quantizers/gptq.py +350 -0
  99. keras/src/quantizers/gptq_config.py +169 -0
  100. keras/src/quantizers/gptq_core.py +335 -0
  101. keras/src/quantizers/gptq_quant.py +133 -0
  102. keras/src/saving/file_editor.py +22 -20
  103. keras/src/saving/object_registration.py +1 -1
  104. keras/src/saving/saving_lib.py +4 -4
  105. keras/src/saving/serialization_lib.py +3 -5
  106. keras/src/trainers/compile_utils.py +1 -1
  107. keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
  108. keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
  109. keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
  110. keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
  111. keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
  112. keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
  113. keras/src/tree/dmtree_impl.py +19 -3
  114. keras/src/tree/optree_impl.py +3 -3
  115. keras/src/tree/tree_api.py +5 -2
  116. keras/src/utils/file_utils.py +13 -5
  117. keras/src/utils/io_utils.py +1 -1
  118. keras/src/utils/model_visualization.py +1 -1
  119. keras/src/utils/progbar.py +5 -5
  120. keras/src/utils/summary_utils.py +4 -4
  121. keras/src/version.py +1 -1
  122. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/METADATA +1 -1
  123. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/RECORD +125 -121
  124. keras/src/layers/preprocessing/tf_data_layer.py +0 -78
  125. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/WHEEL +0 -0
  126. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/top_level.txt +0 -0

keras/src/quantizers/gptq_core.py
@@ -0,0 +1,335 @@
+ import random
+
+ import numpy as np
+ from absl import logging
+
+ from keras.src import ops
+ from keras.src import utils as keras_utils
+ from keras.src.layers import Dense
+ from keras.src.layers import EinsumDense
+ from keras.src.layers import Embedding
+ from keras.src.quantizers.gptq import GPTQ
+ from keras.src.quantizers.gptq_quant import GPTQQuantization
+
+
+ def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
+ """
+ Prepares and chunks the calibration dataloader, repeating short datasets.
+ """
+ all_tokens = []
+
+ if not hasattr(dataset, "__iter__") or isinstance(dataset, (str, bytes)):
+ raise TypeError(
+ "The `dataset` argument must be an iterable (e.g., a list of "
+ "strings, a generator, or a NumPy array). Got type: "
+ f"{type(dataset).__name__}. Please pass the loaded dataset "
+ "directly."
+ )
+
+ dataset_list = list(dataset)
+
+ if not dataset_list:
+ raise ValueError("Provided dataset is empty.")
+
+ if isinstance(dataset_list[0], str):
+ logging.info("(Dataset contains strings, tokenizing now...)")
+ full_text = "\n\n".join(dataset_list)
+ all_tokens = tokenizer.tokenize(full_text)
+ else:
+ logging.info("(Dataset is pre-tokenized, concatenating...)")
+ all_tokens = np.concatenate(
+ [ops.convert_to_numpy(s).reshape(-1) for s in dataset_list], axis=0
+ )
+
+ all_tokens = np.array(all_tokens, dtype=np.int32)
+
+ # Repeat data if it's too short
+ required_tokens = num_samples * sequence_length
+ if len(all_tokens) < required_tokens:
+ logging.info(
+ f"Warning: Dataset is too short ({len(all_tokens)} tokens)."
+ f" Repeating data to generate {num_samples} samples."
+ )
+ repeats = -(-required_tokens // len(all_tokens)) # Ceiling division
+ all_tokens = np.tile(all_tokens, repeats)
+
+ # Chunk the token list into samples
+
+ calibration_samples = []
+ for _ in range(num_samples):
+ # Generate a random starting index
+ start_index = random.randint(0, len(all_tokens) - sequence_length - 1)
+ end_index = start_index + sequence_length
+ sample = all_tokens[start_index:end_index]
+ calibration_samples.append(np.reshape(sample, (1, sequence_length)))
+
+ final_array = np.stack(calibration_samples, axis=0)
+ return final_array
+
+
+ def _find_layers_recursive(layer, prefix, found_layers):
+ """
+ Recursively search for Dense and EinsumDense layers and record them.
+ """
+ for sub_layer in layer._layers:
+ # Construct a unique name for the layer based on its hierarchy
+ layer_name = f"{prefix}.{sub_layer.name}"
+ if isinstance(sub_layer, (Dense, EinsumDense)):
+ found_layers[layer_name] = sub_layer
+
+ # Recurse into nested layers that are not the target types
+ elif hasattr(sub_layer, "_layers") and sub_layer._layers:
+ _find_layers_recursive(sub_layer, layer_name, found_layers)
+
+
+ def find_layers_in_block(block):
+ """
+ A pluggable, generic function to find all Dense and EinsumDense layers
+ within any transformer block by using a recursive search.
+ """
+ found_layers = {}
+ # Start the recursive search from the block itself
+ _find_layers_recursive(block, "block", found_layers)
+ return found_layers
+
+
+ def apply_gptq_layerwise(
+ model,
+ dataloader,
+ num_samples,
+ hessian_damping,
+ group_size,
+ symmetric,
+ activation_order,
+ weight_bits,
+ ):
+ """Applies GPTQ quantization layer-by-layer to a Keras model.
+
+ This function is designed to work with common transformer architectures,
+ like those provided by KerasHub. It automatically discovers the model's
+ structure by first looking for the standard format: a `model.backbone`
+ attribute that contains a `transformer_layers` list.
+
+ If a standard backbone is not found, it falls back to a heuristic for
+ custom models, where it assumes the first `keras.layers.Embedding` layer
+ is the input embedding and any subsequent container layers are the
+ transformer blocks to be quantized.
+
+ The core logic operates as follows:
+ 1. It automatically detects the model's structure, identifying the main
+ embedding layer and a sequence of transformer blocks.
+ 2. It processes the model sequentially, one block at a time. For each
+ block, it uses temporary hooks to capture the input activations of
+ each target layer during a forward pass with the calibration data.
+ 3. These captured activations are used to compute the Hessian matrix for
+ each layer's weights.
+ 4. The GPTQ algorithm is then applied to each layer to find the optimal
+ quantized weights that minimize the error introduced.
+ 5. The output activations from the current block are then used as the
+ input for the next block, ensuring that quantization errors are
+ accounted for throughout the model.
+
+ Args:
+ model: The Keras model instance to be quantized. The function will
+ attempt to automatically discover its structure.
+ dataloader: An iterable providing calibration data. Each item should
+ be a batch of token IDs suitable for the model's embedding layer.
+ num_samples: (int) The number of samples from the dataloader to use for
+ calibration.
+ hessian_damping: (float) The percentage of dampening to add to the
+ Hessian diagonal for stabilization during inverse calculation.
+ A value of 0.01 is common.
+ group_size: (int) The size of the groups to use for quantization. A
+ value of 128 means that 128 weights will share the same scaling
+ factor. Use -1 for per-channel quantization.
+ symmetric: (bool) If True, symmetric quantization is used. Otherwise,
+ asymmetric quantization is used.
+ activation_order: (bool) If True, reorders the weight columns based on
+ activation magnitude, which can improve quantization accuracy.
+ weight_bits: (int) The number of bits to use for the quantized weights,
+ e.g., 4 for 4-bit quantization.
+
+ Raises:
+ ValueError: If the function cannot automatically find an embedding
+ layer or any transformer-like blocks to quantize within the model.
+ """
+ logging.info("Starting model quantization...")
+ embedding_layer = None
+ transformer_blocks = []
+ if hasattr(model, "backbone"):
+ logging.info("Detected KerasHub model structure.")
+ backbone = model.backbone
+
+ # Add the check for the 'transformer_layers' attribute.
+ if hasattr(backbone, "transformer_layers"):
+ transformer_blocks = backbone.transformer_layers
+ else:
+ # Raise a specific error if the attribute is missing.
+ raise ValueError(
+ "The model's backbone does not have a 'transformer_layers' "
+ "attribute. Please ensure you are using a standard KerasHub "
+ "transformer model."
+ )
+ # Find the embedding layer by checking for common names or by type.
+ if hasattr(backbone, "token_embedding"):
+ embedding_layer = backbone.token_embedding
+ elif hasattr(backbone, "embedding"):
+ embedding_layer = backbone.embedding
+ else:
+ raise ValueError(
+ "Could not automatically find an embedding layer in the model."
+ )
+
+ else:
+ logging.info("Detected custom model structure.")
+ for layer in model.layers:
+ # The first Embedding layer found is assumed to be the main one.
+ if isinstance(layer, Embedding) and embedding_layer is None:
+ embedding_layer = layer
+ # A "block" is a container-like layer with its own sub-layers
+ # that we can quantize. This is a heuristic that works for the
+ # test.
+ elif hasattr(layer, "_layers") and layer._layers:
+ transformer_blocks.append(layer)
+
+ if embedding_layer is None:
+ raise ValueError(
+ "Could not automatically find an embedding layer in the model."
+ )
+ if not transformer_blocks:
+ raise ValueError(
+ "Could not automatically find any transformer-like blocks to "
+ "quantize."
+ )
+
+ # Initial inputs are the outputs of the token embedding layer
+ inputs = [
+ embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
+ for batch in dataloader
+ ]
+ progbar = keras_utils.Progbar(target=len(transformer_blocks))
+
+ for block_idx, block in enumerate(transformer_blocks):
+ logging.info(f"Quantizing Block {block_idx}")
+ sub_layers_map = find_layers_in_block(block)
+
+ if not sub_layers_map:
+ logging.info(
+ f" No Dense or EinsumDense layers found in block {block_idx}. "
+ "Skipping."
+ )
+ else:
+ logging.info(f"Found layers: {list(sub_layers_map.keys())}")
+ gptq_objects = {
+ name: GPTQ(layer) for name, layer in sub_layers_map.items()
+ }
+
+ captured_inputs = {name: [] for name in sub_layers_map.keys()}
+ original_calls = {}
+
+ def create_hook(name, original_call_func):
+ """A factory for creating a hook to capture layer inputs."""
+
+ def hook(*args, **kwargs):
+ if args:
+ inp = args[0]
+ else:
+ inp = kwargs["inputs"]
+ captured_inputs[name].append(inp)
+ return original_call_func(*args, **kwargs)
+
+ return hook
+
+ try:
+ for name, layer in sub_layers_map.items():
+ original_call = layer.call
+ original_calls[name] = original_call
+ layer.call = create_hook(name, original_call)
+
+ logging.info(f"Capturing activations for block {block_idx}...")
+ for sample_idx in range(num_samples):
+ current_input = inputs[sample_idx]
+ if len(current_input.shape) == 2:
+ current_input = ops.expand_dims(current_input, axis=0)
+ _ = block(current_input)
+
+ finally:
+ for name, layer in sub_layers_map.items():
+ if name in original_calls:
+ layer.call = original_calls[name]
+
+ logging.info(f"Building Hessians for block {block_idx}...")
+ for name, gptq_object in gptq_objects.items():
+ layer_inputs = ops.concatenate(captured_inputs[name], axis=0)
+
+ # Explicitly reshape the input tensor to be 2D, with the second
+ # dimension matching the number of input features expected by
+ # the layer's kernel.
+ # This correctly handles inputs of any dimensionality
+ # (e.g., 3D or 4D).
+ num_features = gptq_object.rows
+ input_reshaped = ops.reshape(layer_inputs, (-1, num_features))
+ gptq_object.update_hessian_with_batch(input_reshaped)
+
+ quantizer = GPTQQuantization(
+ weight_bits,
+ per_channel=True,
+ symmetric=symmetric,
+ group_size=group_size,
+ )
+ for name, gptq_object in gptq_objects.items():
+ logging.info(f"Quantizing {name}...")
+ gptq_object.quantizer = quantizer
+ gptq_object.quantize_and_correct_block(
+ hessian_damping=hessian_damping,
+ group_size=group_size,
+ activation_order=activation_order,
+ )
+ gptq_object.free()
+
+ del gptq_objects, captured_inputs, original_calls
+
+ if block_idx < len(transformer_blocks) - 1:
+ logging.info(f"Generating inputs for block {block_idx + 1}...")
+ next_block_inputs = []
+ for sample_idx in range(num_samples):
+ current_input = inputs[sample_idx]
+ if len(current_input.shape) == 2:
+ current_input = ops.expand_dims(current_input, axis=0)
+ output = block(current_input)[0]
+ next_block_inputs.append(output)
+ inputs = next_block_inputs
+ progbar.update(current=block_idx + 1)
+
+ logging.info("Quantization process complete.")
+
+
+ def quantize_model(model, config):
+ """
+ Top-level function to quantize a Keras model using GPTQ.
+ """
+ logging.info("Starting GPTQ quantization process...")
+
+ # Load ALL data needed from the generator/source in a single call.
+ total_samples_to_request = config.num_samples
+ full_dataloader = get_dataloader(
+ config.tokenizer,
+ config.sequence_length,
+ config.dataset,
+ num_samples=total_samples_to_request,
+ )
+
+ # Split the materialized data. This works because full_dataloader
+ # is now a NumPy array, which can be sliced and reused.
+ calibration_dataloader = full_dataloader[: config.num_samples]
+
+ apply_gptq_layerwise(
+ model,
+ calibration_dataloader, # Use the calibration slice
+ config.num_samples, # Use the configured number of samples
+ config.hessian_damping,
+ config.group_size,
+ config.symmetric,
+ config.activation_order,
+ config.weight_bits,
+ )
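
For orientation, the sketch below is not part of the diff; it shows one way to drive the new `quantize_model` entry point. The config class added in `keras/src/quantizers/gptq_config.py` is not shown in this hunk, so a `types.SimpleNamespace` carrying only the attributes that `quantize_model` reads stands in for it, and `model` is a hypothetical KerasHub-style causal LM.

from types import SimpleNamespace

import numpy as np

from keras.src.quantizers import gptq_core

# Pre-tokenized calibration data: with non-string samples, get_dataloader
# never calls the tokenizer, so None is acceptable for it here.
calibration_ids = [np.random.randint(0, 1000, size=(512,)) for _ in range(4)]

config = SimpleNamespace(
    tokenizer=None,
    dataset=calibration_ids,
    sequence_length=128,
    num_samples=8,
    weight_bits=4,
    group_size=128,
    symmetric=False,
    activation_order=False,
    hessian_damping=0.01,
)

# `model` is assumed to expose model.backbone.transformer_layers and a token
# embedding, as apply_gptq_layerwise expects for KerasHub models.
gptq_core.quantize_model(model, config)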

keras/src/quantizers/gptq_quant.py
@@ -0,0 +1,133 @@
+ from keras.src import ops
+
+
+ def dequantize(input_tensor, scale, zero, maxq):
+ """The core quantization function."""
+ epsilon = ops.cast(1e-8, dtype=scale.dtype)
+ scale = ops.where(ops.equal(scale, 0), epsilon, scale)
+
+ quantized_tensor = ops.divide(input_tensor, scale)
+ quantized_tensor = ops.round(quantized_tensor)
+ q = ops.add(quantized_tensor, zero)
+ q = ops.clip(q, 0, maxq)
+
+ dequantized_tensor = ops.subtract(q, zero)
+ return ops.multiply(scale, dequantized_tensor)
+
+
+ class GPTQQuantization:
+ """A class that handles the quantization of weights using the GPTQ method.
+
+ This class provides methods to find quantization parameters (scale and zero)
+ for a given tensor and can be used to quantize weights in a GPTQ context.
+
+ Args:
+ weight_bits: (int) The number of bits to quantize to (e.g., 4).
+ per_channel: (bool) A flag indicating whether quantization is
+ applied per-channel (`True`) or per-tensor (`False`).
+ Defaults to `True`.
+ symmetric: (bool) A flag indicating whether symmetric (`True`) or
+ asymmetric (`False`) quantization is used. Defaults to `False`.
+ group_size: (int) The size of weight groups for quantization. A
+ value of -1 indicates that grouping is not used.
+ Defaults to -1.
+ """
+
+ def __init__(
+ self, weight_bits, per_channel=True, symmetric=False, group_size=-1
+ ):
+ self.weight_bits = weight_bits
+ self.maxq = ops.cast(
+ ops.subtract(ops.power(2, weight_bits), 1), "float32"
+ )
+ self.per_channel = per_channel
+ self.symmetric = symmetric
+ self.group_size = group_size
+
+ # These are now determined later by `find_params`
+ self.scale = None
+ self.zero = None
+
+ def find_params(self, input_tensor, weight=False):
+ """Finds quantization parameters (scale and zero) for a given tensor."""
+
+ if input_tensor is None:
+ raise ValueError("Input tensor 'input_tensor' cannot be None.")
+
+ # For weights, we typically expect at least a 2D tensor.
+ if weight and len(input_tensor.shape) < 2:
+ raise ValueError(
+ f"Input weight tensor 'input_tensor' must have a rank of at "
+ f"least 2, but got rank {len(input_tensor.shape)}."
+ )
+
+ if ops.size(input_tensor) == 0:
+ raise ValueError("Input tensor 'input_tensor' cannot be empty.")
+
+ original_shape = input_tensor.shape
+
+ if self.per_channel:
+ if weight:
+ if self.group_size != -1:
+ input_reshaped = ops.reshape(
+ input_tensor, [-1, self.group_size]
+ )
+ else:
+ input_reshaped = ops.reshape(
+ input_tensor, [original_shape[0], -1]
+ )
+ else: # per-tensor
+ input_reshaped = ops.reshape(input_tensor, [1, -1])
+
+ # Find min/max values
+ min_values = ops.min(input_reshaped, axis=1)
+ max_values = ops.max(input_reshaped, axis=1)
+
+ # Apply symmetric quantization logic if enabled
+ if self.symmetric:
+ max_values = ops.maximum(ops.abs(min_values), max_values)
+ min_values = ops.where(
+ ops.less(min_values, 0), ops.negative(max_values), min_values
+ )
+
+ # Ensure range is not zero to avoid division errors
+ zero_range = ops.equal(min_values, max_values)
+ min_values = ops.where(
+ zero_range, ops.subtract(min_values, 1), min_values
+ )
+ max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)
+
+ # Calculate scale and zero-point
+ self.scale = ops.divide(ops.subtract(max_values, min_values), self.maxq)
+ if self.symmetric:
+ self.zero = ops.full_like(
+ self.scale, ops.divide(ops.add(self.maxq, 1), 2)
+ )
+ else:
+ self.zero = ops.round(
+ ops.divide(ops.negative(min_values), self.scale)
+ )
+
+ # Ensure scale is non-zero
+ self.scale = ops.where(ops.less_equal(self.scale, 0), 1e-8, self.scale)
+
+ if weight:
+ # Per-channel, non-grouped case: simple reshape is correct.
+ if self.per_channel and self.group_size == -1:
+ self.scale = ops.reshape(self.scale, [-1, 1])
+ self.zero = ops.reshape(self.zero, [-1, 1])
+ elif not self.per_channel:
+ num_rows = original_shape[0]
+ self.scale = ops.tile(
+ ops.reshape(self.scale, (1, 1)), (num_rows, 1)
+ )
+ self.zero = ops.tile(
+ ops.reshape(self.zero, (1, 1)), (num_rows, 1)
+ )
+ if self.per_channel:
+ self.scale = ops.reshape(self.scale, [-1, 1])
+ self.zero = ops.reshape(self.zero, [-1, 1])
+
+ def ready(self):
+ """Checks if the quantization parameters have been computed."""
+ return self.scale is not None and self.zero is not None
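
As a quick illustration (not part of the diff), the round trip below exercises only the APIs defined in this new file: `find_params` derives a per-channel scale and zero-point for a weight matrix, and `dequantize` snaps the weights onto the 4-bit grid and back, the kind of fake quantization the rest of the GPTQ flow builds on.

import numpy as np

from keras.src import ops
from keras.src.quantizers.gptq_quant import GPTQQuantization, dequantize

# An 8x16 float32 matrix standing in for a layer kernel.
weights = ops.convert_to_tensor(np.random.randn(8, 16).astype("float32"))

quantizer = GPTQQuantization(weight_bits=4, per_channel=True, symmetric=False)
quantizer.find_params(weights, weight=True)  # sets quantizer.scale / .zero

fake_quantized = dequantize(
    weights, quantizer.scale, quantizer.zero, quantizer.maxq
)
# Per element, the reconstruction error should be on the order of scale / 2.
max_error = ops.max(ops.abs(ops.subtract(fake_quantized, weights)))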

keras/src/saving/file_editor.py
@@ -1,5 +1,6 @@
  import collections
  import json
+ import os.path
  import pprint
  import zipfile

@@ -76,7 +77,7 @@ class KerasFileEditor:
  if filepath.endswith(".keras"):
  zf = zipfile.ZipFile(filepath, "r")
  weights_store = H5IOStore(
- saving_lib._VARS_FNAME + ".h5",
+ f"{saving_lib._VARS_FNAME}.h5",
  archive=zf,
  mode="r",
  )
@@ -143,7 +144,7 @@ class KerasFileEditor:
  ):
  base_inner_path = inner_path
  for ref_key, ref_val in ref_spec.items():
- inner_path = base_inner_path + "/" + ref_key
+ inner_path = f"{base_inner_path}/{ref_key}"
  if inner_path in checked_paths:
  continue

@@ -435,7 +436,7 @@ class KerasFileEditor:
  _save(
  weights_dict[name],
  weights_store,
- inner_path=inner_path + "/" + name,
+ inner_path=os.path.join(inner_path, name),
  )
  else:
  # e.g. name="0", value=HDF5Dataset
@@ -462,7 +463,7 @@ class KerasFileEditor:

  result = collections.OrderedDict()
  for key in data.keys():
- inner_path = inner_path + "/" + key
+ inner_path = f"{inner_path}/{key}"
  value = data[key]
  if isinstance(value, h5py.Group):
  if len(value) == 0:
@@ -506,7 +507,7 @@ class KerasFileEditor:
  self, weights_dict, indent=0, is_first=True, prefix="", inner_path=""
  ):
  for idx, (key, value) in enumerate(weights_dict.items()):
- inner_path = inner_path + "/" + key
+ inner_path = os.path.join(inner_path, key)
  is_last = idx == len(weights_dict) - 1
  if is_first:
  is_first = False
@@ -556,29 +557,30 @@ class KerasFileEditor:
  html = ""
  for key, value in dictionary.items():
  if isinstance(value, dict) and value:
+ weights_html = _generate_html_weights(
+ value, margin_left + 20, font_size - 1
+ )
  html += (
  f'<details style="margin-left: {margin_left}px;">'
- + '<summary style="'
- + f"font-size: {font_size}em; "
- + "font-weight: bold;"
- + f'">{key}</summary>'
- + _generate_html_weights(
- value, margin_left + 20, font_size - 1
- )
- + "</details>"
+ '<summary style="'
+ f"font-size: {font_size}em; "
+ "font-weight: bold;"
+ f'">{key}</summary>'
+ f"{weights_html}"
+ "</details>"
  )
  else:
  html += (
  f'<details style="margin-left: {margin_left}px;">'
- + f'<summary style="font-size: {font_size}em;">'
- + f"{key} : shape={value.shape}"
- + f", dtype={value.dtype}</summary>"
- + f"<div style="
+ f'<summary style="font-size: {font_size}em;">'
+ f"{key} : shape={value.shape}"
+ f", dtype={value.dtype}</summary>"
+ f"<div style="
  f'"margin-left: {margin_left}px;'
  f'"margin-top: {margin_left}px;">'
- + f"{display_weight(value)}"
- + "</div>"
- + "</details>"
+ f"{display_weight(value)}"
+ "</div>"
+ "</details>"
  )
  return html


keras/src/saving/object_registration.py
@@ -140,7 +140,7 @@ def register_keras_serializable(package="Custom", name=None):
  def decorator(arg):
  """Registers a class with the Keras serialization framework."""
  class_name = name if name is not None else arg.__name__
- registered_name = package + ">" + class_name
+ registered_name = f"{package}>{class_name}"

  if inspect.isclass(arg) and not hasattr(arg, "get_config"):
  raise ValueError(

keras/src/saving/saving_lib.py
@@ -46,8 +46,8 @@ except ImportError:
  _CONFIG_FILENAME = "config.json"
  _METADATA_FILENAME = "metadata.json"
  _VARS_FNAME = "model.weights" # Will become e.g. "model.weights.h5"
- _VARS_FNAME_H5 = _VARS_FNAME + ".h5"
- _VARS_FNAME_NPZ = _VARS_FNAME + ".npz"
+ _VARS_FNAME_H5 = f"{_VARS_FNAME}.h5"
+ _VARS_FNAME_NPZ = f"{_VARS_FNAME}.npz"
  _ASSETS_DIRNAME = "assets"
  _MEMORY_UPPER_BOUND = 0.5 # 50%

@@ -664,7 +664,7 @@ def _write_to_zip_recursively(zipfile_to_save, system_path, zip_path):
  def _name_key(name):
  """Make sure that private attributes are visited last."""
  if name.startswith("_"):
- return "~" + name
+ return f"~{name}"
  return name


@@ -1288,7 +1288,7 @@ class ShardedH5IOStore(H5IOStore):
  # If not found, check shard map and switch files.
  weight_map = self.sharding_config["weight_map"]
  filenames = weight_map.get(parsed_path) or weight_map.get(
- "/" + parsed_path + "/vars"
+ f"/{parsed_path}/vars"
  )
  if filenames is not None:
  if not isinstance(filenames, list):

keras/src/saving/serialization_lib.py
@@ -778,7 +778,7 @@ def _retrieve_class_or_fn(
  # module name might not match the package structure
  # (e.g. experimental symbols).
  if module == "keras" or module.startswith("keras."):
- api_name = module + "." + name
+ api_name = f"{module}.{name}"

  if api_name in LOADING_APIS:
  raise ValueError(
@@ -796,9 +796,7 @@ def _retrieve_class_or_fn(
  # the corresponding function from the identifying string.
  if obj_type == "function" and module == "builtins":
  for mod in BUILTIN_MODULES:
- obj = api_export.get_symbol_from_name(
- "keras." + mod + "." + name
- )
+ obj = api_export.get_symbol_from_name(f"keras.{mod}.{name}")
  if obj is not None:
  return obj

@@ -807,7 +805,7 @@ def _retrieve_class_or_fn(
  # i.e. "name" instead of "package>name". This allows recent versions
  # of Keras to reload models saved with 3.6 and lower.
  if ">" not in name:
- separated_name = ">" + name
+ separated_name = f">{name}"
  for custom_name, custom_object in custom_objects.items():
  if custom_name.endswith(separated_name):
  return custom_object

keras/src/trainers/compile_utils.py
@@ -659,7 +659,7 @@ class CompileLoss(losses_module.Loss):
  # Add `Mean` metric to the tracker for each loss.
  if len(self._flat_losses) > 1:
  for _loss in self._flat_losses:
- name = _loss.name + "_loss"
+ name = f"{_loss.name}_loss"
  self._tracker.add_to_store(
  "metrics", metrics_module.Mean(name=name)
  )

keras/src/trainers/data_adapters/array_data_adapter.py
@@ -76,7 +76,9 @@ class ArrayDataAdapter(DataAdapter):
  inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight)

  data_adapter_utils.check_data_cardinality(inputs)
- num_samples = set(i.shape[0] for i in tree.flatten(inputs)).pop()
+ num_samples = set(
+ i.shape[0] for i in tree.flatten(inputs) if i is not None
+ ).pop()
  self._num_samples = num_samples
  self._inputs = inputs

@@ -269,7 +271,9 @@ class ArrayDataAdapter(DataAdapter):
  x = convert_to_tensor(x)
  return x

- return tree.map_structure(slice_and_convert, self.array)
+ return tree.map_structure(
+ slice_and_convert, self.array, none_is_leaf=False
+ )

  def __len__(self):
  return len(self.array[0])
@@ -337,7 +341,9 @@ class ArrayDataAdapter(DataAdapter):
  slice_indices_and_convert_fn = functools.partial(
  slice_and_convert_fn, indices=indices
  )
- yield tree.map_structure(slice_indices_and_convert_fn, inputs)
+ yield tree.map_structure(
+ slice_indices_and_convert_fn, inputs, none_is_leaf=False
+ )

  @property
  def num_batches(self):
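
To make the new `none_is_leaf=False` argument above concrete, here is a small sketch (assumptions noted in the comments) of the case it handles: `sample_weight` omitted, so the packed `(x, y, sample_weight)` structure contains a `None`. The behavior of the flag is inferred from this diff and the `keras/src/tree` entries in the file list; it is not documented in this hunk.

import numpy as np

from keras.src import tree

x = np.zeros((8, 3), dtype="float32")
y = np.ones((8, 1), dtype="float32")
inputs = (x, y, None)  # sample_weight was not provided

# Count samples while skipping the None entry, mirroring the adapter change.
num_samples = set(
    i.shape[0] for i in tree.flatten(inputs) if i is not None
).pop()

# With none_is_leaf=False the None is treated as structure rather than as a
# leaf, so the slicing function is never applied to it (inferred semantics).
first_batch = tree.map_structure(lambda t: t[:4], inputs, none_is_leaf=False)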