keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,217 @@
1
+ """AWQ core functionality for layer-wise quantization.
2
+
3
+ This module provides the orchestration logic for applying AWQ quantization
4
+ to transformer models in a layer-by-layer fashion.
5
+ """
6
+
7
+ from contextlib import contextmanager
8
+
9
+ from absl import logging
10
+
11
+ from keras.src import ops
12
+ from keras.src import utils as keras_utils
13
+ from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
14
+ from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
15
+ from keras.src.quantizers.awq import AWQ
16
+ from keras.src.quantizers.awq_config import AWQConfig
17
+ from keras.src.quantizers.gptq_core import find_layers_in_block
18
+ from keras.src.quantizers.gptq_core import get_dataloader
19
+ from keras.src.quantizers.utils import should_quantize_layer
20
+
21
+
22
@contextmanager
def stream_activations(layers_map, awq_objects):
    """Context manager to capture activations for AWQ calibration.

    While the `with` block is active, each layer in `layers_map` has its
    `call` replaced by a hook. The hook reshapes the layer input to
    `(-1, awq_objects[name].rows)` and feeds it to
    `awq_objects[name].update_activation_magnitudes` before delegating to
    the original `call`.

    On exit (normal or exceptional), only the layers that were actually
    patched are restored, in reverse patch order. Restoration assigns the
    saved `call` object back onto the layer.

    Args:
        layers_map: Dict[str, Layer]. Mapping from layer names to layers.
        awq_objects: Dict[str, AWQ]. Mapping from the same names to AWQ
            instances that accumulate activation statistics.

    Yields:
        None: The patched state is active only within the `with` block.
    """
    # (layer, original_call) pairs. A pair is recorded only after its patch
    # succeeded, so cleanup never looks up a layer that was not patched.
    # Cleanup would otherwise raise KeyError and mask the real error.
    patched = []

    def create_hook(name, original_call_func):
        def hook(*args, **kwargs):
            # The activation is the first positional arg or the `inputs`
            # keyword argument.
            inp = args[0] if args else kwargs["inputs"]
            num_features = awq_objects[name].rows
            # Collapse all leading axes so statistics are per input feature.
            input_2d = ops.reshape(inp, (-1, num_features))
            awq_objects[name].update_activation_magnitudes(input_2d)
            return original_call_func(*args, **kwargs)

        return hook

    try:
        for name, layer in layers_map.items():
            original_call = layer.call
            layer.call = create_hook(name, original_call)
            patched.append((layer, original_call))
        yield
    finally:
        # Undo in LIFO order. If one layer object was patched more than once,
        # this ends with its true original `call` rather than a stale hook.
        for layer, original_call in reversed(patched):
            layer.call = original_call
56
+
57
+
58
def _with_batch_axis(tensor):
    """Return `tensor` with a leading batch axis added if it is rank 2."""
    if len(tensor.shape) == 2:
        return ops.expand_dims(tensor, axis=0)
    return tensor


def _calibrate_and_quantize(block, layers, samples, config):
    """Collect activation stats for `layers` over `samples`, then quantize.

    One AWQ instance is built per layer. `block` is run on every sample
    while the layers are hooked by `stream_activations`. Each layer is then
    quantized and its AWQ instance freed, in insertion order.
    """
    trackers = {name: AWQ(layer, config) for name, layer in layers.items()}

    with stream_activations(layers, trackers):
        for sample in samples:
            block(_with_batch_axis(sample))

    for name, tracker in trackers.items():
        logging.info(f"Quantizing {name}...")
        tracker.quantize_layer()
        tracker.free()


def apply_awq_layerwise(dataloader, config, structure, filters=None):
    """Quantize a model with AWQ, one sequential block at a time.

    Steps:
      1. Convert each calibration batch to int32 and run it through
         `structure["pre_block_layers"]`.
      2. For every block in `structure["sequential_blocks"]`, hook the
         quantizable sub-layers (selected via `should_quantize_layer` and
         `filters`), run the block on the calibration samples to gather
         activation magnitudes, and quantize those layers.
      3. Run the (now quantized) block on the samples to produce the inputs
         for the next block.

    Args:
        dataloader: Iterable of calibration batches (e.g. a numpy array).
        config: AWQConfig instance; `config.num_samples` caps how many
            calibration samples are used per block.
        structure: Dict with 'pre_block_layers' and 'sequential_blocks'.
        filters: Optional layer filters forwarded to `should_quantize_layer`.

    Raises:
        ValueError: If `structure` contains no sequential blocks.
    """
    requested_samples = config.num_samples
    logging.info("Starting AWQ quantization...")

    pre_block_layers = structure.get("pre_block_layers", [])
    blocks = structure.get("sequential_blocks", [])
    if not blocks:
        raise ValueError(
            "No sequential blocks found in the structure to quantize."
        )

    # Embed (or otherwise pre-process) every calibration batch once.
    block_inputs = []
    for raw_batch in dataloader:
        hidden = ops.convert_to_tensor(raw_batch, dtype="int32")
        for pre_layer in pre_block_layers:
            hidden = pre_layer(hidden)
        block_inputs.append(hidden)

    # Clamp at 0 so slicing selects exactly what `range(n)` would.
    sample_count = max(min(requested_samples, len(block_inputs)), 0)
    progress = keras_utils.Progbar(target=len(blocks))
    last_block_idx = len(blocks) - 1

    for block_idx, block in enumerate(blocks):
        logging.info(f"Quantizing Block {block_idx}")
        targets = {
            name: layer
            for name, layer in find_layers_in_block(block).items()
            if should_quantize_layer(layer, filters)
        }

        if targets:
            logging.info(f"Found layers: {list(targets.keys())}")
            _calibrate_and_quantize(
                block, targets, block_inputs[:sample_count], config
            )
        else:
            logging.info(
                f" No quantizable layers found in block {block_idx}. Skipping."
            )

        if block_idx < last_block_idx:
            logging.info(f"Generating inputs for block {block_idx + 1}...")
            # `[0]` keeps the first element of the block output, matching
            # the single-sample batches built by `_with_batch_axis`.
            block_inputs = [
                block(_with_batch_axis(sample))[0]
                for sample in block_inputs[:sample_count]
            ]

        progress.update(current=block_idx + 1)

    logging.info("AWQ quantization complete.")
152
+
153
+
154
def awq_quantize(config, quantization_layer_structure, filters=None):
    """Validate inputs, load calibration data, and run layer-wise AWQ.

    Args:
        config: AWQConfig instance providing `dataset`, `tokenizer`,
            `sequence_length` and `num_samples`.
        quantization_layer_structure: Model structure dictionary with
            'pre_block_layers' and 'sequential_blocks'.
        filters: Optional layer filters forwarded to the layer-wise pass.

    Raises:
        ValueError: If `config` has no dataset or no tokenizer, or if
            `quantization_layer_structure` is None.
    """
    lacks_calibration_source = (
        config.dataset is None or config.tokenizer is None
    )
    if lacks_calibration_source:
        raise ValueError(
            "AWQ quantization requires a dataset and tokenizer. "
            "Please provide them in the AWQConfig."
        )

    if quantization_layer_structure is None:
        raise ValueError(
            "For 'awq' mode, a valid quantization structure must be provided "
            "either via `config.quantization_layer_structure` or by overriding "
            "`model.get_quantization_layer_structure(mode)`. The structure "
            "should be a dictionary with keys 'pre_block_layers' and "
            "'sequential_blocks'."
        )

    sample_budget = config.num_samples
    calibration_batches = get_dataloader(
        config.tokenizer,
        config.sequence_length,
        config.dataset,
        num_samples=sample_budget,
    )

    apply_awq_layerwise(
        calibration_batches[:sample_budget],
        config,
        quantization_layer_structure,
        filters=filters,
    )
191
+
192
+
193
def get_group_size_for_layer(layer, config):
    """Return the AWQ quantization group size that applies to `layer`.

    Lookup order:
      1. `config.group_size` when `config` is a truthy AWQConfig.
      2. The layer's dtype policy, if it is an AWQDTypePolicy. When the
         layer carries a DTypePolicyMap, the entry for `layer.path` is
         used instead.

    Args:
        layer: The layer to get group size for.
        config: Optional AWQConfig instance.

    Returns:
        int: The group size for quantization.

    Raises:
        ValueError: If neither source provides a group size.
    """
    if config and isinstance(config, AWQConfig):
        return config.group_size

    policy = layer.dtype_policy
    if isinstance(policy, DTypePolicyMap):
        # Per-layer policies are keyed by the layer's path.
        policy = policy[layer.path]
    if isinstance(policy, AWQDTypePolicy):
        return policy.group_size

    raise ValueError(
        "For AWQ quantization, group_size must be specified "
        "through AWQConfig or AWQDTypePolicy."
    )