tf-keras-nightly 2.20.0.dev2025062109__py3-none-any.whl → 2.20.0.dev2025082818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tf-keras-nightly might be problematic. Click here for more details.
- tf_keras/__init__.py +1 -1
- tf_keras/protobuf/projector_config_pb2.py +23 -12
- tf_keras/protobuf/saved_metadata_pb2.py +21 -10
- tf_keras/protobuf/versions_pb2.py +19 -8
- tf_keras/src/metrics/confusion_metrics.py +47 -1
- tf_keras/src/models/sharpness_aware_minimization.py +17 -7
- tf_keras/src/utils/metrics_utils.py +4 -1
- {tf_keras_nightly-2.20.0.dev2025062109.dist-info → tf_keras_nightly-2.20.0.dev2025082818.dist-info}/METADATA +1 -1
- {tf_keras_nightly-2.20.0.dev2025062109.dist-info → tf_keras_nightly-2.20.0.dev2025082818.dist-info}/RECORD +11 -33
- tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
- tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
- tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
- tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
- tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
- tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
- tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
- tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
- tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
- tf_keras/src/tests/keras_doctest.py +0 -159
- {tf_keras_nightly-2.20.0.dev2025062109.dist-info → tf_keras_nightly-2.20.0.dev2025082818.dist-info}/WHEEL +0 -0
- {tf_keras_nightly-2.20.0.dev2025062109.dist-info → tf_keras_nightly-2.20.0.dev2025082818.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py
CHANGED
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
3
4
|
# source: tf_keras/protobuf/projector_config.proto
|
|
5
|
+
# Protobuf Python Version: 5.28.3
|
|
4
6
|
"""Generated protocol buffer code."""
|
|
5
|
-
from google.protobuf.internal import builder as _builder
|
|
6
7
|
from google.protobuf import descriptor as _descriptor
|
|
7
8
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
8
10
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
5,
|
|
15
|
+
28,
|
|
16
|
+
3,
|
|
17
|
+
'',
|
|
18
|
+
'tf_keras/protobuf/projector_config.proto'
|
|
19
|
+
)
|
|
9
20
|
# @@protoc_insertion_point(imports)
|
|
10
21
|
|
|
11
22
|
_sym_db = _symbol_database.Default()
|
|
@@ -15,15 +26,15 @@ _sym_db = _symbol_database.Default()
|
|
|
15
26
|
|
|
16
27
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(tf_keras/protobuf/projector_config.proto\x12 third_party.py.tf_keras.protobuf\">\n\x0eSpriteMetadata\x12\x12\n\nimage_path\x18\x01 \x01(\t\x12\x18\n\x10single_image_dim\x18\x02 \x03(\r\"\xc0\x01\n\rEmbeddingInfo\x12\x13\n\x0btensor_name\x18\x01 \x01(\t\x12\x15\n\rmetadata_path\x18\x02 \x01(\t\x12\x16\n\x0e\x62ookmarks_path\x18\x03 \x01(\t\x12\x14\n\x0ctensor_shape\x18\x04 \x03(\r\x12@\n\x06sprite\x18\x05 \x01(\x0b\x32\x30.third_party.py.tf_keras.protobuf.SpriteMetadata\x12\x13\n\x0btensor_path\x18\x06 \x01(\t\"\x93\x01\n\x0fProjectorConfig\x12\x1d\n\x15model_checkpoint_path\x18\x01 \x01(\t\x12\x43\n\nembeddings\x18\x02 \x03(\x0b\x32/.third_party.py.tf_keras.protobuf.EmbeddingInfo\x12\x1c\n\x14model_checkpoint_dir\x18\x03 \x01(\tb\x06proto3')
|
|
17
28
|
|
|
18
|
-
|
|
19
|
-
_builder.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
DESCRIPTOR.
|
|
23
|
-
_SPRITEMETADATA._serialized_start=78
|
|
24
|
-
_SPRITEMETADATA._serialized_end=140
|
|
25
|
-
_EMBEDDINGINFO._serialized_start=143
|
|
26
|
-
_EMBEDDINGINFO._serialized_end=335
|
|
27
|
-
_PROJECTORCONFIG._serialized_start=338
|
|
28
|
-
_PROJECTORCONFIG._serialized_end=485
|
|
29
|
+
_globals = globals()
|
|
30
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
31
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.projector_config_pb2', _globals)
|
|
32
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
33
|
+
DESCRIPTOR._loaded_options = None
|
|
34
|
+
_globals['_SPRITEMETADATA']._serialized_start=78
|
|
35
|
+
_globals['_SPRITEMETADATA']._serialized_end=140
|
|
36
|
+
_globals['_EMBEDDINGINFO']._serialized_start=143
|
|
37
|
+
_globals['_EMBEDDINGINFO']._serialized_end=335
|
|
38
|
+
_globals['_PROJECTORCONFIG']._serialized_start=338
|
|
39
|
+
_globals['_PROJECTORCONFIG']._serialized_end=485
|
|
29
40
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
3
4
|
# source: tf_keras/protobuf/saved_metadata.proto
|
|
5
|
+
# Protobuf Python Version: 5.28.3
|
|
4
6
|
"""Generated protocol buffer code."""
|
|
5
|
-
from google.protobuf.internal import builder as _builder
|
|
6
7
|
from google.protobuf import descriptor as _descriptor
|
|
7
8
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
8
10
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
5,
|
|
15
|
+
28,
|
|
16
|
+
3,
|
|
17
|
+
'',
|
|
18
|
+
'tf_keras/protobuf/saved_metadata.proto'
|
|
19
|
+
)
|
|
9
20
|
# @@protoc_insertion_point(imports)
|
|
10
21
|
|
|
11
22
|
_sym_db = _symbol_database.Default()
|
|
@@ -16,13 +27,13 @@ from tf_keras.protobuf import versions_pb2 as tf__keras_dot_protobuf_dot_version
|
|
|
16
27
|
|
|
17
28
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&tf_keras/protobuf/saved_metadata.proto\x12 third_party.py.tf_keras.protobuf\x1a tf_keras/protobuf/versions.proto\"M\n\rSavedMetadata\x12<\n\x05nodes\x18\x01 \x03(\x0b\x32-.third_party.py.tf_keras.protobuf.SavedObject\"\x9c\x01\n\x0bSavedObject\x12\x0f\n\x07node_id\x18\x02 \x01(\x05\x12\x11\n\tnode_path\x18\x03 \x01(\t\x12\x12\n\nidentifier\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x05 \x01(\t\x12=\n\x07version\x18\x06 \x01(\x0b\x32,.third_party.py.tf_keras.protobuf.VersionDefJ\x04\x08\x01\x10\x02\x62\x06proto3')
|
|
18
29
|
|
|
19
|
-
|
|
20
|
-
_builder.
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
DESCRIPTOR.
|
|
24
|
-
_SAVEDMETADATA._serialized_start=110
|
|
25
|
-
_SAVEDMETADATA._serialized_end=187
|
|
26
|
-
_SAVEDOBJECT._serialized_start=190
|
|
27
|
-
_SAVEDOBJECT._serialized_end=346
|
|
30
|
+
_globals = globals()
|
|
31
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
32
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.saved_metadata_pb2', _globals)
|
|
33
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
34
|
+
DESCRIPTOR._loaded_options = None
|
|
35
|
+
_globals['_SAVEDMETADATA']._serialized_start=110
|
|
36
|
+
_globals['_SAVEDMETADATA']._serialized_end=187
|
|
37
|
+
_globals['_SAVEDOBJECT']._serialized_start=190
|
|
38
|
+
_globals['_SAVEDOBJECT']._serialized_end=346
|
|
28
39
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
3
4
|
# source: tf_keras/protobuf/versions.proto
|
|
5
|
+
# Protobuf Python Version: 5.28.3
|
|
4
6
|
"""Generated protocol buffer code."""
|
|
5
|
-
from google.protobuf.internal import builder as _builder
|
|
6
7
|
from google.protobuf import descriptor as _descriptor
|
|
7
8
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
8
10
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
5,
|
|
15
|
+
28,
|
|
16
|
+
3,
|
|
17
|
+
'',
|
|
18
|
+
'tf_keras/protobuf/versions.proto'
|
|
19
|
+
)
|
|
9
20
|
# @@protoc_insertion_point(imports)
|
|
10
21
|
|
|
11
22
|
_sym_db = _symbol_database.Default()
|
|
@@ -15,11 +26,11 @@ _sym_db = _symbol_database.Default()
|
|
|
15
26
|
|
|
16
27
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n tf_keras/protobuf/versions.proto\x12 third_party.py.tf_keras.protobuf\"K\n\nVersionDef\x12\x10\n\x08producer\x18\x01 \x01(\x05\x12\x14\n\x0cmin_consumer\x18\x02 \x01(\x05\x12\x15\n\rbad_consumers\x18\x03 \x03(\x05\x62\x06proto3')
|
|
17
28
|
|
|
18
|
-
|
|
19
|
-
_builder.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
DESCRIPTOR.
|
|
23
|
-
_VERSIONDEF._serialized_start=70
|
|
24
|
-
_VERSIONDEF._serialized_end=145
|
|
29
|
+
_globals = globals()
|
|
30
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
31
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.versions_pb2', _globals)
|
|
32
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
33
|
+
DESCRIPTOR._loaded_options = None
|
|
34
|
+
_globals['_VERSIONDEF']._serialized_start=70
|
|
35
|
+
_globals['_VERSIONDEF']._serialized_end=145
|
|
25
36
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -1612,13 +1612,59 @@ class AUC(base_metric.Metric):
|
|
|
1612
1612
|
)
|
|
1613
1613
|
x = fp_rate
|
|
1614
1614
|
y = recall
|
|
1615
|
-
|
|
1615
|
+
elif self.curve == metrics_utils.AUCCurve.PR:
|
|
1616
1616
|
precision = tf.math.divide_no_nan(
|
|
1617
1617
|
self.true_positives,
|
|
1618
1618
|
tf.math.add(self.true_positives, self.false_positives),
|
|
1619
1619
|
)
|
|
1620
1620
|
x = recall
|
|
1621
1621
|
y = precision
|
|
1622
|
+
else: # curve == 'PR_GAIN'.
|
|
1623
|
+
# Due to the hyperbolic transform, this formula is less robust than
|
|
1624
|
+
# ROC or PR values. In particular
|
|
1625
|
+
# 1) Both measures diverge when there are no negative examples;
|
|
1626
|
+
# 2) Both measures diverge when there are no true positives;
|
|
1627
|
+
# 3) Recall gain becomes negative when the recall is lower than the
|
|
1628
|
+
# label average (i.e. when more negative examples are classified
|
|
1629
|
+
# positive than real positives).
|
|
1630
|
+
#
|
|
1631
|
+
# We ignore case 1 as it is easily communicated. For case 2 we set
|
|
1632
|
+
# recall_gain to 0 and precision_gain to 1. For case 3 we set the
|
|
1633
|
+
# recall_gain to 0. These fixes will result in an overastimation of
|
|
1634
|
+
# the AUC for estimateors that are anti-correlated with the label
|
|
1635
|
+
# (at some thresholds).
|
|
1636
|
+
#
|
|
1637
|
+
# The scaling factor $\frac{P}{N}$ that is used to form both
|
|
1638
|
+
# gain values.
|
|
1639
|
+
scaling_factor = tf.math.divide_no_nan(
|
|
1640
|
+
tf.math.add(self.true_positives, self.false_negatives),
|
|
1641
|
+
tf.math.add(self.false_positives, self.true_negatives),
|
|
1642
|
+
)
|
|
1643
|
+
|
|
1644
|
+
recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
|
|
1645
|
+
self.false_negatives, self.true_positives
|
|
1646
|
+
)
|
|
1647
|
+
precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
|
|
1648
|
+
self.false_positives, self.true_positives
|
|
1649
|
+
)
|
|
1650
|
+
# Handle case 2.
|
|
1651
|
+
recall_gain = tf.where(
|
|
1652
|
+
tf.equal(self.true_positives, 0.0),
|
|
1653
|
+
tf.zeros_like(recall_gain),
|
|
1654
|
+
recall_gain,
|
|
1655
|
+
)
|
|
1656
|
+
precision_gain = tf.where(
|
|
1657
|
+
tf.equal(self.true_positives, 0.0),
|
|
1658
|
+
tf.ones_like(precision_gain),
|
|
1659
|
+
precision_gain,
|
|
1660
|
+
)
|
|
1661
|
+
# Handle case 3.
|
|
1662
|
+
recall_gain = tf.math.maximum(
|
|
1663
|
+
recall_gain, tf.zeros_like(recall_gain)
|
|
1664
|
+
)
|
|
1665
|
+
|
|
1666
|
+
x = recall_gain
|
|
1667
|
+
y = precision_gain
|
|
1622
1668
|
|
|
1623
1669
|
# Find the rectangle heights based on `summation_method`.
|
|
1624
1670
|
if (
|
|
@@ -72,17 +72,27 @@ class SharpnessAwareMinimization(Model):
|
|
|
72
72
|
if self.num_batch_splits is not None:
|
|
73
73
|
x_split = tf.split(x, self.num_batch_splits)
|
|
74
74
|
y_split = tf.split(y, self.num_batch_splits)
|
|
75
|
+
# Split the sample weight if it is provided.
|
|
76
|
+
if sample_weight is not None:
|
|
77
|
+
sample_weight_split = tf.split(
|
|
78
|
+
sample_weight, self.num_batch_splits
|
|
79
|
+
)
|
|
80
|
+
else:
|
|
81
|
+
sample_weight_split = [None] * self.num_batch_splits
|
|
75
82
|
else:
|
|
76
83
|
x_split = [x]
|
|
77
84
|
y_split = [y]
|
|
85
|
+
sample_weight_split = [sample_weight]
|
|
78
86
|
|
|
79
87
|
gradients_all_batches = []
|
|
80
88
|
pred_all_batches = []
|
|
81
|
-
for x_batch, y_batch in zip(
|
|
89
|
+
for x_batch, y_batch, sample_weight_batch in zip(
|
|
90
|
+
x_split, y_split, sample_weight_split
|
|
91
|
+
):
|
|
82
92
|
epsilon_w_cache = []
|
|
83
93
|
with tf.GradientTape() as tape:
|
|
84
|
-
pred = self
|
|
85
|
-
loss = self.compiled_loss(y_batch, pred)
|
|
94
|
+
pred = self(x_batch)
|
|
95
|
+
loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
|
|
86
96
|
pred_all_batches.append(pred)
|
|
87
97
|
trainable_variables = self.model.trainable_variables
|
|
88
98
|
gradients = tape.gradient(loss, trainable_variables)
|
|
@@ -98,8 +108,8 @@ class SharpnessAwareMinimization(Model):
|
|
|
98
108
|
epsilon_w_cache.append(epsilon_w)
|
|
99
109
|
|
|
100
110
|
with tf.GradientTape() as tape:
|
|
101
|
-
pred = self(x_batch)
|
|
102
|
-
loss = self.compiled_loss(y_batch, pred)
|
|
111
|
+
pred = self(x_batch, training=True)
|
|
112
|
+
loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
|
|
103
113
|
gradients = tape.gradient(loss, trainable_variables)
|
|
104
114
|
if len(gradients_all_batches) == 0:
|
|
105
115
|
for gradient in gradients:
|
|
@@ -127,7 +137,7 @@ class SharpnessAwareMinimization(Model):
|
|
|
127
137
|
self.compiled_metrics.update_state(y, pred, sample_weight)
|
|
128
138
|
return {m.name: m.result() for m in self.metrics}
|
|
129
139
|
|
|
130
|
-
def call(self, inputs):
|
|
140
|
+
def call(self, inputs, **kwargs):
|
|
131
141
|
"""Forward pass of SAM.
|
|
132
142
|
|
|
133
143
|
SAM delegates the forward pass call to the wrapped model.
|
|
@@ -138,7 +148,7 @@ class SharpnessAwareMinimization(Model):
|
|
|
138
148
|
Returns:
|
|
139
149
|
A Tensor, the outputs of the wrapped model for given `inputs`.
|
|
140
150
|
"""
|
|
141
|
-
return self.model(inputs)
|
|
151
|
+
return self.model(inputs, **kwargs)
|
|
142
152
|
|
|
143
153
|
def get_config(self):
|
|
144
154
|
config = super().get_config()
|
|
@@ -237,6 +237,7 @@ class AUCCurve(Enum):
|
|
|
237
237
|
|
|
238
238
|
ROC = "ROC"
|
|
239
239
|
PR = "PR"
|
|
240
|
+
PR_GAIN = "PR_GAIN"
|
|
240
241
|
|
|
241
242
|
@staticmethod
|
|
242
243
|
def from_str(key):
|
|
@@ -244,10 +245,12 @@ class AUCCurve(Enum):
|
|
|
244
245
|
return AUCCurve.PR
|
|
245
246
|
elif key in ("roc", "ROC"):
|
|
246
247
|
return AUCCurve.ROC
|
|
248
|
+
elif key in ("pr_gain", "prgain", "PR_GAIN", "PRGAIN"):
|
|
249
|
+
return AUCCurve.PR_GAIN
|
|
247
250
|
else:
|
|
248
251
|
raise ValueError(
|
|
249
252
|
f'Invalid AUC curve value: "{key}". '
|
|
250
|
-
'Expected values are ["PR", "ROC"]'
|
|
253
|
+
'Expected values are ["PR", "ROC", "PR_GAIN"]'
|
|
251
254
|
)
|
|
252
255
|
|
|
253
256
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
tf_keras/__init__.py,sha256=
|
|
1
|
+
tf_keras/__init__.py,sha256=ozEje7VN-_ziSdFlB85zormVRKYEb1VaRO-LO33z8OI,911
|
|
2
2
|
tf_keras/__internal__/__init__.py,sha256=OHQbeIC0QtRBI7dgXaJaVbH8F00x8dCI-DvEcIfyMsE,671
|
|
3
3
|
tf_keras/__internal__/backend/__init__.py,sha256=LnMs2A6685gDG79fxqmdulIYlVE_3WmXlBTBo9ZWYcw,162
|
|
4
4
|
tf_keras/__internal__/layers/__init__.py,sha256=F5SGMhOTPzm-PR44VrfinURHcVeQPIEdwnZlAkSTB3A,176
|
|
@@ -201,9 +201,9 @@ tf_keras/preprocessing/image/__init__.py,sha256=H6rbMLtlGIy_jBLCSDklVTMXUjEUe8KQ
|
|
|
201
201
|
tf_keras/preprocessing/sequence/__init__.py,sha256=Zg9mw0TIRIc-BmVtdXvW3jdIQo05VHZX_xmqZDMuaik,285
|
|
202
202
|
tf_keras/preprocessing/text/__init__.py,sha256=1yQd-VZD6SjnEpPyBFLucYMxu9A5DnAnIec2tba9zQk,329
|
|
203
203
|
tf_keras/protobuf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
204
|
-
tf_keras/protobuf/projector_config_pb2.py,sha256=
|
|
205
|
-
tf_keras/protobuf/saved_metadata_pb2.py,sha256=
|
|
206
|
-
tf_keras/protobuf/versions_pb2.py,sha256=
|
|
204
|
+
tf_keras/protobuf/projector_config_pb2.py,sha256=B9MGtu-Dr7P6DXq7zNAjJyWNXCePXjgWdEWGWtD2OoI,2208
|
|
205
|
+
tf_keras/protobuf/saved_metadata_pb2.py,sha256=LvEonuYDv46N-ANJ3K_1hys8rrSbp-28lI4NiTrmNvo,1946
|
|
206
|
+
tf_keras/protobuf/versions_pb2.py,sha256=bilE48UoIQqPZEPGIvhIkRXQB9-gI7XH28PG-ZjD4ew,1450
|
|
207
207
|
tf_keras/regularizers/__init__.py,sha256=D6TnroEDjnyP79TY_624g2DToxVWuKzuaiBAn_gUQaY,634
|
|
208
208
|
tf_keras/saving/__init__.py,sha256=Xo0imlDhiYV7Rowy8BjMwrFJuAB8h2DdIuVcxvaeEa0,681
|
|
209
209
|
tf_keras/src/__init__.py,sha256=0_kJChFpSVveEmORgCyrhovmeKxMEzS78Hh1b4Egy18,1502
|
|
@@ -402,27 +402,7 @@ tf_keras/src/layers/preprocessing/preprocessing_utils.py,sha256=OR8NDGv8foDT2Ngv
|
|
|
402
402
|
tf_keras/src/layers/preprocessing/string_lookup.py,sha256=2yqsgps42qMd6MB6vwBevionU7dh77OQdLburmn90b0,19179
|
|
403
403
|
tf_keras/src/layers/preprocessing/text_vectorization.py,sha256=mL6sHm3TPXKg8q51vEWyo7LYKyiEoQFzm7GUkrSS-6E,30467
|
|
404
404
|
tf_keras/src/layers/preprocessing/benchmarks/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
405
|
-
tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py,sha256=ZKFxRPRDx9VYUzu3k42DO2hrN9Ve9UNLPYEraN3BU94,2845
|
|
406
|
-
tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py,sha256=IEdxK6eQa1YdxgmOQ13YBeJ94afFWfGazAO6NvfxJ5w,2949
|
|
407
|
-
tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py,sha256=yoGE5ofB7fspimQ1ImShs5KguNGUQ_JpsGYPLZS1gpQ,2809
|
|
408
|
-
tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py,sha256=5b80c35WGpWEXgv2lutqVVuS52mWiD6Cyw1ZA6KkseU,2723
|
|
409
|
-
tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py,sha256=SPmA9yXH3dr6uHfs1IsAkrjNo02YgyfmWrt24pl6ROs,3588
|
|
410
|
-
tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py,sha256=JAM0X1lBkZd7KYtBFaBP2HfxxB3Uj7Ik7WeFhajbwNo,3437
|
|
411
|
-
tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py,sha256=y0RR1TMq5PUv4Jlh7jMmQrJWsjDtgDivsqTeEMi6ovI,2863
|
|
412
|
-
tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py,sha256=lyfRE8NP3gLfTDnIzucPqjMiLAOCEUS-pSwa1f7EXLM,3169
|
|
413
|
-
tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py,sha256=Ebx54Qo5ec-Sys5bPhp1KaVtmWsQxpNotzxxpxtfBPg,3101
|
|
414
|
-
tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py,sha256=WkDOo5borQYk78xKbnsh7tcEZyjrDEGb7NnTkYzoM18,2795
|
|
415
|
-
tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py,sha256=UD48alO_v-Vb8naluZtPozU7U4Oy-1WPSV1oqhzl-Yk,3776
|
|
416
|
-
tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py,sha256=PB7D3pFmVxlxZ4tKO7N-NB-wfJ0KY8B4RpAy_BZG01A,2836
|
|
417
|
-
tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py,sha256=-fif0N3JPQT9fIwmpj-XE2eJif21sK2TsC6fri7ZuWI,2831
|
|
418
405
|
tf_keras/src/layers/preprocessing/benchmarks/feature_column_benchmark.py,sha256=cSSHeEGH1dhxR3UJiCFZUgFeRZHd25eDUMRmKE140is,4814
|
|
419
|
-
tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py,sha256=QV4n0f2j5b1Us-D2NHMA7WMRuUeMyiZpg-FAEopK0qs,2835
|
|
420
|
-
tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py,sha256=wV16NUaNLfYZVwZCuMiX7JN9YDmbqyxaWHERo_uFJoE,3624
|
|
421
|
-
tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py,sha256=x-XDwI75oIW3clnGOOmRG0Tb3hsQTx40bwxT7sj6CaE,5467
|
|
422
|
-
tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py,sha256=LLv8vcdsphIBy5-owcABZdVSGSGMmQ7W-LmFTezO9Wc,4475
|
|
423
|
-
tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py,sha256=O4e0X-yLYWpfN2pX_WshN92ygw7XqlXZfgQjeO1WjuY,4941
|
|
424
|
-
tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py,sha256=sB-Tcem8UdFGXnKx4HI4fLjTsIjaGJ2WAaphrxuItVc,4420
|
|
425
|
-
tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py,sha256=Z5k0UaPM0-VfUw9tMv4_dEhsQNDODWlfNtsZ1RHFrFI,3324
|
|
426
406
|
tf_keras/src/layers/regularization/__init__.py,sha256=9fIrtV8SwP1PG8BXfNrSP8rSyCdh4pPnV7hNvDbRysg,1369
|
|
427
407
|
tf_keras/src/layers/regularization/activity_regularization.py,sha256=QxnBlnkHi2HZ2Pt-mX5WGiJWzljNQmh-X4La9f7XDGo,1942
|
|
428
408
|
tf_keras/src/layers/regularization/alpha_dropout.py,sha256=JmMO6OHzpVtRS2Tl1fTslktQPM4MuN0ivNlCOUhH0VM,3800
|
|
@@ -483,7 +463,7 @@ tf_keras/src/legacy_tf_layers/variable_scope_shim.py,sha256=kGAFW03pVWSB1DhHvQ1W
|
|
|
483
463
|
tf_keras/src/metrics/__init__.py,sha256=dM8S0ZhfiyPaXkdYuOSKvoytmYOkh8aYuJnpgUoT6vg,9699
|
|
484
464
|
tf_keras/src/metrics/accuracy_metrics.py,sha256=RRQqyYZcVrEY2Pfc-OV6k3rYhv9ejSLJ9JbJzs_D5vk,17514
|
|
485
465
|
tf_keras/src/metrics/base_metric.py,sha256=MCaI7Bx-kgs5udTRLvKMJ3SO90-GFs_9QMigrhkX9HQ,36498
|
|
486
|
-
tf_keras/src/metrics/confusion_metrics.py,sha256=
|
|
466
|
+
tf_keras/src/metrics/confusion_metrics.py,sha256=V1uNFUc1zyjxd-m-D83QhJ9bkbtPCfsXf3CROOPWmzs,68068
|
|
487
467
|
tf_keras/src/metrics/f_score_metrics.py,sha256=3uxqH9NNqoKaGPz-R6eERA23bK1TabCXrsJUz2sbetU,12000
|
|
488
468
|
tf_keras/src/metrics/hinge_metrics.py,sha256=QXtNdxE-IgZmdVQXIew_pN6X3aF9i7r7xirmb6oiOKA,4132
|
|
489
469
|
tf_keras/src/metrics/iou_metrics.py,sha256=dUqZpOppIPj3aCtS25Hs6bvJoPHNnrtAChujoA-6bLQ,28530
|
|
@@ -498,7 +478,7 @@ tf_keras/src/mixed_precision/policy.py,sha256=1GWHp99dU0f6D0h_jIrSQkoLyIf0ClRJ0B
|
|
|
498
478
|
tf_keras/src/mixed_precision/test_util.py,sha256=S4dDVLvFmv3OXvo-7kswO8MStwvTjP_caE3DrUhy9Po,8641
|
|
499
479
|
tf_keras/src/models/__init__.py,sha256=VQ3cZve-CsmM_4CEi9q-V7m2qFO9HbdiO38mAR4dKdM,1823
|
|
500
480
|
tf_keras/src/models/cloning.py,sha256=PHLTG0gSjvoKl8jxGaLCUq3ejK_o0PNA7gxSqxyoLBI,36839
|
|
501
|
-
tf_keras/src/models/sharpness_aware_minimization.py,sha256=
|
|
481
|
+
tf_keras/src/models/sharpness_aware_minimization.py,sha256=MArrweVZA85F1tPHZd06AVKpAdacaPplTz6eOS2XcRk,7795
|
|
502
482
|
tf_keras/src/optimizers/__init__.py,sha256=lkPBfjJhWx_0nV8MrEmjWvJTGKutM1a9nIrB0ua0O-k,13044
|
|
503
483
|
tf_keras/src/optimizers/adadelta.py,sha256=47HgdG0v-76B5htebkwl1OoryFPLO2kgk_CsYqgq7hU,6174
|
|
504
484
|
tf_keras/src/optimizers/adafactor.py,sha256=_IYi6WMyXl4nPimr15nAPWvj6ZKcP7cESFsdpeabNKQ,8651
|
|
@@ -549,7 +529,6 @@ tf_keras/src/saving/legacy/serialization.py,sha256=OrmHQPolQFsR-UCMxNTxkIFTKY4DK
|
|
|
549
529
|
tf_keras/src/saving/legacy/saved_model/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
550
530
|
tf_keras/src/saving/legacy/saved_model/base_serialization.py,sha256=dALR19_zt4c80zVw3yjCj9wfRoJufDjCrvkJyS82Dnk,5104
|
|
551
531
|
tf_keras/src/saving/legacy/saved_model/constants.py,sha256=96ymvysCZ2Ru888YT_DEPMDgizHdDoBFGEOXsf-9AwE,1779
|
|
552
|
-
tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py,sha256=mS5jmCsDwUFUKr08G0tphPSA5ZAd7illNyj3QXKejOA,1040
|
|
553
532
|
tf_keras/src/saving/legacy/saved_model/json_utils.py,sha256=WOyJaamx15lJ_V4XZSYM3RtuATa73uRNRM13o2s7yQ4,8071
|
|
554
533
|
tf_keras/src/saving/legacy/saved_model/layer_serialization.py,sha256=FQwNk2XJq8dzgoSAWrmcabglZTA-6oDPtfeLiGrZO6A,8418
|
|
555
534
|
tf_keras/src/saving/legacy/saved_model/load.py,sha256=wJUL0T4ZlUgxk3pZa_7E4NXEE3LXfjSKuGvTlrQchHc,57007
|
|
@@ -567,7 +546,6 @@ tf_keras/src/testing_infra/test_combinations.py,sha256=ETwFTN8eBAusQpqU7dg_Qckb1
|
|
|
567
546
|
tf_keras/src/testing_infra/test_utils.py,sha256=SMEYejGPfYnZT2tVgzHL3gBHNGk6qcTu1qcZetHv870,40307
|
|
568
547
|
tf_keras/src/tests/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
569
548
|
tf_keras/src/tests/get_config_samples.py,sha256=qz2SZb_JIW2NoTak9NphLJkDTgYmlQ5RNm64T9wQ6L8,15307
|
|
570
|
-
tf_keras/src/tests/keras_doctest.py,sha256=qFPhxdstCjwGZw0JIKPMZ_PF-oBzEgP6EqZ9n_0mtio,4638
|
|
571
549
|
tf_keras/src/tests/model_architectures.py,sha256=83-y4n0LtvpcpXPgawvPGIcvaqaPZ_XgOVEDRgycLmw,10830
|
|
572
550
|
tf_keras/src/tests/model_subclassing_test_util.py,sha256=tMRAx38exGbDKEUd5kNDRn7Q-epiMPCAxdbAEGSCP6Y,5515
|
|
573
551
|
tf_keras/src/utils/__init__.py,sha256=HDp6YtwWY9al-pSjokrgj_IzsFi36TWQVGJp3ibTlws,3129
|
|
@@ -587,7 +565,7 @@ tf_keras/src/utils/kernelized_utils.py,sha256=s475SAos2zHQ1NT9AHZmbWUSahHKOhdctP
|
|
|
587
565
|
tf_keras/src/utils/kpl_test_utils.py,sha256=vnaJkySSTVhXsFEdDxNJArwXaah0yPNTK8o_3rYZvOE,7365
|
|
588
566
|
tf_keras/src/utils/layer_utils.py,sha256=cLKqiqJ2em16zZyXaXFErsL6yja28qE6kgPs2TTcdcY,43427
|
|
589
567
|
tf_keras/src/utils/losses_utils.py,sha256=oPHJSNLY8U57ieQD59vnGHNavZpMpeTZtL7VIlDwwfM,16919
|
|
590
|
-
tf_keras/src/utils/metrics_utils.py,sha256=
|
|
568
|
+
tf_keras/src/utils/metrics_utils.py,sha256=h4F4MGcHrpjthypj-nZ1n2szBrBZj4X0R9cEzMcx75w,39938
|
|
591
569
|
tf_keras/src/utils/mode_keys.py,sha256=_QYq58qr_b-RhvMYBYnL47NkC0G1ng8NYcVnS_IYi-A,856
|
|
592
570
|
tf_keras/src/utils/np_utils.py,sha256=4EZ58G1zThQfQEmMNBPnUYRszXRJoY4foxYhOGfS89s,4805
|
|
593
571
|
tf_keras/src/utils/object_identity.py,sha256=HZEETVcCoBrnIFjnxmBhZaCKP9xQMv9rMr_ihlMveVs,6879
|
|
@@ -606,7 +584,7 @@ tf_keras/src/utils/legacy/__init__.py,sha256=EfMmeHYDzwvxNaktPhQbkTdcPSIGCqMhBND
|
|
|
606
584
|
tf_keras/utils/__init__.py,sha256=b7_d-USe_EmLo02_P99Q1rUCzKBYayPCfiYFStP-0nw,2735
|
|
607
585
|
tf_keras/utils/experimental/__init__.py,sha256=DzGogE2AosjxOVILQBT8PDDcqbWTc0wWnZRobCdpcec,97
|
|
608
586
|
tf_keras/utils/legacy/__init__.py,sha256=7ujlDa5HeSRcth2NdqA0S1P2-VZF1kB3n68jye6Dj-8,189
|
|
609
|
-
tf_keras_nightly-2.20.0.
|
|
610
|
-
tf_keras_nightly-2.20.0.
|
|
611
|
-
tf_keras_nightly-2.20.0.
|
|
612
|
-
tf_keras_nightly-2.20.0.
|
|
587
|
+
tf_keras_nightly-2.20.0.dev2025082818.dist-info/METADATA,sha256=ZOv9cjktum7eBqt5JjQf7daEfrYpsh8wpUIwfVoycYw,1857
|
|
588
|
+
tf_keras_nightly-2.20.0.dev2025082818.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
589
|
+
tf_keras_nightly-2.20.0.dev2025082818.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
|
|
590
|
+
tf_keras_nightly-2.20.0.dev2025082818.dist-info/RECORD,,
|
|
@@ -1,85 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
"""Benchmark for KPL implementation of bucketized columns with dense inputs."""
|
|
16
|
-
|
|
17
|
-
import numpy as np
|
|
18
|
-
import tensorflow.compat.v2 as tf
|
|
19
|
-
|
|
20
|
-
import tf_keras.src as keras
|
|
21
|
-
from tf_keras.src.layers.preprocessing import discretization
|
|
22
|
-
from tf_keras.src.layers.preprocessing.benchmarks import (
|
|
23
|
-
feature_column_benchmark as fc_bm,
|
|
24
|
-
)
|
|
25
|
-
|
|
26
|
-
# isort: off
|
|
27
|
-
from tensorflow.python.eager.def_function import (
|
|
28
|
-
function as tf_function,
|
|
29
|
-
)
|
|
30
|
-
|
|
31
|
-
NUM_REPEATS = 10 # The number of times to run each benchmark.
|
|
32
|
-
BATCH_SIZES = [32, 256]
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
### KPL AND FC IMPLEMENTATION BENCHMARKS ###
|
|
36
|
-
def embedding_varlen(batch_size, max_length):
|
|
37
|
-
"""Benchmark a variable-length embedding."""
|
|
38
|
-
# Data and constants.
|
|
39
|
-
max_value = 25.0
|
|
40
|
-
bins = np.arange(1.0, max_value)
|
|
41
|
-
data = fc_bm.create_data(
|
|
42
|
-
max_length, batch_size * NUM_REPEATS, 100000, dtype=float
|
|
43
|
-
)
|
|
44
|
-
|
|
45
|
-
# TF-Keras implementation
|
|
46
|
-
model = keras.Sequential()
|
|
47
|
-
model.add(keras.Input(shape=(max_length,), name="data", dtype=tf.float32))
|
|
48
|
-
model.add(discretization.Discretization(bins))
|
|
49
|
-
|
|
50
|
-
# FC implementation
|
|
51
|
-
fc = tf.feature_column.bucketized_column(
|
|
52
|
-
tf.feature_column.numeric_column("data"), boundaries=list(bins)
|
|
53
|
-
)
|
|
54
|
-
|
|
55
|
-
# Wrap the FC implementation in a tf.function for a fair comparison
|
|
56
|
-
@tf_function()
|
|
57
|
-
def fc_fn(tensors):
|
|
58
|
-
fc.transform_feature(
|
|
59
|
-
tf.__internal__.feature_column.FeatureTransformationCache(tensors),
|
|
60
|
-
None,
|
|
61
|
-
)
|
|
62
|
-
|
|
63
|
-
# Benchmark runs
|
|
64
|
-
keras_data = {"data": data.to_tensor(default_value=0.0)}
|
|
65
|
-
k_avg_time = fc_bm.run_keras(keras_data, model, batch_size, NUM_REPEATS)
|
|
66
|
-
|
|
67
|
-
fc_data = {"data": data.to_tensor(default_value=0.0)}
|
|
68
|
-
fc_avg_time = fc_bm.run_fc(fc_data, fc_fn, batch_size, NUM_REPEATS)
|
|
69
|
-
|
|
70
|
-
return k_avg_time, fc_avg_time
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
class BenchmarkLayer(fc_bm.LayerBenchmark):
|
|
74
|
-
"""Benchmark the layer forward pass."""
|
|
75
|
-
|
|
76
|
-
def benchmark_layer(self):
|
|
77
|
-
for batch in BATCH_SIZES:
|
|
78
|
-
name = f"bucketized|dense|batch_{batch}"
|
|
79
|
-
k_time, f_time = embedding_varlen(batch_size=batch, max_length=256)
|
|
80
|
-
self.report(name, k_time, f_time, NUM_REPEATS)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
if __name__ == "__main__":
|
|
84
|
-
tf.test.main()
|
|
85
|
-
|
|
@@ -1,84 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
"""Benchmark for TF-Keras category_encoding preprocessing layer."""
|
|
16
|
-
|
|
17
|
-
import time
|
|
18
|
-
|
|
19
|
-
import numpy as np
|
|
20
|
-
import tensorflow.compat.v2 as tf
|
|
21
|
-
|
|
22
|
-
import tf_keras.src as keras
|
|
23
|
-
from tf_keras.src.layers.preprocessing import category_encoding
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class BenchmarkLayer(tf.test.Benchmark):
|
|
27
|
-
"""Benchmark the layer forward pass."""
|
|
28
|
-
|
|
29
|
-
def run_dataset_implementation(
|
|
30
|
-
self, output_mode, batch_size, sequence_length, max_tokens
|
|
31
|
-
):
|
|
32
|
-
input_t = keras.Input(shape=(sequence_length,), dtype=tf.int32)
|
|
33
|
-
layer = category_encoding.CategoryEncoding(
|
|
34
|
-
max_tokens=max_tokens, output_mode=output_mode
|
|
35
|
-
)
|
|
36
|
-
_ = layer(input_t)
|
|
37
|
-
|
|
38
|
-
num_repeats = 5
|
|
39
|
-
starts = []
|
|
40
|
-
ends = []
|
|
41
|
-
for _ in range(num_repeats):
|
|
42
|
-
ds = tf.data.Dataset.from_tensor_slices(
|
|
43
|
-
tf.random.uniform(
|
|
44
|
-
[batch_size * 10, sequence_length],
|
|
45
|
-
minval=0,
|
|
46
|
-
maxval=max_tokens - 1,
|
|
47
|
-
dtype=tf.int32,
|
|
48
|
-
)
|
|
49
|
-
)
|
|
50
|
-
ds = ds.shuffle(batch_size * 100)
|
|
51
|
-
ds = ds.batch(batch_size)
|
|
52
|
-
num_batches = 5
|
|
53
|
-
ds = ds.take(num_batches)
|
|
54
|
-
ds = ds.prefetch(num_batches)
|
|
55
|
-
starts.append(time.time())
|
|
56
|
-
# Benchmarked code begins here.
|
|
57
|
-
for i in ds:
|
|
58
|
-
_ = layer(i)
|
|
59
|
-
# Benchmarked code ends here.
|
|
60
|
-
ends.append(time.time())
|
|
61
|
-
|
|
62
|
-
avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches
|
|
63
|
-
name = "category_encoding|batch_%s|seq_length_%s|%s_max_tokens" % (
|
|
64
|
-
batch_size,
|
|
65
|
-
sequence_length,
|
|
66
|
-
max_tokens,
|
|
67
|
-
)
|
|
68
|
-
self.report_benchmark(iters=num_repeats, wall_time=avg_time, name=name)
|
|
69
|
-
|
|
70
|
-
def benchmark_vocab_size_by_batch(self):
|
|
71
|
-
for batch in [32, 256, 2048]:
|
|
72
|
-
for sequence_length in [10, 1000]:
|
|
73
|
-
for num_tokens in [100, 1000, 20000]:
|
|
74
|
-
self.run_dataset_implementation(
|
|
75
|
-
output_mode="count",
|
|
76
|
-
batch_size=batch,
|
|
77
|
-
sequence_length=sequence_length,
|
|
78
|
-
max_tokens=num_tokens,
|
|
79
|
-
)
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
if __name__ == "__main__":
|
|
83
|
-
tf.test.main()
|
|
84
|
-
|