tf-keras-nightly 2.20.0.dev2025062209__py3-none-any.whl → 2.20.0.dev2025082818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tf-keras-nightly might be problematic. Click here for more details.

Files changed (33) hide show
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/protobuf/projector_config_pb2.py +23 -12
  3. tf_keras/protobuf/saved_metadata_pb2.py +21 -10
  4. tf_keras/protobuf/versions_pb2.py +19 -8
  5. tf_keras/src/metrics/confusion_metrics.py +47 -1
  6. tf_keras/src/models/sharpness_aware_minimization.py +17 -7
  7. tf_keras/src/utils/metrics_utils.py +4 -1
  8. {tf_keras_nightly-2.20.0.dev2025062209.dist-info → tf_keras_nightly-2.20.0.dev2025082818.dist-info}/METADATA +1 -1
  9. {tf_keras_nightly-2.20.0.dev2025062209.dist-info → tf_keras_nightly-2.20.0.dev2025082818.dist-info}/RECORD +11 -33
  10. tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
  11. tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
  12. tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
  13. tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
  14. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
  15. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
  16. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
  17. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
  18. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
  19. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
  20. tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
  21. tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
  22. tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
  23. tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
  24. tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
  25. tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
  26. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
  27. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
  28. tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
  29. tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
  30. tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
  31. tf_keras/src/tests/keras_doctest.py +0 -159
  32. {tf_keras_nightly-2.20.0.dev2025062209.dist-info → tf_keras_nightly-2.20.0.dev2025082818.dist-info}/WHEEL +0 -0
  33. {tf_keras_nightly-2.20.0.dev2025062209.dist-info → tf_keras_nightly-2.20.0.dev2025082818.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py CHANGED
@@ -27,4 +27,4 @@ from tf_keras.src.engine.sequential import Sequential
27
27
  from tf_keras.src.engine.training import Model
28
28
 
29
29
 
30
- __version__ = "2.20.0.dev2025062209"
30
+ __version__ = "2.20.0.dev2025082818"
@@ -1,11 +1,22 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
3
4
  # source: tf_keras/protobuf/projector_config.proto
5
+ # Protobuf Python Version: 5.28.3
4
6
  """Generated protocol buffer code."""
5
- from google.protobuf.internal import builder as _builder
6
7
  from google.protobuf import descriptor as _descriptor
7
8
  from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
8
10
  from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 5,
15
+ 28,
16
+ 3,
17
+ '',
18
+ 'tf_keras/protobuf/projector_config.proto'
19
+ )
9
20
  # @@protoc_insertion_point(imports)
10
21
 
11
22
  _sym_db = _symbol_database.Default()
@@ -15,15 +26,15 @@ _sym_db = _symbol_database.Default()
15
26
 
16
27
  DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(tf_keras/protobuf/projector_config.proto\x12 third_party.py.tf_keras.protobuf\">\n\x0eSpriteMetadata\x12\x12\n\nimage_path\x18\x01 \x01(\t\x12\x18\n\x10single_image_dim\x18\x02 \x03(\r\"\xc0\x01\n\rEmbeddingInfo\x12\x13\n\x0btensor_name\x18\x01 \x01(\t\x12\x15\n\rmetadata_path\x18\x02 \x01(\t\x12\x16\n\x0e\x62ookmarks_path\x18\x03 \x01(\t\x12\x14\n\x0ctensor_shape\x18\x04 \x03(\r\x12@\n\x06sprite\x18\x05 \x01(\x0b\x32\x30.third_party.py.tf_keras.protobuf.SpriteMetadata\x12\x13\n\x0btensor_path\x18\x06 \x01(\t\"\x93\x01\n\x0fProjectorConfig\x12\x1d\n\x15model_checkpoint_path\x18\x01 \x01(\t\x12\x43\n\nembeddings\x18\x02 \x03(\x0b\x32/.third_party.py.tf_keras.protobuf.EmbeddingInfo\x12\x1c\n\x14model_checkpoint_dir\x18\x03 \x01(\tb\x06proto3')
17
28
 
18
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
19
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.projector_config_pb2', globals())
20
- if _descriptor._USE_C_DESCRIPTORS == False:
21
-
22
- DESCRIPTOR._options = None
23
- _SPRITEMETADATA._serialized_start=78
24
- _SPRITEMETADATA._serialized_end=140
25
- _EMBEDDINGINFO._serialized_start=143
26
- _EMBEDDINGINFO._serialized_end=335
27
- _PROJECTORCONFIG._serialized_start=338
28
- _PROJECTORCONFIG._serialized_end=485
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.projector_config_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ DESCRIPTOR._loaded_options = None
34
+ _globals['_SPRITEMETADATA']._serialized_start=78
35
+ _globals['_SPRITEMETADATA']._serialized_end=140
36
+ _globals['_EMBEDDINGINFO']._serialized_start=143
37
+ _globals['_EMBEDDINGINFO']._serialized_end=335
38
+ _globals['_PROJECTORCONFIG']._serialized_start=338
39
+ _globals['_PROJECTORCONFIG']._serialized_end=485
29
40
  # @@protoc_insertion_point(module_scope)
@@ -1,11 +1,22 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
3
4
  # source: tf_keras/protobuf/saved_metadata.proto
5
+ # Protobuf Python Version: 5.28.3
4
6
  """Generated protocol buffer code."""
5
- from google.protobuf.internal import builder as _builder
6
7
  from google.protobuf import descriptor as _descriptor
7
8
  from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
8
10
  from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 5,
15
+ 28,
16
+ 3,
17
+ '',
18
+ 'tf_keras/protobuf/saved_metadata.proto'
19
+ )
9
20
  # @@protoc_insertion_point(imports)
10
21
 
11
22
  _sym_db = _symbol_database.Default()
@@ -16,13 +27,13 @@ from tf_keras.protobuf import versions_pb2 as tf__keras_dot_protobuf_dot_version
16
27
 
17
28
  DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&tf_keras/protobuf/saved_metadata.proto\x12 third_party.py.tf_keras.protobuf\x1a tf_keras/protobuf/versions.proto\"M\n\rSavedMetadata\x12<\n\x05nodes\x18\x01 \x03(\x0b\x32-.third_party.py.tf_keras.protobuf.SavedObject\"\x9c\x01\n\x0bSavedObject\x12\x0f\n\x07node_id\x18\x02 \x01(\x05\x12\x11\n\tnode_path\x18\x03 \x01(\t\x12\x12\n\nidentifier\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x05 \x01(\t\x12=\n\x07version\x18\x06 \x01(\x0b\x32,.third_party.py.tf_keras.protobuf.VersionDefJ\x04\x08\x01\x10\x02\x62\x06proto3')
18
29
 
19
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
20
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.saved_metadata_pb2', globals())
21
- if _descriptor._USE_C_DESCRIPTORS == False:
22
-
23
- DESCRIPTOR._options = None
24
- _SAVEDMETADATA._serialized_start=110
25
- _SAVEDMETADATA._serialized_end=187
26
- _SAVEDOBJECT._serialized_start=190
27
- _SAVEDOBJECT._serialized_end=346
30
+ _globals = globals()
31
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
32
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.saved_metadata_pb2', _globals)
33
+ if not _descriptor._USE_C_DESCRIPTORS:
34
+ DESCRIPTOR._loaded_options = None
35
+ _globals['_SAVEDMETADATA']._serialized_start=110
36
+ _globals['_SAVEDMETADATA']._serialized_end=187
37
+ _globals['_SAVEDOBJECT']._serialized_start=190
38
+ _globals['_SAVEDOBJECT']._serialized_end=346
28
39
  # @@protoc_insertion_point(module_scope)
@@ -1,11 +1,22 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
3
4
  # source: tf_keras/protobuf/versions.proto
5
+ # Protobuf Python Version: 5.28.3
4
6
  """Generated protocol buffer code."""
5
- from google.protobuf.internal import builder as _builder
6
7
  from google.protobuf import descriptor as _descriptor
7
8
  from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
8
10
  from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 5,
15
+ 28,
16
+ 3,
17
+ '',
18
+ 'tf_keras/protobuf/versions.proto'
19
+ )
9
20
  # @@protoc_insertion_point(imports)
10
21
 
11
22
  _sym_db = _symbol_database.Default()
@@ -15,11 +26,11 @@ _sym_db = _symbol_database.Default()
15
26
 
16
27
  DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n tf_keras/protobuf/versions.proto\x12 third_party.py.tf_keras.protobuf\"K\n\nVersionDef\x12\x10\n\x08producer\x18\x01 \x01(\x05\x12\x14\n\x0cmin_consumer\x18\x02 \x01(\x05\x12\x15\n\rbad_consumers\x18\x03 \x03(\x05\x62\x06proto3')
17
28
 
18
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
19
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.versions_pb2', globals())
20
- if _descriptor._USE_C_DESCRIPTORS == False:
21
-
22
- DESCRIPTOR._options = None
23
- _VERSIONDEF._serialized_start=70
24
- _VERSIONDEF._serialized_end=145
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.versions_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ DESCRIPTOR._loaded_options = None
34
+ _globals['_VERSIONDEF']._serialized_start=70
35
+ _globals['_VERSIONDEF']._serialized_end=145
25
36
  # @@protoc_insertion_point(module_scope)
@@ -1612,13 +1612,59 @@ class AUC(base_metric.Metric):
1612
1612
  )
1613
1613
  x = fp_rate
1614
1614
  y = recall
1615
- else: # curve == 'PR'.
1615
+ elif self.curve == metrics_utils.AUCCurve.PR:
1616
1616
  precision = tf.math.divide_no_nan(
1617
1617
  self.true_positives,
1618
1618
  tf.math.add(self.true_positives, self.false_positives),
1619
1619
  )
1620
1620
  x = recall
1621
1621
  y = precision
1622
+ else: # curve == 'PR_GAIN'.
1623
+ # Due to the hyperbolic transform, this formula is less robust than
1624
+ # ROC or PR values. In particular
1625
+ # 1) Both measures diverge when there are no negative examples;
1626
+ # 2) Both measures diverge when there are no true positives;
1627
+ # 3) Recall gain becomes negative when the recall is lower than the
1628
+ # label average (i.e. when more negative examples are classified
1629
+ # positive than real positives).
1630
+ #
1631
+ # We ignore case 1 as it is easily communicated. For case 2 we set
1632
+ # recall_gain to 0 and precision_gain to 1. For case 3 we set the
1633
+ # recall_gain to 0. These fixes will result in an overastimation of
1634
+ # the AUC for estimateors that are anti-correlated with the label
1635
+ # (at some thresholds).
1636
+ #
1637
+ # The scaling factor $\frac{P}{N}$ that is used to form both
1638
+ # gain values.
1639
+ scaling_factor = tf.math.divide_no_nan(
1640
+ tf.math.add(self.true_positives, self.false_negatives),
1641
+ tf.math.add(self.false_positives, self.true_negatives),
1642
+ )
1643
+
1644
+ recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
1645
+ self.false_negatives, self.true_positives
1646
+ )
1647
+ precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
1648
+ self.false_positives, self.true_positives
1649
+ )
1650
+ # Handle case 2.
1651
+ recall_gain = tf.where(
1652
+ tf.equal(self.true_positives, 0.0),
1653
+ tf.zeros_like(recall_gain),
1654
+ recall_gain,
1655
+ )
1656
+ precision_gain = tf.where(
1657
+ tf.equal(self.true_positives, 0.0),
1658
+ tf.ones_like(precision_gain),
1659
+ precision_gain,
1660
+ )
1661
+ # Handle case 3.
1662
+ recall_gain = tf.math.maximum(
1663
+ recall_gain, tf.zeros_like(recall_gain)
1664
+ )
1665
+
1666
+ x = recall_gain
1667
+ y = precision_gain
1622
1668
 
1623
1669
  # Find the rectangle heights based on `summation_method`.
1624
1670
  if (
@@ -72,17 +72,27 @@ class SharpnessAwareMinimization(Model):
72
72
  if self.num_batch_splits is not None:
73
73
  x_split = tf.split(x, self.num_batch_splits)
74
74
  y_split = tf.split(y, self.num_batch_splits)
75
+ # Split the sample weight if it is provided.
76
+ if sample_weight is not None:
77
+ sample_weight_split = tf.split(
78
+ sample_weight, self.num_batch_splits
79
+ )
80
+ else:
81
+ sample_weight_split = [None] * self.num_batch_splits
75
82
  else:
76
83
  x_split = [x]
77
84
  y_split = [y]
85
+ sample_weight_split = [sample_weight]
78
86
 
79
87
  gradients_all_batches = []
80
88
  pred_all_batches = []
81
- for x_batch, y_batch in zip(x_split, y_split):
89
+ for x_batch, y_batch, sample_weight_batch in zip(
90
+ x_split, y_split, sample_weight_split
91
+ ):
82
92
  epsilon_w_cache = []
83
93
  with tf.GradientTape() as tape:
84
- pred = self.model(x_batch)
85
- loss = self.compiled_loss(y_batch, pred)
94
+ pred = self(x_batch)
95
+ loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
86
96
  pred_all_batches.append(pred)
87
97
  trainable_variables = self.model.trainable_variables
88
98
  gradients = tape.gradient(loss, trainable_variables)
@@ -98,8 +108,8 @@ class SharpnessAwareMinimization(Model):
98
108
  epsilon_w_cache.append(epsilon_w)
99
109
 
100
110
  with tf.GradientTape() as tape:
101
- pred = self(x_batch)
102
- loss = self.compiled_loss(y_batch, pred)
111
+ pred = self(x_batch, training=True)
112
+ loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
103
113
  gradients = tape.gradient(loss, trainable_variables)
104
114
  if len(gradients_all_batches) == 0:
105
115
  for gradient in gradients:
@@ -127,7 +137,7 @@ class SharpnessAwareMinimization(Model):
127
137
  self.compiled_metrics.update_state(y, pred, sample_weight)
128
138
  return {m.name: m.result() for m in self.metrics}
129
139
 
130
- def call(self, inputs):
140
+ def call(self, inputs, **kwargs):
131
141
  """Forward pass of SAM.
132
142
 
133
143
  SAM delegates the forward pass call to the wrapped model.
@@ -138,7 +148,7 @@ class SharpnessAwareMinimization(Model):
138
148
  Returns:
139
149
  A Tensor, the outputs of the wrapped model for given `inputs`.
140
150
  """
141
- return self.model(inputs)
151
+ return self.model(inputs, **kwargs)
142
152
 
143
153
  def get_config(self):
144
154
  config = super().get_config()
@@ -237,6 +237,7 @@ class AUCCurve(Enum):
237
237
 
238
238
  ROC = "ROC"
239
239
  PR = "PR"
240
+ PR_GAIN = "PR_GAIN"
240
241
 
241
242
  @staticmethod
242
243
  def from_str(key):
@@ -244,10 +245,12 @@ class AUCCurve(Enum):
244
245
  return AUCCurve.PR
245
246
  elif key in ("roc", "ROC"):
246
247
  return AUCCurve.ROC
248
+ elif key in ("pr_gain", "prgain", "PR_GAIN", "PRGAIN"):
249
+ return AUCCurve.PR_GAIN
247
250
  else:
248
251
  raise ValueError(
249
252
  f'Invalid AUC curve value: "{key}". '
250
- 'Expected values are ["PR", "ROC"]'
253
+ 'Expected values are ["PR", "ROC", "PR_GAIN"]'
251
254
  )
252
255
 
253
256
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tf_keras-nightly
3
- Version: 2.20.0.dev2025062209
3
+ Version: 2.20.0.dev2025082818
4
4
  Summary: Deep learning for humans.
5
5
  Home-page: https://keras.io/
6
6
  Download-URL: https://github.com/keras-team/tf-keras/tags
@@ -1,4 +1,4 @@
1
- tf_keras/__init__.py,sha256=NmpTp2UuYQk3UJgOiT28WwSxiqm59P7O4EhQNyqJslQ,911
1
+ tf_keras/__init__.py,sha256=ozEje7VN-_ziSdFlB85zormVRKYEb1VaRO-LO33z8OI,911
2
2
  tf_keras/__internal__/__init__.py,sha256=OHQbeIC0QtRBI7dgXaJaVbH8F00x8dCI-DvEcIfyMsE,671
3
3
  tf_keras/__internal__/backend/__init__.py,sha256=LnMs2A6685gDG79fxqmdulIYlVE_3WmXlBTBo9ZWYcw,162
4
4
  tf_keras/__internal__/layers/__init__.py,sha256=F5SGMhOTPzm-PR44VrfinURHcVeQPIEdwnZlAkSTB3A,176
@@ -201,9 +201,9 @@ tf_keras/preprocessing/image/__init__.py,sha256=H6rbMLtlGIy_jBLCSDklVTMXUjEUe8KQ
201
201
  tf_keras/preprocessing/sequence/__init__.py,sha256=Zg9mw0TIRIc-BmVtdXvW3jdIQo05VHZX_xmqZDMuaik,285
202
202
  tf_keras/preprocessing/text/__init__.py,sha256=1yQd-VZD6SjnEpPyBFLucYMxu9A5DnAnIec2tba9zQk,329
203
203
  tf_keras/protobuf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
204
- tf_keras/protobuf/projector_config_pb2.py,sha256=GHQfZbNY6IgeVYvL1A9o5ET7EiG-jv--zhbAJg9Ez3k,1821
205
- tf_keras/protobuf/saved_metadata_pb2.py,sha256=K4ROX6DQeyFej5TBrUvfY7e_gzpQuCRuiuiVgk3ehhg,1585
206
- tf_keras/protobuf/versions_pb2.py,sha256=HP6fzinb4-KIEZaINXIAe-BpxQnGROxrxECgGcpcvFE,1119
204
+ tf_keras/protobuf/projector_config_pb2.py,sha256=B9MGtu-Dr7P6DXq7zNAjJyWNXCePXjgWdEWGWtD2OoI,2208
205
+ tf_keras/protobuf/saved_metadata_pb2.py,sha256=LvEonuYDv46N-ANJ3K_1hys8rrSbp-28lI4NiTrmNvo,1946
206
+ tf_keras/protobuf/versions_pb2.py,sha256=bilE48UoIQqPZEPGIvhIkRXQB9-gI7XH28PG-ZjD4ew,1450
207
207
  tf_keras/regularizers/__init__.py,sha256=D6TnroEDjnyP79TY_624g2DToxVWuKzuaiBAn_gUQaY,634
208
208
  tf_keras/saving/__init__.py,sha256=Xo0imlDhiYV7Rowy8BjMwrFJuAB8h2DdIuVcxvaeEa0,681
209
209
  tf_keras/src/__init__.py,sha256=0_kJChFpSVveEmORgCyrhovmeKxMEzS78Hh1b4Egy18,1502
@@ -402,27 +402,7 @@ tf_keras/src/layers/preprocessing/preprocessing_utils.py,sha256=OR8NDGv8foDT2Ngv
402
402
  tf_keras/src/layers/preprocessing/string_lookup.py,sha256=2yqsgps42qMd6MB6vwBevionU7dh77OQdLburmn90b0,19179
403
403
  tf_keras/src/layers/preprocessing/text_vectorization.py,sha256=mL6sHm3TPXKg8q51vEWyo7LYKyiEoQFzm7GUkrSS-6E,30467
404
404
  tf_keras/src/layers/preprocessing/benchmarks/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
405
- tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py,sha256=ZKFxRPRDx9VYUzu3k42DO2hrN9Ve9UNLPYEraN3BU94,2845
406
- tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py,sha256=IEdxK6eQa1YdxgmOQ13YBeJ94afFWfGazAO6NvfxJ5w,2949
407
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py,sha256=yoGE5ofB7fspimQ1ImShs5KguNGUQ_JpsGYPLZS1gpQ,2809
408
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py,sha256=5b80c35WGpWEXgv2lutqVVuS52mWiD6Cyw1ZA6KkseU,2723
409
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py,sha256=SPmA9yXH3dr6uHfs1IsAkrjNo02YgyfmWrt24pl6ROs,3588
410
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py,sha256=JAM0X1lBkZd7KYtBFaBP2HfxxB3Uj7Ik7WeFhajbwNo,3437
411
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py,sha256=y0RR1TMq5PUv4Jlh7jMmQrJWsjDtgDivsqTeEMi6ovI,2863
412
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py,sha256=lyfRE8NP3gLfTDnIzucPqjMiLAOCEUS-pSwa1f7EXLM,3169
413
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py,sha256=Ebx54Qo5ec-Sys5bPhp1KaVtmWsQxpNotzxxpxtfBPg,3101
414
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py,sha256=WkDOo5borQYk78xKbnsh7tcEZyjrDEGb7NnTkYzoM18,2795
415
- tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py,sha256=UD48alO_v-Vb8naluZtPozU7U4Oy-1WPSV1oqhzl-Yk,3776
416
- tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py,sha256=PB7D3pFmVxlxZ4tKO7N-NB-wfJ0KY8B4RpAy_BZG01A,2836
417
- tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py,sha256=-fif0N3JPQT9fIwmpj-XE2eJif21sK2TsC6fri7ZuWI,2831
418
405
  tf_keras/src/layers/preprocessing/benchmarks/feature_column_benchmark.py,sha256=cSSHeEGH1dhxR3UJiCFZUgFeRZHd25eDUMRmKE140is,4814
419
- tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py,sha256=QV4n0f2j5b1Us-D2NHMA7WMRuUeMyiZpg-FAEopK0qs,2835
420
- tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py,sha256=wV16NUaNLfYZVwZCuMiX7JN9YDmbqyxaWHERo_uFJoE,3624
421
- tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py,sha256=x-XDwI75oIW3clnGOOmRG0Tb3hsQTx40bwxT7sj6CaE,5467
422
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py,sha256=LLv8vcdsphIBy5-owcABZdVSGSGMmQ7W-LmFTezO9Wc,4475
423
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py,sha256=O4e0X-yLYWpfN2pX_WshN92ygw7XqlXZfgQjeO1WjuY,4941
424
- tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py,sha256=sB-Tcem8UdFGXnKx4HI4fLjTsIjaGJ2WAaphrxuItVc,4420
425
- tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py,sha256=Z5k0UaPM0-VfUw9tMv4_dEhsQNDODWlfNtsZ1RHFrFI,3324
426
406
  tf_keras/src/layers/regularization/__init__.py,sha256=9fIrtV8SwP1PG8BXfNrSP8rSyCdh4pPnV7hNvDbRysg,1369
427
407
  tf_keras/src/layers/regularization/activity_regularization.py,sha256=QxnBlnkHi2HZ2Pt-mX5WGiJWzljNQmh-X4La9f7XDGo,1942
428
408
  tf_keras/src/layers/regularization/alpha_dropout.py,sha256=JmMO6OHzpVtRS2Tl1fTslktQPM4MuN0ivNlCOUhH0VM,3800
@@ -483,7 +463,7 @@ tf_keras/src/legacy_tf_layers/variable_scope_shim.py,sha256=kGAFW03pVWSB1DhHvQ1W
483
463
  tf_keras/src/metrics/__init__.py,sha256=dM8S0ZhfiyPaXkdYuOSKvoytmYOkh8aYuJnpgUoT6vg,9699
484
464
  tf_keras/src/metrics/accuracy_metrics.py,sha256=RRQqyYZcVrEY2Pfc-OV6k3rYhv9ejSLJ9JbJzs_D5vk,17514
485
465
  tf_keras/src/metrics/base_metric.py,sha256=MCaI7Bx-kgs5udTRLvKMJ3SO90-GFs_9QMigrhkX9HQ,36498
486
- tf_keras/src/metrics/confusion_metrics.py,sha256=e3BFKJgIRjMDGHjpbtZwiproFE9wnJJ9fKhc9zoVdsE,65962
466
+ tf_keras/src/metrics/confusion_metrics.py,sha256=V1uNFUc1zyjxd-m-D83QhJ9bkbtPCfsXf3CROOPWmzs,68068
487
467
  tf_keras/src/metrics/f_score_metrics.py,sha256=3uxqH9NNqoKaGPz-R6eERA23bK1TabCXrsJUz2sbetU,12000
488
468
  tf_keras/src/metrics/hinge_metrics.py,sha256=QXtNdxE-IgZmdVQXIew_pN6X3aF9i7r7xirmb6oiOKA,4132
489
469
  tf_keras/src/metrics/iou_metrics.py,sha256=dUqZpOppIPj3aCtS25Hs6bvJoPHNnrtAChujoA-6bLQ,28530
@@ -498,7 +478,7 @@ tf_keras/src/mixed_precision/policy.py,sha256=1GWHp99dU0f6D0h_jIrSQkoLyIf0ClRJ0B
498
478
  tf_keras/src/mixed_precision/test_util.py,sha256=S4dDVLvFmv3OXvo-7kswO8MStwvTjP_caE3DrUhy9Po,8641
499
479
  tf_keras/src/models/__init__.py,sha256=VQ3cZve-CsmM_4CEi9q-V7m2qFO9HbdiO38mAR4dKdM,1823
500
480
  tf_keras/src/models/cloning.py,sha256=PHLTG0gSjvoKl8jxGaLCUq3ejK_o0PNA7gxSqxyoLBI,36839
501
- tf_keras/src/models/sharpness_aware_minimization.py,sha256=4nofg5_fbrRuGa5RAIQwJ-OL8eeiWg7jlNkMuJSCB_g,7301
481
+ tf_keras/src/models/sharpness_aware_minimization.py,sha256=MArrweVZA85F1tPHZd06AVKpAdacaPplTz6eOS2XcRk,7795
502
482
  tf_keras/src/optimizers/__init__.py,sha256=lkPBfjJhWx_0nV8MrEmjWvJTGKutM1a9nIrB0ua0O-k,13044
503
483
  tf_keras/src/optimizers/adadelta.py,sha256=47HgdG0v-76B5htebkwl1OoryFPLO2kgk_CsYqgq7hU,6174
504
484
  tf_keras/src/optimizers/adafactor.py,sha256=_IYi6WMyXl4nPimr15nAPWvj6ZKcP7cESFsdpeabNKQ,8651
@@ -549,7 +529,6 @@ tf_keras/src/saving/legacy/serialization.py,sha256=OrmHQPolQFsR-UCMxNTxkIFTKY4DK
549
529
  tf_keras/src/saving/legacy/saved_model/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
550
530
  tf_keras/src/saving/legacy/saved_model/base_serialization.py,sha256=dALR19_zt4c80zVw3yjCj9wfRoJufDjCrvkJyS82Dnk,5104
551
531
  tf_keras/src/saving/legacy/saved_model/constants.py,sha256=96ymvysCZ2Ru888YT_DEPMDgizHdDoBFGEOXsf-9AwE,1779
552
- tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py,sha256=mS5jmCsDwUFUKr08G0tphPSA5ZAd7illNyj3QXKejOA,1040
553
532
  tf_keras/src/saving/legacy/saved_model/json_utils.py,sha256=WOyJaamx15lJ_V4XZSYM3RtuATa73uRNRM13o2s7yQ4,8071
554
533
  tf_keras/src/saving/legacy/saved_model/layer_serialization.py,sha256=FQwNk2XJq8dzgoSAWrmcabglZTA-6oDPtfeLiGrZO6A,8418
555
534
  tf_keras/src/saving/legacy/saved_model/load.py,sha256=wJUL0T4ZlUgxk3pZa_7E4NXEE3LXfjSKuGvTlrQchHc,57007
@@ -567,7 +546,6 @@ tf_keras/src/testing_infra/test_combinations.py,sha256=ETwFTN8eBAusQpqU7dg_Qckb1
567
546
  tf_keras/src/testing_infra/test_utils.py,sha256=SMEYejGPfYnZT2tVgzHL3gBHNGk6qcTu1qcZetHv870,40307
568
547
  tf_keras/src/tests/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
569
548
  tf_keras/src/tests/get_config_samples.py,sha256=qz2SZb_JIW2NoTak9NphLJkDTgYmlQ5RNm64T9wQ6L8,15307
570
- tf_keras/src/tests/keras_doctest.py,sha256=qFPhxdstCjwGZw0JIKPMZ_PF-oBzEgP6EqZ9n_0mtio,4638
571
549
  tf_keras/src/tests/model_architectures.py,sha256=83-y4n0LtvpcpXPgawvPGIcvaqaPZ_XgOVEDRgycLmw,10830
572
550
  tf_keras/src/tests/model_subclassing_test_util.py,sha256=tMRAx38exGbDKEUd5kNDRn7Q-epiMPCAxdbAEGSCP6Y,5515
573
551
  tf_keras/src/utils/__init__.py,sha256=HDp6YtwWY9al-pSjokrgj_IzsFi36TWQVGJp3ibTlws,3129
@@ -587,7 +565,7 @@ tf_keras/src/utils/kernelized_utils.py,sha256=s475SAos2zHQ1NT9AHZmbWUSahHKOhdctP
587
565
  tf_keras/src/utils/kpl_test_utils.py,sha256=vnaJkySSTVhXsFEdDxNJArwXaah0yPNTK8o_3rYZvOE,7365
588
566
  tf_keras/src/utils/layer_utils.py,sha256=cLKqiqJ2em16zZyXaXFErsL6yja28qE6kgPs2TTcdcY,43427
589
567
  tf_keras/src/utils/losses_utils.py,sha256=oPHJSNLY8U57ieQD59vnGHNavZpMpeTZtL7VIlDwwfM,16919
590
- tf_keras/src/utils/metrics_utils.py,sha256=feW5GoiznbQKkxmE3Url2nlXfWgvHMpJPXdGKCdiV_U,39803
568
+ tf_keras/src/utils/metrics_utils.py,sha256=h4F4MGcHrpjthypj-nZ1n2szBrBZj4X0R9cEzMcx75w,39938
591
569
  tf_keras/src/utils/mode_keys.py,sha256=_QYq58qr_b-RhvMYBYnL47NkC0G1ng8NYcVnS_IYi-A,856
592
570
  tf_keras/src/utils/np_utils.py,sha256=4EZ58G1zThQfQEmMNBPnUYRszXRJoY4foxYhOGfS89s,4805
593
571
  tf_keras/src/utils/object_identity.py,sha256=HZEETVcCoBrnIFjnxmBhZaCKP9xQMv9rMr_ihlMveVs,6879
@@ -606,7 +584,7 @@ tf_keras/src/utils/legacy/__init__.py,sha256=EfMmeHYDzwvxNaktPhQbkTdcPSIGCqMhBND
606
584
  tf_keras/utils/__init__.py,sha256=b7_d-USe_EmLo02_P99Q1rUCzKBYayPCfiYFStP-0nw,2735
607
585
  tf_keras/utils/experimental/__init__.py,sha256=DzGogE2AosjxOVILQBT8PDDcqbWTc0wWnZRobCdpcec,97
608
586
  tf_keras/utils/legacy/__init__.py,sha256=7ujlDa5HeSRcth2NdqA0S1P2-VZF1kB3n68jye6Dj-8,189
609
- tf_keras_nightly-2.20.0.dev2025062209.dist-info/METADATA,sha256=c6nykYDOl9DWcGm2se_RXVg0DK8u23dswgFfKeeWvXw,1857
610
- tf_keras_nightly-2.20.0.dev2025062209.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
611
- tf_keras_nightly-2.20.0.dev2025062209.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
612
- tf_keras_nightly-2.20.0.dev2025062209.dist-info/RECORD,,
587
+ tf_keras_nightly-2.20.0.dev2025082818.dist-info/METADATA,sha256=ZOv9cjktum7eBqt5JjQf7daEfrYpsh8wpUIwfVoycYw,1857
588
+ tf_keras_nightly-2.20.0.dev2025082818.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
589
+ tf_keras_nightly-2.20.0.dev2025082818.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
590
+ tf_keras_nightly-2.20.0.dev2025082818.dist-info/RECORD,,
@@ -1,85 +0,0 @@
1
- # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Benchmark for KPL implementation of bucketized columns with dense inputs."""
16
-
17
- import numpy as np
18
- import tensorflow.compat.v2 as tf
19
-
20
- import tf_keras.src as keras
21
- from tf_keras.src.layers.preprocessing import discretization
22
- from tf_keras.src.layers.preprocessing.benchmarks import (
23
- feature_column_benchmark as fc_bm,
24
- )
25
-
26
- # isort: off
27
- from tensorflow.python.eager.def_function import (
28
- function as tf_function,
29
- )
30
-
31
- NUM_REPEATS = 10 # The number of times to run each benchmark.
32
- BATCH_SIZES = [32, 256]
33
-
34
-
35
- ### KPL AND FC IMPLEMENTATION BENCHMARKS ###
36
- def embedding_varlen(batch_size, max_length):
37
- """Benchmark a variable-length embedding."""
38
- # Data and constants.
39
- max_value = 25.0
40
- bins = np.arange(1.0, max_value)
41
- data = fc_bm.create_data(
42
- max_length, batch_size * NUM_REPEATS, 100000, dtype=float
43
- )
44
-
45
- # TF-Keras implementation
46
- model = keras.Sequential()
47
- model.add(keras.Input(shape=(max_length,), name="data", dtype=tf.float32))
48
- model.add(discretization.Discretization(bins))
49
-
50
- # FC implementation
51
- fc = tf.feature_column.bucketized_column(
52
- tf.feature_column.numeric_column("data"), boundaries=list(bins)
53
- )
54
-
55
- # Wrap the FC implementation in a tf.function for a fair comparison
56
- @tf_function()
57
- def fc_fn(tensors):
58
- fc.transform_feature(
59
- tf.__internal__.feature_column.FeatureTransformationCache(tensors),
60
- None,
61
- )
62
-
63
- # Benchmark runs
64
- keras_data = {"data": data.to_tensor(default_value=0.0)}
65
- k_avg_time = fc_bm.run_keras(keras_data, model, batch_size, NUM_REPEATS)
66
-
67
- fc_data = {"data": data.to_tensor(default_value=0.0)}
68
- fc_avg_time = fc_bm.run_fc(fc_data, fc_fn, batch_size, NUM_REPEATS)
69
-
70
- return k_avg_time, fc_avg_time
71
-
72
-
73
- class BenchmarkLayer(fc_bm.LayerBenchmark):
74
- """Benchmark the layer forward pass."""
75
-
76
- def benchmark_layer(self):
77
- for batch in BATCH_SIZES:
78
- name = f"bucketized|dense|batch_{batch}"
79
- k_time, f_time = embedding_varlen(batch_size=batch, max_length=256)
80
- self.report(name, k_time, f_time, NUM_REPEATS)
81
-
82
-
83
- if __name__ == "__main__":
84
- tf.test.main()
85
-
@@ -1,84 +0,0 @@
1
- # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Benchmark for TF-Keras category_encoding preprocessing layer."""
16
-
17
- import time
18
-
19
- import numpy as np
20
- import tensorflow.compat.v2 as tf
21
-
22
- import tf_keras.src as keras
23
- from tf_keras.src.layers.preprocessing import category_encoding
24
-
25
-
26
- class BenchmarkLayer(tf.test.Benchmark):
27
- """Benchmark the layer forward pass."""
28
-
29
- def run_dataset_implementation(
30
- self, output_mode, batch_size, sequence_length, max_tokens
31
- ):
32
- input_t = keras.Input(shape=(sequence_length,), dtype=tf.int32)
33
- layer = category_encoding.CategoryEncoding(
34
- max_tokens=max_tokens, output_mode=output_mode
35
- )
36
- _ = layer(input_t)
37
-
38
- num_repeats = 5
39
- starts = []
40
- ends = []
41
- for _ in range(num_repeats):
42
- ds = tf.data.Dataset.from_tensor_slices(
43
- tf.random.uniform(
44
- [batch_size * 10, sequence_length],
45
- minval=0,
46
- maxval=max_tokens - 1,
47
- dtype=tf.int32,
48
- )
49
- )
50
- ds = ds.shuffle(batch_size * 100)
51
- ds = ds.batch(batch_size)
52
- num_batches = 5
53
- ds = ds.take(num_batches)
54
- ds = ds.prefetch(num_batches)
55
- starts.append(time.time())
56
- # Benchmarked code begins here.
57
- for i in ds:
58
- _ = layer(i)
59
- # Benchmarked code ends here.
60
- ends.append(time.time())
61
-
62
- avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches
63
- name = "category_encoding|batch_%s|seq_length_%s|%s_max_tokens" % (
64
- batch_size,
65
- sequence_length,
66
- max_tokens,
67
- )
68
- self.report_benchmark(iters=num_repeats, wall_time=avg_time, name=name)
69
-
70
- def benchmark_vocab_size_by_batch(self):
71
- for batch in [32, 256, 2048]:
72
- for sequence_length in [10, 1000]:
73
- for num_tokens in [100, 1000, 20000]:
74
- self.run_dataset_implementation(
75
- output_mode="count",
76
- batch_size=batch,
77
- sequence_length=sequence_length,
78
- max_tokens=num_tokens,
79
- )
80
-
81
-
82
- if __name__ == "__main__":
83
- tf.test.main()
84
-