keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/utils/jax_layer.py

@@ -1,4 +1,7 @@
+import functools
 import inspect
+import itertools
+import string
 
 import numpy as np
 
@@ -8,10 +11,27 @@ from keras.src.api_export import keras_export
 from keras.src.backend.common.variables import is_float_dtype
 from keras.src.backend.common.variables import standardize_dtype
 from keras.src.layers.layer import Layer
+from keras.src.random.seed_generator import draw_seed
 from keras.src.saving import serialization_lib
 from keras.src.utils import jax_utils
 from keras.src.utils import tracking
 from keras.src.utils.module_utils import jax
+from keras.src.utils.module_utils import tensorflow as tf
+
+if backend.backend() == "tensorflow":
+    tf_no_automatic_dependency_tracking = (
+        tf.__internal__.tracking.no_automatic_dependency_tracking
+    )
+else:
+
+    def tf_no_automatic_dependency_tracking(fn):
+        return fn
+
+
+def _convert_to_jax_key(tensor):
+    if backend.backend() == "tensorflow":
+        return tf.bitcast(tensor, tf.uint32)[0]
+    return tensor
 
 
 @keras_export("keras.layers.JaxLayer")
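
The `_convert_to_jax_key` helper reinterprets the backend seed tensor as raw `uint32` words so it can later be wrapped into a JAX PRNG key. A minimal numpy stand-in for the bitcast, assuming (for illustration only) that the TensorFlow seed is an `int64` tensor of shape `[2]`:

```python
import numpy as np

# Hypothetical stand-in for tf.bitcast(seed, tf.uint32)[0]: reinterpret a
# [2] int64 seed as uint32 words and keep the first pair, which has the
# shape and dtype of a raw JAX PRNG key.
seed = np.array([123, 456], dtype=np.int64)
key = seed.view(np.uint32).reshape(2, 2)[0]
print(key.shape, key.dtype)  # (2,) uint32
```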
@@ -219,21 +239,15 @@ class JaxLayer(Layer):
         seed=None,
         **kwargs,
     ):
-        if backend.backend() != "jax":
+        if backend.backend() not in ["jax", "tensorflow"]:
             raise ValueError(
-                "JaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
-            )
-
-        if init_fn is None and params is None and state is None:
-            raise ValueError(
-                "`init_fn`, `params` and `state` cannot all be `None`."
+                f"{self.__class__.__name__} is only supported with the JAX or"
+                f" Tensorflow backend. Current backend: {backend.backend()}"
             )
 
         super().__init__(**kwargs)
         self.call_fn = call_fn
         self.init_fn = init_fn
-        self.seed_generator = backend.random.SeedGenerator(seed)
         self.tracked_params = self._create_variables(params, trainable=True)
         self.tracked_state = self._create_variables(state, trainable=False)
         if self.params is not None or self.state is not None:
@@ -245,13 +259,35 @@ class JaxLayer(Layer):
             {"params", "state", "rng", "inputs", "training"},
             {"inputs"},
         )
-        self.has_state = "state" in self.call_fn_arguments
+        self.call_fn_has_params = "params" in self.call_fn_arguments
+        self.call_fn_has_state = "state" in self.call_fn_arguments
+        call_fn_has_rng = "rng" in self.call_fn_arguments
+
+        if call_fn_has_rng:
+            self.seed_generator = backend.random.SeedGenerator(seed)
+        else:
+            self.seed_generator = None
+
+        if (
+            init_fn is None
+            and params is None
+            and state is None
+            and (self.call_fn_has_params or self.call_fn_has_state)
+        ):
+            raise ValueError(
+                "`init_fn`, `params` and `state` cannot all be `None` when "
+                "`call_fn` takes a `params` or a `state` argument."
+            )
 
         if init_fn:
             self.init_fn_arguments = self._validate_signature(
                 init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
             )
 
+        # Attributes for jax2tf functions
+        self.jax2tf_training_false_fn = None
+        self.jax2tf_training_true_fn = None
+
     def _validate_signature(self, fn, fn_name, allowed, required):
         fn_parameters = inspect.signature(fn).parameters
         for parameter_name in required:
@@ -272,7 +308,81 @@ class JaxLayer(Layer):
 
         return parameter_names
 
+    def _get_jax2tf_input_shape(self, input_shape):
+        """Convert input shape in a format suitable for `jax2tf`.
+
+        `jax2tf` expects a letter for each unknown dimension, which allows
+        correlated dimensions. Since correlated dimensions are not supported by
+        Keras, we simply use 'a', 'b', 'c'..., for each unknown dimension. We
+        however use 'batch' for dimension 0 if not defined to correlate the
+        batch size across inputs.
+
+        Example (spaces added for readability):
+        ```
+        input_shape: (None , 4 , None, None, 5 )
+        result:      "(batch, 4 , a   , b   , 5 )"
+        ```
+
+        Args:
+            input_shape: a single shape or a structure of shapes for the inputs.
+        Returns:
+            the shape or shapes structure in the `jax2tf` format as strings.
+        """
+        dim_names = itertools.chain(
+            string.ascii_lowercase,  # a, b, ... z
+            itertools.starmap(  # aa, ab, ... az, ba, bb, ... zz
+                lambda a, b: a + b,
+                itertools.product(string.ascii_lowercase, repeat=2),
+            ),
+        )
+
+        def get_single_jax2tf_shape(shape):
+            jax2tf_shape = []
+
+            for index, dim in enumerate(shape):
+                if dim is not None:
+                    jax2tf_shape.append(str(dim))
+                elif index == 0:
+                    jax2tf_shape.append("batch")
+                else:
+                    jax2tf_shape.append(next(dim_names))
+
+            return "(" + ", ".join(jax2tf_shape) + ")"
+
+        res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
+        return res
+
+    def _jax2tf_convert(self, fn, polymorphic_shapes):
+        from jax.experimental import jax2tf
+
+        converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
+        # Autograph won't work with the output of jax2tf.
+        converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
+        return converted_fn
+
+    def _partial_with_positional(self, fn, index, value):
+        """Return a new partial with one positional argument set to a value.
+
+        This is needed because `jax2tf` only supports positional arguments and
+        `functools.partial` only supports setting positional arguments starting
+        from the left. Our use case is the `training` argument which is
+        typically the rightmost argument.
+
+        Args:
+            fn: the function to wrap.
+            index: the index of the positional argument to set to `value`.
+            value: the value for the positional argument at `index`.
+        """
+
+        @functools.wraps(fn)
+        def wrapper(*args):
+            args = args[0:index] + (value,) + args[index:]
+            return fn(*args)
+
+        return wrapper
+
     @tracking.no_automatic_dependency_tracking
+    @tf_no_automatic_dependency_tracking
     def _create_variables(self, values, trainable):
         """Create a structure of variables from a structure of JAX arrays.
 
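
A quick standalone illustration of the `_partial_with_positional` idea added above; `call_fn` here is a made-up function, not taken from the diff:

```python
import functools


def partial_with_positional(fn, index, value):
    # Pin the positional argument at `index` to `value`; unlike
    # functools.partial, this works for arguments that are not leftmost.
    @functools.wraps(fn)
    def wrapper(*args):
        args = args[0:index] + (value,) + args[index:]
        return fn(*args)

    return wrapper


def call_fn(params, inputs, training):
    # Made-up call_fn with the `training` flag as its last positional arg.
    return [x * params for x in inputs] if training else list(inputs)


inference_fn = partial_with_positional(call_fn, 2, False)
print(inference_fn(2.0, [1.0, 2.0]))  # training pinned to False -> [1.0, 2.0]
```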
@@ -296,14 +406,14 @@ class JaxLayer(Layer):
 
         def create_variable(value):
             if backend.is_tensor(value) or isinstance(
-                value, (np.ndarray, np.generic)
+                value, (np.ndarray, np.generic, jax.Array)
             ):
                 dtype = value.dtype
                 if is_float_dtype(dtype):
                     dtype = None  # Use the layer dtype policy
                 return self.add_weight(
                     value.shape,
-                    initializer=value,
+                    initializer=backend.convert_to_tensor(value),
                     dtype=dtype,
                     trainable=trainable,
                 )
@@ -331,46 +441,69 @@ class JaxLayer(Layer):
         flat_variables, _ = jax.tree_util.tree_flatten(variables)
         return flat_variables
 
+    def _get_init_seed(self):
+        """
+        Returns a single seed as a tensor of shape [2].
+
+        Call this within `_get_init_rng()` to obtain a new seed.
+
+        Returns:
+            A native tensor of shape [2] and the backend dtype for seeds.
+        """
+        # Use the global SeedGenerator.
+        return draw_seed(None)
+
     def _get_init_rng(self):
         """
-        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`.
+        Returns a seed or seeds to pass as the `rng` argument of `init_fn`.
 
-        By default, this returns a single `PRNGKey` retrieved by calling
-        `self.seed_generator.next()`. Override this to return a different
-        structure.
+        By default, this returns a single seed. Override this to return a
+        different structure. Overrides should use `self._get_init_seed()` to
+        obtain new seeds.
 
         Returns:
-            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
-            the `rng` argument of `init_fn`.
+            RNG key or structure of keys as tensors of shape [2] and the backend
+            dtype for seeds.
+        """
+        return self._get_init_seed()
+
+    def _get_call_seed(self):
+        """
+        Returns a single seed as a tensor of shape [2].
+
+        Call this within `_get_call_rng()` to obtain a new seed.
+
+        Returns:
+            A native tensor of shape [2] and the backend dtype for seeds.
         """
         return self.seed_generator.next()
 
     def _get_call_rng(self, training):
         """
-        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`.
+        Returns a seed or seeds to pass as the `rng` argument of `call_fn`.
 
-        By default, this returns a single `PRNGKey` retrieved by calling
-        `self.seed_generator.next()` when `training` is `True`, and `None` when
-        `training` is `False`. Override this to return a different structure or
-        to pass RNGs in inference mode too.
+        By default, this returns a seed when `training` is `True`, and `None`
+        when `training` is `False`. Override this to return a different
+        structure or to pass seeds in inference mode too. Overrides should use
+        `self._get_call_seed()` to obtain seeds.
 
         Returns:
-            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
-            the `rng` argument of `call_fn`.
+            RNG key or structure of keys as tensors of shape [2] and the backend
+            dtype for seeds.
         """
         if training:
-            return self.seed_generator.next()
+            return self._get_call_seed()
         else:
             return None
 
-    def build(self, input_shape):
-        if self.params is not None or self.state is not None:
-            return
-
-        if jax_utils.is_in_jax_tracing_scope():
+    def _initialize_weights(self, input_shape):
+        if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
             # This exception is not actually shown, it is caught and a detailed
             # warning about calling 'build' is printed.
-            raise ValueError("'JaxLayer' cannot be built in tracing scope")
+            raise ValueError(
+                "'JaxLayer' cannot be built in tracing scope "
+                "or inside tf function"
+            )
 
         # Initialize `params` and `state` if needed by calling `init_fn`.
         def create_input(shape):
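
Because all seed creation now funnels through `_get_init_seed()` and `_get_call_seed()`, subclasses that override the rng hooks stay backend-agnostic. A hedged sketch of such an override (hypothetical subclass, shown only to illustrate the hook):

```python
import keras


class AlwaysRngJaxLayer(keras.layers.JaxLayer):
    # Hypothetical override: also pass an rng at inference time, using the
    # `_get_call_seed()` hook introduced in this change.
    def _get_call_rng(self, training):
        return self._get_call_seed()
```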
@@ -381,14 +514,19 @@ class JaxLayer(Layer):
         init_args = []
         for argument_name in self.init_fn_arguments:
             if argument_name == "rng":
-                init_args.append(self._get_init_rng())
+                init_args.append(
+                    jax.tree_util.tree_map(
+                        lambda x: jax.numpy.array(_convert_to_jax_key(x)),
+                        self._get_init_rng(),
+                    )
+                )
             elif argument_name == "inputs":
                 init_args.append(init_inputs)
             elif argument_name == "training":
                 init_args.append(True)
 
         init_result = self.init_fn(*init_args)
-        if self.has_state:
+        if self.call_fn_has_state:
             init_params, init_state = init_result
         else:
             init_params, init_state = init_result, None
@@ -398,6 +536,49 @@ class JaxLayer(Layer):
         )
         self.tracked_state = self._create_variables(init_state, trainable=False)
 
+    def build(self, input_shape):
+        if (
+            self.params is None
+            and self.state is None
+            and (self.call_fn_has_params or self.call_fn_has_state)
+        ):
+            self._initialize_weights(input_shape)
+
+        if backend.backend() == "tensorflow":
+            polymorphic_shapes = []
+            for argument in self.call_fn_arguments:
+                if argument == "inputs":
+                    polymorphic_shapes.append(
+                        self._get_jax2tf_input_shape(input_shape)
+                    )
+                elif argument != "training":
+                    # params, state, rng
+                    polymorphic_shapes.append("...")
+
+            if "training" in self.call_fn_arguments:
+                training_argument_index = self.call_fn_arguments.index(
+                    "training"
+                )
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, False
+                    ),
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, True
+                    ),
+                    polymorphic_shapes,
+                )
+            else:
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self.call_fn,
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = None
+        super().build(input_shape)
+
     def call(self, inputs, training=False):
         def unwrap_variable(variable):
             return None if variable is None else variable.value
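
For reference, a hedged sketch of the conversion that `build()` performs on the TensorFlow backend: the JAX `call_fn` is converted once with shape-polymorphic inputs. The toy `call_fn` and the "(batch, 4)" input spec are assumptions for the example, not taken from the diff:

```python
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf


def call_fn(params, inputs):
    # Toy JAX function standing in for a user-provided call_fn.
    return inputs @ params["kernel"]


# "..." leaves the params pytree spec unconstrained; "(batch, 4)" keeps the
# batch dimension symbolic so the converted function accepts any batch size.
tf_fn = jax2tf.convert(call_fn, polymorphic_shapes=["...", "(batch, 4)"])
tf_fn = tf.autograph.experimental.do_not_convert(tf_fn)

params = {"kernel": np.ones((4, 2), dtype=np.float32)}
print(tf_fn(params, tf.ones((3, 4))).shape)  # (3, 2)
```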
@@ -413,11 +594,16 @@ class JaxLayer(Layer):
                     jax.tree_util.tree_map(unwrap_variable, self.state)
                 )
             elif argument_name == "rng":
-                call_args.append(self._get_call_rng(training))
+                call_args.append(
+                    jax.tree_util.tree_map(
+                        _convert_to_jax_key, self._get_call_rng(training)
+                    )
+                )
             elif argument_name == "inputs":
                 call_args.append(inputs)
             elif argument_name == "training":
-                call_args.append(training)
+                if backend.backend() == "jax":
+                    call_args.append(training)
 
         def assign_state_to_variable(value, variable):
             # This exists only to make debugging this error case easier.
@@ -429,14 +615,23 @@ class JaxLayer(Layer):
                 )
             variable.assign(value)
 
-        if self.has_state:
-            predictions, new_state = self.call_fn(*call_args)
-            jax.tree_util.tree_map(
-                assign_state_to_variable, new_state, self.state
-            )
-            return predictions
-        else:
-            return self.call_fn(*call_args)
+        def call_with_fn(fn):
+            if self.call_fn_has_state:
+                predictions, new_state = fn(*call_args)
+                jax.tree_util.tree_map(
+                    assign_state_to_variable, new_state, self.state
+                )
+                return predictions
+            else:
+                return fn(*call_args)
+
+        if backend.backend() == "jax":
+            return call_with_fn(self.call_fn)
+        elif backend.backend() == "tensorflow":
+            if training and self.jax2tf_training_true_fn is not None:
+                return call_with_fn(self.jax2tf_training_true_fn)
+            else:
+                return call_with_fn(self.jax2tf_training_false_fn)
 
     def get_config(self):
         config = {
@@ -554,18 +749,12 @@ class FlaxLayer(JaxLayer):
         **kwargs,
     ):
         # Late import to only require Flax when this is used.
-        from flax.core import scope as flax_scope
-
-        if backend.backend() != "jax":
-            raise ValueError(
-                "FlaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
-            )
+        from flax.linen import DenyList
 
         self.module = module
         self.method = method
 
-        apply_mutable = flax_scope.DenyList(["params"])
+        apply_mutable = DenyList(["params"])
 
         def apply_with_training(params, state, rng, inputs, training):
             return self.module.apply(
@@ -650,13 +839,13 @@ class FlaxLayer(JaxLayer):
 
     def _get_init_rng(self):
         return {
-            "params": self.seed_generator.next(),
-            "dropout": self.seed_generator.next(),
+            "params": self._get_init_seed(),
+            "dropout": self._get_init_seed(),
         }
 
     def _get_call_rng(self, training):
         if training:
-            return {"dropout": self.seed_generator.next()}
+            return {"dropout": self._get_call_seed()}
         else:
             return {}
 
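
Taken together, these hunks let `keras.layers.JaxLayer` (and `FlaxLayer`) run under the TensorFlow backend by routing `call_fn` through `jax2tf`. A minimal usage sketch, assuming both `jax` and `tensorflow` are installed; the toy `call_fn` is illustrative, not from the diff:

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # "jax" keeps working as before

import numpy as np
import keras


def call_fn(params, inputs):
    # Toy stateless JAX function: a per-feature scale.
    return inputs * params["scale"]


layer = keras.layers.JaxLayer(
    call_fn=call_fn,
    params={"scale": np.ones((4,), dtype="float32")},
)
x = np.ones((2, 4), dtype="float32")
print(keras.ops.convert_to_numpy(layer(x)))  # all ones, shape (2, 4)
```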
keras/src/utils/module_utils.py

@@ -39,11 +39,31 @@ class LazyModule:
         return f"LazyModule({self.name})"
 
 
+class OrbaxLazyModule(LazyModule):
+    def initialize(self):
+        try:
+            parent_module = importlib.import_module("orbax.checkpoint")
+            self.module = parent_module.v1
+            self.parent_module = parent_module
+        except ImportError:
+            raise ImportError(self.import_error_msg)
+
+    def __getattr__(self, name):
+        if name == "_api_export_path":
+            raise AttributeError
+        if self.module is None:
+            self.initialize()
+        if name == "multihost":
+            return self.parent_module.multihost
+        return getattr(self.module, name)
+
+
 tensorflow = LazyModule("tensorflow")
 gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
 tensorflow_io = LazyModule("tensorflow_io")
 scipy = LazyModule("scipy")
 jax = LazyModule("jax")
+h5py = LazyModule("h5py")
 torch_xla = LazyModule(
     "torch_xla",
     import_error_msg=(
@@ -59,3 +79,12 @@ optree = LazyModule("optree")
 dmtree = LazyModule("tree")
 tf2onnx = LazyModule("tf2onnx")
 grain = LazyModule("grain")
+litert = LazyModule("ai_edge_litert")
+ocp = OrbaxLazyModule(
+    "orbax.checkpoint.v1",
+    pip_name="orbax-checkpoint",
+    import_error_msg=(
+        "OrbaxCheckpoint requires the 'orbax-checkpoint' package. "
+        "You can install it via pip install orbax-checkpoint"
+    ),
+)
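
`OrbaxLazyModule` specializes the lazy-import pattern used throughout `module_utils`: nothing is imported until the first attribute access. A simplified standalone sketch of that pattern (not the actual `LazyModule` implementation):

```python
import importlib


class LazyModuleSketch:
    """Defer importing a module until the first attribute access."""

    def __init__(self, name, import_error_msg=None):
        self.name = name
        self.module = None
        self.import_error_msg = import_error_msg or f"{name} is not installed."

    def __getattr__(self, attr):
        # Only reached for attributes not set in __init__.
        if self.module is None:
            try:
                self.module = importlib.import_module(self.name)
            except ImportError:
                raise ImportError(self.import_error_msg)
        return getattr(self.module, attr)


math_mod = LazyModuleSketch("math")  # nothing imported yet
print(math_mod.sqrt(9.0))            # import happens here -> 3.0
```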
keras/src/utils/progbar.py

@@ -3,7 +3,8 @@ import os
 import sys
 import time
 
-from keras.src import backend
+import numpy as np
+
 from keras.src.api_export import keras_export
 from keras.src.utils import io_utils
 
@@ -162,12 +163,10 @@ class Progbar:
             for k in self._values_order:
                 info += f" - {k}:"
                 if isinstance(self._values[k], list):
-                    avg = backend.convert_to_numpy(
-                        backend.numpy.mean(
-                            self._values[k][0] / max(1, self._values[k][1])
-                        )
-                    )
-                    avg = float(avg)
+                    values, count = self._values[k]
+                    if not isinstance(values, float):
+                        values = np.mean(values)
+                    avg = values / max(1, count)
                     if abs(avg) > 1e-3:
                         info += f" {avg:.4f}"
                     else:
@@ -194,11 +193,10 @@
             info += f" -{self._format_time(time_per_unit, self.unit_name)}"
             for k in self._values_order:
                 info += f" - {k}:"
-                avg = backend.convert_to_numpy(
-                    backend.numpy.mean(
-                        self._values[k][0] / max(1, self._values[k][1])
-                    )
-                )
+                values, count = self._values[k]
+                if not isinstance(values, float):
+                    values = np.mean(values)
+                avg = values / max(1, count)
                 if avg > 1e-3:
                     info += f" {avg:.4f}"
                 else:
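
The Progbar change drops backend tensor ops from the display path: each metric is kept as a `[weighted_sum, weight]` pair and averaged with plain numpy only when formatted. A standalone sketch of that bookkeeping (hypothetical helper names, not the Progbar API):

```python
import numpy as np

totals = {}  # metric name -> [weighted_sum, weight]


def accumulate(name, value, weight=1):
    sums, count = totals.get(name, (0.0, 0))
    totals[name] = [sums + value * weight, count + weight]


def display_average(name):
    values, count = totals[name]
    if not isinstance(values, float):
        values = np.mean(values)  # mirrors the patched code path
    return values / max(1, count)


accumulate("loss", 0.9)
accumulate("loss", 0.7)
print(f"loss: {display_average('loss'):.4f}")  # loss: 0.8000
```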
keras/src/utils/python_utils.py

@@ -181,6 +181,8 @@ def pythonify_logs(logs):
         A flattened dict with values converted to Python-native types if
         possible.
     """
+    from keras.src import backend
+
     logs = logs or {}
     result = {}
     for key, value in sorted(logs.items()):
@@ -188,6 +190,9 @@
             result.update(pythonify_logs(value))
         else:
             try:
+                # Prevent torch compiler from breaking the graph.
+                if backend.is_tensor(value):
+                    value = backend.convert_to_numpy(value)
                 value = float(value)
             except:
                 pass
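
The added lines materialize backend tensors as numpy before the `float()` cast, so a `torch.compile`d step does not graph-break on `float(tensor)`. A hypothetical standalone version of that conversion order (not the actual `pythonify_logs` code):

```python
import numpy as np


def to_python_scalar(value):
    # Materialize tensors/arrays first, then cast; mirrors the patch's order.
    if hasattr(value, "__array__"):
        value = np.asarray(value)
    try:
        return float(value)
    except (TypeError, ValueError):
        return value  # leave non-numeric values untouched


logs = {"loss": np.float32(0.25), "lr": 1e-3, "phase": "train"}
print({k: to_python_scalar(v) for k, v in logs.items()})
# {'loss': 0.25, 'lr': 0.001, 'phase': 'train'}
```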
keras/src/utils/rng_utils.py

@@ -5,6 +5,7 @@ import numpy as np
 from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
+from keras.src.random import seed_generator
 from keras.src.utils.module_utils import tensorflow as tf
 
 GLOBAL_RANDOM_SEED = "global_random_seed"
@@ -20,7 +21,7 @@ def set_random_seed(seed):
     sources of randomness, or when certain non-deterministic cuDNN ops are
     involved.
 
-    Calling this utility is equivalent to the following:
+    Calling this utility does the following:
 
     ```python
     import random
@@ -36,6 +37,9 @@
     torch.manual_seed(seed)
     ```
 
+    Additionally, it resets the global Keras `SeedGenerator`, which is used by
+    `keras.random` functions when the `seed` is not provided.
+
     Note that the TensorFlow seed is set even if you're not using TensorFlow
     as your backend framework, since many workflows leverage `tf.data`
     pipelines (which feature random shuffling). Likewise many workflows
@@ -52,6 +56,10 @@
 
     # Store seed in global state so we can query it if set.
     global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
+    # Remove global SeedGenerator, it will be recreated from the seed.
+    global_state.set_global_attribute(
+        seed_generator.GLOBAL_SEED_GENERATOR, None
+    )
    random.seed(seed)
    np.random.seed(seed)
    if tf.available:
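
With the global `SeedGenerator` reset added here, unseeded `keras.random` calls become reproducible after `keras.utils.set_random_seed`. A small check, runnable on any backend:

```python
import numpy as np
import keras

keras.utils.set_random_seed(812)
a = keras.random.normal((3,))  # draws from the global SeedGenerator

keras.utils.set_random_seed(812)
b = keras.random.normal((3,))

print(
    np.allclose(
        keras.ops.convert_to_numpy(a), keras.ops.convert_to_numpy(b)
    )
)  # True
```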