keras-nightly 3.14.0.dev2026011504__py3-none-any.whl → 3.14.0.dev2026011704__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
keras/src/backend/tensorflow/rnn.py CHANGED
@@ -539,11 +539,21 @@ def _do_lstm_arguments_support_cudnn(
 
 
 def _has_fully_masked_sequence(mask):
-    # Cudnn kernel will error out if the input sequence contains any
-    # fully masked data. We walk around this issue by rerouting the computation
-    # to standard kernel, until the issue on cudnn side has been fixed. For a
-    # fully masked sequence, it will contain all Falses. To make it easy to
-    # check, we inverse the boolean, check if any of the sequence has all True.
+    """Check if input sequence contains any fully masked data.
+
+    cuDNN kernel will error out if the input sequence contains any fully masked
+    data. We work around this issue by rerouting the computation to the
+    standard kernel until the issue on the cuDNN side has been fixed. For a
+    fully masked sequence, it will contain all `False` values. To make it easy
+    to check, we invert the boolean and check if any of the sequences has all
+    `True` values.
+
+    Args:
+        mask: The mask tensor.
+
+    Returns:
+        A boolean tensor, `True` if the mask contains a fully masked sequence.
+    """
     return tf.reduce_any(
         tf.reduce_all(tf.logical_not(tf.cast(mask, dtype="bool")), axis=1)
     )
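
As a minimal illustration of the behavior the new docstring documents (the example mask below is hypothetical; the ops mirror the function body):

import tensorflow as tf

# Row 0 is fully masked (all False), so the batch-level check returns True.
mask = tf.constant([[False, False, False],
                    [True, True, False]])
has_fully_masked = tf.reduce_any(
    tf.reduce_all(tf.logical_not(tf.cast(mask, dtype="bool")), axis=1)
)
print(has_fully_masked)  # tf.Tensor(True, shape=(), dtype=bool)
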
@@ -900,8 +910,8 @@ def _cudnn_lstm(
 
     if tf.sysconfig.get_build_info()["is_rocm_build"]:
         # ROCm MIOpen's weight sequence for LSTM is different from both
-        # canonical and Cudnn format
-        # MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o]
+        # canonical and cuDNN format
+        # MIOpen: [i, f, o, c] cuDNN/Canonical: [i, f, c, o]
         # i is input gate weights.
         # f is forget gate weights.
         # o is output gate weights.
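
A rough sketch of the gate reordering this comment describes, assuming equal-sized flat gate blocks; names and shapes are illustrative, not the kernel's actual weight layout:

import numpy as np

units = 4
canonical = np.arange(4 * units)       # [i, f, c, o] gate blocks, flattened
i, f, c, o = np.split(canonical, 4)
miopen = np.concatenate([i, f, o, c])  # MIOpen expects o before c
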
keras/src/backend/torch/rnn.py CHANGED
@@ -413,11 +413,21 @@ def _is_sequence_right_padded(mask):
 
 
 def _has_fully_masked_sequence(mask):
-    # Cudnn kernel will error out if the input sequence contains any
-    # fully masked data. We walk around this issue by rerouting the computation
-    # to standard kernel, until the issue on cudnn side has been fixed. For a
-    # fully masked sequence, it will contain all Falses. To make it easy to
-    # check, we inverse the boolean, check if any of the sequence has all True.
+    """Check if input sequence contains any fully masked data.
+
+    cuDNN kernel will error out if the input sequence contains any fully masked
+    data. We work around this issue by rerouting the computation to the
+    standard kernel until the issue on the cuDNN side has been fixed. For a
+    fully masked sequence, it will contain all `False` values. To make it easy
+    to check, we invert the boolean and check if any of the sequences has all
+    `True` values.
+
+    Args:
+        mask: The mask tensor.
+
+    Returns:
+        A boolean tensor, `True` if the mask contains a fully masked sequence.
+    """
     return torch.any(torch.all(~mask, dim=1))
 
 
@@ -447,8 +457,8 @@ def _compute_sequence_length_from_mask(mask, batch_first):
     The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For
     any timestep that should be masked, the corresponding field will be False.
     Consider the following example:
-      a = [[True, True, False, False]
-           [True, True, True, False]]
+        a = [[True, True, False, False]
+             [True, True, True, False]]
     It is a (2, 4) tensor, and the corresponding sequence length result should
     be 1D tensor with value [2, 3]. Note that the masking tensor must be right
     padded that could be checked by, e.g., `is_sequence_right_padded()`.
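
The docstring's example can be reproduced directly: for a right-padded boolean mask, summing along the time axis yields the per-sequence lengths.

import torch

mask = torch.tensor([[True, True, False, False],
                     [True, True, True, False]])
seq_lengths = mask.sum(dim=1)
print(seq_lengths)  # tensor([2, 3])
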
@@ -467,12 +477,19 @@ def _compute_sequence_length_from_mask(mask, batch_first):
 
 
 def prepare_lstm_weights(lstm, kernel, recurrent_kernel, bias, device):
-    """Copies kernel and recurrent kernel weights in the Pytorch format
+    """Copies kernel and recurrent kernel weights into the PyTorch format.
+
     We split the kernel and recurrent kernel weights, create associated
-    torch tensors adapted to be in line with the Cudnn optimization.
-    After we have copied the weights, we ensure the paramters are on
-    the same device and memory layout is optimized for Cudnn.
+    torch tensors adapted to be in line with the cuDNN optimization.
+    After we have copied the weights, we ensure the parameters are on
+    the same device and memory layout is optimized for cuDNN.
 
+    Args:
+        lstm: The PyTorch LSTM layer to prepare weights for.
+        kernel: The kernel weights tensor.
+        recurrent_kernel: The recurrent kernel weights tensor.
+        bias: The bias tensor.
+        device: The device to place the tensors on.
     """
 
     lstm = lstm.to(device)
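
A hedged sketch of the kind of copy this function performs, assuming a Keras-style `(input_dim, 4 * units)` kernel whose `[i, f, c, o]` gate order lines up with `torch.nn.LSTM`'s `[i, f, g, o]` layout after a transpose; this is an illustration, not the function's actual code:

import torch

input_dim, units = 8, 16
kernel = torch.randn(input_dim, 4 * units)  # stand-in for a Keras kernel
lstm = torch.nn.LSTM(input_dim, units, batch_first=True)
with torch.no_grad():
    # torch stores input-to-hidden weights as (4 * units, input_dim).
    lstm.weight_ih_l0.copy_(kernel.T)
lstm = lstm.to("cpu")  # finally, move the parameters to the target device
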
keras/src/layers/preprocessing/image_preprocessing/random_contrast.py CHANGED
@@ -92,8 +92,8 @@ class RandomContrast(BaseImagePreprocessingLayer):
 
     def transform_images(self, images, transformation, training=True):
         if training:
-            constrast_factor = transformation["contrast_factor"]
-            outputs = self._adjust_constrast(images, constrast_factor)
+            contrast_factor = transformation["contrast_factor"]
+            outputs = self._adjust_contrast(images, contrast_factor)
             outputs = self.backend.numpy.clip(
                 outputs, self.value_range[0], self.value_range[1]
             )
@@ -117,7 +117,7 @@ class RandomContrast(BaseImagePreprocessingLayer):
     ):
         return segmentation_masks
 
-    def _adjust_constrast(self, inputs, contrast_factor):
+    def _adjust_contrast(self, inputs, contrast_factor):
         if self.data_format == "channels_first":
             height_axis = -2
             width_axis = -1
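
The renamed helper implements standard contrast adjustment: blend each pixel toward the per-image spatial mean. A NumPy sketch of that formula for channels_last data (illustrative, not the layer's exact code):

import numpy as np

def adjust_contrast(images, factor):
    # Mean over the height and width axes, kept for broadcasting.
    mean = images.mean(axis=(-3, -2), keepdims=True)
    return (images - mean) * factor + mean

images = np.random.uniform(0, 255, size=(2, 4, 4, 3)).astype("float32")
low = adjust_contrast(images, 0.5)   # factor < 1 pulls pixels toward the mean
high = adjust_contrast(images, 1.5)  # factor > 1 pushes them away
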
keras/src/utils/jax_layer.py CHANGED
@@ -11,6 +11,7 @@ from keras.src.api_export import keras_export
 from keras.src.backend.common.variables import is_float_dtype
 from keras.src.backend.common.variables import standardize_dtype
 from keras.src.layers.layer import Layer
+from keras.src.random.seed_generator import draw_seed
 from keras.src.saving import serialization_lib
 from keras.src.utils import jax_utils
 from keras.src.utils import tracking
@@ -244,15 +245,9 @@ class JaxLayer(Layer):
                 f" Tensorflow backend. Current backend: {backend.backend()}"
             )
 
-        if init_fn is None and params is None and state is None:
-            raise ValueError(
-                "`init_fn`, `params` and `state` cannot all be `None`."
-            )
-
         super().__init__(**kwargs)
         self.call_fn = call_fn
         self.init_fn = init_fn
-        self.seed_generator = backend.random.SeedGenerator(seed)
         self.tracked_params = self._create_variables(params, trainable=True)
         self.tracked_state = self._create_variables(state, trainable=False)
         if self.params is not None or self.state is not None:
@@ -264,7 +259,25 @@ class JaxLayer(Layer):
             {"params", "state", "rng", "inputs", "training"},
             {"inputs"},
         )
-        self.has_state = "state" in self.call_fn_arguments
+        self.call_fn_has_params = "params" in self.call_fn_arguments
+        self.call_fn_has_state = "state" in self.call_fn_arguments
+        call_fn_has_rng = "rng" in self.call_fn_arguments
+
+        if call_fn_has_rng:
+            self.seed_generator = backend.random.SeedGenerator(seed)
+        else:
+            self.seed_generator = None
+
+        if (
+            init_fn is None
+            and params is None
+            and state is None
+            and (self.call_fn_has_params or self.call_fn_has_state)
+        ):
+            raise ValueError(
+                "`init_fn`, `params` and `state` cannot all be `None` when "
+                "`call_fn` takes a `params` or a `state` argument."
+            )
 
         if init_fn:
             self.init_fn_arguments = self._validate_signature(
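
The net effect of the two changes above: a `call_fn` that takes neither `params` nor `state` no longer requires `init_fn`, `params`, or `state`, and a seed generator is only created when `call_fn` accepts `rng`. A sketch of a now-valid stateless layer, assuming the public `keras.layers.JaxLayer` export:

import jax.numpy as jnp
from keras.layers import JaxLayer

def stateless_fn(inputs):
    # No params, state, or rng arguments: nothing to initialize.
    return jnp.tanh(inputs)

layer = JaxLayer(call_fn=stateless_fn)  # no longer raises ValueError
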
@@ -428,37 +441,58 @@ class JaxLayer(Layer):
         flat_variables, _ = jax.tree_util.tree_flatten(variables)
         return flat_variables
 
+    def _get_init_seed(self):
+        """
+        Returns a single seed as a tensor of shape [2].
+
+        Call this within `_get_init_rng()` to obtain a new seed.
+
+        Returns:
+            A native tensor of shape [2] and the backend dtype for seeds.
+        """
+        # Use the global SeedGenerator.
+        return draw_seed(None)
+
     def _get_init_rng(self):
         """
-        Returns a key in form of the backend array of size 2 dtype uint32
-        to pass to `init_fn`.
+        Returns a seed or seeds to pass as the `rng` argument of `init_fn`.
+
+        By default, this returns a single seed. Override this to return a
+        different structure. Overrides should use `self._get_init_seed()` to
+        obtain new seeds.
+
+        Returns:
+            RNG key or structure of keys as tensors of shape [2] and the backend
+            dtype for seeds.
+        """
+        return self._get_init_seed()
+
+    def _get_call_seed(self):
+        """
+        Returns a single seed as a tensor of shape [2].
 
-        By default, this returns a Jax or TF array of size 2 by calling
-        `self.seed_generator.next()`. Override this to return a different
-        structure.
+        Call this within `_get_call_rng()` to obtain a new seed.
 
         Returns:
-            a key as an Jax or TF array of size 2 dtype uint32 will be passed
-            as the `rng` argument of `init_fn`.
+            A native tensor of shape [2] and the backend dtype for seeds.
         """
         return self.seed_generator.next()
 
     def _get_call_rng(self, training):
         """
-        Returns a key in form of the backend array of size 2 dtype uint32
-        to pass to `call_fn`.
+        Returns a seed or seeds to pass as the `rng` argument of `call_fn`.
 
-        By default, this returns a Jax or TF array of size 2 by calling
-        `self.seed_generator.next()` when `training` is `True`, and `None` when
-        `training` is `False`. Override this to return a different structure or
-        to pass RNGs in inference mode too.
+        By default, this returns a seed when `training` is `True`, and `None`
+        when `training` is `False`. Override this to return a different
+        structure or to pass seeds in inference mode too. Overrides should use
+        `self._get_call_seed()` to obtain seeds.
 
         Returns:
-            a key as an Jax or TF array of size 2 dtype uint32 will be passed
-            as the `rng` argument of `call_fn`.
+            RNG key or structure of keys as tensors of shape [2] and the backend
+            dtype for seeds.
         """
         if training:
-            return self.seed_generator.next()
+            return self._get_call_seed()
         else:
             return None
 
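
A sketch of the override pattern these docstrings describe: a hypothetical subclass returning a structure of named seeds, drawing each one through the new helper:

from keras.layers import JaxLayer

class MyJaxLayer(JaxLayer):
    def _get_call_rng(self, training):
        if training:
            # A dict of independently drawn seeds instead of a single one.
            return {"dropout": self._get_call_seed(),
                    "noise": self._get_call_seed()}
        return None
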
@@ -492,7 +526,7 @@ class JaxLayer(Layer):
             init_args.append(True)
 
         init_result = self.init_fn(*init_args)
-        if self.has_state:
+        if self.call_fn_has_state:
             init_params, init_state = init_result
         else:
             init_params, init_state = init_result, None
@@ -503,7 +537,11 @@ class JaxLayer(Layer):
         self.tracked_state = self._create_variables(init_state, trainable=False)
 
     def build(self, input_shape):
-        if self.params is None and self.state is None:
+        if (
+            self.params is None
+            and self.state is None
+            and (self.call_fn_has_params or self.call_fn_has_state)
+        ):
             self._initialize_weights(input_shape)
 
         if backend.backend() == "tensorflow":
@@ -578,7 +616,7 @@ class JaxLayer(Layer):
             variable.assign(value)
 
         def call_with_fn(fn):
-            if self.has_state:
+            if self.call_fn_has_state:
                 predictions, new_state = fn(*call_args)
                 jax.tree_util.tree_map(
                     assign_state_to_variable, new_state, self.state
@@ -711,12 +749,12 @@ class FlaxLayer(JaxLayer):
         **kwargs,
     ):
         # Late import to only require Flax when this is used.
-        from flax.core import scope as flax_scope
+        from flax.linen import DenyList
 
         self.module = module
         self.method = method
 
-        apply_mutable = flax_scope.DenyList(["params"])
+        apply_mutable = DenyList(["params"])
 
         def apply_with_training(params, state, rng, inputs, training):
             return self.module.apply(
@@ -801,13 +839,13 @@
 
     def _get_init_rng(self):
         return {
-            "params": self.seed_generator.next(),
-            "dropout": self.seed_generator.next(),
+            "params": self._get_init_seed(),
+            "dropout": self._get_init_seed(),
         }
 
     def _get_call_rng(self, training):
         if training:
-            return {"dropout": self.seed_generator.next()}
+            return {"dropout": self._get_call_seed()}
         else:
             return {}
 
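
For context, a minimal usage sketch: during initialization, FlaxLayer draws the "params" and "dropout" seeds shown above, and one "dropout" seed per training call. The module below is illustrative:

import flax.linen as nn
from keras.layers import FlaxLayer

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, training=False):
        x = nn.Dense(32)(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)
        return nn.Dense(1)(x)

layer = FlaxLayer(MLP())  # params are created on first build
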
keras/src/version.py CHANGED
@@ -1,7 +1,7 @@
 from keras.src.api_export import keras_export
 
 # Unique source of truth for the version number.
-__version__ = "3.14.0.dev2026011504"
+__version__ = "3.14.0.dev2026011704"
 
 
 @keras_export("keras.version")
keras_nightly-3.14.0.dev2026011504.dist-info/METADATA → keras_nightly-3.14.0.dev2026011704.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-nightly
-Version: 3.14.0.dev2026011504
+Version: 3.14.0.dev2026011704
 Summary: Multi-backend Keras
 Author-email: Keras team <keras-users@googlegroups.com>
 License: Apache License 2.0
keras_nightly-3.14.0.dev2026011504.dist-info/RECORD → keras_nightly-3.14.0.dev2026011704.dist-info/RECORD
@@ -128,7 +128,7 @@ keras/regularizers/__init__.py,sha256=542Shphw7W8h4Dyf2rmqMKUECVZ8IVBvN9g1LWhz-b
 keras/saving/__init__.py,sha256=KvL2GZxjvgFgEhvEnkvqjIR9JSNHKz-NWZacXajsjLI,1298
 keras/src/__init__.py,sha256=Gi4S7EiCMkE03PbdGNpFdaUYySWDs_FcAJ8Taz9Y1BE,684
 keras/src/api_export.py,sha256=gXOkBOnmscV013WAc75lc4Up01-Kkg9EylIAT_QWctg,1173
-keras/src/version.py,sha256=9qMnmtF-qZTXfhu_aYy3T9wAf18rbTFfpRqEiny8QSU,204
+keras/src/version.py,sha256=PYl1X5NcUeyqMJlZOu02EORfq3XehPpN28bC457e3F8,204
 keras/src/activations/__init__.py,sha256=0nL3IFDB9unlrMz8ninKOWo-uCHasTUpTo1tXZb2u44,4433
 keras/src/activations/activations.py,sha256=mogPggtp4CGldI3VOPNmesRxp6EbiR1_i4KLGaVwzL8,17614
 keras/src/applications/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -214,7 +214,7 @@ keras/src/backend/tensorflow/nn.py,sha256=6vtZHzUED6_blUPE1Tnc3GAxPpJ2ebxoaiMn80
 keras/src/backend/tensorflow/numpy.py,sha256=nIpMvr-g81I9KF74RD4AbU4e4t-0eFa9MND2Fh1u8Tk,104623
 keras/src/backend/tensorflow/optimizer.py,sha256=kFlyEOnGjEYdLpd8mpwhUeku78__xBfZbbrDWpJrq60,9307
 keras/src/backend/tensorflow/random.py,sha256=iO8V_soaDXZm9ewyAVbjudhsMj08C348c9Bz64nxXC4,6475
-keras/src/backend/tensorflow/rnn.py,sha256=99EJqbPdWddmG14zyjjhUZfU5zo9ObmslF_Mak7EmAs,34602
+keras/src/backend/tensorflow/rnn.py,sha256=JbOSpt48cm612c7YwiTYOQCQsNXyI_6QeRhtUn8qEvM,34829
 keras/src/backend/tensorflow/sparse.py,sha256=a_FZcJY-wPl1x4vY0T7j-GORa4SAuMjNEToJLmK0daQ,32247
 keras/src/backend/tensorflow/tensorboard.py,sha256=e7pXicuMfQjuCmq1wOmixWhWt2EbjLMBo_JPAqCbZRk,504
 keras/src/backend/tensorflow/trackable.py,sha256=QZn0JvpBJ7Kx4e6zM2IVIWz9ADcWDB-dHN6vjoQBa9Q,1993
@@ -229,7 +229,7 @@ keras/src/backend/torch/math.py,sha256=g-ElDii2Y_o1-t6BAu2nbS7JH-aPqVS5Fqds8aYzI
 keras/src/backend/torch/nn.py,sha256=zmEzXEuwD7fVRDm145zsxzUDmqNmRgZS4LmeIx4Nbus,37498
 keras/src/backend/torch/numpy.py,sha256=gvHviedkAoEaTax89wDqUrjbUSX1ndjxicHy-PLv2Nc,57668
 keras/src/backend/torch/random.py,sha256=YhLfC7qkGpzlU_i6gGPVormo3BMSo7OUA3TC3GCehrA,8292
-keras/src/backend/torch/rnn.py,sha256=J0vg7ikxBiv1FzEavgwT8IVCs0ceBcEv5LYyM5C2suA,25545
+keras/src/backend/torch/rnn.py,sha256=MJIVbHKsUA2dZm4Gu2NvRxlrFCWeWSxSZRmFxSsC3Zg,26041
 keras/src/backend/torch/trainer.py,sha256=dcikz1c5O0FHNzRKSi6WhIHsHfLV2HDlrXPElSd1cgE,17985
 keras/src/backend/torch/optimizers/__init__.py,sha256=yvqiyKgMEh-nGpacssdpsMySujyYB6lPy-Wil3onXvo,78
 keras/src/backend/torch/optimizers/torch_adadelta.py,sha256=iPjGHvD7q_VD0WaMNxuNcvz8uIWd0smRyEMzMqryUD4,1672
@@ -401,7 +401,7 @@ keras/src/layers/preprocessing/image_preprocessing/rand_augment.py,sha256=upDdEg
 keras/src/layers/preprocessing/image_preprocessing/random_brightness.py,sha256=Ix01T1xsbf_QknyWcSlK1SxVPvFNtHw20xmWHhuQPZI,6083
 keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py,sha256=N6rCXPhWCEh-xWqC9ETYwrbJ2f6lIqyCR9Z18uV3xd0,4896
 keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py,sha256=rbQvLhCPPXyAaYfcMiVzyN0yvfFrcfbRbkVruO9o38U,9464
-keras/src/layers/preprocessing/image_preprocessing/random_contrast.py,sha256=GvB5iQngY-4v99mGS9dXOlGTX4GB6Z7ZvDqW1TKJR5A,5474
+keras/src/layers/preprocessing/image_preprocessing/random_contrast.py,sha256=eJ7aakES1YfSv1JXjv8ZT3ltTqgG6Oo1_XU6BopKDng,5470
 keras/src/layers/preprocessing/image_preprocessing/random_crop.py,sha256=y2iHw-xbSV11uK4D34VT9QEkpvKOk-D-TmVSCZUjDn0,10553
 keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py,sha256=fIfPe-906LUhTUDpiuPwM5oEOJ_1UQ9BhMHBFpItcGM,10208
 keras/src/layers/preprocessing/image_preprocessing/random_erasing.py,sha256=O7f44V805Wta9RMZyks4sl-LViglTCdp7_n-qj_nWbI,11233
@@ -585,7 +585,7 @@ keras/src/utils/grain_utils.py,sha256=Wfwv12E3UrNZjJjTEk2JVV6_YEUav35UJ6bV1UAPEI
 keras/src/utils/image_dataset_utils.py,sha256=0lOzD1CiXwZOe1wW-5uvFKuIgot9PWUC9KJJA0NVuP8,24017
 keras/src/utils/image_utils.py,sha256=lGe4iKYQkQ6j15CbHoqpSMC6JEvCrekYBuYGoMClcpo,17051
 keras/src/utils/io_utils.py,sha256=Riv9TCCnz6xQLUvR1QC-UOCoGZ_KiNTwQVvLY6dKcX8,4432
-keras/src/utils/jax_layer.py,sha256=ytws8NcxWzJ4kViBy3bc-Pk3st3_3L8RqXxgq9sYp1k,32912
+keras/src/utils/jax_layer.py,sha256=xwUkk-yp5lieC_uJesn4T4Lkw1bdjtSY5Q-bK8PuHH0,34027
 keras/src/utils/jax_utils.py,sha256=vY3P4S9mfWEjdirLd81ocKqeCm-UVfgQ1yTi6UHdBiM,322
 keras/src/utils/model_visualization.py,sha256=0ENeiq8q-qbyGjfcRixyyInb3aTxfcKCooKhZ1hSuI0,17794
 keras/src/utils/module_utils.py,sha256=FTZPMRLurURchLPX1tu-h3b-UoPW28faNOlDzpYDW6A,2894
@@ -618,7 +618,7 @@ keras/utils/bounding_boxes/__init__.py,sha256=jtvQll4u8ZY0Z96HwNhP1nxWEG9FM3gI-6
 keras/utils/legacy/__init__.py,sha256=oSYZz6uS8UxSElRaaJYWJEoweJ4GAasZjnn7fNaOlog,342
 keras/visualization/__init__.py,sha256=UKWmiy6sps4SWlmQi9WX8_Z53cPpLlphz2zIeHdwJpQ,722
 keras/wrappers/__init__.py,sha256=QkS-O5K8qGS7C3sytF8MpmO6PasATpNVGF8qtb7Ojsw,407
-keras_nightly-3.14.0.dev2026011504.dist-info/METADATA,sha256=6s-lhD6ZQgn-dl1KiHGJMLNttJ4ir79glEAKZAMn6fI,6339
-keras_nightly-3.14.0.dev2026011504.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-keras_nightly-3.14.0.dev2026011504.dist-info/top_level.txt,sha256=ptcw_-QuGZ4ZDjMdwi_Z0clZm8QAqFdvzzFnDEOTs9o,6
-keras_nightly-3.14.0.dev2026011504.dist-info/RECORD,,
+keras_nightly-3.14.0.dev2026011704.dist-info/METADATA,sha256=XtouV2KcEzUqH0W897TEDF7jmTiNPzWoJyzib0rfKAo,6339
+keras_nightly-3.14.0.dev2026011704.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+keras_nightly-3.14.0.dev2026011704.dist-info/top_level.txt,sha256=ptcw_-QuGZ4ZDjMdwi_Z0clZm8QAqFdvzzFnDEOTs9o,6
+keras_nightly-3.14.0.dev2026011704.dist-info/RECORD,,