keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,7 @@
1
+ import functools
1
2
  import inspect
3
+ import itertools
4
+ import string
2
5
 
3
6
  import numpy as np
4
7
 
@@ -8,10 +11,27 @@ from keras.src.api_export import keras_export
8
11
  from keras.src.backend.common.variables import is_float_dtype
9
12
  from keras.src.backend.common.variables import standardize_dtype
10
13
  from keras.src.layers.layer import Layer
14
+ from keras.src.random.seed_generator import draw_seed
11
15
  from keras.src.saving import serialization_lib
12
16
  from keras.src.utils import jax_utils
13
17
  from keras.src.utils import tracking
14
18
  from keras.src.utils.module_utils import jax
19
+ from keras.src.utils.module_utils import tensorflow as tf
20
+
21
+ if backend.backend() == "tensorflow":
22
+ tf_no_automatic_dependency_tracking = (
23
+ tf.__internal__.tracking.no_automatic_dependency_tracking
24
+ )
25
+ else:
26
+
27
+ def tf_no_automatic_dependency_tracking(fn):
28
+ return fn
29
+
30
+
31
+ def _convert_to_jax_key(tensor):
32
+ if backend.backend() == "tensorflow":
33
+ return tf.bitcast(tensor, tf.uint32)[0]
34
+ return tensor
15
35
 
16
36
 
17
37
  @keras_export("keras.layers.JaxLayer")
@@ -219,21 +239,15 @@ class JaxLayer(Layer):
219
239
  seed=None,
220
240
  **kwargs,
221
241
  ):
222
- if backend.backend() != "jax":
242
+ if backend.backend() not in ["jax", "tensorflow"]:
223
243
  raise ValueError(
224
- "JaxLayer is only supported with the JAX backend. Current "
225
- f"backend: {backend.backend()}"
226
- )
227
-
228
- if init_fn is None and params is None and state is None:
229
- raise ValueError(
230
- "`init_fn`, `params` and `state` cannot all be `None`."
244
+ f"{self.__class__.__name__} is only supported with the JAX or"
245
+ f" Tensorflow backend. Current backend: {backend.backend()}"
231
246
  )
232
247
 
233
248
  super().__init__(**kwargs)
234
249
  self.call_fn = call_fn
235
250
  self.init_fn = init_fn
236
- self.seed_generator = backend.random.SeedGenerator(seed)
237
251
  self.tracked_params = self._create_variables(params, trainable=True)
238
252
  self.tracked_state = self._create_variables(state, trainable=False)
239
253
  if self.params is not None or self.state is not None:
@@ -245,13 +259,35 @@ class JaxLayer(Layer):
245
259
  {"params", "state", "rng", "inputs", "training"},
246
260
  {"inputs"},
247
261
  )
248
- self.has_state = "state" in self.call_fn_arguments
262
+ self.call_fn_has_params = "params" in self.call_fn_arguments
263
+ self.call_fn_has_state = "state" in self.call_fn_arguments
264
+ call_fn_has_rng = "rng" in self.call_fn_arguments
265
+
266
+ if call_fn_has_rng:
267
+ self.seed_generator = backend.random.SeedGenerator(seed)
268
+ else:
269
+ self.seed_generator = None
270
+
271
+ if (
272
+ init_fn is None
273
+ and params is None
274
+ and state is None
275
+ and (self.call_fn_has_params or self.call_fn_has_state)
276
+ ):
277
+ raise ValueError(
278
+ "`init_fn`, `params` and `state` cannot all be `None` when "
279
+ "`call_fn` takes a `params` or a `state` argument."
280
+ )
249
281
 
250
282
  if init_fn:
251
283
  self.init_fn_arguments = self._validate_signature(
252
284
  init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
253
285
  )
254
286
 
287
+ # Attributes for jax2tf functions
288
+ self.jax2tf_training_false_fn = None
289
+ self.jax2tf_training_true_fn = None
290
+
255
291
  def _validate_signature(self, fn, fn_name, allowed, required):
256
292
  fn_parameters = inspect.signature(fn).parameters
257
293
  for parameter_name in required:
@@ -272,7 +308,81 @@ class JaxLayer(Layer):
272
308
 
273
309
  return parameter_names
274
310
 
311
+ def _get_jax2tf_input_shape(self, input_shape):
312
+ """Convert input shape in a format suitable for `jax2tf`.
313
+
314
+ `jax2tf` expects a letter for each unknown dimension, which allows
315
+ correlated dimensions. Since correlated dimensions are not supported by
316
+ Keras, we simply use 'a', 'b', 'c'..., for each unknown dimension. We
317
+ however use 'batch' for dimension 0 if not defined to correlate the
318
+ batch size across inputs.
319
+
320
+ Example (spaces added for readability):
321
+ ```
322
+ input_shape: (None , 4 , None, None, 5 )
323
+ result: "(batch, 4 , a , b , 5 )"
324
+ ```
325
+
326
+ Args:
327
+ input_shape: a single shape or a structure of shapes for the inputs.
328
+ Returns:
329
+ the shape or shapes structure in the `jax2tf` format as strings.
330
+ """
331
+ dim_names = itertools.chain(
332
+ string.ascii_lowercase, # a, b, ... z
333
+ itertools.starmap( # aa, ab, ... az, ba, bb, ... zz
334
+ lambda a, b: a + b,
335
+ itertools.product(string.ascii_lowercase, repeat=2),
336
+ ),
337
+ )
338
+
339
+ def get_single_jax2tf_shape(shape):
340
+ jax2tf_shape = []
341
+
342
+ for index, dim in enumerate(shape):
343
+ if dim is not None:
344
+ jax2tf_shape.append(str(dim))
345
+ elif index == 0:
346
+ jax2tf_shape.append("batch")
347
+ else:
348
+ jax2tf_shape.append(next(dim_names))
349
+
350
+ return "(" + ", ".join(jax2tf_shape) + ")"
351
+
352
+ res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
353
+ return res
354
+
355
+ def _jax2tf_convert(self, fn, polymorphic_shapes):
356
+ from jax.experimental import jax2tf
357
+
358
+ converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
359
+ # Autograph won't work with the output of jax2tf.
360
+ converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
361
+ return converted_fn
362
+
363
+ def _partial_with_positional(self, fn, index, value):
364
+ """Return a new partial with one positional argument set to a value.
365
+
366
+ This is needed because `jax2tf` only supports positional arguments and
367
+ `functools.partial` only supports setting positional arguments starting
368
+ from the left. Our use case is the `training` argument which is
369
+ typically the rightmost argument.
370
+
371
+ Args:
372
+ fn: the function to wrap.
373
+ index: the index of the positional argument to set to `value`.
374
+ value: the value for the positional argument at `index`.
375
+ """
376
+
377
+ @functools.wraps(fn)
378
+ def wrapper(*args):
379
+ args = args[0:index] + (value,) + args[index:]
380
+ return fn(*args)
381
+
382
+ return wrapper
383
+
275
384
  @tracking.no_automatic_dependency_tracking
385
+ @tf_no_automatic_dependency_tracking
276
386
  def _create_variables(self, values, trainable):
277
387
  """Create a structure of variables from a structure of JAX arrays.
278
388
 
@@ -296,14 +406,14 @@ class JaxLayer(Layer):
296
406
 
297
407
  def create_variable(value):
298
408
  if backend.is_tensor(value) or isinstance(
299
- value, (np.ndarray, np.generic)
409
+ value, (np.ndarray, np.generic, jax.Array)
300
410
  ):
301
411
  dtype = value.dtype
302
412
  if is_float_dtype(dtype):
303
413
  dtype = None # Use the layer dtype policy
304
414
  return self.add_weight(
305
415
  value.shape,
306
- initializer=value,
416
+ initializer=backend.convert_to_tensor(value),
307
417
  dtype=dtype,
308
418
  trainable=trainable,
309
419
  )
@@ -331,46 +441,69 @@ class JaxLayer(Layer):
331
441
  flat_variables, _ = jax.tree_util.tree_flatten(variables)
332
442
  return flat_variables
333
443
 
444
+ def _get_init_seed(self):
445
+ """
446
+ Returns a single seed as a tensor of shape [2].
447
+
448
+ Call this within `_get_init_rng()` to obtain a new seed.
449
+
450
+ Returns:
451
+ A native tensor of shape [2] and the backend dtype for seeds.
452
+ """
453
+ # Use the global SeedGenerator.
454
+ return draw_seed(None)
455
+
334
456
  def _get_init_rng(self):
335
457
  """
336
- Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`.
458
+ Returns a seed or seeds to pass as the `rng` argument of `init_fn`.
337
459
 
338
- By default, this returns a single `PRNGKey` retrieved by calling
339
- `self.seed_generator.next()`. Override this to return a different
340
- structure.
460
+ By default, this returns a single seed. Override this to return a
461
+ different structure. Overrides should use `self._get_init_seed()` to
462
+ obtain new seeds.
341
463
 
342
464
  Returns:
343
- a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
344
- the `rng` argument of `init_fn`.
465
+ RNG key or structure of keys as tensors of shape [2] and the backend
466
+ dtype for seeds.
467
+ """
468
+ return self._get_init_seed()
469
+
470
+ def _get_call_seed(self):
471
+ """
472
+ Returns a single seed as a tensor of shape [2].
473
+
474
+ Call this within `_get_call_rng()` to obtain a new seed.
475
+
476
+ Returns:
477
+ A native tensor of shape [2] and the backend dtype for seeds.
345
478
  """
346
479
  return self.seed_generator.next()
347
480
 
348
481
  def _get_call_rng(self, training):
349
482
  """
350
- Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`.
483
+ Returns a seed or seeds to pass as the `rng` argument of `call_fn`.
351
484
 
352
- By default, this returns a single `PRNGKey` retrieved by calling
353
- `self.seed_generator.next()` when `training` is `True`, and `None` when
354
- `training` is `False`. Override this to return a different structure or
355
- to pass RNGs in inference mode too.
485
+ By default, this returns a seed when `training` is `True`, and `None`
486
+ when `training` is `False`. Override this to return a different
487
+ structure or to pass seeds in inference mode too. Overrides should use
488
+ `self._get_call_seed()` to obtain seeds.
356
489
 
357
490
  Returns:
358
- a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
359
- the `rng` argument of `call_fn`.
491
+ RNG key or structure of keys as tensors of shape [2] and the backend
492
+ dtype for seeds.
360
493
  """
361
494
  if training:
362
- return self.seed_generator.next()
495
+ return self._get_call_seed()
363
496
  else:
364
497
  return None
365
498
 
366
- def build(self, input_shape):
367
- if self.params is not None or self.state is not None:
368
- return
369
-
370
- if jax_utils.is_in_jax_tracing_scope():
499
+ def _initialize_weights(self, input_shape):
500
+ if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
371
501
  # This exception is not actually shown, it is caught and a detailed
372
502
  # warning about calling 'build' is printed.
373
- raise ValueError("'JaxLayer' cannot be built in tracing scope")
503
+ raise ValueError(
504
+ "'JaxLayer' cannot be built in tracing scope"
505
+ "or inside tf function"
506
+ )
374
507
 
375
508
  # Initialize `params` and `state` if needed by calling `init_fn`.
376
509
  def create_input(shape):
@@ -381,14 +514,19 @@ class JaxLayer(Layer):
381
514
  init_args = []
382
515
  for argument_name in self.init_fn_arguments:
383
516
  if argument_name == "rng":
384
- init_args.append(self._get_init_rng())
517
+ init_args.append(
518
+ jax.tree_util.tree_map(
519
+ lambda x: jax.numpy.array(_convert_to_jax_key(x)),
520
+ self._get_init_rng(),
521
+ )
522
+ )
385
523
  elif argument_name == "inputs":
386
524
  init_args.append(init_inputs)
387
525
  elif argument_name == "training":
388
526
  init_args.append(True)
389
527
 
390
528
  init_result = self.init_fn(*init_args)
391
- if self.has_state:
529
+ if self.call_fn_has_state:
392
530
  init_params, init_state = init_result
393
531
  else:
394
532
  init_params, init_state = init_result, None
@@ -398,6 +536,49 @@ class JaxLayer(Layer):
398
536
  )
399
537
  self.tracked_state = self._create_variables(init_state, trainable=False)
400
538
 
539
+ def build(self, input_shape):
540
+ if (
541
+ self.params is None
542
+ and self.state is None
543
+ and (self.call_fn_has_params or self.call_fn_has_state)
544
+ ):
545
+ self._initialize_weights(input_shape)
546
+
547
+ if backend.backend() == "tensorflow":
548
+ polymorphic_shapes = []
549
+ for argument in self.call_fn_arguments:
550
+ if argument == "inputs":
551
+ polymorphic_shapes.append(
552
+ self._get_jax2tf_input_shape(input_shape)
553
+ )
554
+ elif argument != "training":
555
+ # params, state, rng
556
+ polymorphic_shapes.append("...")
557
+
558
+ if "training" in self.call_fn_arguments:
559
+ training_argument_index = self.call_fn_arguments.index(
560
+ "training"
561
+ )
562
+ self.jax2tf_training_false_fn = self._jax2tf_convert(
563
+ self._partial_with_positional(
564
+ self.call_fn, training_argument_index, False
565
+ ),
566
+ polymorphic_shapes,
567
+ )
568
+ self.jax2tf_training_true_fn = self._jax2tf_convert(
569
+ self._partial_with_positional(
570
+ self.call_fn, training_argument_index, True
571
+ ),
572
+ polymorphic_shapes,
573
+ )
574
+ else:
575
+ self.jax2tf_training_false_fn = self._jax2tf_convert(
576
+ self.call_fn,
577
+ polymorphic_shapes,
578
+ )
579
+ self.jax2tf_training_true_fn = None
580
+ super().build(input_shape)
581
+
401
582
  def call(self, inputs, training=False):
402
583
  def unwrap_variable(variable):
403
584
  return None if variable is None else variable.value
@@ -413,11 +594,16 @@ class JaxLayer(Layer):
413
594
  jax.tree_util.tree_map(unwrap_variable, self.state)
414
595
  )
415
596
  elif argument_name == "rng":
416
- call_args.append(self._get_call_rng(training))
597
+ call_args.append(
598
+ jax.tree_util.tree_map(
599
+ _convert_to_jax_key, self._get_call_rng(training)
600
+ )
601
+ )
417
602
  elif argument_name == "inputs":
418
603
  call_args.append(inputs)
419
604
  elif argument_name == "training":
420
- call_args.append(training)
605
+ if backend.backend() == "jax":
606
+ call_args.append(training)
421
607
 
422
608
  def assign_state_to_variable(value, variable):
423
609
  # This exists only to make debugging this error case easier.
@@ -429,14 +615,23 @@ class JaxLayer(Layer):
429
615
  )
430
616
  variable.assign(value)
431
617
 
432
- if self.has_state:
433
- predictions, new_state = self.call_fn(*call_args)
434
- jax.tree_util.tree_map(
435
- assign_state_to_variable, new_state, self.state
436
- )
437
- return predictions
438
- else:
439
- return self.call_fn(*call_args)
618
+ def call_with_fn(fn):
619
+ if self.call_fn_has_state:
620
+ predictions, new_state = fn(*call_args)
621
+ jax.tree_util.tree_map(
622
+ assign_state_to_variable, new_state, self.state
623
+ )
624
+ return predictions
625
+ else:
626
+ return fn(*call_args)
627
+
628
+ if backend.backend() == "jax":
629
+ return call_with_fn(self.call_fn)
630
+ elif backend.backend() == "tensorflow":
631
+ if training and self.jax2tf_training_true_fn is not None:
632
+ return call_with_fn(self.jax2tf_training_true_fn)
633
+ else:
634
+ return call_with_fn(self.jax2tf_training_false_fn)
440
635
 
441
636
  def get_config(self):
442
637
  config = {
@@ -554,18 +749,12 @@ class FlaxLayer(JaxLayer):
554
749
  **kwargs,
555
750
  ):
556
751
  # Late import to only require Flax when this is used.
557
- from flax.core import scope as flax_scope
558
-
559
- if backend.backend() != "jax":
560
- raise ValueError(
561
- "FlaxLayer is only supported with the JAX backend. Current "
562
- f"backend: {backend.backend()}"
563
- )
752
+ from flax.linen import DenyList
564
753
 
565
754
  self.module = module
566
755
  self.method = method
567
756
 
568
- apply_mutable = flax_scope.DenyList(["params"])
757
+ apply_mutable = DenyList(["params"])
569
758
 
570
759
  def apply_with_training(params, state, rng, inputs, training):
571
760
  return self.module.apply(
@@ -650,13 +839,13 @@ class FlaxLayer(JaxLayer):
650
839
 
651
840
  def _get_init_rng(self):
652
841
  return {
653
- "params": self.seed_generator.next(),
654
- "dropout": self.seed_generator.next(),
842
+ "params": self._get_init_seed(),
843
+ "dropout": self._get_init_seed(),
655
844
  }
656
845
 
657
846
  def _get_call_rng(self, training):
658
847
  if training:
659
- return {"dropout": self.seed_generator.next()}
848
+ return {"dropout": self._get_call_seed()}
660
849
  else:
661
850
  return {}
662
851
 
@@ -39,11 +39,31 @@ class LazyModule:
39
39
  return f"LazyModule({self.name})"
40
40
 
41
41
 
42
+ class OrbaxLazyModule(LazyModule):
43
+ def initialize(self):
44
+ try:
45
+ parent_module = importlib.import_module("orbax.checkpoint")
46
+ self.module = parent_module.v1
47
+ self.parent_module = parent_module
48
+ except ImportError:
49
+ raise ImportError(self.import_error_msg)
50
+
51
+ def __getattr__(self, name):
52
+ if name == "_api_export_path":
53
+ raise AttributeError
54
+ if self.module is None:
55
+ self.initialize()
56
+ if name == "multihost":
57
+ return self.parent_module.multihost
58
+ return getattr(self.module, name)
59
+
60
+
42
61
  tensorflow = LazyModule("tensorflow")
43
62
  gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
44
63
  tensorflow_io = LazyModule("tensorflow_io")
45
64
  scipy = LazyModule("scipy")
46
65
  jax = LazyModule("jax")
66
+ h5py = LazyModule("h5py")
47
67
  torch_xla = LazyModule(
48
68
  "torch_xla",
49
69
  import_error_msg=(
@@ -59,3 +79,12 @@ optree = LazyModule("optree")
59
79
  dmtree = LazyModule("tree")
60
80
  tf2onnx = LazyModule("tf2onnx")
61
81
  grain = LazyModule("grain")
82
+ litert = LazyModule("ai_edge_litert")
83
+ ocp = OrbaxLazyModule(
84
+ "orbax.checkpoint.v1",
85
+ pip_name="orbax-checkpoint",
86
+ import_error_msg=(
87
+ "OrbaxCheckpoint requires the 'orbax-checkpoint' package. "
88
+ "You can install it via pip install orbax-checkpoint"
89
+ ),
90
+ )
@@ -3,6 +3,8 @@ import os
3
3
  import sys
4
4
  import time
5
5
 
6
+ import numpy as np
7
+
6
8
  from keras.src.api_export import keras_export
7
9
  from keras.src.utils import io_utils
8
10
 
@@ -161,7 +163,10 @@ class Progbar:
161
163
  for k in self._values_order:
162
164
  info += f" - {k}:"
163
165
  if isinstance(self._values[k], list):
164
- avg = self._values[k][0] / max(1, self._values[k][1])
166
+ values, count = self._values[k]
167
+ if not isinstance(values, float):
168
+ values = np.mean(values)
169
+ avg = values / max(1, count)
165
170
  if abs(avg) > 1e-3:
166
171
  info += f" {avg:.4f}"
167
172
  else:
@@ -188,7 +193,10 @@ class Progbar:
188
193
  info += f" -{self._format_time(time_per_unit, self.unit_name)}"
189
194
  for k in self._values_order:
190
195
  info += f" - {k}:"
191
- avg = self._values[k][0] / max(1, self._values[k][1])
196
+ values, count = self._values[k]
197
+ if not isinstance(values, float):
198
+ values = np.mean(values)
199
+ avg = values / max(1, count)
192
200
  if avg > 1e-3:
193
201
  info += f" {avg:.4f}"
194
202
  else:
@@ -5,6 +5,7 @@ import numpy as np
5
5
  from keras.src import backend
6
6
  from keras.src.api_export import keras_export
7
7
  from keras.src.backend.common import global_state
8
+ from keras.src.random import seed_generator
8
9
  from keras.src.utils.module_utils import tensorflow as tf
9
10
 
10
11
  GLOBAL_RANDOM_SEED = "global_random_seed"
@@ -20,7 +21,7 @@ def set_random_seed(seed):
20
21
  sources of randomness, or when certain non-deterministic cuDNN ops are
21
22
  involved.
22
23
 
23
- Calling this utility is equivalent to the following:
24
+ Calling this utility does the following:
24
25
 
25
26
  ```python
26
27
  import random
@@ -36,6 +37,9 @@ def set_random_seed(seed):
36
37
  torch.manual_seed(seed)
37
38
  ```
38
39
 
40
+ Additionally, it resets the global Keras `SeedGenerator`, which is used by
41
+ `keras.random` functions when the `seed` is not provided.
42
+
39
43
  Note that the TensorFlow seed is set even if you're not using TensorFlow
40
44
  as your backend framework, since many workflows leverage `tf.data`
41
45
  pipelines (which feature random shuffling). Likewise many workflows
@@ -52,6 +56,10 @@ def set_random_seed(seed):
52
56
 
53
57
  # Store seed in global state so we can query it if set.
54
58
  global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
59
+ # Remove global SeedGenerator, it will be recreated from the seed.
60
+ global_state.set_global_attribute(
61
+ seed_generator.GLOBAL_SEED_GENERATOR, None
62
+ )
55
63
  random.seed(seed)
56
64
  np.random.seed(seed)
57
65
  if tf.available:
@@ -31,13 +31,13 @@ def no_automatic_dependency_tracking(fn):
31
31
  class Tracker:
32
32
  """Attribute tracker, used for e.g. Variable tracking.
33
33
 
34
- Monitors certain attribute types
35
- and put them in appropriate lists in case of a match.
34
+ Monitors certain attribute types and places matching
35
+ objects into user provided tracking collections.
36
36
 
37
37
  Also passively tracks certain mutable collections
38
- (dict, list) so that items added to them later
39
- still get tracked. This is done by wrapping these
40
- collections into an equivalent, tracking-aware object.
38
+ (e.g. dict and list) ensuring that items added after
39
+ initialization are still tracked. This is done by wrapping
40
+ these collections in tracking-aware proxy objects.
41
41
 
42
42
  Example:
43
43
 
keras/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras.src.api_export import keras_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "3.12.0.dev2025100503"
4
+ __version__ = "3.14.0.dev2026011604"
5
5
 
6
6
 
7
7
  @keras_export("keras.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-nightly
3
- Version: 3.12.0.dev2025100503
3
+ Version: 3.14.0.dev2026011604
4
4
  Summary: Multi-backend Keras
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -8,15 +8,15 @@ Project-URL: Home, https://keras.io/
8
8
  Project-URL: Repository, https://github.com/keras-team/keras
9
9
  Classifier: Development Status :: 4 - Beta
10
10
  Classifier: Programming Language :: Python :: 3
11
- Classifier: Programming Language :: Python :: 3.10
12
11
  Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
13
  Classifier: Programming Language :: Python :: 3 :: Only
14
14
  Classifier: Operating System :: Unix
15
15
  Classifier: Operating System :: MacOS
16
16
  Classifier: Intended Audience :: Science/Research
17
17
  Classifier: Topic :: Scientific/Engineering
18
18
  Classifier: Topic :: Software Development
19
- Requires-Python: >=3.10
19
+ Requires-Python: >=3.11
20
20
  Description-Content-Type: text/markdown
21
21
  Requires-Dist: absl-py
22
22
  Requires-Dist: numpy
@@ -56,9 +56,8 @@ pip install keras --upgrade
56
56
 
57
57
  2. Install backend package(s).
58
58
 
59
- To use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`.
60
- Note that `tensorflow` is required for using certain Keras 3 features: certain preprocessing layers
61
- as well as `tf.data` pipelines.
59
+ To use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`. Additionally,
60
+ the `openvino` backend is available with support for model inference only.
62
61
 
63
62
  ### Local installation
64
63
 
@@ -85,6 +84,17 @@ python pip_build.py --install
85
84
  ./shell/api_gen.sh
86
85
  ```
87
86
 
87
+ ## Backend Compatibility Table
88
+
89
+ The following table lists the minimum supported versions of each backend for the latest stable release of Keras (v3.x):
90
+
91
+ | Backend | Minimum Supported Version |
92
+ |------------|---------------------------|
93
+ | TensorFlow | 2.16.1 |
94
+ | JAX | 0.4.20 |
95
+ | PyTorch | 2.1.0 |
96
+ | OpenVINO | 2025.3.0 |
97
+
88
98
  #### Adding GPU support
89
99
 
90
100
  The `requirements.txt` file will install a CPU-only version of TensorFlow, JAX, and PyTorch. For GPU support, we also