tf-keras-nightly 2.19.0.dev2024121210__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/protobuf/projector_config_pb2.py +23 -12
- tf_keras/protobuf/saved_metadata_pb2.py +21 -10
- tf_keras/protobuf/versions_pb2.py +19 -8
- tf_keras/src/__init__.py +1 -1
- tf_keras/src/backend.py +1 -1
- tf_keras/src/datasets/boston_housing.py +14 -5
- tf_keras/src/datasets/cifar10.py +9 -1
- tf_keras/src/datasets/cifar100.py +7 -1
- tf_keras/src/datasets/fashion_mnist.py +16 -4
- tf_keras/src/datasets/imdb.py +8 -0
- tf_keras/src/datasets/mnist.py +9 -3
- tf_keras/src/datasets/reuters.py +8 -0
- tf_keras/src/engine/base_layer.py +235 -97
- tf_keras/src/engine/base_layer_utils.py +17 -5
- tf_keras/src/engine/base_layer_v1.py +12 -3
- tf_keras/src/engine/data_adapter.py +35 -19
- tf_keras/src/engine/functional.py +36 -15
- tf_keras/src/engine/input_layer.py +9 -0
- tf_keras/src/engine/input_spec.py +11 -1
- tf_keras/src/engine/sequential.py +29 -12
- tf_keras/src/layers/activation/softmax.py +26 -11
- tf_keras/src/layers/attention/multi_head_attention.py +8 -1
- tf_keras/src/layers/core/tf_op_layer.py +4 -0
- tf_keras/src/layers/normalization/spectral_normalization.py +29 -22
- tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
- tf_keras/src/metrics/confusion_metrics.py +51 -4
- tf_keras/src/models/sharpness_aware_minimization.py +17 -7
- tf_keras/src/preprocessing/sequence.py +2 -2
- tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
- tf_keras/src/saving/legacy/saving_utils.py +14 -2
- tf_keras/src/saving/saving_api.py +18 -5
- tf_keras/src/saving/saving_lib.py +1 -1
- tf_keras/src/utils/layer_utils.py +45 -3
- tf_keras/src/utils/metrics_utils.py +4 -1
- tf_keras/src/utils/tf_utils.py +2 -2
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +14 -3
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +40 -62
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
- tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
- tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
- tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
- tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
- tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
- tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
- tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
- tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
- tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
- tf_keras/src/tests/keras_doctest.py +0 -159
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
|
@@ -1471,9 +1471,10 @@ class AUC(base_metric.Metric):
|
|
|
1471
1471
|
# label_weights should be of length equal to the number of
|
|
1472
1472
|
# labels.
|
|
1473
1473
|
shapes.append((self.label_weights, ("L",)))
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1474
|
+
|
|
1475
|
+
tf.debugging.assert_shapes(
|
|
1476
|
+
shapes, message="Number of labels is not consistent."
|
|
1477
|
+
)
|
|
1477
1478
|
|
|
1478
1479
|
# Only forward label_weights to update_confusion_matrix_variables when
|
|
1479
1480
|
# multi_label is False. Otherwise the averaging of individual label AUCs
|
|
@@ -1611,13 +1612,59 @@ class AUC(base_metric.Metric):
|
|
|
1611
1612
|
)
|
|
1612
1613
|
x = fp_rate
|
|
1613
1614
|
y = recall
|
|
1614
|
-
|
|
1615
|
+
elif self.curve == metrics_utils.AUCCurve.PR:
|
|
1615
1616
|
precision = tf.math.divide_no_nan(
|
|
1616
1617
|
self.true_positives,
|
|
1617
1618
|
tf.math.add(self.true_positives, self.false_positives),
|
|
1618
1619
|
)
|
|
1619
1620
|
x = recall
|
|
1620
1621
|
y = precision
|
|
1622
|
+
else: # curve == 'PR_GAIN'.
|
|
1623
|
+
# Due to the hyperbolic transform, this formula is less robust than
|
|
1624
|
+
# ROC or PR values. In particular
|
|
1625
|
+
# 1) Both measures diverge when there are no negative examples;
|
|
1626
|
+
# 2) Both measures diverge when there are no true positives;
|
|
1627
|
+
# 3) Recall gain becomes negative when the recall is lower than the
|
|
1628
|
+
# label average (i.e. when more negative examples are classified
|
|
1629
|
+
# positive than real positives).
|
|
1630
|
+
#
|
|
1631
|
+
# We ignore case 1 as it is easily communicated. For case 2 we set
|
|
1632
|
+
# recall_gain to 0 and precision_gain to 1. For case 3 we set the
|
|
1633
|
+
# recall_gain to 0. These fixes will result in an overastimation of
|
|
1634
|
+
# the AUC for estimateors that are anti-correlated with the label
|
|
1635
|
+
# (at some thresholds).
|
|
1636
|
+
#
|
|
1637
|
+
# The scaling factor $\frac{P}{N}$ that is used to form both
|
|
1638
|
+
# gain values.
|
|
1639
|
+
scaling_factor = tf.math.divide_no_nan(
|
|
1640
|
+
tf.math.add(self.true_positives, self.false_negatives),
|
|
1641
|
+
tf.math.add(self.false_positives, self.true_negatives),
|
|
1642
|
+
)
|
|
1643
|
+
|
|
1644
|
+
recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
|
|
1645
|
+
self.false_negatives, self.true_positives
|
|
1646
|
+
)
|
|
1647
|
+
precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
|
|
1648
|
+
self.false_positives, self.true_positives
|
|
1649
|
+
)
|
|
1650
|
+
# Handle case 2.
|
|
1651
|
+
recall_gain = tf.where(
|
|
1652
|
+
tf.equal(self.true_positives, 0.0),
|
|
1653
|
+
tf.zeros_like(recall_gain),
|
|
1654
|
+
recall_gain,
|
|
1655
|
+
)
|
|
1656
|
+
precision_gain = tf.where(
|
|
1657
|
+
tf.equal(self.true_positives, 0.0),
|
|
1658
|
+
tf.ones_like(precision_gain),
|
|
1659
|
+
precision_gain,
|
|
1660
|
+
)
|
|
1661
|
+
# Handle case 3.
|
|
1662
|
+
recall_gain = tf.math.maximum(
|
|
1663
|
+
recall_gain, tf.zeros_like(recall_gain)
|
|
1664
|
+
)
|
|
1665
|
+
|
|
1666
|
+
x = recall_gain
|
|
1667
|
+
y = precision_gain
|
|
1621
1668
|
|
|
1622
1669
|
# Find the rectangle heights based on `summation_method`.
|
|
1623
1670
|
if (
|
|
@@ -72,17 +72,27 @@ class SharpnessAwareMinimization(Model):
|
|
|
72
72
|
if self.num_batch_splits is not None:
|
|
73
73
|
x_split = tf.split(x, self.num_batch_splits)
|
|
74
74
|
y_split = tf.split(y, self.num_batch_splits)
|
|
75
|
+
# Split the sample weight if it is provided.
|
|
76
|
+
if sample_weight is not None:
|
|
77
|
+
sample_weight_split = tf.split(
|
|
78
|
+
sample_weight, self.num_batch_splits
|
|
79
|
+
)
|
|
80
|
+
else:
|
|
81
|
+
sample_weight_split = [None] * self.num_batch_splits
|
|
75
82
|
else:
|
|
76
83
|
x_split = [x]
|
|
77
84
|
y_split = [y]
|
|
85
|
+
sample_weight_split = [sample_weight]
|
|
78
86
|
|
|
79
87
|
gradients_all_batches = []
|
|
80
88
|
pred_all_batches = []
|
|
81
|
-
for x_batch, y_batch in zip(
|
|
89
|
+
for x_batch, y_batch, sample_weight_batch in zip(
|
|
90
|
+
x_split, y_split, sample_weight_split
|
|
91
|
+
):
|
|
82
92
|
epsilon_w_cache = []
|
|
83
93
|
with tf.GradientTape() as tape:
|
|
84
|
-
pred = self
|
|
85
|
-
loss = self.compiled_loss(y_batch, pred)
|
|
94
|
+
pred = self(x_batch)
|
|
95
|
+
loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
|
|
86
96
|
pred_all_batches.append(pred)
|
|
87
97
|
trainable_variables = self.model.trainable_variables
|
|
88
98
|
gradients = tape.gradient(loss, trainable_variables)
|
|
@@ -98,8 +108,8 @@ class SharpnessAwareMinimization(Model):
|
|
|
98
108
|
epsilon_w_cache.append(epsilon_w)
|
|
99
109
|
|
|
100
110
|
with tf.GradientTape() as tape:
|
|
101
|
-
pred = self(x_batch)
|
|
102
|
-
loss = self.compiled_loss(y_batch, pred)
|
|
111
|
+
pred = self(x_batch, training=True)
|
|
112
|
+
loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
|
|
103
113
|
gradients = tape.gradient(loss, trainable_variables)
|
|
104
114
|
if len(gradients_all_batches) == 0:
|
|
105
115
|
for gradient in gradients:
|
|
@@ -127,7 +137,7 @@ class SharpnessAwareMinimization(Model):
|
|
|
127
137
|
self.compiled_metrics.update_state(y, pred, sample_weight)
|
|
128
138
|
return {m.name: m.result() for m in self.metrics}
|
|
129
139
|
|
|
130
|
-
def call(self, inputs):
|
|
140
|
+
def call(self, inputs, **kwargs):
|
|
131
141
|
"""Forward pass of SAM.
|
|
132
142
|
|
|
133
143
|
SAM delegates the forward pass call to the wrapped model.
|
|
@@ -138,7 +148,7 @@ class SharpnessAwareMinimization(Model):
|
|
|
138
148
|
Returns:
|
|
139
149
|
A Tensor, the outputs of the wrapped model for given `inputs`.
|
|
140
150
|
"""
|
|
141
|
-
return self.model(inputs)
|
|
151
|
+
return self.model(inputs, **kwargs)
|
|
142
152
|
|
|
143
153
|
def get_config(self):
|
|
144
154
|
config = super().get_config()
|
|
@@ -365,7 +365,7 @@ def skipgrams(
|
|
|
365
365
|
random.shuffle(words)
|
|
366
366
|
|
|
367
367
|
couples += [
|
|
368
|
-
[words[i % len(words)], random.randint(1, vocabulary_size - 1)]
|
|
368
|
+
[words[i % len(words)], random.randint(1, int(vocabulary_size - 1))]
|
|
369
369
|
for i in range(num_negative_samples)
|
|
370
370
|
]
|
|
371
371
|
if categorical:
|
|
@@ -375,7 +375,7 @@ def skipgrams(
|
|
|
375
375
|
|
|
376
376
|
if shuffle:
|
|
377
377
|
if seed is None:
|
|
378
|
-
seed = random.randint(0, 10e6)
|
|
378
|
+
seed = random.randint(0, int(10e6))
|
|
379
379
|
random.seed(seed)
|
|
380
380
|
random.shuffle(couples)
|
|
381
381
|
random.seed(seed)
|
|
@@ -219,7 +219,11 @@ def wrap_layer_functions(layer, serialization_cache):
|
|
|
219
219
|
with tracing_scope():
|
|
220
220
|
call_collection.trace_with_input_signature()
|
|
221
221
|
with base_layer_utils.call_context().enter(
|
|
222
|
-
layer,
|
|
222
|
+
layer,
|
|
223
|
+
inputs=None,
|
|
224
|
+
build_graph=True,
|
|
225
|
+
call_context_args={},
|
|
226
|
+
saving=True,
|
|
223
227
|
):
|
|
224
228
|
for fn in fns.values():
|
|
225
229
|
if fn is not None and not isinstance(fn, LayerCall):
|
|
@@ -515,19 +519,28 @@ class LayerCallCollection:
|
|
|
515
519
|
else:
|
|
516
520
|
add_trace_to_queue(fn, args, kwargs)
|
|
517
521
|
|
|
518
|
-
def
|
|
522
|
+
def arg_was_passed(self, arg_name, args, kwargs):
|
|
523
|
+
"""Returns True if the argument was passed to the call function."""
|
|
519
524
|
return self._call_spec.arg_was_passed(
|
|
520
|
-
|
|
525
|
+
arg_name, args, kwargs, inputs_in_args=True
|
|
521
526
|
)
|
|
522
527
|
|
|
523
|
-
def
|
|
528
|
+
def training_arg_was_passed(self, args, kwargs):
|
|
529
|
+
"""Returns True if the training arg was passed to the call function."""
|
|
530
|
+
return self.arg_was_passed("training", args, kwargs)
|
|
531
|
+
|
|
532
|
+
def get_arg_value(self, arg_name, args, kwargs):
|
|
533
|
+
"""Returns the value of the given argument or None if not found."""
|
|
524
534
|
try:
|
|
525
535
|
return self._call_spec.get_arg_value(
|
|
526
|
-
|
|
536
|
+
arg_name, args, kwargs, inputs_in_args=True
|
|
527
537
|
)
|
|
528
|
-
except KeyError: #
|
|
538
|
+
except KeyError: # Arg not found in args or kwargs.
|
|
529
539
|
return None
|
|
530
540
|
|
|
541
|
+
def get_training_arg_value(self, args, kwargs):
|
|
542
|
+
return self.get_arg_value("training", args, kwargs)
|
|
543
|
+
|
|
531
544
|
def get_input_arg_value(self, args, kwargs):
|
|
532
545
|
return self._call_spec.get_arg_value(
|
|
533
546
|
self._input_arg_name, args, kwargs, inputs_in_args=True
|
|
@@ -613,20 +626,23 @@ def layer_call_wrapper(call_collection, method, name):
|
|
|
613
626
|
def wrapper(*args, **kwargs):
|
|
614
627
|
"""Calls method within call context."""
|
|
615
628
|
layer = call_collection.layer
|
|
616
|
-
|
|
629
|
+
propagated = {"training": None}
|
|
617
630
|
inputs = _filtered_inputs([args, kwargs])
|
|
618
631
|
|
|
619
|
-
|
|
620
|
-
args
|
|
621
|
-
|
|
622
|
-
|
|
632
|
+
for context_arg in layer._call_context_args:
|
|
633
|
+
if (args or kwargs) and call_collection.arg_was_passed(
|
|
634
|
+
context_arg, args, kwargs
|
|
635
|
+
):
|
|
636
|
+
propagated[context_arg] = call_collection.get_arg_value(
|
|
637
|
+
context_arg, args, kwargs
|
|
638
|
+
)
|
|
623
639
|
|
|
624
640
|
original_losses = _reset_layer_losses(layer)
|
|
625
641
|
with base_layer_utils.call_context().enter(
|
|
626
642
|
layer,
|
|
627
643
|
inputs=inputs,
|
|
628
644
|
build_graph=False,
|
|
629
|
-
|
|
645
|
+
call_context_args=propagated,
|
|
630
646
|
saving=True,
|
|
631
647
|
):
|
|
632
648
|
with autocast_variable.enable_auto_cast_variables(
|
|
@@ -138,12 +138,24 @@ def trace_model_call(model, input_signature=None):
|
|
|
138
138
|
@tf.function
|
|
139
139
|
def _wrapped_model(*args, **kwargs):
|
|
140
140
|
"""A concrete tf.function that wraps the model's call function."""
|
|
141
|
+
call_context = base_layer_utils.call_context()
|
|
142
|
+
|
|
143
|
+
args, kwargs, propagated = model._get_propagated_call_context_arguments(
|
|
144
|
+
args, kwargs, call_context, model._call_context_args
|
|
145
|
+
)
|
|
146
|
+
|
|
141
147
|
(args, kwargs,) = model._call_spec.set_arg_value(
|
|
142
148
|
"training", False, args, kwargs, inputs_in_args=True
|
|
143
149
|
)
|
|
144
150
|
|
|
145
|
-
|
|
146
|
-
|
|
151
|
+
propagated["training"] = False
|
|
152
|
+
|
|
153
|
+
with call_context.enter(
|
|
154
|
+
model,
|
|
155
|
+
inputs=None,
|
|
156
|
+
build_graph=False,
|
|
157
|
+
call_context_args=propagated,
|
|
158
|
+
saving=True,
|
|
147
159
|
):
|
|
148
160
|
outputs = model(*args, **kwargs)
|
|
149
161
|
|
|
@@ -24,6 +24,7 @@ from tensorflow.python.util.tf_export import keras_export
|
|
|
24
24
|
|
|
25
25
|
from tf_keras.src.saving import saving_lib
|
|
26
26
|
from tf_keras.src.saving.legacy import save as legacy_sm_saving_lib
|
|
27
|
+
from tf_keras.src.saving.legacy import saving_utils
|
|
27
28
|
from tf_keras.src.utils import io_utils
|
|
28
29
|
|
|
29
30
|
try:
|
|
@@ -75,8 +76,7 @@ class SupportWriteToRemote:
|
|
|
75
76
|
supports remoted saved model out of the box.
|
|
76
77
|
"""
|
|
77
78
|
|
|
78
|
-
def __init__(self, filepath, overwrite
|
|
79
|
-
save_format = get_save_format(filepath, save_format=save_format)
|
|
79
|
+
def __init__(self, filepath, overwrite, save_format):
|
|
80
80
|
self.overwrite = overwrite
|
|
81
81
|
if saving_lib.is_remote_path(filepath) and save_format != "tf":
|
|
82
82
|
self.temp_directory = tempfile.TemporaryDirectory()
|
|
@@ -191,14 +191,14 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
|
|
|
191
191
|
when loading the model. See the `custom_objects` argument in
|
|
192
192
|
`tf.keras.saving.load_model`.
|
|
193
193
|
"""
|
|
194
|
+
save_format = get_save_format(filepath, save_format)
|
|
195
|
+
|
|
194
196
|
# Supports remote paths via a temporary file
|
|
195
197
|
with SupportWriteToRemote(
|
|
196
198
|
filepath,
|
|
197
199
|
overwrite=overwrite,
|
|
198
200
|
save_format=save_format,
|
|
199
201
|
) as local_filepath:
|
|
200
|
-
save_format = get_save_format(filepath, save_format)
|
|
201
|
-
|
|
202
202
|
# Deprecation warnings
|
|
203
203
|
if save_format == "h5":
|
|
204
204
|
warnings.warn(
|
|
@@ -307,8 +307,12 @@ def load_model(
|
|
|
307
307
|
|
|
308
308
|
|
|
309
309
|
def save_weights(model, filepath, overwrite=True, **kwargs):
|
|
310
|
+
save_format = get_save_weights_format(filepath)
|
|
311
|
+
|
|
310
312
|
# Supports remote paths via a temporary file
|
|
311
|
-
with SupportWriteToRemote(
|
|
313
|
+
with SupportWriteToRemote(
|
|
314
|
+
filepath, overwrite=overwrite, save_format=save_format
|
|
315
|
+
) as local_filepath:
|
|
312
316
|
if str(local_filepath).endswith(".weights.h5"):
|
|
313
317
|
# If file exists and should not be overwritten.
|
|
314
318
|
try:
|
|
@@ -386,3 +390,12 @@ def get_save_format(filepath, save_format):
|
|
|
386
390
|
else:
|
|
387
391
|
return "h5"
|
|
388
392
|
|
|
393
|
+
|
|
394
|
+
def get_save_weights_format(filepath):
|
|
395
|
+
filepath = io_utils.path_to_string(filepath)
|
|
396
|
+
filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
|
|
397
|
+
if filepath_is_h5:
|
|
398
|
+
return "h5"
|
|
399
|
+
else:
|
|
400
|
+
return "tf"
|
|
401
|
+
|
|
@@ -639,7 +639,7 @@ class NpzIOStore:
|
|
|
639
639
|
self.f = archive.open(root_path, mode="r")
|
|
640
640
|
else:
|
|
641
641
|
self.f = open(root_path, mode="rb")
|
|
642
|
-
self.contents = np.load(self.f, allow_pickle=
|
|
642
|
+
self.contents = np.load(self.f, allow_pickle=False)
|
|
643
643
|
|
|
644
644
|
def make(self, path):
|
|
645
645
|
if not path:
|
|
@@ -775,11 +775,13 @@ class CallFunctionSpec:
|
|
|
775
775
|
"""Caches the spec and provides utilities for handling call function
|
|
776
776
|
args."""
|
|
777
777
|
|
|
778
|
-
def __init__(self, full_argspec):
|
|
778
|
+
def __init__(self, full_argspec, call_context_args=set()):
|
|
779
779
|
"""Initialies a `CallFunctionSpec`.
|
|
780
780
|
|
|
781
781
|
Args:
|
|
782
782
|
full_argspec: the FullArgSpec of a call function of a layer.
|
|
783
|
+
call_context_args: The set of call-context arguments registered
|
|
784
|
+
with to the current layer.
|
|
783
785
|
"""
|
|
784
786
|
self._full_argspec = full_argspec
|
|
785
787
|
|
|
@@ -797,6 +799,18 @@ class CallFunctionSpec:
|
|
|
797
799
|
"mask" in self._arg_names or call_accepts_kwargs
|
|
798
800
|
)
|
|
799
801
|
|
|
802
|
+
# Track the set of call-context arguments that the current layer's
|
|
803
|
+
# `call` method accepts.
|
|
804
|
+
self._expected_context_args = set()
|
|
805
|
+
self._update_call_context_arguments(call_context_args)
|
|
806
|
+
|
|
807
|
+
self._context_arg_defaults = dict()
|
|
808
|
+
self._update_call_context_argument_defaults(call_context_args)
|
|
809
|
+
|
|
810
|
+
def _update_call_context_argument_defaults(self, context_args):
|
|
811
|
+
"""Updates the set of call-context argument defaults for the current
|
|
812
|
+
layer's `call` method.
|
|
813
|
+
"""
|
|
800
814
|
call_fn_defaults = self._full_argspec.defaults or []
|
|
801
815
|
defaults = dict()
|
|
802
816
|
# The call arg defaults are an n-tuple of the last n elements of the
|
|
@@ -806,7 +820,21 @@ class CallFunctionSpec:
|
|
|
806
820
|
# The default training arg will be any (non-None) default specified in
|
|
807
821
|
# the method signature, or None if no value is specified.
|
|
808
822
|
defaults.update(self._full_argspec.kwonlydefaults or {})
|
|
809
|
-
|
|
823
|
+
|
|
824
|
+
for arg in context_args:
|
|
825
|
+
self._context_arg_defaults[arg] = defaults.get(arg)
|
|
826
|
+
|
|
827
|
+
def _update_call_context_arguments(self, context_args):
|
|
828
|
+
"""Updates the set of call-context arguments that the current layer's
|
|
829
|
+
`call` method accepts.
|
|
830
|
+
"""
|
|
831
|
+
call_accepts_kwargs = self._full_argspec.varkw is not None
|
|
832
|
+
args_to_add = {
|
|
833
|
+
arg
|
|
834
|
+
for arg in context_args
|
|
835
|
+
if call_accepts_kwargs or arg in self._arg_names
|
|
836
|
+
}
|
|
837
|
+
self._expected_context_args.update(args_to_add)
|
|
810
838
|
|
|
811
839
|
@property
|
|
812
840
|
def full_argspec(self):
|
|
@@ -843,6 +871,16 @@ class CallFunctionSpec:
|
|
|
843
871
|
def expects_training_arg(self, value):
|
|
844
872
|
self._expects_training_arg = value
|
|
845
873
|
|
|
874
|
+
@property
|
|
875
|
+
def expected_context_args(self):
|
|
876
|
+
"""The set of call-context arguments that the current layer's
|
|
877
|
+
`call` method accepts."""
|
|
878
|
+
return self._expected_context_args
|
|
879
|
+
|
|
880
|
+
@expected_context_args.setter
|
|
881
|
+
def expected_context_args(self, value):
|
|
882
|
+
self._expected_context_args = value
|
|
883
|
+
|
|
846
884
|
@property
|
|
847
885
|
def expects_mask_arg(self):
|
|
848
886
|
"""Whether the call function uses `mask` as a parameter."""
|
|
@@ -855,7 +893,11 @@ class CallFunctionSpec:
|
|
|
855
893
|
@property
|
|
856
894
|
def default_training_arg(self):
|
|
857
895
|
"""The default value given to the "training" argument."""
|
|
858
|
-
return self.
|
|
896
|
+
return self.get_context_arg_default("training")
|
|
897
|
+
|
|
898
|
+
def get_context_arg_default(self, arg_name):
|
|
899
|
+
"""The default value given to the call context arguments."""
|
|
900
|
+
return self._context_arg_defaults.get(arg_name, None)
|
|
859
901
|
|
|
860
902
|
def arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
|
|
861
903
|
"""Returns true if argument is present in `args` or `kwargs`.
|
|
@@ -237,6 +237,7 @@ class AUCCurve(Enum):
|
|
|
237
237
|
|
|
238
238
|
ROC = "ROC"
|
|
239
239
|
PR = "PR"
|
|
240
|
+
PR_GAIN = "PR_GAIN"
|
|
240
241
|
|
|
241
242
|
@staticmethod
|
|
242
243
|
def from_str(key):
|
|
@@ -244,10 +245,12 @@ class AUCCurve(Enum):
|
|
|
244
245
|
return AUCCurve.PR
|
|
245
246
|
elif key in ("roc", "ROC"):
|
|
246
247
|
return AUCCurve.ROC
|
|
248
|
+
elif key in ("pr_gain", "prgain", "PR_GAIN", "PRGAIN"):
|
|
249
|
+
return AUCCurve.PR_GAIN
|
|
247
250
|
else:
|
|
248
251
|
raise ValueError(
|
|
249
252
|
f'Invalid AUC curve value: "{key}". '
|
|
250
|
-
'Expected values are ["PR", "ROC"]'
|
|
253
|
+
'Expected values are ["PR", "ROC", "PR_GAIN"]'
|
|
251
254
|
)
|
|
252
255
|
|
|
253
256
|
|
tf_keras/src/utils/tf_utils.py
CHANGED
|
@@ -78,9 +78,9 @@ def get_random_seed():
|
|
|
78
78
|
the random seed as an integer.
|
|
79
79
|
"""
|
|
80
80
|
if getattr(backend._SEED_GENERATOR, "generator", None):
|
|
81
|
-
return backend._SEED_GENERATOR.generator.randint(1, 1e9)
|
|
81
|
+
return backend._SEED_GENERATOR.generator.randint(1, int(1e9))
|
|
82
82
|
else:
|
|
83
|
-
return random.randint(1, 1e9)
|
|
83
|
+
return random.randint(1, int(1e9))
|
|
84
84
|
|
|
85
85
|
|
|
86
86
|
def is_tensor_or_tensor_list(v):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: tf_keras-nightly
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.21.0.dev2025123010
|
|
4
4
|
Summary: Deep learning for humans.
|
|
5
5
|
Home-page: https://keras.io/
|
|
6
6
|
Download-URL: https://github.com/keras-team/tf-keras/tags
|
|
@@ -26,7 +26,18 @@ Classifier: Topic :: Software Development
|
|
|
26
26
|
Classifier: Topic :: Software Development :: Libraries
|
|
27
27
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
28
28
|
Requires-Python: >=3.9
|
|
29
|
-
Requires-Dist: tf-nightly~=2.
|
|
29
|
+
Requires-Dist: tf-nightly~=2.21.0.dev
|
|
30
|
+
Dynamic: author
|
|
31
|
+
Dynamic: author-email
|
|
32
|
+
Dynamic: classifier
|
|
33
|
+
Dynamic: description
|
|
34
|
+
Dynamic: download-url
|
|
35
|
+
Dynamic: home-page
|
|
36
|
+
Dynamic: keywords
|
|
37
|
+
Dynamic: license
|
|
38
|
+
Dynamic: requires-dist
|
|
39
|
+
Dynamic: requires-python
|
|
40
|
+
Dynamic: summary
|
|
30
41
|
|
|
31
42
|
TF-Keras is a deep learning API written in Python,
|
|
32
43
|
running on top of the machine learning platform TensorFlow.
|