tf-models-nightly 2.17.0.dev20240410__py2.py3-none-any.whl → 2.17.0.dev20240412__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/recommendation/uplift/keys.py +1 -1
- official/recommendation/uplift/layers/heads/two_tower_logits_head.py +1 -1
- official/recommendation/uplift/metrics/loss_metric.py +23 -15
- official/recommendation/uplift/metrics/loss_metric_test.py +32 -1
- tensorflow_models/__init__.py +1 -0
- tensorflow_models/tensorflow_models_test.py +11 -0
- tensorflow_models/uplift/__init__.py +23 -0
- {tf_models_nightly-2.17.0.dev20240410.dist-info → tf_models_nightly-2.17.0.dev20240412.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.17.0.dev20240410.dist-info → tf_models_nightly-2.17.0.dev20240412.dist-info}/RECORD +13 -12
- {tf_models_nightly-2.17.0.dev20240410.dist-info → tf_models_nightly-2.17.0.dev20240412.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.17.0.dev20240410.dist-info → tf_models_nightly-2.17.0.dev20240412.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.17.0.dev20240410.dist-info → tf_models_nightly-2.17.0.dev20240412.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.17.0.dev20240410.dist-info → tf_models_nightly-2.17.0.dev20240412.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|
19
19
|
import inspect
|
20
20
|
from typing import Any, Callable
|
21
21
|
|
22
|
+
import numpy as np
|
22
23
|
import tensorflow as tf, tf_keras
|
23
24
|
|
24
25
|
from official.recommendation.uplift import types
|
@@ -29,9 +30,6 @@ from official.recommendation.uplift.metrics import treatment_sliced_metric
|
|
29
30
|
class LossMetric(tf_keras.metrics.Metric):
|
30
31
|
"""Computes a loss sliced by treatment group.
|
31
32
|
|
32
|
-
Note that the prediction tensor is expected to be of type
|
33
|
-
`TwoTowerTrainingOutputs`.
|
34
|
-
|
35
33
|
Example standalone usage:
|
36
34
|
|
37
35
|
>>> sliced_loss = LossMetric(tf_keras.losses.mean_squared_error)
|
@@ -74,9 +72,10 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
74
72
|
`__call__(y_true: tf.Tensor, y_pred: tf.Tensor, **loss_fn_kwargs)`. Note
|
75
73
|
that the `loss_fn_kwargs` will not be passed to the `__call__` method if
|
76
74
|
`loss_fn` is a Keras metric.
|
77
|
-
from_logits:
|
78
|
-
|
79
|
-
the true logits.
|
75
|
+
from_logits: When `y_pred` is of type `TwoTowerTrainingOutputs`, specifies
|
76
|
+
whether the true logits or true predictions should be used to compute
|
77
|
+
the loss (defaults to using the true logits). Otherwise, this argument
|
78
|
+
is ignored if `y_pred` is a `tf.Tensor` or `np.ndarray`.
|
80
79
|
slice_by_treatment: Specifies whether the loss should be sliced by the
|
81
80
|
treatment indicator tensor. If `True`, `loss_fn` will be wrapped in a
|
82
81
|
`TreatmentSlicedMetric` to report the loss values sliced by the
|
@@ -129,16 +128,16 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
129
128
|
def update_state(
|
130
129
|
self,
|
131
130
|
y_true: tf.Tensor,
|
132
|
-
y_pred: types.TwoTowerTrainingOutputs,
|
131
|
+
y_pred: types.TwoTowerTrainingOutputs | tf.Tensor | np.ndarray,
|
133
132
|
sample_weight: tf.Tensor | None = None,
|
134
133
|
):
|
135
134
|
"""Updates the overall, control and treatment losses.
|
136
135
|
|
137
136
|
Args:
|
138
137
|
y_true: A `tf.Tensor` with the targets.
|
139
|
-
y_pred:
|
140
|
-
to slice the true logits or true predictions
|
141
|
-
losses.
|
138
|
+
y_pred: Model outputs. If of type `TwoTowerTrainingOutputs`, the treatment
|
139
|
+
indicator tensor is used to slice the true logits or true predictions
|
140
|
+
into control and treatment losses.
|
142
141
|
sample_weight: Optional sample weight to compute weighted losses. If
|
143
142
|
given, the sample weight will also be sliced by the treatment indicator
|
144
143
|
tensor to compute the weighted control and treatment losses.
|
@@ -146,14 +145,23 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
146
145
|
Raises:
|
147
146
|
TypeError: if `y_pred` is not of type `TwoTowerTrainingOutputs`, `tf.Tensor` or `np.ndarray`.
|
148
147
|
"""
|
149
|
-
if
|
148
|
+
if isinstance(y_pred, (tf.Tensor, np.ndarray)):
|
149
|
+
if self._slice_by_treatment:
|
150
|
+
raise ValueError(
|
151
|
+
"`slice_by_treatment` must be False when y_pred is a `tf.Tensor` or"
|
152
|
+
" `np.ndarray`."
|
153
|
+
)
|
154
|
+
pred = y_pred
|
155
|
+
elif isinstance(y_pred, types.TwoTowerTrainingOutputs):
|
156
|
+
pred = (
|
157
|
+
y_pred.true_logits if self._from_logits else y_pred.true_predictions
|
158
|
+
)
|
159
|
+
else:
|
150
160
|
raise TypeError(
|
151
|
-
"y_pred must be of type `TwoTowerTrainingOutputs`
|
152
|
-
f" {type(y_pred)} instead."
|
161
|
+
"y_pred must be of type `TwoTowerTrainingOutputs`, `tf.Tensor` or"
|
162
|
+
f" `np.ndarray` but got type {type(y_pred)} instead."
|
153
163
|
)
|
154
164
|
|
155
|
-
pred = y_pred.true_logits if self._from_logits else y_pred.true_predictions
|
156
|
-
|
157
165
|
is_treatment = {}
|
158
166
|
if self._slice_by_treatment:
|
159
167
|
is_treatment["is_treatment"] = y_pred.is_treatment
|
@@ -288,6 +288,19 @@ class LossMetricTest(keras_test_case.KerasTestCase, parameterized.TestCase):
|
|
288
288
|
metric(y_true, outputs, sample_weight=sample_weight)
|
289
289
|
self.assertEqual(expected_losses, metric.result())
|
290
290
|
|
291
|
+
def test_metric_with_y_pred_tensor(self):
|
292
|
+
y_true = tf.constant([[0], [0], [2], [7]])
|
293
|
+
y_pred = tf.constant([[1], [2], [3], [4]])
|
294
|
+
sample_weight = tf.constant([[0.5], [0.5], [0.7], [1.8]])
|
295
|
+
|
296
|
+
metric = loss_metric.LossMetric(
|
297
|
+
loss_fn=tf_keras.metrics.mae, slice_by_treatment=False
|
298
|
+
)
|
299
|
+
metric(y_true, y_pred, sample_weight)
|
300
|
+
|
301
|
+
expected_loss = np.average([1, 2, 1, 3], weights=[0.5, 0.5, 0.7, 1.8])
|
302
|
+
self.assertAllClose(expected_loss, metric.result())
|
303
|
+
|
291
304
|
def test_multiple_update_batches_returns_aggregated_sliced_losses(self):
|
292
305
|
metric = loss_metric.LossMetric(
|
293
306
|
loss_fn=tf_keras.losses.mean_absolute_error,
|
@@ -397,8 +410,26 @@ class LossMetricTest(keras_test_case.KerasTestCase, parameterized.TestCase):
|
|
397
410
|
metric = loss_metric.LossMetric(
|
398
411
|
tf_keras.metrics.mean_absolute_percentage_error
|
399
412
|
)
|
413
|
+
y_true = tf.ones((3, 1))
|
414
|
+
y_pred = types.TwoTowerNetworkOutputs(
|
415
|
+
shared_embedding=tf.ones((3, 5)),
|
416
|
+
control_logits=tf.ones((3, 1)),
|
417
|
+
treatment_logits=tf.ones((3, 1)),
|
418
|
+
)
|
419
|
+
with self.assertRaisesRegex(
|
420
|
+
TypeError,
|
421
|
+
"y_pred must be of type `TwoTowerTrainingOutputs`, `tf.Tensor` or"
|
422
|
+
" `np.ndarray`",
|
423
|
+
):
|
424
|
+
metric.update_state(y_true=y_true, y_pred=y_pred)
|
425
|
+
|
426
|
+
def test_slice_by_treatment_with_y_pred_tensor_raises_error(self):
|
427
|
+
metric = loss_metric.LossMetric(
|
428
|
+
tf_keras.metrics.mae, slice_by_treatment=True
|
429
|
+
)
|
400
430
|
with self.assertRaisesRegex(
|
401
|
-
|
431
|
+
ValueError,
|
432
|
+
"`slice_by_treatment` must be False when y_pred is a `tf.Tensor`.",
|
402
433
|
):
|
403
434
|
metric.update_state(y_true=tf.ones((3, 1)), y_pred=tf.ones((3, 1)))
|
404
435
|
|
tensorflow_models/__init__.py
CHANGED
@@ -35,6 +35,17 @@ class TensorflowModelsTest(tf.test.TestCase):
|
|
35
35
|
_ = tfm.optimization.LinearWarmup(
|
36
36
|
after_warmup_lr_sched=0.0, warmup_steps=10, warmup_learning_rate=0.1)
|
37
37
|
|
38
|
+
def testUpliftImports(self):
|
39
|
+
_ = tfm.uplift.keys.TwoTowerOutputKeys.CONTROL_PREDICTIONS
|
40
|
+
_ = tfm.uplift.types.TwoTowerNetworkOutputs(
|
41
|
+
shared_embedding=tf.ones((10, 10)),
|
42
|
+
control_logits=tf.ones((10, 1)),
|
43
|
+
treatment_logits=tf.ones((10, 1)),
|
44
|
+
)
|
45
|
+
_ = tfm.uplift.layers.encoders.concat_features.ConcatFeatures(['feature'])
|
46
|
+
_ = tfm.uplift.metrics.treatment_fraction.TreatmentFraction()
|
47
|
+
_ = tfm.uplift.losses.true_logits_loss.TrueLogitsLoss(tf_keras.losses.mse)
|
48
|
+
|
38
49
|
|
39
50
|
if __name__ == '__main__':
|
40
51
|
tf.test.main()
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""TensorFlow Models Uplift Libraries."""
|
16
|
+
|
17
|
+
from official.recommendation.uplift import keys
|
18
|
+
from official.recommendation.uplift import layers
|
19
|
+
from official.recommendation.uplift import losses
|
20
|
+
from official.recommendation.uplift import metrics
|
21
|
+
from official.recommendation.uplift import models
|
22
|
+
from official.recommendation.uplift import types
|
23
|
+
from official.recommendation.uplift import utils
|
@@ -887,7 +887,7 @@ official/recommendation/ranking/data/data_pipeline_multi_hot_test.py,sha256=arLj
|
|
887
887
|
official/recommendation/ranking/data/data_pipeline_test.py,sha256=VRYo7WqURRkM3lbmfctvSZxyH1EzUfqwxR3sy2sZxdc,2345
|
888
888
|
official/recommendation/uplift/__init__.py,sha256=_jZilTPWKu-MfMaz1IgBjEW6wqkK3FNZ1QAP4a8my3I,990
|
889
889
|
official/recommendation/uplift/keras_test_case.py,sha256=gF5Z2FzXlKAvhuJDdj7PmFj3jsW_ZmUAv_F9Xokvs2M,6156
|
890
|
-
official/recommendation/uplift/keys.py,sha256=
|
890
|
+
official/recommendation/uplift/keys.py,sha256=7zkxkPIcXceIN5hWm4ATai4h8ymwCj85dU2r8-3XifM,1032
|
891
891
|
official/recommendation/uplift/types.py,sha256=OBpMvU4uAOKXGwN1UIPZw6KWJXBXeuZGT3P3VyDkyWc,4769
|
892
892
|
official/recommendation/uplift/utils.py,sha256=ZSrzoFJosSdmu7P3yYDo-TER8UQKFimy9ntuXqV4L4c,3501
|
893
893
|
official/recommendation/uplift/utils_test.py,sha256=K-pzTe3MoOEjdy5tLNSRpIQoWWH12MjFTPkhEIHfS3A,4610
|
@@ -896,7 +896,7 @@ official/recommendation/uplift/layers/encoders/__init__.py,sha256=QaQk0BMHJNE8Pc
|
|
896
896
|
official/recommendation/uplift/layers/encoders/concat_features.py,sha256=UNjDXN2C-GBXJaiitMYmKuhIsOiR1ROfJ73vHkuGkoA,4210
|
897
897
|
official/recommendation/uplift/layers/encoders/concat_features_test.py,sha256=sIjdoXqQHfQQ5P0yOz_PCuGQHl6j-wKNzaN42PL1jKI,9368
|
898
898
|
official/recommendation/uplift/layers/heads/__init__.py,sha256=nbtNxVbIGZq4GLEw9APPwCX3Zhatv5qwc0ddvAaelvU,728
|
899
|
-
official/recommendation/uplift/layers/heads/two_tower_logits_head.py,sha256=
|
899
|
+
official/recommendation/uplift/layers/heads/two_tower_logits_head.py,sha256=EyeVdiRNe4qk1aEvYEp1tEeGdllUFqtt8saCrK_QI1M,6615
|
900
900
|
official/recommendation/uplift/layers/heads/two_tower_logits_head_test.py,sha256=GKk8kYgq_gXcheUfcG1u_i5I5nVh8jRdam9KxvEHsXU,7615
|
901
901
|
official/recommendation/uplift/layers/uplift_networks/__init__.py,sha256=eaM75bIO4WZRABzDjTl5bIsR5kKiICpzw-QpgXh7K0A,917
|
902
902
|
official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py,sha256=pGVmdNXyDZR0TA9Htxa07KoUOzE0uEl5IyHEAVRgTk4,1329
|
@@ -912,8 +912,8 @@ official/recommendation/uplift/metrics/label_mean.py,sha256=ECaes7FZmsksnwySn7jf
|
|
912
912
|
official/recommendation/uplift/metrics/label_mean_test.py,sha256=b_d3lNlpkDm2xKLUkxfiXeQg7pjL8HNx7y9NaYarpV0,7083
|
913
913
|
official/recommendation/uplift/metrics/label_variance.py,sha256=9DCl42BJkehxfWD3pSbZnRNvwfhVM6VyHwivGdaU72s,3610
|
914
914
|
official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53-HIEO5HGtfp1MleifD-h4bZNKtTvM3Ws,7681
|
915
|
-
official/recommendation/uplift/metrics/loss_metric.py,sha256=
|
916
|
-
official/recommendation/uplift/metrics/loss_metric_test.py,sha256=
|
915
|
+
official/recommendation/uplift/metrics/loss_metric.py,sha256=owN7A98TCc_UhOURvGfccaoVGOthdHdx1_fawEUGnmw,7289
|
916
|
+
official/recommendation/uplift/metrics/loss_metric_test.py,sha256=48rQG8bKFdy0xBFjoOLXKRUlYpCEyAzSmPOFoF7FX94,16021
|
917
917
|
official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
|
918
918
|
official/recommendation/uplift/metrics/sliced_metric.py,sha256=O2I2apZK6IfOQK9Q_mgSiTUCnGokczp4e14zrrYNeRU,8564
|
919
919
|
official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=dhY41X8lqT_WW04XLjyjDerZwujEBGeTXtxf4NkYThw,11359
|
@@ -1199,13 +1199,14 @@ orbit/utils/summary_manager.py,sha256=oTdTQYDjH315rbJ__NFALldruHX5lLrwY34IU1KJwN
|
|
1199
1199
|
orbit/utils/summary_manager_interface.py,sha256=MXXPd3aWqushByK-gh-Jk9A42VC5ETEEm7XY5DCB-9w,2253
|
1200
1200
|
orbit/utils/tpu_summaries.py,sha256=vjfLJlNifFJayLo_nkdRD85XoisVBV2VDuzgZWZud6U,5641
|
1201
1201
|
orbit/utils/tpu_summaries_test.py,sha256=ZRYguds3BcXn1gOXBHVJ5rGRZdeXQLa5w3IbwEBwGS4,4523
|
1202
|
-
tensorflow_models/__init__.py,sha256=
|
1203
|
-
tensorflow_models/tensorflow_models_test.py,sha256=
|
1202
|
+
tensorflow_models/__init__.py,sha256=n7s5ZGPDM-RqOQNldPyipjL8aEB4DdAQQUENYhgf_Q8,946
|
1203
|
+
tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbdanoDB12hb32dzPc,1897
|
1204
1204
|
tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
|
1205
|
+
tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
|
1205
1206
|
tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
|
1206
|
-
tf_models_nightly-2.17.0.
|
1207
|
-
tf_models_nightly-2.17.0.
|
1208
|
-
tf_models_nightly-2.17.0.
|
1209
|
-
tf_models_nightly-2.17.0.
|
1210
|
-
tf_models_nightly-2.17.0.
|
1211
|
-
tf_models_nightly-2.17.0.
|
1207
|
+
tf_models_nightly-2.17.0.dev20240412.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1208
|
+
tf_models_nightly-2.17.0.dev20240412.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1209
|
+
tf_models_nightly-2.17.0.dev20240412.dist-info/METADATA,sha256=y_U6M920Hgob94pOpRJWfqistgFzPqBH0Ds0HOZOpoo,1432
|
1210
|
+
tf_models_nightly-2.17.0.dev20240412.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1211
|
+
tf_models_nightly-2.17.0.dev20240412.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1212
|
+
tf_models_nightly-2.17.0.dev20240412.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|