keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
--- keras/src/quantizers/gptq_quant.py
+++ /dev/null
@@ -1,133 +0,0 @@
1
- from keras.src import ops
2
-
3
-
4
- def dequantize(input_tensor, scale, zero, maxq):
5
- """The core quantization function."""
6
- epsilon = ops.cast(1e-8, dtype=scale.dtype)
7
- scale = ops.where(ops.equal(scale, 0), epsilon, scale)
8
-
9
- quantized_tensor = ops.divide(input_tensor, scale)
10
- quantized_tensor = ops.round(quantized_tensor)
11
- q = ops.add(quantized_tensor, zero)
12
- q = ops.clip(q, 0, maxq)
13
-
14
- dequantized_tensor = ops.subtract(q, zero)
15
- return ops.multiply(scale, dequantized_tensor)
16
-
17
-
18
- class GPTQQuantization:
19
- """A class that handles the quantization of weights using GPTQ method.
20
-
21
- This class provides methods to find quantization parameters (scale and zero)
22
- for a given tensor and can be used to quantize weights in a GPTQ context.
23
-
24
- Args:
25
- weight_bits: (int) The number of bits to quantize to (e.g., 4).
26
- per_channel: (bool) A flag indicating whether quantization is
27
- applied per-channel (`True`) or per-tensor (`False`).
28
- Defaults to `False`.
29
- symmetric: (bool) A flag indicating whether symmetric (`True`) or
30
- asymmetric (`False`) quantization is used. Defaults to `False`.
31
- group_size: (int) The size of weight groups for quantization. A
32
- value of -1 indicates that grouping is not used.
33
- Defaults to -1.
34
- """
35
-
36
- def __init__(
37
- self, weight_bits, per_channel=True, symmetric=False, group_size=-1
38
- ):
39
- self.weight_bits = weight_bits
40
- self.maxq = ops.cast(
41
- ops.subtract(ops.power(2, weight_bits), 1), "float32"
42
- )
43
- self.per_channel = per_channel
44
- self.symmetric = symmetric
45
- self.group_size = group_size
46
-
47
- # These are now determined later by `find_params`
48
- self.scale = None
49
- self.zero = None
50
-
51
- def find_params(self, input_tensor, weight=False):
52
- """Finds quantization parameters (scale and zero) for a given tensor."""
53
-
54
- if input_tensor is None:
55
- raise ValueError("Input tensor 'input_tensor' cannot be None.")
56
-
57
- # For weights, we typically expect at least a 2D tensor.
58
- if weight and len(input_tensor.shape) < 2:
59
- raise ValueError(
60
- f"Input weight tensor 'input_tensor' must have a rank of at "
61
- f"least 2, but got rank {len(input_tensor.shape)}."
62
- )
63
-
64
- if ops.size(input_tensor) == 0:
65
- raise ValueError("Input tensor 'input_tensor' cannot be empty.")
66
-
67
- original_shape = input_tensor.shape
68
-
69
- if self.per_channel:
70
- if weight:
71
- if self.group_size != -1:
72
- input_reshaped = ops.reshape(
73
- input_tensor, [-1, self.group_size]
74
- )
75
- else:
76
- input_reshaped = ops.reshape(
77
- input_tensor, [original_shape[0], -1]
78
- )
79
- else: # per-tensor
80
- input_reshaped = ops.reshape(input_tensor, [1, -1])
81
-
82
- # Find min/max values
83
- min_values = ops.min(input_reshaped, axis=1)
84
- max_values = ops.max(input_reshaped, axis=1)
85
-
86
- # Apply symmetric quantization logic if enabled
87
- if self.symmetric:
88
- max_values = ops.maximum(ops.abs(min_values), max_values)
89
- min_values = ops.where(
90
- ops.less(min_values, 0), ops.negative(max_values), min_values
91
- )
92
-
93
- # Ensure range is not zero to avoid division errors
94
- zero_range = ops.equal(min_values, max_values)
95
- min_values = ops.where(
96
- zero_range, ops.subtract(min_values, 1), min_values
97
- )
98
- max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)
99
-
100
- # Calculate scale and zero-point
101
- self.scale = ops.divide(ops.subtract(max_values, min_values), self.maxq)
102
- if self.symmetric:
103
- self.zero = ops.full_like(
104
- self.scale, ops.divide(ops.add(self.maxq, 1), 2)
105
- )
106
- else:
107
- self.zero = ops.round(
108
- ops.divide(ops.negative(min_values), self.scale)
109
- )
110
-
111
- # Ensure scale is non-zero
112
- self.scale = ops.where(ops.less_equal(self.scale, 0), 1e-8, self.scale)
113
-
114
- if weight:
115
- # Per-channel, non-grouped case: simple reshape is correct.
116
- if self.per_channel and self.group_size == -1:
117
- self.scale = ops.reshape(self.scale, [-1, 1])
118
- self.zero = ops.reshape(self.zero, [-1, 1])
119
- elif not self.per_channel:
120
- num_rows = original_shape[0]
121
- self.scale = ops.tile(
122
- ops.reshape(self.scale, (1, 1)), (num_rows, 1)
123
- )
124
- self.zero = ops.tile(
125
- ops.reshape(self.zero, (1, 1)), (num_rows, 1)
126
- )
127
- if self.per_channel:
128
- self.scale = ops.reshape(self.scale, [-1, 1])
129
- self.zero = ops.reshape(self.zero, [-1, 1])
130
-
131
- def ready(self):
132
- """Checks if the quantization parameters have been computed."""
133
- return self.scale is not None and self.zero is not None