keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff shows the changes between two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0

keras/src/backend/common/backend_utils.py
@@ -1,4 +1,5 @@
 import functools
+import math
 import operator
 import re
 import warnings
@@ -96,13 +97,13 @@ def _convert_conv_transpose_padding_args_from_keras_to_torch(
     )

     if torch_output_padding >= stride:
-        raise ValueError(
-            f"The padding arguments (padding={padding}) and "
-            f"output_padding={output_padding}) lead to a Torch "
-            f"output_padding ({torch_output_padding}) that is greater than "
-            f"strides ({stride}). This is not supported. You can change the "
-            f"padding arguments, kernel or stride, or run on another backend. "
+        warnings.warn(
+            f"Torch backend requires output_padding < stride. "
+            f"Clamping output_padding {torch_output_padding} -> {stride - 1} "
+            f"for stride {stride}.",
+            UserWarning,
         )
+        torch_output_padding = stride - 1

     return torch_padding, torch_output_padding
 
@@ -184,6 +185,22 @@ compute_conv_transpose_padding_args_for_torch(
         torch_paddings.append(torch_padding)
         torch_output_paddings.append(torch_output_padding)

+    # --- FIX FOR TORCH CONSTRAINT: output_padding < stride ---
+    corrected_output_paddings = []
+    for s, op in zip(
+        strides
+        if isinstance(strides, (list, tuple))
+        else [strides] * num_spatial_dims,
+        torch_output_paddings,
+    ):
+        max_allowed = max(0, s - 1)
+        if op > max_allowed:
+            corrected_output_paddings.append(max_allowed)
+        else:
+            corrected_output_paddings.append(op)
+
+    torch_output_paddings = corrected_output_paddings
+
     return torch_paddings, torch_output_paddings

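Note: the clamping rule above can be checked with plain arithmetic. A minimal standalone sketch (it mirrors the correction loop rather than importing the Keras helper; clamp_output_paddings is a hypothetical name):

# Torch requires output_padding < stride, so any larger computed value
# is clamped to stride - 1 (never below 0).
def clamp_output_paddings(strides, output_paddings):
    return [min(op, max(0, s - 1)) for s, op in zip(strides, output_paddings)]

# With strides (2, 1), computed output paddings (2, 1) become (1, 0).
print(clamp_output_paddings([2, 1], [2, 1]))  # [1, 0]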
 
@@ -523,3 +540,10 @@ def slice_along_axis(x, start=0, stop=None, step=1, axis=0):
         -1 - axis
     )
     return x[tuple(slices)]
+
+
+def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
+    """Compute small and big window sizes for adaptive pooling."""
+    small = math.ceil(input_dim / output_dim)
+    big = small + 1
+    return small, big
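Note: adaptive pooling covers input_dim positions with output_dim windows whose sizes differ by at most one. A worked example that reproduces the helper above:

import math

def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
    # ceil(10 / 4) = 3, so windows of size 3 and 4 tile 10 inputs
    # into 4 outputs.
    small = math.ceil(input_dim / output_dim)
    big = small + 1
    return small, big

print(compute_adaptive_pooling_window_sizes(10, 4))  # (3, 4)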

keras/src/backend/common/dtypes.py
@@ -232,18 +232,12 @@ def _resolve_weak_type(dtype, precision="32"):
     return f"float{precision}"


-BIT64_TO_BIT16_DTYPE = {
-    "int32": "int16",
-    "int64": "int16",
-    "uint32": "uint16",
-    "uint64": "uint16",
-    "float32": "float16",
-    "float64": "float16",
-}
 BIT64_TO_BIT32_DTYPE = {
-    "int64": "int32",
+    # Since TF variables require int64 to be placed on the GPU, we exclusively
+    # enable the int64 dtype for TF.
+    "int64": "int64" if config.backend() == "tensorflow" else "int32",
     "uint64": "uint32",
-    "float64": "float32",
+    "float64": "float64" if config.backend() == "tensorflow" else "float32",
     "complex128": "complex64",
 }
 
@@ -277,8 +271,8 @@ def _lattice_result_type(*args):
     if out_weak_type:
         out_dtype = _resolve_weak_type(out_dtype, precision=precision)

-    # Force to be 32-bit dtype when encountering 64-bit dtype.
-    # TODO(hongyu): Add a config to enable 64-bit dtypes.
+    # Force to be 32-bit dtype when encountering 64-bit dtype. This is to
+    # be aligned with JAX's default behavior.
     out_dtype = BIT64_TO_BIT32_DTYPE.get(out_dtype, out_dtype)
     return out_dtype
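Note: on a non-TensorFlow backend the conditionals above resolve to the 32-bit names, so result-type resolution demotes 64-bit dtypes. A minimal sketch of the table's effect under that assumption:

# The table as it resolves when config.backend() != "tensorflow".
BIT64_TO_BIT32_DTYPE = {
    "int64": "int32",
    "uint64": "uint32",
    "float64": "float32",
    "complex128": "complex64",
}

for dtype in ("float64", "int64", "float32"):
    print(dtype, "->", BIT64_TO_BIT32_DTYPE.get(dtype, dtype))
# float64 -> float32, int64 -> int32, float32 -> float32 (unchanged)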
 
keras/src/backend/common/name_scope.py
@@ -58,7 +58,8 @@ class name_scope:
         name_scope_stack = global_state.get_global_attribute(
             "name_scope_stack"
         )
-        name_scope_stack.pop()
+        if name_scope_stack:
+            name_scope_stack.pop()


 def current_path():
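Note: the new guard only matters for unbalanced exits (popping an empty stack no longer raises IndexError). A minimal sketch of the balanced case, assuming current_path joins scope names with "/" and returns "" when no scope is open:

from keras.src.backend.common.name_scope import current_path, name_scope

with name_scope("block"):
    with name_scope("dense"):
        print(current_path())  # "block/dense"
print(current_path())  # "" after both scopes are popped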

keras/src/backend/common/variables.py
@@ -1,5 +1,3 @@
-import os.path
-
 import numpy as np

 from keras.src import backend
@@ -144,7 +142,7 @@ class Variable:
         self._name = name
         parent_path = current_path()
         if parent_path:
-            self._path = os.path.join(current_path(), name)
+            self._path = f"{parent_path}/{name}"
         else:
             self._path = name
         self._shape = None
@@ -278,13 +276,13 @@ class Variable:
         return self._maybe_autocast(self._value)

     def assign(self, value):
-        value = self._convert_to_tensor(value, dtype=self.dtype)
+        value = self._convert_to_tensor(value, dtype=self._dtype)
         if not shape_equal(value.shape, self.shape):
             raise ValueError(
                 "The shape of the target variable and "
                 "the shape of the target value in "
                 "`variable.assign(value)` must match. "
-                f"variable.shape={self.value.shape}, "
+                f"variable.shape={self.shape}, "
                 f"Received: value.shape={value.shape}. "
                 f"Target variable: {self}"
             )
@@ -401,7 +399,11 @@ class Variable:
     def __repr__(self):
         value = None
         if hasattr(self, "_value") and self._value is not None:
-            value = backend.core.convert_to_numpy(self._value)
+            try:
+                value = backend.core.convert_to_numpy(self._value)
+            except:
+                # In some cases the conversion to numpy can fail.
+                pass
         value_str = f", value={value}" if value is not None else ""
         return (
             f"<Variable path={self.path}, shape={self.shape}, "
@@ -597,30 +599,46 @@ def standardize_shape(shape):
         # `tf.TensorShape` may contain `Dimension` objects.
         # We need to convert the items in it to either int or `None`
         shape = shape.as_list()
-    shape = tuple(shape)

-    if config.backend() == "torch":
-        # `shape` might be `torch.Size`. We need to convert the items in it to
-        # either int or `None`
-        shape = tuple(map(lambda x: int(x) if x is not None else None, shape))
+    if config.backend() == "jax":
+        # Replace `_DimExpr` (dimension expression) with None
+        from jax import export as jax_export

-    for e in shape:
-        if e is None:
-            continue
-        if config.backend() == "jax" and "_DimExpr" in str(type(e)):
-            # JAX2TF tracing uses JAX-native dimension expressions
+        shape = tuple(
+            None if jax_export.is_symbolic_dim(d) else d for d in shape
+        )
+
+    # Handle dimensions that are not ints and not None, verify they're >= 0.
+    standardized_shape = []
+    for d in shape:
+        if d is None:
+            standardized_shape.append(d)
             continue
-        if not is_int_dtype(type(e)):
+
+        # Reject these even if they can be cast to int successfully.
+        if isinstance(d, (str, float)):
             raise ValueError(
                 f"Cannot convert '{shape}' to a shape. "
-                f"Found invalid entry '{e}' of type '{type(e)}'. "
+                f"Found invalid dimension '{d}' of type '{type(d)}'. "
             )
-        if e < 0:
+
+        try:
+            # Cast numpy scalars, tf constant tensors, etc.
+            d = int(d)
+        except Exception as e:
+            raise ValueError(
+                f"Cannot convert '{shape}' to a shape. "
+                f"Found invalid dimension '{d}' of type '{type(d)}'. "
+            ) from e
+        if d < 0:
             raise ValueError(
                 f"Cannot convert '{shape}' to a shape. "
                 "Negative dimensions are not allowed."
             )
-    return shape
+        standardized_shape.append(d)
+
+    # This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.
+    return tuple(standardized_shape)


 def shape_equal(a_shape, b_shape):
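Note: a minimal sketch of the rewritten contract, assuming the function is called directly: int-like entries are cast via int(), None passes through, str/float are rejected outright, and the result is always a plain tuple (including for tuple subclasses such as torch.Size):

import numpy as np

from keras.src.backend.common.variables import standardize_shape

print(standardize_shape((np.int64(3), None, 4)))  # (3, None, 4)
print(standardize_shape([2, 2]))  # (2, 2), as a plain tuple
# standardize_shape((3, -1))   raises ValueError (negative dimension)
# standardize_shape((3, "4"))  raises ValueError even though int("4") works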

keras/src/backend/jax/core.py
@@ -3,6 +3,7 @@ import jax.experimental.sparse as jax_sparse
 import jax.numpy as jnp
 import ml_dtypes
 import numpy as np
+from jax import export as jax_export

 from keras.src import tree
 from keras.src.backend import config
@@ -29,9 +30,7 @@ class JaxVariable(KerasVariable):
         self._layout = layout
         super().__init__(*args, **kwargs)

-    def _initialize(self, value):
-        # Note that variable.shape is needed by distribution_lib
-        self._shape = self._validate_shape(value.shape)
+    def _initialize_layout(self):
         # We can't import the keras/distribution/distribution_lib
         # due to circular dependency.
         distribution = global_state.get_global_attribute("distribution")
@@ -43,8 +42,28 @@ class JaxVariable(KerasVariable):
             self._layout = tensor_layout.backend_layout
         else:
             self._layout = tensor_layout
+
+    def _initialize(self, value):
+        # Note that variable.shape is needed by distribution_lib
+        self._shape = self._validate_shape(value.shape)
+        self._initialize_layout()
         self._direct_assign(value)

+    def _initialize_with_initializer(self, initializer):
+        self._initialize_layout()
+        layout = self._layout
+        shape = self._shape
+        if should_shard_at_init(layout, shape):
+            jitted_initializer = jax.jit(
+                initializer.__call__,
+                out_shardings=layout,
+                static_argnames=["shape", "dtype"],
+            )
+            value = jitted_initializer(shape=self._shape, dtype=self._dtype)
+            self._value = value
+        else:
+            super()._initialize_with_initializer(initializer)
+
     def _direct_assign(self, value):
         if self._layout is not None:
             value = distribution_lib.distribute_variable(value, self._layout)
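Note: the jit-with-out_shardings pattern above creates the array already distributed instead of materializing it on one device and resharding afterwards. A minimal standalone sketch, assuming at least one JAX device and a first dimension divisible by the device count:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("model",))
layout = NamedSharding(mesh, PartitionSpec("model"))

def initializer(shape, dtype):
    return jnp.zeros(shape, dtype)

# Both arguments are static, so the output sharding fully determines
# where each shard of the initialized array is created.
jitted = jax.jit(
    initializer, out_shardings=layout, static_argnames=["shape", "dtype"]
)
value = jitted(shape=(8, 16), dtype=jnp.float32)
print(value.sharding)  # the NamedSharding above, no post-init reshard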
@@ -111,6 +130,12 @@ if config.is_nnx_enabled():
             # The real value is now set in self._value, sync it to raw_value
             object.__setattr__(self, "raw_value", self._value)

+        def _initialize_with_initializer(self, initializer):
+            value = self._convert_to_tensor(
+                initializer(self._shape, dtype=self._dtype)
+            )
+            self._initialize(value)
+
         @property
         def _value(self):
             if hasattr(self, "raw_value"):
@@ -233,6 +258,71 @@ if config.is_nnx_enabled():

     Variable = NnxVariable

+    def _flatten_nnx_variable(variable):
+        children = (variable.raw_value,)
+        # We copy __dict__ to avoid side effects
+        keras_state = variable.__dict__.copy()
+        # Remove elements that might be problematic or redundant if
+        # nnx.Variable's __getstate__
+        keras_state.pop("raw_value", None)
+        aux_data = (
+            variable._var_metadata,
+            getattr(variable, "_trace_state", None),
+            keras_state,
+        )
+        return children, aux_data
+
+    def _unflatten_nnx_variable(aux_data, children):
+        var_metadata, trace_state, keras_state = aux_data
+        raw_value = children[0]
+
+        # Create uninitialized instance
+        variable = NnxVariable.__new__(NnxVariable)
+
+        # Restore state
+        variable._var_metadata = var_metadata
+        if trace_state is not None:
+            variable._trace_state = trace_state
+        variable.__dict__.update(keras_state)
+        variable.raw_value = raw_value
+
+        return variable
+
+    try:
+        jax.tree_util.register_pytree_node(
+            NnxVariable,
+            _flatten_nnx_variable,
+            _unflatten_nnx_variable,
+        )
+    except ValueError:
+        pass
+
+    def __setattr__(self, name, value):
+        # Mirror Keras attributes to _var_metadata to ensure persistence
+        # if the Pytree registration is not respected by NNX.
+        if (
+            name != "_var_metadata"
+            and name not in ("_raw_value", "_trace_state")
+            and hasattr(self, "_var_metadata")
+        ):
+            self._var_metadata[name] = value
+
+        object.__setattr__(self, name, value)
+
+    NnxVariable.__setattr__ = __setattr__
+
+
+def should_shard_at_init(init_layout, shape):
+    if not isinstance(init_layout, jax.sharding.NamedSharding):
+        return False
+
+    if all(dim is None for dim in init_layout.spec):
+        return False
+
+    size_threshold = 250 * 1024 * 1024
+    array_size = np.prod(shape) * 4
+    return array_size >= size_threshold
+

 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if ragged:
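Note: should_shard_at_init assumes 4 bytes per element and only shards at initialization above a 250 MiB threshold. The arithmetic, worked for two shapes:

import numpy as np

size_threshold = 250 * 1024 * 1024  # 262_144_000 bytes
print(np.prod((8192, 8192)) * 4)  # 268_435_456 -> sharded at init
print(np.prod((1024, 1024)) * 4)  # 4_194_304 -> initialized unsharded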
@@ -282,8 +372,6 @@ def is_tensor(x):


 def shape(x):
-    # This will work as long as we disallow
-    # dynamic shapes in JAX.
     return x.shape

 
@@ -315,31 +403,29 @@ def compute_output_spec(fn, *args, **kwargs):
         else:
             maybe_symbolic_kwargs[k] = v

-    # Second, find out if there are dynamic shapes
-    has_none = False
-    for x in tree.flatten((maybe_symbolic_args, maybe_symbolic_kwargs)):
-        if isinstance(x, KerasTensor) and any(d is None for d in x.shape):
-            has_none = True
-
-    def convert_keras_tensor_to_jax(x, fill_value=None):
+    # Create a _DimExpr instance for one dimension by creating a symbolic
+    # shape with one dimension and extracting it.
+    #
+    # We create a single dynamic dimension and reuse it instead of creating
+    # N dynamic dimensions. This is for backwards compatibility. Previously
+    # we would fill all dynamic dimensions with the same concrete value.
+    # This can handle the case where there is an implicit assumption that
+    # two dimensions are the same (e.g. square images).
+    #
+    # We add the constraint "dynamic_dimension>=2" to prevent JAX from
+    # assuming that the dimension can be broadcastable or squeezable. It
+    # removes this ambiguity.
+    dynamic_dimension = jax_export.symbolic_shape(
+        "(dynamic_dimension)",
+        constraints=["dynamic_dimension>=2"],
+    )[0]
+
+    def convert_keras_tensor_to_jax(x):
         if isinstance(x, KerasTensor):
-            shape = list(x.shape)
-            if fill_value:
-                for i, e in enumerate(shape):
-                    if e is None:
-                        shape[i] = fill_value
-            jax_tensor = jax.ShapeDtypeStruct(shape, dtype=x.dtype)
-            return jax_tensor
-        if isinstance(x, dict):
-            return {
-                k: convert_keras_tensor_to_jax(v, fill_value=fill_value)
-                for k, v in x.items()
-            }
-        if isinstance(x, list):
-            return [
-                convert_keras_tensor_to_jax(xi, fill_value=fill_value)
-                for xi in x
-            ]
+            shape = tuple(
+                [d if d is not None else dynamic_dimension for d in x.shape]
+            )
+            return jax.ShapeDtypeStruct(shape, dtype=x.dtype)
         return x

     def wrapped_fn(*args, **kwargs):
@@ -374,63 +460,25 @@
         with StatelessScope():
             return fn(*rec_args, **kwargs, **static_kwargs)

-    if has_none:
-        ms_args_1, ms_kwargs_1 = tree.map_structure(
-            lambda x: convert_keras_tensor_to_jax(x, fill_value=83),
-            (maybe_symbolic_args, maybe_symbolic_kwargs),
-        )
-        _, jax_out_1 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
-            *ms_args_1, **ms_kwargs_1
-        )
-
-        ms_args_2, ms_kwargs_2 = tree.map_structure(
-            lambda x: convert_keras_tensor_to_jax(x, fill_value=89),
-            (maybe_symbolic_args, maybe_symbolic_kwargs),
-        )
-        _, jax_out_2 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
-            *ms_args_2, **ms_kwargs_2
-        )
-
-        def merge_shapes(shape1, shape2):
-            return tuple(
-                [d1 if d1 == d2 else None for d1, d2 in zip(shape1, shape2)]
-            )
-
-        def convert_jax_specs_to_keras_tensor(x1, x2):
-            if isinstance(x1, jax.ShapeDtypeStruct):
-                if not isinstance(x2, jax.ShapeDtypeStruct):
-                    raise ValueError("Indeterministic output ordering.")
-                return KerasTensor(
-                    merge_shapes(x1.shape, x2.shape), dtype=x1.dtype
-                )
-            elif isinstance(x1, jax_sparse.BCOO):
-                if not isinstance(x2, jax_sparse.BCOO):
-                    raise ValueError("Indeterministic output ordering.")
-                return KerasTensor(
-                    merge_shapes(x1.shape, x2.shape),
-                    dtype=x1.dtype,
-                    sparse=True,
-                )
-            else:
-                return x1
-
-        return tree.map_structure(
-            convert_jax_specs_to_keras_tensor, jax_out_1, jax_out_2
-        )
-
-    maybe_symbolic_args, maybe_symbolic_kwargs = tree.map_structure(
+    maybe_symbolic_args_jax, maybe_symbolic_kwargs_jax = tree.map_structure(
         convert_keras_tensor_to_jax,
         (maybe_symbolic_args, maybe_symbolic_kwargs),
     )
-    _, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)(
-        *maybe_symbolic_args, **maybe_symbolic_kwargs
+    jax_out = jax.eval_shape(
+        wrapped_fn, *maybe_symbolic_args_jax, **maybe_symbolic_kwargs_jax
    )

     def convert_jax_spec_to_keras_tensor(x):
         if isinstance(x, jax.ShapeDtypeStruct):
-            return KerasTensor(x.shape, x.dtype)
+            shape = tuple(
+                d if isinstance(d, int) else None for d in x.shape
+            )
+            return KerasTensor(shape, x.dtype)
         elif isinstance(x, jax_sparse.BCOO):
-            return KerasTensor(x.shape, x.dtype, sparse=True)
+            shape = tuple(
+                d if isinstance(d, int) else None for d in x.shape
+            )
+            return KerasTensor(shape, x.dtype, sparse=True)
         return x

     return tree.map_structure(convert_jax_spec_to_keras_tensor, jax_out)
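Note: a minimal sketch of the symbolic-shape tracing used above. One symbolic dimension stands in for every unknown (None) dimension, and jax.eval_shape propagates shapes without running the computation; non-int (symbolic) output dims then map back to None:

import jax
import jax.numpy as jnp
from jax import export as jax_export

dyn = jax_export.symbolic_shape(
    "(dynamic_dimension)", constraints=["dynamic_dimension>=2"]
)[0]
spec = jax.ShapeDtypeStruct((dyn, 4), jnp.float32)

# (d, 4) @ (4, d) -> (d, d), with d still symbolic after tracing.
out = jax.eval_shape(lambda x: jnp.matmul(x, jnp.transpose(x)), spec)
print(out.shape)  # (dynamic_dimension, dynamic_dimension)
print(tuple(d if isinstance(d, int) else None for d in out.shape))  # (None, None)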

keras/src/backend/jax/distribution_lib.py
@@ -27,6 +27,20 @@ def list_devices(device_type=None):
     return [f"{device.platform}:{device.id}" for device in jax_devices]


+def get_device_count(device_type=None):
+    """Returns the number of available JAX devices.
+    Args:
+        device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
+            If `None`, it defaults to counting "gpu" or "tpu" devices if
+            available, otherwise it counts "cpu" devices. It does not
+            return the sum of all device types.
+    Returns:
+        int: The total number of JAX devices for the specified type.
+    """
+    device_type = device_type.lower() if device_type else None
+    return jax.device_count(device_type)
+
+
 def distribute_variable(value, layout):
     """Create a distributed variable for JAX.
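Note: usage sketch for the new helper; jax.device_count counts devices for the default platform when no type is given:

from keras.src.backend.jax.distribution_lib import get_device_count

print(get_device_count())  # e.g. 8 on an 8-device TPU host, 1 on CPU-only
print(get_device_count("cpu"))  # number of visible CPU devices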
 
@@ -146,13 +160,13 @@ def initialize_rng():
     # Check if the global seed generator is set and ensure it has an initialized
     # seed. Otherwise, reset the seed to the global seed.
     global_seed_generator = global_state.get_global_attribute(
-        "global_seed_generator"
+        seed_generator.GLOBAL_SEED_GENERATOR
     )
     if global_seed_generator is not None:
         seed = global_seed_generator.get_config()["seed"]
         if seed is None:
             global_state.set_global_attribute(
-                "global_seed_generator",
+                seed_generator.GLOBAL_SEED_GENERATOR,
                 seed_generator.SeedGenerator(
                     seed=global_seed,
                     name=global_seed_generator.name,

keras/src/backend/jax/layer.py
@@ -3,7 +3,9 @@ from keras.src.backend.config import is_nnx_enabled
 if is_nnx_enabled():
     from flax import nnx

-    BaseLayer = nnx.Module
+    class BaseLayer(nnx.Module):
+        def __init_subclass__(cls, **kwargs):
+            super().__init_subclass__(pytree=False, **kwargs)
 else:
     BaseLayer = object
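Note: the shim above forwards pytree=False to nnx.Module's subclass hook for every Keras layer. The mechanism is plain-Python __init_subclass__ keyword forwarding; a minimal flax-free sketch (Registry, Base, and MyLayer are hypothetical names):

class Registry:
    flags = {}

    def __init_subclass__(cls, pytree=True, **kwargs):
        super().__init_subclass__(**kwargs)
        Registry.flags[cls.__name__] = pytree

class Base(Registry):
    def __init_subclass__(cls, **kwargs):
        # Inject the keyword for every subclass, like BaseLayer above.
        super().__init_subclass__(pytree=False, **kwargs)

class MyLayer(Base):
    pass

print(Registry.flags)  # {'Base': True, 'MyLayer': False}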
 
keras/src/backend/jax/linalg.py
@@ -97,3 +97,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return jnp.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    return jax.jvp(fun, primals, tangents, has_aux=has_aux)
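Note: usage sketch for the new jvp export; jax.jvp evaluates the function and its forward-mode (directional) derivative in one pass:

import jax

def f(x):
    return x ** 3

primal_out, tangent_out = jax.jvp(f, (2.0,), (1.0,))
print(primal_out)   # 8.0
print(tangent_out)  # 12.0, i.e. 3 * x**2 at x = 2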