tf-keras-nightly 2.19.0.dev2024121210__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (62)
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/protobuf/projector_config_pb2.py +23 -12
  3. tf_keras/protobuf/saved_metadata_pb2.py +21 -10
  4. tf_keras/protobuf/versions_pb2.py +19 -8
  5. tf_keras/src/__init__.py +1 -1
  6. tf_keras/src/backend.py +1 -1
  7. tf_keras/src/datasets/boston_housing.py +14 -5
  8. tf_keras/src/datasets/cifar10.py +9 -1
  9. tf_keras/src/datasets/cifar100.py +7 -1
  10. tf_keras/src/datasets/fashion_mnist.py +16 -4
  11. tf_keras/src/datasets/imdb.py +8 -0
  12. tf_keras/src/datasets/mnist.py +9 -3
  13. tf_keras/src/datasets/reuters.py +8 -0
  14. tf_keras/src/engine/base_layer.py +235 -97
  15. tf_keras/src/engine/base_layer_utils.py +17 -5
  16. tf_keras/src/engine/base_layer_v1.py +12 -3
  17. tf_keras/src/engine/data_adapter.py +35 -19
  18. tf_keras/src/engine/functional.py +36 -15
  19. tf_keras/src/engine/input_layer.py +9 -0
  20. tf_keras/src/engine/input_spec.py +11 -1
  21. tf_keras/src/engine/sequential.py +29 -12
  22. tf_keras/src/layers/activation/softmax.py +26 -11
  23. tf_keras/src/layers/attention/multi_head_attention.py +8 -1
  24. tf_keras/src/layers/core/tf_op_layer.py +4 -0
  25. tf_keras/src/layers/normalization/spectral_normalization.py +29 -22
  26. tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
  27. tf_keras/src/metrics/confusion_metrics.py +51 -4
  28. tf_keras/src/models/sharpness_aware_minimization.py +17 -7
  29. tf_keras/src/preprocessing/sequence.py +2 -2
  30. tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
  31. tf_keras/src/saving/legacy/saving_utils.py +14 -2
  32. tf_keras/src/saving/saving_api.py +18 -5
  33. tf_keras/src/saving/saving_lib.py +1 -1
  34. tf_keras/src/utils/layer_utils.py +45 -3
  35. tf_keras/src/utils/metrics_utils.py +4 -1
  36. tf_keras/src/utils/tf_utils.py +2 -2
  37. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +14 -3
  38. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +40 -62
  39. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
  40. tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
  41. tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
  42. tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
  43. tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
  44. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
  45. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
  46. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
  47. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
  48. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
  49. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
  50. tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
  51. tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
  52. tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
  53. tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
  54. tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
  55. tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
  56. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
  57. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
  58. tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
  59. tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
  60. tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
  61. tf_keras/src/tests/keras_doctest.py +0 -159
  62. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
@@ -1471,9 +1471,10 @@ class AUC(base_metric.Metric):
                 # label_weights should be of length equal to the number of
                 # labels.
                 shapes.append((self.label_weights, ("L",)))
-                tf.debugging.assert_shapes(
-                    shapes, message="Number of labels is not consistent."
-                )
+
+            tf.debugging.assert_shapes(
+                shapes, message="Number of labels is not consistent."
+            )
 
         # Only forward label_weights to update_confusion_matrix_variables when
         # multi_label is False. Otherwise the averaging of individual label AUCs
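Note: `tf.debugging.assert_shapes` validates a list of (tensor, shape) pairs in which named symbolic dimensions must agree across all tensors. A minimal standalone sketch of the check relocated above (tensor values are illustrative):

    import tensorflow as tf

    # The symbolic dimension "L" must resolve to the same size in every
    # listed tensor, otherwise the assertion raises.
    label_weights = tf.ones([3])
    true_positives = tf.zeros([200, 3])
    tf.debugging.assert_shapes(
        [(label_weights, ("L",)), (true_positives, ("T", "L"))],
        message="Number of labels is not consistent.",
    )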
@@ -1611,13 +1612,59 @@ class AUC(base_metric.Metric):
             )
             x = fp_rate
             y = recall
-        else:  # curve == 'PR'.
+        elif self.curve == metrics_utils.AUCCurve.PR:
             precision = tf.math.divide_no_nan(
                 self.true_positives,
                 tf.math.add(self.true_positives, self.false_positives),
             )
             x = recall
             y = precision
+        else:  # curve == 'PR_GAIN'.
+            # Due to the hyperbolic transform, this formula is less robust than
+            # ROC or PR values. In particular:
+            # 1) Both measures diverge when there are no negative examples;
+            # 2) Both measures diverge when there are no true positives;
+            # 3) Recall gain becomes negative when the recall is lower than the
+            #    label average (i.e. when more negative examples are classified
+            #    positive than real positives).
+            #
+            # We ignore case 1 as it is easily communicated. For case 2 we set
+            # recall_gain to 0 and precision_gain to 1. For case 3 we set the
+            # recall_gain to 0. These fixes will result in an overestimation of
+            # the AUC for estimators that are anti-correlated with the label
+            # (at some thresholds).
+            #
+            # The scaling factor $\frac{P}{N}$ that is used to form both
+            # gain values.
+            scaling_factor = tf.math.divide_no_nan(
+                tf.math.add(self.true_positives, self.false_negatives),
+                tf.math.add(self.false_positives, self.true_negatives),
+            )
+
+            recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
+                self.false_negatives, self.true_positives
+            )
+            precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
+                self.false_positives, self.true_positives
+            )
+            # Handle case 2.
+            recall_gain = tf.where(
+                tf.equal(self.true_positives, 0.0),
+                tf.zeros_like(recall_gain),
+                recall_gain,
+            )
+            precision_gain = tf.where(
+                tf.equal(self.true_positives, 0.0),
+                tf.ones_like(precision_gain),
+                precision_gain,
+            )
+            # Handle case 3.
+            recall_gain = tf.math.maximum(
+                recall_gain, tf.zeros_like(recall_gain)
+            )
+
+            x = recall_gain
+            y = precision_gain
 
         # Find the rectangle heights based on `summation_method`.
         if (
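Note: precision-recall-gain curves (Flach & Kull, 2015) rescale precision and recall by the positive prevalence π = P / (P + N): recall gain is (recall − π) / ((1 − π) · recall), which the hunk computes in the equivalent form 1 − (P/N) · (FN/TP), and precision gain likewise as 1 − (P/N) · (FP/TP). A usage sketch, assuming `tf.keras` resolves to this tf-keras nightly build:

    import tensorflow as tf

    # "PR_GAIN" is only accepted once the AUCCurve changes below are in place.
    m = tf.keras.metrics.AUC(curve="PR_GAIN")
    m.update_state([0, 0, 1, 1], [0.1, 0.6, 0.35, 0.95])
    print(float(m.result()))  # area under the precision-recall-gain curve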
@@ -72,17 +72,27 @@ class SharpnessAwareMinimization(Model):
         if self.num_batch_splits is not None:
             x_split = tf.split(x, self.num_batch_splits)
             y_split = tf.split(y, self.num_batch_splits)
+            # Split the sample weight if it is provided.
+            if sample_weight is not None:
+                sample_weight_split = tf.split(
+                    sample_weight, self.num_batch_splits
+                )
+            else:
+                sample_weight_split = [None] * self.num_batch_splits
         else:
             x_split = [x]
             y_split = [y]
+            sample_weight_split = [sample_weight]
 
         gradients_all_batches = []
         pred_all_batches = []
-        for x_batch, y_batch in zip(x_split, y_split):
+        for x_batch, y_batch, sample_weight_batch in zip(
+            x_split, y_split, sample_weight_split
+        ):
             epsilon_w_cache = []
             with tf.GradientTape() as tape:
-                pred = self.model(x_batch)
-                loss = self.compiled_loss(y_batch, pred)
+                pred = self(x_batch)
+                loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
             pred_all_batches.append(pred)
             trainable_variables = self.model.trainable_variables
             gradients = tape.gradient(loss, trainable_variables)
@@ -98,8 +108,8 @@ class SharpnessAwareMinimization(Model):
                 epsilon_w_cache.append(epsilon_w)
 
             with tf.GradientTape() as tape:
-                pred = self(x_batch)
-                loss = self.compiled_loss(y_batch, pred)
+                pred = self(x_batch, training=True)
+                loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
             gradients = tape.gradient(loss, trainable_variables)
             if len(gradients_all_batches) == 0:
                 for gradient in gradients:
@@ -127,7 +137,7 @@ class SharpnessAwareMinimization(Model):
         self.compiled_metrics.update_state(y, pred, sample_weight)
         return {m.name: m.result() for m in self.metrics}
 
-    def call(self, inputs):
+    def call(self, inputs, **kwargs):
         """Forward pass of SAM.
 
         SAM delegates the forward pass call to the wrapped model.
@@ -138,7 +148,7 @@ class SharpnessAwareMinimization(Model):
         Returns:
           A Tensor, the outputs of the wrapped model for given `inputs`.
         """
-        return self.model(inputs)
+        return self.model(inputs, **kwargs)
 
     def get_config(self):
         config = super().get_config()
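Note: taken together, these hunks thread `sample_weight` through both SAM gradient passes and forward arbitrary call kwargs (e.g. `training`) to the wrapped model. A minimal sketch of the now-supported call pattern, assuming `tf.keras` resolves to this build and using the public `tf.keras.models.experimental.SharpnessAwareMinimization` export:

    import numpy as np
    import tensorflow as tf

    inner = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    sam = tf.keras.models.experimental.SharpnessAwareMinimization(
        inner, num_batch_splits=2
    )
    sam.compile(optimizer="sgd", loss="mse")

    x = np.random.rand(8, 4).astype("float32")
    y = np.random.rand(8, 1).astype("float32")
    w = np.ones(8, dtype="float32")  # per-sample weights, split with x and y
    sam.fit(x, y, sample_weight=w, batch_size=8, epochs=1, verbose=0)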
@@ -365,7 +365,7 @@ def skipgrams(
         random.shuffle(words)
 
         couples += [
-            [words[i % len(words)], random.randint(1, vocabulary_size - 1)]
+            [words[i % len(words)], random.randint(1, int(vocabulary_size - 1))]
             for i in range(num_negative_samples)
         ]
         if categorical:
@@ -375,7 +375,7 @@ def skipgrams(
 
     if shuffle:
         if seed is None:
-            seed = random.randint(0, 10e6)
+            seed = random.randint(0, int(10e6))
        random.seed(seed)
        random.shuffle(couples)
        random.seed(seed)
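Note: both casts guard against Python 3.12, where `random.randint` (via `randrange`) raises `TypeError` for non-integer bounds that older versions accepted (with a `DeprecationWarning` since 3.10); `10e6` is a float literal. For illustration:

    import random

    random.randint(0, int(10e6))  # works on all supported Python versions
    # random.randint(0, 10e6)    # TypeError on Python 3.12+ (float bound)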
@@ -219,7 +219,11 @@ def wrap_layer_functions(layer, serialization_cache):
     with tracing_scope():
         call_collection.trace_with_input_signature()
     with base_layer_utils.call_context().enter(
-        layer, inputs=None, build_graph=True, training=None, saving=True
+        layer,
+        inputs=None,
+        build_graph=True,
+        call_context_args={},
+        saving=True,
     ):
         for fn in fns.values():
             if fn is not None and not isinstance(fn, LayerCall):
@@ -515,19 +519,28 @@ class LayerCallCollection:
         else:
             add_trace_to_queue(fn, args, kwargs)
 
-    def training_arg_was_passed(self, args, kwargs):
+    def arg_was_passed(self, arg_name, args, kwargs):
+        """Returns True if the argument was passed to the call function."""
         return self._call_spec.arg_was_passed(
-            "training", args, kwargs, inputs_in_args=True
+            arg_name, args, kwargs, inputs_in_args=True
         )
 
-    def get_training_arg_value(self, args, kwargs):
+    def training_arg_was_passed(self, args, kwargs):
+        """Returns True if the training arg was passed to the call function."""
+        return self.arg_was_passed("training", args, kwargs)
+
+    def get_arg_value(self, arg_name, args, kwargs):
+        """Returns the value of the given argument or None if not found."""
         try:
             return self._call_spec.get_arg_value(
-                "training", args, kwargs, inputs_in_args=True
+                arg_name, args, kwargs, inputs_in_args=True
             )
-        except KeyError:  # Training is not in args or kwargs.
+        except KeyError:  # Arg not found in args or kwargs.
             return None
 
+    def get_training_arg_value(self, args, kwargs):
+        return self.get_arg_value("training", args, kwargs)
+
     def get_input_arg_value(self, args, kwargs):
         return self._call_spec.get_arg_value(
             self._input_arg_name, args, kwargs, inputs_in_args=True
@@ -613,20 +626,23 @@ def layer_call_wrapper(call_collection, method, name):
     def wrapper(*args, **kwargs):
         """Calls method within call context."""
         layer = call_collection.layer
-        training = None
+        propagated = {"training": None}
         inputs = _filtered_inputs([args, kwargs])
 
-        if (args or kwargs) and call_collection.training_arg_was_passed(
-            args, kwargs
-        ):
-            training = call_collection.get_training_arg_value(args, kwargs)
+        for context_arg in layer._call_context_args:
+            if (args or kwargs) and call_collection.arg_was_passed(
+                context_arg, args, kwargs
+            ):
+                propagated[context_arg] = call_collection.get_arg_value(
+                    context_arg, args, kwargs
+                )
 
         original_losses = _reset_layer_losses(layer)
         with base_layer_utils.call_context().enter(
             layer,
             inputs=inputs,
             build_graph=False,
-            training=training,
+            call_context_args=propagated,
             saving=True,
         ):
             with autocast_variable.enable_auto_cast_variables(
@@ -138,12 +138,24 @@ def trace_model_call(model, input_signature=None):
     @tf.function
     def _wrapped_model(*args, **kwargs):
         """A concrete tf.function that wraps the model's call function."""
+        call_context = base_layer_utils.call_context()
+
+        args, kwargs, propagated = model._get_propagated_call_context_arguments(
+            args, kwargs, call_context, model._call_context_args
+        )
+
         (args, kwargs,) = model._call_spec.set_arg_value(
             "training", False, args, kwargs, inputs_in_args=True
         )
 
-        with base_layer_utils.call_context().enter(
-            model, inputs=None, build_graph=False, training=False, saving=True
+        propagated["training"] = False
+
+        with call_context.enter(
+            model,
+            inputs=None,
+            build_graph=False,
+            call_context_args=propagated,
+            saving=True,
         ):
             outputs = model(*args, **kwargs)
 
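Note: these serialization hunks replace the hard-coded `training` plumbing with a dict of named "call-context args" that is captured from the caller and re-applied inside the saving call context. The following is a schematic, self-contained sketch of that pattern only; it deliberately does not mirror the private tf-keras internals (`_call_context_args`, `CallContext.enter`, etc.):

    import contextlib

    class CallContext:
        """Schematic stand-in for tf-keras's internal call context."""

        def __init__(self):
            self.args = {"training": None}

        @contextlib.contextmanager
        def enter(self, call_context_args):
            saved = dict(self.args)
            self.args.update(call_context_args)
            try:
                yield self
            finally:
                self.args = saved

    def layer_call_wrapper(call, context_arg_names, ctx):
        """Capture registered context args from kwargs and re-apply them."""

        def wrapper(*args, **kwargs):
            propagated = {
                name: kwargs[name]
                for name in context_arg_names
                if name in kwargs
            }
            with ctx.enter(call_context_args=propagated):
                return call(*args, **kwargs)

        return wrapper

    ctx = CallContext()
    wrapped = layer_call_wrapper(lambda x, training=None: x, {"training"}, ctx)
    wrapped(1.0, training=True)  # inside, ctx.args["training"] is True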
@@ -24,6 +24,7 @@ from tensorflow.python.util.tf_export import keras_export
 
 from tf_keras.src.saving import saving_lib
 from tf_keras.src.saving.legacy import save as legacy_sm_saving_lib
+from tf_keras.src.saving.legacy import saving_utils
 from tf_keras.src.utils import io_utils
 
 try:
@@ -75,8 +76,7 @@ class SupportWriteToRemote:
     supports remote saved model out of the box.
     """
 
-    def __init__(self, filepath, overwrite=True, save_format=None):
-        save_format = get_save_format(filepath, save_format=save_format)
+    def __init__(self, filepath, overwrite, save_format):
         self.overwrite = overwrite
         if saving_lib.is_remote_path(filepath) and save_format != "tf":
             self.temp_directory = tempfile.TemporaryDirectory()
@@ -191,14 +191,14 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
         when loading the model. See the `custom_objects` argument in
         `tf.keras.saving.load_model`.
     """
+    save_format = get_save_format(filepath, save_format)
+
     # Supports remote paths via a temporary file
     with SupportWriteToRemote(
         filepath,
         overwrite=overwrite,
         save_format=save_format,
     ) as local_filepath:
-        save_format = get_save_format(filepath, save_format)
-
         # Deprecation warnings
         if save_format == "h5":
             warnings.warn(
@@ -307,8 +307,12 @@ def load_model(
 
 
 def save_weights(model, filepath, overwrite=True, **kwargs):
+    save_format = get_save_weights_format(filepath)
+
     # Supports remote paths via a temporary file
-    with SupportWriteToRemote(filepath, overwrite=overwrite) as local_filepath:
+    with SupportWriteToRemote(
+        filepath, overwrite=overwrite, save_format=save_format
+    ) as local_filepath:
         if str(local_filepath).endswith(".weights.h5"):
             # If file exists and should not be overwritten.
             try:
@@ -386,3 +390,12 @@ def get_save_format(filepath, save_format):
     else:
         return "h5"
 
+
+def get_save_weights_format(filepath):
+    filepath = io_utils.path_to_string(filepath)
+    filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
+    if filepath_is_h5:
+        return "h5"
+    else:
+        return "tf"
+
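Note: with this change the weights format is decided once, up front, from the file suffix; the legacy `saving_utils.is_hdf5_filepath` helper treats HDF5-style suffixes as `"h5"`, and everything else falls back to the TF checkpoint format. A usage sketch (paths are illustrative):

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])

    model.save_weights("/tmp/ckpt.weights.h5")  # HDF5 suffix -> "h5"
    model.save_weights("/tmp/ckpt")             # otherwise -> "tf" checkpoint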
@@ -639,7 +639,7 @@ class NpzIOStore:
             self.f = archive.open(root_path, mode="r")
         else:
             self.f = open(root_path, mode="rb")
-        self.contents = np.load(self.f, allow_pickle=True)
+        self.contents = np.load(self.f, allow_pickle=False)
 
     def make(self, path):
         if not path:
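Note: `allow_pickle=False` makes `np.load` refuse object arrays, the pickle-backed payload type that can execute arbitrary code when an untrusted archive is loaded; plain numeric arrays are unaffected. A small demonstration:

    import numpy as np

    # Plain numeric arrays round-trip fine without pickle.
    np.savez("/tmp/weights.npz", w=np.ones(3))
    print(np.load("/tmp/weights.npz", allow_pickle=False)["w"])

    # Object arrays need pickle to deserialize, so loading them is refused.
    np.save("/tmp/obj.npy", np.array([{"a": 1}], dtype=object),
            allow_pickle=True)
    try:
        np.load("/tmp/obj.npy", allow_pickle=False)
    except ValueError as exc:
        print("refused:", exc)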
@@ -775,11 +775,13 @@ class CallFunctionSpec:
     """Caches the spec and provides utilities for handling call function
     args."""
 
-    def __init__(self, full_argspec):
+    def __init__(self, full_argspec, call_context_args=set()):
         """Initializes a `CallFunctionSpec`.
 
         Args:
             full_argspec: the FullArgSpec of a call function of a layer.
+            call_context_args: the set of call-context arguments registered
+                with the current layer.
         """
         self._full_argspec = full_argspec
 
@@ -797,6 +799,18 @@ class CallFunctionSpec:
             "mask" in self._arg_names or call_accepts_kwargs
         )
 
+        # Track the set of call-context arguments that the current layer's
+        # `call` method accepts.
+        self._expected_context_args = set()
+        self._update_call_context_arguments(call_context_args)
+
+        self._context_arg_defaults = dict()
+        self._update_call_context_argument_defaults(call_context_args)
+
+    def _update_call_context_argument_defaults(self, context_args):
+        """Updates the set of call-context argument defaults for the current
+        layer's `call` method.
+        """
         call_fn_defaults = self._full_argspec.defaults or []
         defaults = dict()
         # The call arg defaults are an n-tuple of the last n elements of the
@@ -806,7 +820,21 @@ class CallFunctionSpec:
         # The default training arg will be any (non-None) default specified in
         # the method signature, or None if no value is specified.
         defaults.update(self._full_argspec.kwonlydefaults or {})
-        self._default_training_arg = defaults.get("training")
+
+        for arg in context_args:
+            self._context_arg_defaults[arg] = defaults.get(arg)
+
+    def _update_call_context_arguments(self, context_args):
+        """Updates the set of call-context arguments that the current layer's
+        `call` method accepts.
+        """
+        call_accepts_kwargs = self._full_argspec.varkw is not None
+        args_to_add = {
+            arg
+            for arg in context_args
+            if call_accepts_kwargs or arg in self._arg_names
+        }
+        self._expected_context_args.update(args_to_add)
 
     @property
     def full_argspec(self):
@@ -843,6 +871,16 @@ class CallFunctionSpec:
     def expects_training_arg(self, value):
         self._expects_training_arg = value
 
+    @property
+    def expected_context_args(self):
+        """The set of call-context arguments that the current layer's
+        `call` method accepts."""
+        return self._expected_context_args
+
+    @expected_context_args.setter
+    def expected_context_args(self, value):
+        self._expected_context_args = value
+
     @property
     def expects_mask_arg(self):
         """Whether the call function uses `mask` as a parameter."""
@@ -855,7 +893,11 @@ class CallFunctionSpec:
     @property
     def default_training_arg(self):
         """The default value given to the "training" argument."""
-        return self._default_training_arg
+        return self.get_context_arg_default("training")
+
+    def get_context_arg_default(self, arg_name):
+        """The default value of the given call-context argument."""
+        return self._context_arg_defaults.get(arg_name, None)
 
     def arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
         """Returns true if argument is present in `args` or `kwargs`.
@@ -237,6 +237,7 @@ class AUCCurve(Enum):
 
     ROC = "ROC"
     PR = "PR"
+    PR_GAIN = "PR_GAIN"
 
     @staticmethod
     def from_str(key):
@@ -244,10 +245,12 @@ class AUCCurve(Enum):
             return AUCCurve.PR
         elif key in ("roc", "ROC"):
             return AUCCurve.ROC
+        elif key in ("pr_gain", "prgain", "PR_GAIN", "PRGAIN"):
+            return AUCCurve.PR_GAIN
         else:
             raise ValueError(
                 f'Invalid AUC curve value: "{key}". '
-                'Expected values are ["PR", "ROC"]'
+                'Expected values are ["PR", "ROC", "PR_GAIN"]'
             )
 
 
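Note: the new aliases all normalize to the same enum member. A quick check against this package's module path:

    from tf_keras.src.utils import metrics_utils

    curve = metrics_utils.AUCCurve.from_str("pr_gain")
    assert curve is metrics_utils.AUCCurve.PR_GAIN
    assert metrics_utils.AUCCurve.from_str("PRGAIN") is curve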
@@ -78,9 +78,9 @@ def get_random_seed():
       the random seed as an integer.
     """
     if getattr(backend._SEED_GENERATOR, "generator", None):
-        return backend._SEED_GENERATOR.generator.randint(1, 1e9)
+        return backend._SEED_GENERATOR.generator.randint(1, int(1e9))
     else:
-        return random.randint(1, 1e9)
+        return random.randint(1, int(1e9))
 
 
 def is_tensor_or_tensor_list(v):
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.4
 Name: tf_keras-nightly
-Version: 2.19.0.dev2024121210
+Version: 2.21.0.dev2025123010
 Summary: Deep learning for humans.
 Home-page: https://keras.io/
 Download-URL: https://github.com/keras-team/tf-keras/tags
@@ -26,7 +26,18 @@ Classifier: Topic :: Software Development
 Classifier: Topic :: Software Development :: Libraries
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
-Requires-Dist: tf-nightly~=2.19.0.dev
+Requires-Dist: tf-nightly~=2.21.0.dev
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: download-url
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 TF-Keras is a deep learning API written in Python,
 running on top of the machine learning platform TensorFlow.