keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/common/backend_utils.py
@@ -1,4 +1,5 @@
  import functools
+ import math
  import operator
  import re
  import warnings
@@ -96,13 +97,13 @@ def _convert_conv_transpose_padding_args_from_keras_to_torch(
      )

      if torch_output_padding >= stride:
-         raise ValueError(
-             f"The padding arguments (padding={padding}) and "
-             f"output_padding={output_padding}) lead to a Torch "
-             f"output_padding ({torch_output_padding}) that is greater than "
-             f"strides ({stride}). This is not supported. You can change the "
-             f"padding arguments, kernel or stride, or run on another backend. "
+         warnings.warn(
+             f"Torch backend requires output_padding < stride. "
+             f"Clamping output_padding {torch_output_padding} -> {stride - 1} "
+             f"for stride {stride}.",
+             UserWarning,
          )
+         torch_output_padding = stride - 1

      return torch_padding, torch_output_padding
@@ -184,6 +185,22 @@ def compute_conv_transpose_padding_args_for_torch(
          torch_paddings.append(torch_padding)
          torch_output_paddings.append(torch_output_padding)

+     # --- FIX FOR TORCH CONSTRAINT: output_padding < stride ---
+     corrected_output_paddings = []
+     for s, op in zip(
+         strides
+         if isinstance(strides, (list, tuple))
+         else [strides] * num_spatial_dims,
+         torch_output_paddings,
+     ):
+         max_allowed = max(0, s - 1)
+         if op > max_allowed:
+             corrected_output_paddings.append(max_allowed)
+         else:
+             corrected_output_paddings.append(op)
+
+     torch_output_paddings = corrected_output_paddings
+
      return torch_paddings, torch_output_paddings
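Note: torch's `ConvTranspose*d` layers require each `output_padding` to be strictly less than the corresponding stride (or dilation), which is what the clamp above enforces per spatial dimension. A standalone sketch of the rule with made-up numbers (`clamp_output_paddings` is a hypothetical helper, not part of the Keras API):

    def clamp_output_paddings(strides, output_paddings):
        # Clamp each output_padding into torch's allowed range [0, stride - 1].
        return [min(op, max(0, s - 1)) for s, op in zip(strides, output_paddings)]

    print(clamp_output_paddings([2, 3], [2, 1]))  # [1, 1]: first entry clamped to 2 - 1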
@@ -523,3 +540,10 @@ def slice_along_axis(x, start=0, stop=None, step=1, axis=0):
          -1 - axis
      )
      return x[tuple(slices)]
+
+
+ def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
+     """Compute small and big window sizes for adaptive pooling."""
+     small = math.ceil(input_dim / output_dim)
+     big = small + 1
+     return small, big
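For intuition: adaptive pooling implementations typically realize an `input_dim -> output_dim` reduction with windows of exactly two adjacent sizes, `ceil(input_dim / output_dim)` ("small") and that plus one ("big"). A worked check with assumed sizes, using PyTorch-style window bounds:

    import math

    input_dim, output_dim = 10, 6
    small = math.ceil(input_dim / output_dim)  # 2
    big = small + 1                            # 3

    # Window i spans [floor(i * in / out), ceil((i + 1) * in / out));
    # every resulting window size is either `small` or `big`.
    sizes = [
        math.ceil((i + 1) * input_dim / output_dim)
        - math.floor(i * input_dim / output_dim)
        for i in range(output_dim)
    ]
    print(small, big, sizes)  # 2 3 [2, 3, 2, 2, 3, 2]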
keras/src/backend/common/name_scope.py
@@ -58,7 +58,8 @@ class name_scope:
          name_scope_stack = global_state.get_global_attribute(
              "name_scope_stack"
          )
-         name_scope_stack.pop()
+         if name_scope_stack:
+             name_scope_stack.pop()


  def current_path():
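The guard turns an unbalanced `__exit__` into a no-op instead of an `IndexError` from popping an empty stack. A hedged sketch of the failure mode it covers (assuming no enclosing scopes are active):

    from keras.src.backend.common.name_scope import name_scope

    scope = name_scope("block")
    scope.__enter__()
    scope.__exit__(None, None, None)  # pops "block"
    scope.__exit__(None, None, None)  # stack now empty: guarded pop is a no-op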
keras/src/backend/common/variables.py
@@ -276,13 +276,13 @@ class Variable:
          return self._maybe_autocast(self._value)

      def assign(self, value):
-         value = self._convert_to_tensor(value, dtype=self.dtype)
+         value = self._convert_to_tensor(value, dtype=self._dtype)
          if not shape_equal(value.shape, self.shape):
              raise ValueError(
                  "The shape of the target variable and "
                  "the shape of the target value in "
                  "`variable.assign(value)` must match. "
-                 f"variable.shape={self.value.shape}, "
+                 f"variable.shape={self.shape}, "
                  f"Received: value.shape={value.shape}. "
                  f"Target variable: {self}"
              )
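The switch from `self.dtype` to `self._dtype` matters because the public `dtype` property can report the autocast (compute) dtype while inside a mixed-precision scope, whereas `_dtype` is the variable's stored dtype; casting the assigned value with the former could silently change the storage dtype. A minimal usage sketch, assuming the public `keras.Variable` constructor:

    import keras

    v = keras.Variable(initializer="zeros", shape=(2,), dtype="float32")
    # assign() casts to the stored dtype (float32), not whatever compute
    # dtype `v.dtype` may report under an autocast scope.
    v.assign([1.0, 2.0])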
@@ -399,7 +399,11 @@ class Variable:
      def __repr__(self):
          value = None
          if hasattr(self, "_value") and self._value is not None:
-             value = backend.core.convert_to_numpy(self._value)
+             try:
+                 value = backend.core.convert_to_numpy(self._value)
+             except:
+                 # In some cases the conversion to numpy can fail.
+                 pass
          value_str = f", value={value}" if value is not None else ""
          return (
              f"<Variable path={self.path}, shape={self.shape}, "
@@ -595,7 +599,6 @@ def standardize_shape(shape):
          # `tf.TensorShape` may contain `Dimension` objects.
          # We need to convert the items in it to either int or `None`
          shape = shape.as_list()
-         shape = tuple(shape)

      if config.backend() == "jax":
          # Replace `_DimExpr` (dimension expression) with None
@@ -605,25 +608,37 @@ def standardize_shape(shape):
              None if jax_export.is_symbolic_dim(d) else d for d in shape
          )

-     if config.backend() == "torch":
-         # `shape` might be `torch.Size`. We need to convert the items in it to
-         # either int or `None`
-         shape = tuple(map(lambda x: int(x) if x is not None else None, shape))
-
-     for e in shape:
-         if e is None:
+     # Handle dimensions that are not ints and not None, verify they're >= 0.
+     standardized_shape = []
+     for d in shape:
+         if d is None:
+             standardized_shape.append(d)
              continue
-         if not is_int_dtype(type(e)):
+
+         # Reject these even if they can be cast to int successfully.
+         if isinstance(d, (str, float)):
              raise ValueError(
                  f"Cannot convert '{shape}' to a shape. "
-                 f"Found invalid entry '{e}' of type '{type(e)}'. "
+                 f"Found invalid dimension '{d}' of type '{type(d)}'. "
              )
-         if e < 0:
+
+         try:
+             # Cast numpy scalars, tf constant tensors, etc.
+             d = int(d)
+         except Exception as e:
+             raise ValueError(
+                 f"Cannot convert '{shape}' to a shape. "
+                 f"Found invalid dimension '{d}' of type '{type(d)}'. "
+             ) from e
+         if d < 0:
              raise ValueError(
                  f"Cannot convert '{shape}' to a shape. "
                  "Negative dimensions are not allowed."
              )
-     return shape
+         standardized_shape.append(d)
+
+     # This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.
+     return tuple(standardized_shape)


  def shape_equal(a_shape, b_shape):
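A hedged usage sketch of the rewritten validation (note the private `keras.src` import path, which is an internal module and may move):

    import numpy as np
    from keras.src.backend.common.variables import standardize_shape

    print(standardize_shape((None, np.int64(3), 4)))  # (None, 3, 4)
    try:
        standardize_shape((2, 3.0))  # floats rejected even though int(3.0) works
    except ValueError as err:
        print(err)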
keras/src/backend/jax/core.py
@@ -30,9 +30,7 @@ class JaxVariable(KerasVariable):
          self._layout = layout
          super().__init__(*args, **kwargs)

-     def _initialize(self, value):
-         # Note that variable.shape is needed by distribution_lib
-         self._shape = self._validate_shape(value.shape)
+     def _initialize_layout(self):
          # We can't import the keras/distribution/distribution_lib
          # due to circular dependency.
          distribution = global_state.get_global_attribute("distribution")
@@ -44,8 +42,28 @@ class JaxVariable(KerasVariable):
              self._layout = tensor_layout.backend_layout
          else:
              self._layout = tensor_layout
+
+     def _initialize(self, value):
+         # Note that variable.shape is needed by distribution_lib
+         self._shape = self._validate_shape(value.shape)
+         self._initialize_layout()
          self._direct_assign(value)

+     def _initialize_with_initializer(self, initializer):
+         self._initialize_layout()
+         layout = self._layout
+         shape = self._shape
+         if should_shard_at_init(layout, shape):
+             jitted_initializer = jax.jit(
+                 initializer.__call__,
+                 out_shardings=layout,
+                 static_argnames=["shape", "dtype"],
+             )
+             value = jitted_initializer(shape=self._shape, dtype=self._dtype)
+             self._value = value
+         else:
+             super()._initialize_with_initializer(initializer)
+
      def _direct_assign(self, value):
          if self._layout is not None:
              value = distribution_lib.distribute_variable(value, self._layout)
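The jitted initializer asks XLA to materialize the output directly with the requested `out_shardings`, so a large weight never exists unsharded on a single device. A standalone sketch of the pattern (the mesh, axis name, and shapes are assumptions; the leading dimension must divide the device count):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(jax.devices(), axis_names=("model",))
    layout = NamedSharding(mesh, PartitionSpec("model"))

    def initializer(shape, dtype):
        return jnp.zeros(shape, dtype=dtype)

    jitted = jax.jit(
        initializer, out_shardings=layout, static_argnames=["shape", "dtype"]
    )
    value = jitted(shape=(8, 4), dtype=jnp.float32)
    print(value.sharding)  # NamedSharding over the "model" axis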
@@ -112,6 +130,12 @@ if config.is_nnx_enabled():
              # The real value is now set in self._value, sync it to raw_value
              object.__setattr__(self, "raw_value", self._value)

+         def _initialize_with_initializer(self, initializer):
+             value = self._convert_to_tensor(
+                 initializer(self._shape, dtype=self._dtype)
+             )
+             self._initialize(value)
+
          @property
          def _value(self):
              if hasattr(self, "raw_value"):
@@ -234,6 +258,71 @@ if config.is_nnx_enabled():

      Variable = NnxVariable

+     def _flatten_nnx_variable(variable):
+         children = (variable.raw_value,)
+         # We copy __dict__ to avoid side effects
+         keras_state = variable.__dict__.copy()
+         # Remove elements that might be problematic or redundant if
+         # nnx.Variable's __getstate__
+         keras_state.pop("raw_value", None)
+         aux_data = (
+             variable._var_metadata,
+             getattr(variable, "_trace_state", None),
+             keras_state,
+         )
+         return children, aux_data
+
+     def _unflatten_nnx_variable(aux_data, children):
+         var_metadata, trace_state, keras_state = aux_data
+         raw_value = children[0]
+
+         # Create uninitialized instance
+         variable = NnxVariable.__new__(NnxVariable)
+
+         # Restore state
+         variable._var_metadata = var_metadata
+         if trace_state is not None:
+             variable._trace_state = trace_state
+         variable.__dict__.update(keras_state)
+         variable.raw_value = raw_value
+
+         return variable
+
+     try:
+         jax.tree_util.register_pytree_node(
+             NnxVariable,
+             _flatten_nnx_variable,
+             _unflatten_nnx_variable,
+         )
+     except ValueError:
+         pass
+
+     def __setattr__(self, name, value):
+         # Mirror Keras attributes to _var_metadata to ensure persistence
+         # if the Pytree registration is not respected by NNX.
+         if (
+             name != "_var_metadata"
+             and name not in ("_raw_value", "_trace_state")
+             and hasattr(self, "_var_metadata")
+         ):
+             self._var_metadata[name] = value
+
+         object.__setattr__(self, name, value)
+
+     NnxVariable.__setattr__ = __setattr__
+
+
+ def should_shard_at_init(init_layout, shape):
+     if not isinstance(init_layout, jax.sharding.NamedSharding):
+         return False
+
+     if all(dim is None for dim in init_layout.spec):
+         return False
+
+     size_threshold = 250 * 1024 * 1024
+     array_size = np.prod(shape) * 4
+     return array_size >= size_threshold
+

  def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
      if ragged:
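`should_shard_at_init` assumes 4 bytes per element (float32) and a 250 MiB cutoff. Worked arithmetic for a shape just over the line:

    import numpy as np

    shape = (8192, 8192)                    # 67,108,864 elements
    size_bytes = np.prod(shape) * 4         # 268,435,456 bytes = 256 MiB
    print(size_bytes >= 250 * 1024 * 1024)  # True -> eligible for sharded init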
keras/src/backend/jax/distribution_lib.py
@@ -27,6 +27,20 @@ def list_devices(device_type=None):
      return [f"{device.platform}:{device.id}" for device in jax_devices]


+ def get_device_count(device_type=None):
+     """Returns the number of available JAX devices.
+     Args:
+         device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
+             If `None`, it defaults to counting "gpu" or "tpu" devices if
+             available, otherwise it counts "cpu" devices. It does not
+             return the sum of all device types.
+     Returns:
+         int: The total number of JAX devices for the specified type.
+     """
+     device_type = device_type.lower() if device_type else None
+     return jax.device_count(device_type)
+
+
  def distribute_variable(value, layout):
      """Create a distributed variable for JAX.
@@ -146,13 +160,13 @@ def initialize_rng():
      # Check if the global seed generator is set and ensure it has an initialized
      # seed. Otherwise, reset the seed to the global seed.
      global_seed_generator = global_state.get_global_attribute(
-         "global_seed_generator"
+         seed_generator.GLOBAL_SEED_GENERATOR
      )
      if global_seed_generator is not None:
          seed = global_seed_generator.get_config()["seed"]
          if seed is None:
              global_state.set_global_attribute(
-                 "global_seed_generator",
+                 seed_generator.GLOBAL_SEED_GENERATOR,
                  seed_generator.SeedGenerator(
                      seed=global_seed,
                      name=global_seed_generator.name,
keras/src/backend/jax/linalg.py
@@ -97,3 +97,7 @@ def lstsq(a, b, rcond=None):
      a = convert_to_tensor(a)
      b = convert_to_tensor(b)
      return jnp.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+ def jvp(fun, primals, tangents, has_aux=False):
+     return jax.jvp(fun, primals, tangents, has_aux=has_aux)
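The new `jvp` forwards directly to `jax.jvp`, which evaluates a function and its Jacobian-vector product in a single forward pass. For example:

    import jax

    primal_out, tangent_out = jax.jvp(lambda x: x**3, (2.0,), (1.0,))
    print(primal_out)   # 8.0
    print(tangent_out)  # 12.0 (= 3 * 2**2, the derivative at x = 2)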