tf-models-nightly 2.17.0.dev20240411__py2.py3-none-any.whl → 2.17.0.dev20240413__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/modeling/multitask/interleaving_trainer.py +8 -8
- official/recommendation/uplift/keys.py +1 -1
- official/recommendation/uplift/layers/heads/two_tower_logits_head.py +1 -1
- official/recommendation/uplift/metrics/loss_metric.py +3 -0
- official/recommendation/uplift/metrics/sliced_metric.py +5 -0
- official/recommendation/uplift/metrics/sliced_metric_test.py +27 -0
- official/recommendation/uplift/models/two_tower_uplift_model_test.py +48 -0
- {tf_models_nightly-2.17.0.dev20240411.dist-info → tf_models_nightly-2.17.0.dev20240413.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.17.0.dev20240411.dist-info → tf_models_nightly-2.17.0.dev20240413.dist-info}/RECORD +13 -13
- {tf_models_nightly-2.17.0.dev20240411.dist-info → tf_models_nightly-2.17.0.dev20240413.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.17.0.dev20240411.dist-info → tf_models_nightly-2.17.0.dev20240413.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.17.0.dev20240411.dist-info → tf_models_nightly-2.17.0.dev20240413.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.17.0.dev20240411.dist-info → tf_models_nightly-2.17.0.dev20240413.dist-info}/top_level.txt +0 -0
@@ -43,12 +43,6 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
|
|
43
43
|
trainer_options=trainer_options)
|
44
44
|
self._task_sampler = task_sampler
|
45
45
|
|
46
|
-
# TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
|
47
|
-
# on TensorBoard.
|
48
|
-
self._task_step_counters = {
|
49
|
-
name: orbit.utils.create_global_step() for name in self.multi_task.tasks
|
50
|
-
}
|
51
|
-
|
52
46
|
# Build per task train step.
|
53
47
|
def _get_task_step(task_name, task):
|
54
48
|
|
@@ -63,8 +57,6 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
|
|
63
57
|
optimizer=self.optimizer,
|
64
58
|
metrics=self.training_metrics[task_name])
|
65
59
|
self.training_losses[task_name].update_state(task_logs[task.loss])
|
66
|
-
self.global_step.assign_add(1)
|
67
|
-
self.task_step_counter(task_name).assign_add(1)
|
68
60
|
|
69
61
|
return step_fn
|
70
62
|
|
@@ -73,6 +65,12 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
|
|
73
65
|
for name, task in self.multi_task.tasks.items()
|
74
66
|
}
|
75
67
|
|
68
|
+
# TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
|
69
|
+
# on TensorBoard.
|
70
|
+
self._task_step_counters = {
|
71
|
+
name: orbit.utils.create_global_step() for name in self.multi_task.tasks
|
72
|
+
}
|
73
|
+
|
76
74
|
# If the new Keras optimizer is used, we require all model variables are
|
77
75
|
# created before the training and let the optimizer to create the slot
|
78
76
|
# variable all together.
|
@@ -99,6 +97,8 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
|
|
99
97
|
if rn >= begin and rn < end:
|
100
98
|
self._strategy.run(
|
101
99
|
self._task_train_step_map[name], args=(next(iterator_map[name]),))
|
100
|
+
self.global_step.assign_add(1)
|
101
|
+
self.task_step_counter(name).assign_add(1)
|
102
102
|
|
103
103
|
def train_loop_end(self):
|
104
104
|
"""Record loss and metric values per task."""
|
@@ -183,6 +183,9 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
183
183
|
def result(self) -> tf.Tensor | dict[str, tf.Tensor]:
|
184
184
|
return self._loss.result()
|
185
185
|
|
186
|
+
def reset_state(self):
|
187
|
+
self._loss.reset_state()
|
188
|
+
|
186
189
|
def get_config(self) -> dict[str, Any]:
|
187
190
|
config = super().get_config()
|
188
191
|
config["loss_fn"] = tf_keras.utils.serialize_keras_object(self._loss_fn)
|
@@ -193,6 +193,11 @@ class SlicedMetric(tf_keras.metrics.Metric):
|
|
193
193
|
f"{metric_result}."
|
194
194
|
)
|
195
195
|
|
196
|
+
def reset_state(self):
|
197
|
+
self._metric.reset_state()
|
198
|
+
for metric in self._sliced_metrics:
|
199
|
+
metric.reset_state()
|
200
|
+
|
196
201
|
def get_config(self):
|
197
202
|
return {
|
198
203
|
"name": self.name,
|
@@ -315,6 +315,33 @@ class SlicedMetricTest(keras_test_case.KerasTestCase, parameterized.TestCase):
|
|
315
315
|
}
|
316
316
|
self.assertDictEqual(expected_result, metric.result())
|
317
317
|
|
318
|
+
def test_reset_state(self):
|
319
|
+
metric = sliced_metric.SlicedMetric(
|
320
|
+
metric=tf_keras.metrics.AUC(curve="PR", from_logits=False, name="auc"),
|
321
|
+
slicing_spec={"control": False, "treatment": True},
|
322
|
+
)
|
323
|
+
|
324
|
+
expected_initial_result = {
|
325
|
+
"auc": 0.0,
|
326
|
+
"auc/control": 0.0,
|
327
|
+
"auc/treatment": 0.0,
|
328
|
+
}
|
329
|
+
self.assertAllClose(expected_initial_result, metric.result())
|
330
|
+
|
331
|
+
metric.update_state(
|
332
|
+
tf.constant([[0], [0], [1], [1]]), # y_true
|
333
|
+
tf.constant([[0.2], [0.6], [0.3], [0.7]]), # y_pred
|
334
|
+
slicing_feature=tf.constant([[True], [False], [True], [False]]),
|
335
|
+
)
|
336
|
+
|
337
|
+
result = metric.result()
|
338
|
+
self.assertGreater(result["auc"], 0.0)
|
339
|
+
self.assertGreater(result["auc/control"], 0.0)
|
340
|
+
self.assertGreater(result["auc/treatment"], 0.0)
|
341
|
+
|
342
|
+
metric.reset_state()
|
343
|
+
self.assertAllClose(expected_initial_result, metric.result())
|
344
|
+
|
318
345
|
def test_metric_config(self):
|
319
346
|
metric = sliced_metric.SlicedMetric(
|
320
347
|
tf_keras.metrics.SparseTopKCategoricalAccuracy(k=2, name="accuracy@2"),
|
@@ -21,6 +21,7 @@ from official.recommendation.uplift import keras_test_case
|
|
21
21
|
from official.recommendation.uplift import keys
|
22
22
|
from official.recommendation.uplift.layers.uplift_networks import two_tower_uplift_network
|
23
23
|
from official.recommendation.uplift.losses import true_logits_loss
|
24
|
+
from official.recommendation.uplift.metrics import loss_metric
|
24
25
|
from official.recommendation.uplift.models import two_tower_uplift_model
|
25
26
|
|
26
27
|
|
@@ -127,6 +128,53 @@ class TwoTowerUpliftModelTest(
|
|
127
128
|
}
|
128
129
|
self.assertAllClose(expected_predictions, model.predict(dataset))
|
129
130
|
|
131
|
+
def test_classification_model_trains(self):
|
132
|
+
tf_keras.utils.set_random_seed(1)
|
133
|
+
|
134
|
+
# Create binary classifier uplift model.
|
135
|
+
uplift_network = self._get_uplift_network(
|
136
|
+
control_feature_encoder=None, control_input_combiner=None
|
137
|
+
)
|
138
|
+
model = two_tower_uplift_model.TwoTowerUpliftModel(
|
139
|
+
treatment_indicator_feature_name="is_treatment",
|
140
|
+
uplift_network=uplift_network,
|
141
|
+
inverse_link_fn=tf.math.sigmoid,
|
142
|
+
)
|
143
|
+
model.compile(
|
144
|
+
optimizer=tf_keras.optimizers.SGD(0.1),
|
145
|
+
loss=true_logits_loss.TrueLogitsLoss(
|
146
|
+
loss_fn=tf_keras.losses.binary_crossentropy, from_logits=True
|
147
|
+
),
|
148
|
+
metrics=[
|
149
|
+
loss_metric.LossMetric(
|
150
|
+
tf_keras.metrics.AUC(curve="PR", from_logits=True, name="aucpr")
|
151
|
+
),
|
152
|
+
],
|
153
|
+
)
|
154
|
+
|
155
|
+
# Create toy classification dataset.
|
156
|
+
treatment = tf.constant([[1], [1], [0], [1], [1], [1], [0], [1], [0], [1]])
|
157
|
+
y = treatment
|
158
|
+
dataset = tf.data.Dataset.from_tensor_slices((
|
159
|
+
{
|
160
|
+
"shared_feature": np.random.normal(size=(10, 1)),
|
161
|
+
"treatment_feature": np.random.normal(size=(10, 1)),
|
162
|
+
"is_treatment": treatment,
|
163
|
+
},
|
164
|
+
y,
|
165
|
+
)).batch(5)
|
166
|
+
|
167
|
+
# Test model training.
|
168
|
+
history = model.fit(dataset, epochs=100)
|
169
|
+
self.assertIn("loss", history.history)
|
170
|
+
self.assertLen(history.history["loss"], 100)
|
171
|
+
self.assertBetween(
|
172
|
+
history.history["loss"][-1], 0.0, history.history["loss"][0]
|
173
|
+
)
|
174
|
+
self.assertIn("aucpr", history.history)
|
175
|
+
self.assertLess(history.history["aucpr"][0], 1.0)
|
176
|
+
self.assertEqual(history.history["aucpr"][-1], 1.0)
|
177
|
+
|
130
178
|
@parameterized.named_parameters(
|
131
179
|
{
|
132
180
|
"testcase_name": "identity",
|
@@ -218,7 +218,7 @@ official/modeling/multitask/base_trainer_test.py,sha256=qJ7z4kid2XAX6hOIvUHa7dwq
|
|
218
218
|
official/modeling/multitask/configs.py,sha256=ZO2waQrMn9CAgyFpsmeQvplCF5VeXz7tCPmIuy5jvlc,3164
|
219
219
|
official/modeling/multitask/evaluator.py,sha256=spDm2X8EX62qsxI2ehVjrkIKoo-omQQOYcAVKZNgxHc,6078
|
220
220
|
official/modeling/multitask/evaluator_test.py,sha256=vU-q-gM7GqiMqE5zbBnOT8mPFhQmHjniMyNnwganhso,4643
|
221
|
-
official/modeling/multitask/interleaving_trainer.py,sha256=
|
221
|
+
official/modeling/multitask/interleaving_trainer.py,sha256=ZZHKsqbJKLqvwtgy-PUv_S_8bDG0MhJDNwWICY_IF6Q,4458
|
222
222
|
official/modeling/multitask/interleaving_trainer_test.py,sha256=MeQQxpcinPTQuTrAcITjwHa2bAj-XCBCqYsrbxPBus8,4305
|
223
223
|
official/modeling/multitask/multitask.py,sha256=DV-ysfhPiIZgsrzZNylsPBxKNBf_xzPxJYjF4buWVgE,5948
|
224
224
|
official/modeling/multitask/task_sampler.py,sha256=SGVVdjMb5oG4vnCczpfdgBtbsdsXiyBLl9si_0V6nko,4897
|
@@ -887,7 +887,7 @@ official/recommendation/ranking/data/data_pipeline_multi_hot_test.py,sha256=arLj
|
|
887
887
|
official/recommendation/ranking/data/data_pipeline_test.py,sha256=VRYo7WqURRkM3lbmfctvSZxyH1EzUfqwxR3sy2sZxdc,2345
|
888
888
|
official/recommendation/uplift/__init__.py,sha256=_jZilTPWKu-MfMaz1IgBjEW6wqkK3FNZ1QAP4a8my3I,990
|
889
889
|
official/recommendation/uplift/keras_test_case.py,sha256=gF5Z2FzXlKAvhuJDdj7PmFj3jsW_ZmUAv_F9Xokvs2M,6156
|
890
|
-
official/recommendation/uplift/keys.py,sha256=
|
890
|
+
official/recommendation/uplift/keys.py,sha256=7zkxkPIcXceIN5hWm4ATai4h8ymwCj85dU2r8-3XifM,1032
|
891
891
|
official/recommendation/uplift/types.py,sha256=OBpMvU4uAOKXGwN1UIPZw6KWJXBXeuZGT3P3VyDkyWc,4769
|
892
892
|
official/recommendation/uplift/utils.py,sha256=ZSrzoFJosSdmu7P3yYDo-TER8UQKFimy9ntuXqV4L4c,3501
|
893
893
|
official/recommendation/uplift/utils_test.py,sha256=K-pzTe3MoOEjdy5tLNSRpIQoWWH12MjFTPkhEIHfS3A,4610
|
@@ -896,7 +896,7 @@ official/recommendation/uplift/layers/encoders/__init__.py,sha256=QaQk0BMHJNE8Pc
|
|
896
896
|
official/recommendation/uplift/layers/encoders/concat_features.py,sha256=UNjDXN2C-GBXJaiitMYmKuhIsOiR1ROfJ73vHkuGkoA,4210
|
897
897
|
official/recommendation/uplift/layers/encoders/concat_features_test.py,sha256=sIjdoXqQHfQQ5P0yOz_PCuGQHl6j-wKNzaN42PL1jKI,9368
|
898
898
|
official/recommendation/uplift/layers/heads/__init__.py,sha256=nbtNxVbIGZq4GLEw9APPwCX3Zhatv5qwc0ddvAaelvU,728
|
899
|
-
official/recommendation/uplift/layers/heads/two_tower_logits_head.py,sha256=
|
899
|
+
official/recommendation/uplift/layers/heads/two_tower_logits_head.py,sha256=EyeVdiRNe4qk1aEvYEp1tEeGdllUFqtt8saCrK_QI1M,6615
|
900
900
|
official/recommendation/uplift/layers/heads/two_tower_logits_head_test.py,sha256=GKk8kYgq_gXcheUfcG1u_i5I5nVh8jRdam9KxvEHsXU,7615
|
901
901
|
official/recommendation/uplift/layers/uplift_networks/__init__.py,sha256=eaM75bIO4WZRABzDjTl5bIsR5kKiICpzw-QpgXh7K0A,917
|
902
902
|
official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py,sha256=pGVmdNXyDZR0TA9Htxa07KoUOzE0uEl5IyHEAVRgTk4,1329
|
@@ -912,11 +912,11 @@ official/recommendation/uplift/metrics/label_mean.py,sha256=ECaes7FZmsksnwySn7jf
|
|
912
912
|
official/recommendation/uplift/metrics/label_mean_test.py,sha256=b_d3lNlpkDm2xKLUkxfiXeQg7pjL8HNx7y9NaYarpV0,7083
|
913
913
|
official/recommendation/uplift/metrics/label_variance.py,sha256=9DCl42BJkehxfWD3pSbZnRNvwfhVM6VyHwivGdaU72s,3610
|
914
914
|
official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53-HIEO5HGtfp1MleifD-h4bZNKtTvM3Ws,7681
|
915
|
-
official/recommendation/uplift/metrics/loss_metric.py,sha256=
|
915
|
+
official/recommendation/uplift/metrics/loss_metric.py,sha256=gYZdnTsuL_2q1FZuPip-DaWxt_Q-02YYaePyMBVNx7w,7344
|
916
916
|
official/recommendation/uplift/metrics/loss_metric_test.py,sha256=48rQG8bKFdy0xBFjoOLXKRUlYpCEyAzSmPOFoF7FX94,16021
|
917
917
|
official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
|
918
|
-
official/recommendation/uplift/metrics/sliced_metric.py,sha256=
|
919
|
-
official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=
|
918
|
+
official/recommendation/uplift/metrics/sliced_metric.py,sha256=uhvzudOWtMNKZ0avwGhX-37UELR9Cq9b4C0g8erBkXw,8688
|
919
|
+
official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=bhVGyI1tOkFkVOtruJo3p6XopDFyG1JW5qdZm9-RqeU,12248
|
920
920
|
official/recommendation/uplift/metrics/treatment_fraction.py,sha256=WHrKfsN42xU7S-pK99xEVpVtd3zLD7UidLT1K8vgIn4,2757
|
921
921
|
official/recommendation/uplift/metrics/treatment_fraction_test.py,sha256=LtFljDdz9yfH1GNDMo8OcdS4yhsez5WyHsthH3qJf3s,5430
|
922
922
|
official/recommendation/uplift/metrics/treatment_sliced_metric.py,sha256=S0ZSoOHcjeWDWiEZlRnFHtRkOzizvrfmsFwbYP0Z0rY,3804
|
@@ -927,7 +927,7 @@ official/recommendation/uplift/metrics/variance.py,sha256=rhwZzUX-cRbwr-7vhC0I0b
|
|
927
927
|
official/recommendation/uplift/metrics/variance_test.py,sha256=EPISeHOFIh6WfODuC0SXbnmMugh90acMmm4BJkEZXlo,7757
|
928
928
|
official/recommendation/uplift/models/__init__.py,sha256=kWy2K5LGXHVyyrTjJvbVFcBjj1bjPRI2dpIq-sfdhvo,716
|
929
929
|
official/recommendation/uplift/models/two_tower_uplift_model.py,sha256=Fb6nLFAOqch81ravK57K9kggAeqvtJcBtKGZwCex0ts,5028
|
930
|
-
official/recommendation/uplift/models/two_tower_uplift_model_test.py,sha256=
|
930
|
+
official/recommendation/uplift/models/two_tower_uplift_model_test.py,sha256=J7qC9f0fDG1aIrLz85K1qUzTFyAIH0v8eA1yfPJb9YY,10061
|
931
931
|
official/utils/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
|
932
932
|
official/utils/hyperparams_flags.py,sha256=2FCAxfblio6ay36Yf4o7Nx188wRzFM1mbKOtVXiZCzo,4607
|
933
933
|
official/utils/docs/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
|
@@ -1204,9 +1204,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
|
|
1204
1204
|
tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
|
1205
1205
|
tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
|
1206
1206
|
tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
|
1207
|
-
tf_models_nightly-2.17.0.
|
1208
|
-
tf_models_nightly-2.17.0.
|
1209
|
-
tf_models_nightly-2.17.0.
|
1210
|
-
tf_models_nightly-2.17.0.
|
1211
|
-
tf_models_nightly-2.17.0.
|
1212
|
-
tf_models_nightly-2.17.0.
|
1207
|
+
tf_models_nightly-2.17.0.dev20240413.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1208
|
+
tf_models_nightly-2.17.0.dev20240413.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1209
|
+
tf_models_nightly-2.17.0.dev20240413.dist-info/METADATA,sha256=DqjFt5jNaqGygdbDJo34myYN5F7c6EboWUtItI2AQVQ,1432
|
1210
|
+
tf_models_nightly-2.17.0.dev20240413.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1211
|
+
tf_models_nightly-2.17.0.dev20240413.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1212
|
+
tf_models_nightly-2.17.0.dev20240413.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|