tf-models-nightly 2.17.0.dev20240410__py2.py3-none-any.whl → 2.17.0.dev20240412__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,7 +17,7 @@
17
17
  import enum
18
18
 
19
19
 
20
- class TwoTowerOutputKeys(enum.StrEnum):
20
+ class TwoTowerOutputKeys(str, enum.Enum):
21
21
  """Keys for training and inference output tensors."""
22
22
 
23
23
  CONTROL_PREDICTIONS = "control_predictions"
@@ -24,7 +24,7 @@ import tensorflow as tf, tf_keras
24
24
 
25
25
 
26
26
  @enum.unique
27
- class LayeringMethod(enum.StrEnum):
27
+ class LayeringMethod(str, enum.Enum):
28
28
  """Layering method between the control and treatment towers."""
29
29
 
30
30
  # No layering.
@@ -19,6 +19,7 @@ from __future__ import annotations
19
19
  import inspect
20
20
  from typing import Any, Callable
21
21
 
22
+ import numpy as np
22
23
  import tensorflow as tf, tf_keras
23
24
 
24
25
  from official.recommendation.uplift import types
@@ -29,9 +30,6 @@ from official.recommendation.uplift.metrics import treatment_sliced_metric
29
30
  class LossMetric(tf_keras.metrics.Metric):
30
31
  """Computes a loss sliced by treatment group.
31
32
 
32
- Note that the prediction tensor is expected to be of type
33
- `TwoTowerTrainingOutputs`.
34
-
35
33
  Example standalone usage:
36
34
 
37
35
  >>> sliced_loss = LossMetric(tf_keras.losses.mean_squared_error)
@@ -74,9 +72,10 @@ class LossMetric(tf_keras.metrics.Metric):
74
72
  `__call__(y_true: tf,Tensor, y_pred: tf.Tensor, **loss_fn_kwargs)`. Note
75
73
  that the `loss_fn_kwargs` will not be passed to the `__call__` method if
76
74
  `loss_fn` is a Keras metric.
77
- from_logits: Specifies whether the true logits or true predictions should
78
- be used from the model outputs to compute the loss. Defaults to using
79
- the true logits.
75
+ from_logits: When `y_pred` is of type `TwoTowerTrainingOutputs`, specifies
76
+ whether the true logits or true predictions should be used to compute
77
+ the loss (defaults to using the true logits). Othwerwise, this argument
78
+ will be ignored if `y_pred` is of type `tf.Tensor`.
80
79
  slice_by_treatment: Specifies whether the loss should be sliced by the
81
80
  treatment indicator tensor. If `True`, `loss_fn` will be wrapped in a
82
81
  `TreatmentSlicedMetric` to report the loss values sliced by the
@@ -129,16 +128,16 @@ class LossMetric(tf_keras.metrics.Metric):
129
128
  def update_state(
130
129
  self,
131
130
  y_true: tf.Tensor,
132
- y_pred: types.TwoTowerTrainingOutputs,
131
+ y_pred: types.TwoTowerTrainingOutputs | tf.Tensor | np.ndarray,
133
132
  sample_weight: tf.Tensor | None = None,
134
133
  ):
135
134
  """Updates the overall, control and treatment losses.
136
135
 
137
136
  Args:
138
137
  y_true: A `tf.Tensor` with the targets.
139
- y_pred: Two tower training outputs. The treatment indicator tensor is used
140
- to slice the true logits or true predictions into control and treatment
141
- losses.
138
+ y_pred: Model outputs. If of type `TwoTowerTrainingOutputs`, the treatment
139
+ indicator tensor is used to slice the true logits or true predictions
140
+ into control and treatment losses.
142
141
  sample_weight: Optional sample weight to compute weighted losses. If
143
142
  given, the sample weight will also be sliced by the treatment indicator
144
143
  tensor to compute the weighted control and treatment losses.
@@ -146,14 +145,23 @@ class LossMetric(tf_keras.metrics.Metric):
146
145
  Raises:
147
146
  TypeError: if `y_pred` is not of type `TwoTowerTrainingOutputs`.
148
147
  """
149
- if not isinstance(y_pred, types.TwoTowerTrainingOutputs):
148
+ if isinstance(y_pred, (tf.Tensor, np.ndarray)):
149
+ if self._slice_by_treatment:
150
+ raise ValueError(
151
+ "`slice_by_treatment` must be False when y_pred is a `tf.Tensor` or"
152
+ " `np.ndarray`."
153
+ )
154
+ pred = y_pred
155
+ elif isinstance(y_pred, types.TwoTowerTrainingOutputs):
156
+ pred = (
157
+ y_pred.true_logits if self._from_logits else y_pred.true_predictions
158
+ )
159
+ else:
150
160
  raise TypeError(
151
- "y_pred must be of type `TwoTowerTrainingOutputs` but got type"
152
- f" {type(y_pred)} instead."
161
+ "y_pred must be of type `TwoTowerTrainingOutputs`, `tf.Tensor` or"
162
+ f" `np.ndarray` but got type {type(y_pred)} instead."
153
163
  )
154
164
 
155
- pred = y_pred.true_logits if self._from_logits else y_pred.true_predictions
156
-
157
165
  is_treatment = {}
158
166
  if self._slice_by_treatment:
159
167
  is_treatment["is_treatment"] = y_pred.is_treatment
@@ -288,6 +288,19 @@ class LossMetricTest(keras_test_case.KerasTestCase, parameterized.TestCase):
288
288
  metric(y_true, outputs, sample_weight=sample_weight)
289
289
  self.assertEqual(expected_losses, metric.result())
290
290
 
291
+ def test_metric_with_y_pred_tensor(self):
292
+ y_true = tf.constant([[0], [0], [2], [7]])
293
+ y_pred = tf.constant([[1], [2], [3], [4]])
294
+ sample_weight = tf.constant([[0.5], [0.5], [0.7], [1.8]])
295
+
296
+ metric = loss_metric.LossMetric(
297
+ loss_fn=tf_keras.metrics.mae, slice_by_treatment=False
298
+ )
299
+ metric(y_true, y_pred, sample_weight)
300
+
301
+ expected_loss = np.average([1, 2, 1, 3], weights=[0.5, 0.5, 0.7, 1.8])
302
+ self.assertAllClose(expected_loss, metric.result())
303
+
291
304
  def test_multiple_update_batches_returns_aggregated_sliced_losses(self):
292
305
  metric = loss_metric.LossMetric(
293
306
  loss_fn=tf_keras.losses.mean_absolute_error,
@@ -397,8 +410,26 @@ class LossMetricTest(keras_test_case.KerasTestCase, parameterized.TestCase):
397
410
  metric = loss_metric.LossMetric(
398
411
  tf_keras.metrics.mean_absolute_percentage_error
399
412
  )
413
+ y_true = tf.ones((3, 1))
414
+ y_pred = types.TwoTowerNetworkOutputs(
415
+ shared_embedding=tf.ones((3, 5)),
416
+ control_logits=tf.ones((3, 1)),
417
+ treatment_logits=tf.ones((3, 1)),
418
+ )
419
+ with self.assertRaisesRegex(
420
+ TypeError,
421
+ "y_pred must be of type `TwoTowerTrainingOutputs`, `tf.Tensor` or"
422
+ " `np.ndarray`",
423
+ ):
424
+ metric.update_state(y_true=y_true, y_pred=y_pred)
425
+
426
+ def test_slice_by_treatment_with_y_pred_tensor_raises_error(self):
427
+ metric = loss_metric.LossMetric(
428
+ tf_keras.metrics.mae, slice_by_treatment=True
429
+ )
400
430
  with self.assertRaisesRegex(
401
- TypeError, "y_pred must be of type `TwoTowerTrainingOutputs`"
431
+ ValueError,
432
+ "`slice_by_treatment` must be False when y_pred is a `tf.Tensor`.",
402
433
  ):
403
434
  metric.update_state(y_true=tf.ones((3, 1)), y_pred=tf.ones((3, 1)))
404
435
 
@@ -15,6 +15,7 @@
15
15
  """TensorFlow Models Libraries."""
16
16
  # pylint: disable=wildcard-import
17
17
  from tensorflow_models import nlp
18
+ from tensorflow_models import uplift
18
19
  from tensorflow_models import vision
19
20
 
20
21
  from official import core
@@ -35,6 +35,17 @@ class TensorflowModelsTest(tf.test.TestCase):
35
35
  _ = tfm.optimization.LinearWarmup(
36
36
  after_warmup_lr_sched=0.0, warmup_steps=10, warmup_learning_rate=0.1)
37
37
 
38
+ def testUpliftImports(self):
39
+ _ = tfm.uplift.keys.TwoTowerOutputKeys.CONTROL_PREDICTIONS
40
+ _ = tfm.uplift.types.TwoTowerNetworkOutputs(
41
+ shared_embedding=tf.ones((10, 10)),
42
+ control_logits=tf.ones((10, 1)),
43
+ treatment_logits=tf.ones((10, 1)),
44
+ )
45
+ _ = tfm.uplift.layers.encoders.concat_features.ConcatFeatures(['feature'])
46
+ _ = tfm.uplift.metrics.treatment_fraction.TreatmentFraction()
47
+ _ = tfm.uplift.losses.true_logits_loss.TrueLogitsLoss(tf_keras.losses.mse)
48
+
38
49
 
39
50
  if __name__ == '__main__':
40
51
  tf.test.main()
@@ -0,0 +1,23 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """TensorFlow Models Uplift Libraries."""
16
+
17
+ from official.recommendation.uplift import keys
18
+ from official.recommendation.uplift import layers
19
+ from official.recommendation.uplift import losses
20
+ from official.recommendation.uplift import metrics
21
+ from official.recommendation.uplift import models
22
+ from official.recommendation.uplift import types
23
+ from official.recommendation.uplift import utils
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.17.0.dev20240410
3
+ Version: 2.17.0.dev20240412
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -887,7 +887,7 @@ official/recommendation/ranking/data/data_pipeline_multi_hot_test.py,sha256=arLj
887
887
  official/recommendation/ranking/data/data_pipeline_test.py,sha256=VRYo7WqURRkM3lbmfctvSZxyH1EzUfqwxR3sy2sZxdc,2345
888
888
  official/recommendation/uplift/__init__.py,sha256=_jZilTPWKu-MfMaz1IgBjEW6wqkK3FNZ1QAP4a8my3I,990
889
889
  official/recommendation/uplift/keras_test_case.py,sha256=gF5Z2FzXlKAvhuJDdj7PmFj3jsW_ZmUAv_F9Xokvs2M,6156
890
- official/recommendation/uplift/keys.py,sha256=jXrgCwSl1Ma4sThUeh8GdeQAig4lzdSbgrOwBcjvQSE,1030
890
+ official/recommendation/uplift/keys.py,sha256=7zkxkPIcXceIN5hWm4ATai4h8ymwCj85dU2r8-3XifM,1032
891
891
  official/recommendation/uplift/types.py,sha256=OBpMvU4uAOKXGwN1UIPZw6KWJXBXeuZGT3P3VyDkyWc,4769
892
892
  official/recommendation/uplift/utils.py,sha256=ZSrzoFJosSdmu7P3yYDo-TER8UQKFimy9ntuXqV4L4c,3501
893
893
  official/recommendation/uplift/utils_test.py,sha256=K-pzTe3MoOEjdy5tLNSRpIQoWWH12MjFTPkhEIHfS3A,4610
@@ -896,7 +896,7 @@ official/recommendation/uplift/layers/encoders/__init__.py,sha256=QaQk0BMHJNE8Pc
896
896
  official/recommendation/uplift/layers/encoders/concat_features.py,sha256=UNjDXN2C-GBXJaiitMYmKuhIsOiR1ROfJ73vHkuGkoA,4210
897
897
  official/recommendation/uplift/layers/encoders/concat_features_test.py,sha256=sIjdoXqQHfQQ5P0yOz_PCuGQHl6j-wKNzaN42PL1jKI,9368
898
898
  official/recommendation/uplift/layers/heads/__init__.py,sha256=nbtNxVbIGZq4GLEw9APPwCX3Zhatv5qwc0ddvAaelvU,728
899
- official/recommendation/uplift/layers/heads/two_tower_logits_head.py,sha256=1HkR2LP3nBbAQrYcbGHsQPO36UuVQ8Pb8CAjWFT0c1Q,6613
899
+ official/recommendation/uplift/layers/heads/two_tower_logits_head.py,sha256=EyeVdiRNe4qk1aEvYEp1tEeGdllUFqtt8saCrK_QI1M,6615
900
900
  official/recommendation/uplift/layers/heads/two_tower_logits_head_test.py,sha256=GKk8kYgq_gXcheUfcG1u_i5I5nVh8jRdam9KxvEHsXU,7615
901
901
  official/recommendation/uplift/layers/uplift_networks/__init__.py,sha256=eaM75bIO4WZRABzDjTl5bIsR5kKiICpzw-QpgXh7K0A,917
902
902
  official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py,sha256=pGVmdNXyDZR0TA9Htxa07KoUOzE0uEl5IyHEAVRgTk4,1329
@@ -912,8 +912,8 @@ official/recommendation/uplift/metrics/label_mean.py,sha256=ECaes7FZmsksnwySn7jf
912
912
  official/recommendation/uplift/metrics/label_mean_test.py,sha256=b_d3lNlpkDm2xKLUkxfiXeQg7pjL8HNx7y9NaYarpV0,7083
913
913
  official/recommendation/uplift/metrics/label_variance.py,sha256=9DCl42BJkehxfWD3pSbZnRNvwfhVM6VyHwivGdaU72s,3610
914
914
  official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53-HIEO5HGtfp1MleifD-h4bZNKtTvM3Ws,7681
915
- official/recommendation/uplift/metrics/loss_metric.py,sha256=qWMcLk1JmLkru5T9qcqO2UbB_AJQKjQ8uUSAxjR1l58,6883
916
- official/recommendation/uplift/metrics/loss_metric_test.py,sha256=QIUujZvUud-A2itI2vs2KpzWfLwbq2XDd9H9CFq5D4Y,14929
915
+ official/recommendation/uplift/metrics/loss_metric.py,sha256=owN7A98TCc_UhOURvGfccaoVGOthdHdx1_fawEUGnmw,7289
916
+ official/recommendation/uplift/metrics/loss_metric_test.py,sha256=48rQG8bKFdy0xBFjoOLXKRUlYpCEyAzSmPOFoF7FX94,16021
917
917
  official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
918
918
  official/recommendation/uplift/metrics/sliced_metric.py,sha256=O2I2apZK6IfOQK9Q_mgSiTUCnGokczp4e14zrrYNeRU,8564
919
919
  official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=dhY41X8lqT_WW04XLjyjDerZwujEBGeTXtxf4NkYThw,11359
@@ -1199,13 +1199,14 @@ orbit/utils/summary_manager.py,sha256=oTdTQYDjH315rbJ__NFALldruHX5lLrwY34IU1KJwN
1199
1199
  orbit/utils/summary_manager_interface.py,sha256=MXXPd3aWqushByK-gh-Jk9A42VC5ETEEm7XY5DCB-9w,2253
1200
1200
  orbit/utils/tpu_summaries.py,sha256=vjfLJlNifFJayLo_nkdRD85XoisVBV2VDuzgZWZud6U,5641
1201
1201
  orbit/utils/tpu_summaries_test.py,sha256=ZRYguds3BcXn1gOXBHVJ5rGRZdeXQLa5w3IbwEBwGS4,4523
1202
- tensorflow_models/__init__.py,sha256=etxw45SHxuwFCRX5qGxGMP83II0JfJulzNl5GSNJvhw,909
1203
- tensorflow_models/tensorflow_models_test.py,sha256=AxUYUdiQn416UR7jg0h6rmv688esvlKDfpyDCIQkF18,1395
1202
+ tensorflow_models/__init__.py,sha256=n7s5ZGPDM-RqOQNldPyipjL8aEB4DdAQQUENYhgf_Q8,946
1203
+ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbdanoDB12hb32dzPc,1897
1204
1204
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1205
+ tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
1205
1206
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1206
- tf_models_nightly-2.17.0.dev20240410.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1207
- tf_models_nightly-2.17.0.dev20240410.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1208
- tf_models_nightly-2.17.0.dev20240410.dist-info/METADATA,sha256=Bh56e1__Pn_yvpm9QMq4h12nKucvnKAftuk0ElQqn_w,1432
1209
- tf_models_nightly-2.17.0.dev20240410.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1210
- tf_models_nightly-2.17.0.dev20240410.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1211
- tf_models_nightly-2.17.0.dev20240410.dist-info/RECORD,,
1207
+ tf_models_nightly-2.17.0.dev20240412.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1208
+ tf_models_nightly-2.17.0.dev20240412.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1209
+ tf_models_nightly-2.17.0.dev20240412.dist-info/METADATA,sha256=y_U6M920Hgob94pOpRJWfqistgFzPqBH0Ds0HOZOpoo,1432
1210
+ tf_models_nightly-2.17.0.dev20240412.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1211
+ tf_models_nightly-2.17.0.dev20240412.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1212
+ tf_models_nightly-2.17.0.dev20240412.dist-info/RECORD,,