keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,7 @@
1
+ import functools
1
2
  import inspect
3
+ import itertools
4
+ import string
2
5
 
3
6
  import numpy as np
4
7
 
@@ -12,6 +15,22 @@ from keras.src.saving import serialization_lib
12
15
  from keras.src.utils import jax_utils
13
16
  from keras.src.utils import tracking
14
17
  from keras.src.utils.module_utils import jax
18
+ from keras.src.utils.module_utils import tensorflow as tf
19
+
20
+ if backend.backend() == "tensorflow":
21
+ tf_no_automatic_dependency_tracking = (
22
+ tf.__internal__.tracking.no_automatic_dependency_tracking
23
+ )
24
+ else:
25
+
26
+ def tf_no_automatic_dependency_tracking(fn):
27
+ return fn
28
+
29
+
30
+ def _convert_to_jax_key(tensor):
31
+ if backend.backend() == "tensorflow":
32
+ return tf.bitcast(tensor, tf.uint32)[0]
33
+ return tensor
15
34
 
16
35
 
17
36
  @keras_export("keras.layers.JaxLayer")
@@ -219,10 +238,10 @@ class JaxLayer(Layer):
219
238
  seed=None,
220
239
  **kwargs,
221
240
  ):
222
- if backend.backend() != "jax":
241
+ if backend.backend() not in ["jax", "tensorflow"]:
223
242
  raise ValueError(
224
- "JaxLayer is only supported with the JAX backend. Current "
225
- f"backend: {backend.backend()}"
243
+ f"{self.__class__.__name__} is only supported with the JAX or"
244
+ f" Tensorflow backend. Current backend: {backend.backend()}"
226
245
  )
227
246
 
228
247
  if init_fn is None and params is None and state is None:
@@ -252,6 +271,10 @@ class JaxLayer(Layer):
252
271
  init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
253
272
  )
254
273
 
274
+ # Attributes for jax2tf functions
275
+ self.jax2tf_training_false_fn = None
276
+ self.jax2tf_training_true_fn = None
277
+
255
278
  def _validate_signature(self, fn, fn_name, allowed, required):
256
279
  fn_parameters = inspect.signature(fn).parameters
257
280
  for parameter_name in required:
@@ -272,7 +295,81 @@ class JaxLayer(Layer):
272
295
 
273
296
  return parameter_names
274
297
 
298
+ def _get_jax2tf_input_shape(self, input_shape):
299
+ """Convert input shape in a format suitable for `jax2tf`.
300
+
301
+ `jax2tf` expects a letter for each unknown dimension, which allows
302
+ correlated dimensions. Since correlated dimensions are not supported by
303
+ Keras, we simply use 'a', 'b', 'c'..., for each unknown dimension. We
304
+ however use 'batch' for dimension 0 if not defined to correlate the
305
+ batch size across inputs.
306
+
307
+ Example (spaces added for readability):
308
+ ```
309
+ input_shape: (None , 4 , None, None, 5 )
310
+ result: "(batch, 4 , a , b , 5 )"
311
+ ```
312
+
313
+ Args:
314
+ input_shape: a single shape or a structure of shapes for the inputs.
315
+ Returns:
316
+ the shape or shapes structure in the `jax2tf` format as strings.
317
+ """
318
+ dim_names = itertools.chain(
319
+ string.ascii_lowercase, # a, b, ... z
320
+ itertools.starmap( # aa, ab, ... az, ba, bb, ... zz
321
+ lambda a, b: a + b,
322
+ itertools.product(string.ascii_lowercase, repeat=2),
323
+ ),
324
+ )
325
+
326
+ def get_single_jax2tf_shape(shape):
327
+ jax2tf_shape = []
328
+
329
+ for index, dim in enumerate(shape):
330
+ if dim is not None:
331
+ jax2tf_shape.append(str(dim))
332
+ elif index == 0:
333
+ jax2tf_shape.append("batch")
334
+ else:
335
+ jax2tf_shape.append(next(dim_names))
336
+
337
+ return "(" + ", ".join(jax2tf_shape) + ")"
338
+
339
+ res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
340
+ return res
341
+
342
+ def _jax2tf_convert(self, fn, polymorphic_shapes):
343
+ from jax.experimental import jax2tf
344
+
345
+ converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
346
+ # Autograph won't work with the output of jax2tf.
347
+ converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
348
+ return converted_fn
349
+
350
+ def _partial_with_positional(self, fn, index, value):
351
+ """Return a new partial with one positional argument set to a value.
352
+
353
+ This is needed because `jax2tf` only supports positional arguments and
354
+ `functools.partial` only supports setting positional arguments starting
355
+ from the left. Our use case is the `training` argument which is
356
+ typically the righmost argument.
357
+
358
+ Args:
359
+ fn: the function to wrap.
360
+ index: the index of the positional argument to set to `value`.
361
+ value: the value for the positional argument at `index`.
362
+ """
363
+
364
+ @functools.wraps(fn)
365
+ def wrapper(*args):
366
+ args = args[0:index] + (value,) + args[index:]
367
+ return fn(*args)
368
+
369
+ return wrapper
370
+
275
371
  @tracking.no_automatic_dependency_tracking
372
+ @tf_no_automatic_dependency_tracking
276
373
  def _create_variables(self, values, trainable):
277
374
  """Create a structure of variables from a structure of JAX arrays.
278
375
 
@@ -296,14 +393,14 @@ class JaxLayer(Layer):
296
393
 
297
394
  def create_variable(value):
298
395
  if backend.is_tensor(value) or isinstance(
299
- value, (np.ndarray, np.generic)
396
+ value, (np.ndarray, np.generic, jax.Array)
300
397
  ):
301
398
  dtype = value.dtype
302
399
  if is_float_dtype(dtype):
303
400
  dtype = None # Use the layer dtype policy
304
401
  return self.add_weight(
305
402
  value.shape,
306
- initializer=value,
403
+ initializer=backend.convert_to_tensor(value),
307
404
  dtype=dtype,
308
405
  trainable=trainable,
309
406
  )
@@ -333,44 +430,46 @@ class JaxLayer(Layer):
333
430
 
334
431
  def _get_init_rng(self):
335
432
  """
336
- Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`.
433
+ Returns a key in form of the backend array of size 2 dtype uint32
434
+ to pass to `init_fn`.
337
435
 
338
- By default, this returns a single `PRNGKey` retrieved by calling
436
+ By default, this returns a Jax or TF array of size 2 by calling
339
437
  `self.seed_generator.next()`. Override this to return a different
340
438
  structure.
341
439
 
342
440
  Returns:
343
- a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
344
- the `rng` argument of `init_fn`.
441
+ a key as an Jax or TF array of size 2 dtype uint32 will be passed
442
+ as the `rng` argument of `init_fn`.
345
443
  """
346
444
  return self.seed_generator.next()
347
445
 
348
446
  def _get_call_rng(self, training):
349
447
  """
350
- Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`.
448
+ Returns a key in form of the backend array of size 2 dtype uint32
449
+ to pass to `call_fn`.
351
450
 
352
- By default, this returns a single `PRNGKey` retrieved by calling
451
+ By default, this returns a Jax or TF array of size 2 by calling
353
452
  `self.seed_generator.next()` when `training` is `True`, and `None` when
354
453
  `training` is `False`. Override this to return a different structure or
355
454
  to pass RNGs in inference mode too.
356
455
 
357
456
  Returns:
358
- a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
359
- the `rng` argument of `call_fn`.
457
+ a key as an Jax or TF array of size 2 dtype uint32 will be passed
458
+ as the `rng` argument of `call_fn`.
360
459
  """
361
460
  if training:
362
461
  return self.seed_generator.next()
363
462
  else:
364
463
  return None
365
464
 
366
- def build(self, input_shape):
367
- if self.params is not None or self.state is not None:
368
- return
369
-
370
- if jax_utils.is_in_jax_tracing_scope():
465
+ def _initialize_weights(self, input_shape):
466
+ if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
371
467
  # This exception is not actually shown, it is caught and a detailed
372
468
  # warning about calling 'build' is printed.
373
- raise ValueError("'JaxLayer' cannot be built in tracing scope")
469
+ raise ValueError(
470
+ "'JaxLayer' cannot be built in tracing scope"
471
+ "or inside tf function"
472
+ )
374
473
 
375
474
  # Initialize `params` and `state` if needed by calling `init_fn`.
376
475
  def create_input(shape):
@@ -381,7 +480,12 @@ class JaxLayer(Layer):
381
480
  init_args = []
382
481
  for argument_name in self.init_fn_arguments:
383
482
  if argument_name == "rng":
384
- init_args.append(self._get_init_rng())
483
+ init_args.append(
484
+ jax.tree_util.tree_map(
485
+ lambda x: jax.numpy.array(_convert_to_jax_key(x)),
486
+ self._get_init_rng(),
487
+ )
488
+ )
385
489
  elif argument_name == "inputs":
386
490
  init_args.append(init_inputs)
387
491
  elif argument_name == "training":
@@ -398,6 +502,45 @@ class JaxLayer(Layer):
398
502
  )
399
503
  self.tracked_state = self._create_variables(init_state, trainable=False)
400
504
 
505
+ def build(self, input_shape):
506
+ if self.params is None and self.state is None:
507
+ self._initialize_weights(input_shape)
508
+
509
+ if backend.backend() == "tensorflow":
510
+ polymorphic_shapes = []
511
+ for argument in self.call_fn_arguments:
512
+ if argument == "inputs":
513
+ polymorphic_shapes.append(
514
+ self._get_jax2tf_input_shape(input_shape)
515
+ )
516
+ elif argument != "training":
517
+ # params, state, rng
518
+ polymorphic_shapes.append("...")
519
+
520
+ if "training" in self.call_fn_arguments:
521
+ training_argument_index = self.call_fn_arguments.index(
522
+ "training"
523
+ )
524
+ self.jax2tf_training_false_fn = self._jax2tf_convert(
525
+ self._partial_with_positional(
526
+ self.call_fn, training_argument_index, False
527
+ ),
528
+ polymorphic_shapes,
529
+ )
530
+ self.jax2tf_training_true_fn = self._jax2tf_convert(
531
+ self._partial_with_positional(
532
+ self.call_fn, training_argument_index, True
533
+ ),
534
+ polymorphic_shapes,
535
+ )
536
+ else:
537
+ self.jax2tf_training_false_fn = self._jax2tf_convert(
538
+ self.call_fn,
539
+ polymorphic_shapes,
540
+ )
541
+ self.jax2tf_training_true_fn = None
542
+ super().build(input_shape)
543
+
401
544
  def call(self, inputs, training=False):
402
545
  def unwrap_variable(variable):
403
546
  return None if variable is None else variable.value
@@ -413,11 +556,16 @@ class JaxLayer(Layer):
413
556
  jax.tree_util.tree_map(unwrap_variable, self.state)
414
557
  )
415
558
  elif argument_name == "rng":
416
- call_args.append(self._get_call_rng(training))
559
+ call_args.append(
560
+ jax.tree_util.tree_map(
561
+ _convert_to_jax_key, self._get_call_rng(training)
562
+ )
563
+ )
417
564
  elif argument_name == "inputs":
418
565
  call_args.append(inputs)
419
566
  elif argument_name == "training":
420
- call_args.append(training)
567
+ if backend.backend() == "jax":
568
+ call_args.append(training)
421
569
 
422
570
  def assign_state_to_variable(value, variable):
423
571
  # This exists only to make debugging this error case easier.
@@ -429,14 +577,23 @@ class JaxLayer(Layer):
429
577
  )
430
578
  variable.assign(value)
431
579
 
432
- if self.has_state:
433
- predictions, new_state = self.call_fn(*call_args)
434
- jax.tree_util.tree_map(
435
- assign_state_to_variable, new_state, self.state
436
- )
437
- return predictions
438
- else:
439
- return self.call_fn(*call_args)
580
+ def call_with_fn(fn):
581
+ if self.has_state:
582
+ predictions, new_state = fn(*call_args)
583
+ jax.tree_util.tree_map(
584
+ assign_state_to_variable, new_state, self.state
585
+ )
586
+ return predictions
587
+ else:
588
+ return fn(*call_args)
589
+
590
+ if backend.backend() == "jax":
591
+ return call_with_fn(self.call_fn)
592
+ elif backend.backend() == "tensorflow":
593
+ if training and self.jax2tf_training_true_fn is not None:
594
+ return call_with_fn(self.jax2tf_training_true_fn)
595
+ else:
596
+ return call_with_fn(self.jax2tf_training_false_fn)
440
597
 
441
598
  def get_config(self):
442
599
  config = {
@@ -556,12 +713,6 @@ class FlaxLayer(JaxLayer):
556
713
  # Late import to only require Flax when this is used.
557
714
  from flax.core import scope as flax_scope
558
715
 
559
- if backend.backend() != "jax":
560
- raise ValueError(
561
- "FlaxLayer is only supported with the JAX backend. Current "
562
- f"backend: {backend.backend()}"
563
- )
564
-
565
716
  self.module = module
566
717
  self.method = method
567
718
 
@@ -39,6 +39,15 @@ class LazyModule:
39
39
  return f"LazyModule({self.name})"
40
40
 
41
41
 
42
+ class OrbaxLazyModule(LazyModule):
43
+ def initialize(self):
44
+ try:
45
+ parent_module = importlib.import_module("orbax.checkpoint")
46
+ self.module = parent_module.v1
47
+ except ImportError:
48
+ raise ImportError(self.import_error_msg)
49
+
50
+
42
51
  tensorflow = LazyModule("tensorflow")
43
52
  gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
44
53
  tensorflow_io = LazyModule("tensorflow_io")
@@ -59,3 +68,12 @@ optree = LazyModule("optree")
59
68
  dmtree = LazyModule("tree")
60
69
  tf2onnx = LazyModule("tf2onnx")
61
70
  grain = LazyModule("grain")
71
+ litert = LazyModule("ai_edge_litert")
72
+ ocp = OrbaxLazyModule(
73
+ "orbax.checkpoint.v1",
74
+ pip_name="orbax-checkpoint",
75
+ import_error_msg=(
76
+ "OrbaxCheckpoint requires the 'orbax-checkpoint' package. "
77
+ "You can install it via pip install orbax-checkpoint"
78
+ ),
79
+ )
@@ -3,7 +3,8 @@ import os
3
3
  import sys
4
4
  import time
5
5
 
6
- from keras.src import backend
6
+ import numpy as np
7
+
7
8
  from keras.src.api_export import keras_export
8
9
  from keras.src.utils import io_utils
9
10
 
@@ -162,12 +163,10 @@ class Progbar:
162
163
  for k in self._values_order:
163
164
  info += f" - {k}:"
164
165
  if isinstance(self._values[k], list):
165
- avg = backend.convert_to_numpy(
166
- backend.numpy.mean(
167
- self._values[k][0] / max(1, self._values[k][1])
168
- )
169
- )
170
- avg = float(avg)
166
+ values, count = self._values[k]
167
+ if not isinstance(values, float):
168
+ values = np.mean(values)
169
+ avg = values / max(1, count)
171
170
  if abs(avg) > 1e-3:
172
171
  info += f" {avg:.4f}"
173
172
  else:
@@ -194,11 +193,10 @@ class Progbar:
194
193
  info += f" -{self._format_time(time_per_unit, self.unit_name)}"
195
194
  for k in self._values_order:
196
195
  info += f" - {k}:"
197
- avg = backend.convert_to_numpy(
198
- backend.numpy.mean(
199
- self._values[k][0] / max(1, self._values[k][1])
200
- )
201
- )
196
+ values, count = self._values[k]
197
+ if not isinstance(values, float):
198
+ values = np.mean(values)
199
+ avg = values / max(1, count)
202
200
  if avg > 1e-3:
203
201
  info += f" {avg:.4f}"
204
202
  else:
@@ -5,6 +5,7 @@ import numpy as np
5
5
  from keras.src import backend
6
6
  from keras.src.api_export import keras_export
7
7
  from keras.src.backend.common import global_state
8
+ from keras.src.random import seed_generator
8
9
  from keras.src.utils.module_utils import tensorflow as tf
9
10
 
10
11
  GLOBAL_RANDOM_SEED = "global_random_seed"
@@ -20,7 +21,7 @@ def set_random_seed(seed):
20
21
  sources of randomness, or when certain non-deterministic cuDNN ops are
21
22
  involved.
22
23
 
23
- Calling this utility is equivalent to the following:
24
+ Calling this utility does the following:
24
25
 
25
26
  ```python
26
27
  import random
@@ -36,6 +37,9 @@ def set_random_seed(seed):
36
37
  torch.manual_seed(seed)
37
38
  ```
38
39
 
40
+ Additionally, it resets the global Keras `SeedGenerator`, which is used by
41
+ `keras.random` functions when the `seed` is not provided.
42
+
39
43
  Note that the TensorFlow seed is set even if you're not using TensorFlow
40
44
  as your backend framework, since many workflows leverage `tf.data`
41
45
  pipelines (which feature random shuffling). Likewise many workflows
@@ -52,6 +56,10 @@ def set_random_seed(seed):
52
56
 
53
57
  # Store seed in global state so we can query it if set.
54
58
  global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
59
+ # Remove global SeedGenerator, it will be recreated from the seed.
60
+ global_state.set_global_attribute(
61
+ seed_generator.GLOBAL_SEED_GENERATOR, None
62
+ )
55
63
  random.seed(seed)
56
64
  np.random.seed(seed)
57
65
  if tf.available:
keras/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras.src.api_export import keras_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "3.12.0.dev2025092403"
4
+ __version__ = "3.14.0.dev2026010104"
5
5
 
6
6
 
7
7
  @keras_export("keras.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-nightly
3
- Version: 3.12.0.dev2025092403
3
+ Version: 3.14.0.dev2026010104
4
4
  Summary: Multi-backend Keras
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -8,15 +8,15 @@ Project-URL: Home, https://keras.io/
8
8
  Project-URL: Repository, https://github.com/keras-team/keras
9
9
  Classifier: Development Status :: 4 - Beta
10
10
  Classifier: Programming Language :: Python :: 3
11
- Classifier: Programming Language :: Python :: 3.10
12
11
  Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
13
  Classifier: Programming Language :: Python :: 3 :: Only
14
14
  Classifier: Operating System :: Unix
15
15
  Classifier: Operating System :: MacOS
16
16
  Classifier: Intended Audience :: Science/Research
17
17
  Classifier: Topic :: Scientific/Engineering
18
18
  Classifier: Topic :: Software Development
19
- Requires-Python: >=3.10
19
+ Requires-Python: >=3.11
20
20
  Description-Content-Type: text/markdown
21
21
  Requires-Dist: absl-py
22
22
  Requires-Dist: numpy
@@ -56,9 +56,8 @@ pip install keras --upgrade
56
56
 
57
57
  2. Install backend package(s).
58
58
 
59
- To use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`.
60
- Note that `tensorflow` is required for using certain Keras 3 features: certain preprocessing layers
61
- as well as `tf.data` pipelines.
59
+ To use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`. Additionally,
60
+ The `openvino` backend is available with support for model inference only.
62
61
 
63
62
  ### Local installation
64
63
 
@@ -85,6 +84,17 @@ python pip_build.py --install
85
84
  ./shell/api_gen.sh
86
85
  ```
87
86
 
87
+ ## Backend Compatibility Table
88
+
89
+ The following table lists the minimum supported versions of each backend for the latest stable release of Keras (v3.x):
90
+
91
+ | Backend | Minimum Supported Version |
92
+ |------------|---------------------------|
93
+ | TensorFlow | 2.16.1 |
94
+ | JAX | 0.4.20 |
95
+ | PyTorch | 2.1.0 |
96
+ | OpenVINO | 2025.3.0 |
97
+
88
98
  #### Adding GPU support
89
99
 
90
100
  The `requirements.txt` file will install a CPU-only version of TensorFlow, JAX, and PyTorch. For GPU support, we also