keras-nightly 3.12.0.dev2025092403-py3-none-any.whl → 3.14.0.dev2026010104-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/backend/common/name_scope.py
@@ -58,7 +58,8 @@ class name_scope:
         name_scope_stack = global_state.get_global_attribute(
             "name_scope_stack"
         )
-        name_scope_stack.pop()
+        if name_scope_stack:
+            name_scope_stack.pop()
 
 
 def current_path():
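Note: the guard above matters when the name-scope stack is emptied while a scope is still active. A minimal sketch of the failure mode this avoids (assuming `clear_session()` in `keras.src.backend.common.global_state` resets the stack mid-scope):

```python
from keras.src.backend.common import global_state
from keras.src.backend.common.name_scope import name_scope

with name_scope("outer"):
    # Clearing global state empties the name-scope stack while the
    # scope's __exit__ is still pending.
    global_state.clear_session()
# Before this change, __exit__ popped unconditionally and raised here;
# the `if name_scope_stack:` guard now makes the exit a no-op.
```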
keras/src/backend/common/variables.py
@@ -276,13 +276,13 @@ class Variable:
         return self._maybe_autocast(self._value)
 
     def assign(self, value):
-        value = self._convert_to_tensor(value, dtype=self.dtype)
+        value = self._convert_to_tensor(value, dtype=self._dtype)
         if not shape_equal(value.shape, self.shape):
             raise ValueError(
                 "The shape of the target variable and "
                 "the shape of the target value in "
                 "`variable.assign(value)` must match. "
-                f"variable.shape={self.value.shape}, "
+                f"variable.shape={self.shape}, "
                 f"Received: value.shape={value.shape}. "
                 f"Target variable: {self}"
             )
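Note: besides reading the private `_dtype` directly, the fix reports the variable's static `self.shape` instead of `self.value.shape`, which went through the `value` property just to render the error message. A quick sketch of the error path (hypothetical shapes):

```python
import numpy as np
import keras

v = keras.Variable(initializer="zeros", shape=(2, 3))
v.assign(np.ones((2, 3)))  # OK: shapes match
v.assign(np.ones((3, 2)))  # ValueError now reports variable.shape=(2, 3)
```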
@@ -399,7 +399,11 @@ class Variable:
     def __repr__(self):
         value = None
         if hasattr(self, "_value") and self._value is not None:
-            value = backend.core.convert_to_numpy(self._value)
+            try:
+                value = backend.core.convert_to_numpy(self._value)
+            except:
+                # In some cases the conversion to numpy can fail.
+                pass
         value_str = f", value={value}" if value is not None else ""
         return (
             f"<Variable path={self.path}, shape={self.shape}, "
@@ -595,33 +599,46 @@ def standardize_shape(shape):
                 # `tf.TensorShape` may contain `Dimension` objects.
                 # We need to convert the items in it to either int or `None`
                 shape = shape.as_list()
-        shape = tuple(shape)
 
     if config.backend() == "jax":
         # Replace `_DimExpr` (dimension expression) with None
+        from jax import export as jax_export
+
         shape = tuple(
-            [None if "_DimExpr" in str(type(d)) else d for d in shape]
+            None if jax_export.is_symbolic_dim(d) else d for d in shape
         )
 
-    if config.backend() == "torch":
-        # `shape` might be `torch.Size`. We need to convert the items in it to
-        # either int or `None`
-        shape = tuple(map(lambda x: int(x) if x is not None else None, shape))
-
-    for e in shape:
-        if e is None:
+    # Handle dimensions that are not ints and not None, verify they're >= 0.
+    standardized_shape = []
+    for d in shape:
+        if d is None:
+            standardized_shape.append(d)
             continue
-        if not is_int_dtype(type(e)):
+
+        # Reject these even if they can be cast to int successfully.
+        if isinstance(d, (str, float)):
             raise ValueError(
                 f"Cannot convert '{shape}' to a shape. "
-                f"Found invalid entry '{e}' of type '{type(e)}'. "
+                f"Found invalid dimension '{d}' of type '{type(d)}'. "
             )
-        if e < 0:
+
+        try:
+            # Cast numpy scalars, tf constant tensors, etc.
+            d = int(d)
+        except Exception as e:
+            raise ValueError(
+                f"Cannot convert '{shape}' to a shape. "
+                f"Found invalid dimension '{d}' of type '{type(d)}'. "
+            ) from e
+        if d < 0:
             raise ValueError(
                 f"Cannot convert '{shape}' to a shape. "
                 "Negative dimensions are not allowed."
            )
-    return shape
+        standardized_shape.append(d)
+
+    # This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.
+    return tuple(standardized_shape)
 
 
 def shape_equal(a_shape, b_shape):
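Note: a short sketch of the new acceptance rules (illustrative, not exhaustive): `None` passes through, int-castable scalars such as NumPy integers are cast to plain `int`, and `str`/`float` dimensions are rejected even when `int()` would succeed:

```python
import numpy as np
from keras.src.backend.common.variables import standardize_shape

standardize_shape((None, 3))         # -> (None, 3)
standardize_shape((np.int64(2), 3))  # -> (2, 3), cast to plain int
standardize_shape((2.0, 3))          # ValueError: floats are rejected
standardize_shape((-1, 3))           # ValueError: negative dimensions
```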
keras/src/backend/jax/core.py
@@ -30,9 +30,7 @@ class JaxVariable(KerasVariable):
         self._layout = layout
         super().__init__(*args, **kwargs)
 
-    def _initialize(self, value):
-        # Note that variable.shape is needed by distribution_lib
-        self._shape = self._validate_shape(value.shape)
+    def _initialize_layout(self):
         # We can't import the keras/distribution/distribution_lib
         # due to circular dependency.
         distribution = global_state.get_global_attribute("distribution")
@@ -44,8 +42,28 @@ class JaxVariable(KerasVariable):
                 self._layout = tensor_layout.backend_layout
             else:
                 self._layout = tensor_layout
+
+    def _initialize(self, value):
+        # Note that variable.shape is needed by distribution_lib
+        self._shape = self._validate_shape(value.shape)
+        self._initialize_layout()
         self._direct_assign(value)
 
+    def _initialize_with_initializer(self, initializer):
+        self._initialize_layout()
+        layout = self._layout
+        shape = self._shape
+        if should_shard_at_init(layout, shape):
+            jitted_initializer = jax.jit(
+                initializer.__call__,
+                out_shardings=layout,
+                static_argnames=["shape", "dtype"],
+            )
+            value = jitted_initializer(shape=self._shape, dtype=self._dtype)
+            self._value = value
+        else:
+            super()._initialize_with_initializer(initializer)
+
     def _direct_assign(self, value):
         if self._layout is not None:
             value = distribution_lib.distribute_variable(value, self._layout)
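Note: the key piece in `_initialize_with_initializer` is `jax.jit(..., out_shardings=...)`: compiling the initializer with an output sharding lets each device materialize only its own shard, so a large weight never exists unsharded on a single device. A standalone sketch of the same technique (the mesh axis name `"model"` is illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
layout = NamedSharding(mesh, PartitionSpec("model", None))

# The output of the jitted initializer is produced shard-by-shard.
init = jax.jit(lambda: jnp.zeros((8, 4), jnp.float32), out_shardings=layout)
value = init()
print(value.sharding)  # NamedSharding(..., spec=PartitionSpec('model', None))
```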
@@ -112,6 +130,12 @@ if config.is_nnx_enabled():
             # The real value is now set in self._value, sync it to raw_value
             object.__setattr__(self, "raw_value", self._value)
 
+        def _initialize_with_initializer(self, initializer):
+            value = self._convert_to_tensor(
+                initializer(self._shape, dtype=self._dtype)
+            )
+            self._initialize(value)
+
         @property
         def _value(self):
             if hasattr(self, "raw_value"):
@@ -234,6 +258,71 @@
 
     Variable = NnxVariable
 
+    def _flatten_nnx_variable(variable):
+        children = (variable.raw_value,)
+        # We copy __dict__ to avoid side effects
+        keras_state = variable.__dict__.copy()
+        # Remove elements that might be problematic or redundant if
+        # nnx.Variable's __getstate__
+        keras_state.pop("raw_value", None)
+        aux_data = (
+            variable._var_metadata,
+            getattr(variable, "_trace_state", None),
+            keras_state,
+        )
+        return children, aux_data
+
+    def _unflatten_nnx_variable(aux_data, children):
+        var_metadata, trace_state, keras_state = aux_data
+        raw_value = children[0]
+
+        # Create uninitialized instance
+        variable = NnxVariable.__new__(NnxVariable)
+
+        # Restore state
+        variable._var_metadata = var_metadata
+        if trace_state is not None:
+            variable._trace_state = trace_state
+        variable.__dict__.update(keras_state)
+        variable.raw_value = raw_value
+
+        return variable
+
+    try:
+        jax.tree_util.register_pytree_node(
+            NnxVariable,
+            _flatten_nnx_variable,
+            _unflatten_nnx_variable,
+        )
+    except ValueError:
+        pass
+
+    def __setattr__(self, name, value):
+        # Mirror Keras attributes to _var_metadata to ensure persistence
+        # if the Pytree registration is not respected by NNX.
+        if (
+            name != "_var_metadata"
+            and name not in ("_raw_value", "_trace_state")
+            and hasattr(self, "_var_metadata")
+        ):
+            self._var_metadata[name] = value
+
+        object.__setattr__(self, name, value)
+
+    NnxVariable.__setattr__ = __setattr__
+
+
+def should_shard_at_init(init_layout, shape):
+    if not isinstance(init_layout, jax.sharding.NamedSharding):
+        return False
+
+    if all(dim is None for dim in init_layout.spec):
+        return False
+
+    size_threshold = 250 * 1024 * 1024
+    array_size = np.prod(shape) * 4
+    return array_size >= size_threshold
+
 
 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if ragged:
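Note: `should_shard_at_init` gates the jitted-initializer path on tensor size: `np.prod(shape) * 4` assumes 4 bytes per element (i.e. float32), compared against a 250 MiB threshold. Worked out:

```python
import numpy as np

threshold = 250 * 1024 * 1024  # 262_144_000 bytes
np.prod((8192, 8192)) * 4      # 268_435_456 -> sharded at init
np.prod((1024, 1024)) * 4      # 4_194_304   -> initialized locally
```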
keras/src/backend/jax/distribution_lib.py
@@ -27,6 +27,20 @@ def list_devices(device_type=None):
     return [f"{device.platform}:{device.id}" for device in jax_devices]
 
 
+def get_device_count(device_type=None):
+    """Returns the number of available JAX devices.
+    Args:
+        device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
+            If `None`, it defaults to counting "gpu" or "tpu" devices if
+            available, otherwise it counts "cpu" devices. It does not
+            return the sum of all device types.
+    Returns:
+        int: The total number of JAX devices for the specified type.
+    """
+    device_type = device_type.lower() if device_type else None
+    return jax.device_count(device_type)
+
+
 def distribute_variable(value, layout):
     """Create a distributed variable for JAX.
 
@@ -146,13 +160,13 @@ def initialize_rng():
     # Check if the global seed generator is set and ensure it has an initialized
     # seed. Otherwise, reset the seed to the global seed.
     global_seed_generator = global_state.get_global_attribute(
-        "global_seed_generator"
+        seed_generator.GLOBAL_SEED_GENERATOR
     )
     if global_seed_generator is not None:
         seed = global_seed_generator.get_config()["seed"]
         if seed is None:
             global_state.set_global_attribute(
-                "global_seed_generator",
+                seed_generator.GLOBAL_SEED_GENERATOR,
                 seed_generator.SeedGenerator(
                     seed=global_seed,
                     name=global_seed_generator.name,
keras/src/backend/jax/linalg.py
@@ -97,3 +97,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return jnp.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    return jax.jvp(fun, primals, tangents, has_aux=has_aux)
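Note: the new `jvp` forwards directly to `jax.jvp`, which returns the primal output together with the forward-mode directional derivative along the given tangents. A small worked example:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x
primal_out, tangent_out = jax.jvp(f, (jnp.array(1.0),), (jnp.array(1.0),))
# primal_out  == sin(1) * 1          ~= 0.8415
# tangent_out == cos(1) * 1 + sin(1) ~= 1.3818  (product rule)
```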