tf-models-nightly 2.17.0.dev20240410__py2.py3-none-any.whl → 2.17.0.dev20240411__py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

official/modeling/multitask/interleaving_trainer.py

@@ -43,6 +43,12 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  trainer_options=trainer_options)
  self._task_sampler = task_sampler

+ # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
+ # on TensorBoard.
+ self._task_step_counters = {
+ name: orbit.utils.create_global_step() for name in self.multi_task.tasks
+ }
+
  # Build per task train step.
  def _get_task_step(task_name, task):

@@ -57,6 +63,8 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  optimizer=self.optimizer,
  metrics=self.training_metrics[task_name])
  self.training_losses[task_name].update_state(task_logs[task.loss])
+ self.global_step.assign_add(1)
+ self.task_step_counter(task_name).assign_add(1)

  return step_fn

@@ -65,12 +73,6 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  for name, task in self.multi_task.tasks.items()
  }

- # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
- # on TensorBoard.
- self._task_step_counters = {
- name: orbit.utils.create_global_step() for name in self.multi_task.tasks
- }
-
  # If the new Keras optimizer is used, we require all model variables are
  # created before the training and let the optimizer to create the slot
  # variable all together.
@@ -97,8 +99,6 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  if rn >= begin and rn < end:
  self._strategy.run(
  self._task_train_step_map[name], args=(next(iterator_map[name]),))
- self.global_step.assign_add(1)
- self.task_step_counter(name).assign_add(1)

  def train_loop_end(self):
  """Record loss and metric values per task."""

official/recommendation/uplift/metrics/loss_metric.py

@@ -19,6 +19,7 @@ from __future__ import annotations
  import inspect
  from typing import Any, Callable

+ import numpy as np
  import tensorflow as tf, tf_keras

  from official.recommendation.uplift import types
@@ -29,9 +30,6 @@ from official.recommendation.uplift.metrics import treatment_sliced_metric
  class LossMetric(tf_keras.metrics.Metric):
  """Computes a loss sliced by treatment group.

- Note that the prediction tensor is expected to be of type
- `TwoTowerTrainingOutputs`.
-
  Example standalone usage:

  >>> sliced_loss = LossMetric(tf_keras.losses.mean_squared_error)
@@ -74,9 +72,10 @@ class LossMetric(tf_keras.metrics.Metric):
  `__call__(y_true: tf,Tensor, y_pred: tf.Tensor, **loss_fn_kwargs)`. Note
  that the `loss_fn_kwargs` will not be passed to the `__call__` method if
  `loss_fn` is a Keras metric.
- from_logits: Specifies whether the true logits or true predictions should
- be used from the model outputs to compute the loss. Defaults to using
- the true logits.
+ from_logits: When `y_pred` is of type `TwoTowerTrainingOutputs`, specifies
+ whether the true logits or true predictions should be used to compute
+ the loss (defaults to using the true logits). Othwerwise, this argument
+ will be ignored if `y_pred` is of type `tf.Tensor`.
  slice_by_treatment: Specifies whether the loss should be sliced by the
  treatment indicator tensor. If `True`, `loss_fn` will be wrapped in a
  `TreatmentSlicedMetric` to report the loss values sliced by the
@@ -129,16 +128,16 @@ class LossMetric(tf_keras.metrics.Metric):
  def update_state(
  self,
  y_true: tf.Tensor,
- y_pred: types.TwoTowerTrainingOutputs,
+ y_pred: types.TwoTowerTrainingOutputs | tf.Tensor | np.ndarray,
  sample_weight: tf.Tensor | None = None,
  ):
  """Updates the overall, control and treatment losses.

  Args:
  y_true: A `tf.Tensor` with the targets.
- y_pred: Two tower training outputs. The treatment indicator tensor is used
- to slice the true logits or true predictions into control and treatment
- losses.
+ y_pred: Model outputs. If of type `TwoTowerTrainingOutputs`, the treatment
+ indicator tensor is used to slice the true logits or true predictions
+ into control and treatment losses.
  sample_weight: Optional sample weight to compute weighted losses. If
  given, the sample weight will also be sliced by the treatment indicator
  tensor to compute the weighted control and treatment losses.
@@ -146,14 +145,23 @@ class LossMetric(tf_keras.metrics.Metric):
  Raises:
  TypeError: if `y_pred` is not of type `TwoTowerTrainingOutputs`.
  """
- if not isinstance(y_pred, types.TwoTowerTrainingOutputs):
+ if isinstance(y_pred, (tf.Tensor, np.ndarray)):
+ if self._slice_by_treatment:
+ raise ValueError(
+ "`slice_by_treatment` must be False when y_pred is a `tf.Tensor` or"
+ " `np.ndarray`."
+ )
+ pred = y_pred
+ elif isinstance(y_pred, types.TwoTowerTrainingOutputs):
+ pred = (
+ y_pred.true_logits if self._from_logits else y_pred.true_predictions
+ )
+ else:
  raise TypeError(
- "y_pred must be of type `TwoTowerTrainingOutputs` but got type"
- f" {type(y_pred)} instead."
+ "y_pred must be of type `TwoTowerTrainingOutputs`, `tf.Tensor` or"
+ f" `np.ndarray` but got type {type(y_pred)} instead."
  )

- pred = y_pred.true_logits if self._from_logits else y_pred.true_predictions
-
  is_treatment = {}
  if self._slice_by_treatment:
  is_treatment["is_treatment"] = y_pred.is_treatment

official/recommendation/uplift/metrics/loss_metric_test.py

@@ -288,6 +288,19 @@ class LossMetricTest(keras_test_case.KerasTestCase, parameterized.TestCase):
  metric(y_true, outputs, sample_weight=sample_weight)
  self.assertEqual(expected_losses, metric.result())

+ def test_metric_with_y_pred_tensor(self):
+ y_true = tf.constant([[0], [0], [2], [7]])
+ y_pred = tf.constant([[1], [2], [3], [4]])
+ sample_weight = tf.constant([[0.5], [0.5], [0.7], [1.8]])
+
+ metric = loss_metric.LossMetric(
+ loss_fn=tf_keras.metrics.mae, slice_by_treatment=False
+ )
+ metric(y_true, y_pred, sample_weight)
+
+ expected_loss = np.average([1, 2, 1, 3], weights=[0.5, 0.5, 0.7, 1.8])
+ self.assertAllClose(expected_loss, metric.result())
+
  def test_multiple_update_batches_returns_aggregated_sliced_losses(self):
  metric = loss_metric.LossMetric(
  loss_fn=tf_keras.losses.mean_absolute_error,
@@ -397,8 +410,26 @@ class LossMetricTest(keras_test_case.KerasTestCase, parameterized.TestCase):
  metric = loss_metric.LossMetric(
  tf_keras.metrics.mean_absolute_percentage_error
  )
+ y_true = tf.ones((3, 1))
+ y_pred = types.TwoTowerNetworkOutputs(
+ shared_embedding=tf.ones((3, 5)),
+ control_logits=tf.ones((3, 1)),
+ treatment_logits=tf.ones((3, 1)),
+ )
+ with self.assertRaisesRegex(
+ TypeError,
+ "y_pred must be of type `TwoTowerTrainingOutputs`, `tf.Tensor` or"
+ " `np.ndarray`",
+ ):
+ metric.update_state(y_true=y_true, y_pred=y_pred)
+
+ def test_slice_by_treatment_with_y_pred_tensor_raises_error(self):
+ metric = loss_metric.LossMetric(
+ tf_keras.metrics.mae, slice_by_treatment=True
+ )
  with self.assertRaisesRegex(
- TypeError, "y_pred must be of type `TwoTowerTrainingOutputs`"
+ ValueError,
+ "`slice_by_treatment` must be False when y_pred is a `tf.Tensor`.",
  ):
  metric.update_state(y_true=tf.ones((3, 1)), y_pred=tf.ones((3, 1)))


tensorflow_models/__init__.py

@@ -15,6 +15,7 @@
  """TensorFlow Models Libraries."""
  # pylint: disable=wildcard-import
  from tensorflow_models import nlp
+ from tensorflow_models import uplift
  from tensorflow_models import vision

  from official import core

tensorflow_models/tensorflow_models_test.py

@@ -35,6 +35,17 @@ class TensorflowModelsTest(tf.test.TestCase):
  _ = tfm.optimization.LinearWarmup(
  after_warmup_lr_sched=0.0, warmup_steps=10, warmup_learning_rate=0.1)

+ def testUpliftImports(self):
+ _ = tfm.uplift.keys.TwoTowerOutputKeys.CONTROL_PREDICTIONS
+ _ = tfm.uplift.types.TwoTowerNetworkOutputs(
+ shared_embedding=tf.ones((10, 10)),
+ control_logits=tf.ones((10, 1)),
+ treatment_logits=tf.ones((10, 1)),
+ )
+ _ = tfm.uplift.layers.encoders.concat_features.ConcatFeatures(['feature'])
+ _ = tfm.uplift.metrics.treatment_fraction.TreatmentFraction()
+ _ = tfm.uplift.losses.true_logits_loss.TrueLogitsLoss(tf_keras.losses.mse)
+

  if __name__ == '__main__':
  tf.test.main()

tensorflow_models/uplift/__init__.py (new file)

@@ -0,0 +1,23 @@
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """TensorFlow Models Uplift Libraries."""
+
+ from official.recommendation.uplift import keys
+ from official.recommendation.uplift import layers
+ from official.recommendation.uplift import losses
+ from official.recommendation.uplift import metrics
+ from official.recommendation.uplift import models
+ from official.recommendation.uplift import types
+ from official.recommendation.uplift import utils
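
With `uplift` now exported from `tensorflow_models/__init__.py` and backed by the new `tensorflow_models/uplift/__init__.py` package above, the uplift building blocks become reachable from the top-level `tfm` namespace. A short sketch, limited to the symbols exercised by the new `testUpliftImports` test:

```python
import tensorflow as tf, tf_keras
import tensorflow_models as tfm

# Two-tower network outputs, accessed through the new tfm.uplift alias
# of official.recommendation.uplift.
outputs = tfm.uplift.types.TwoTowerNetworkOutputs(
    shared_embedding=tf.ones((10, 10)),
    control_logits=tf.ones((10, 1)),
    treatment_logits=tf.ones((10, 1)),
)
loss = tfm.uplift.losses.true_logits_loss.TrueLogitsLoss(tf_keras.losses.mse)
fraction = tfm.uplift.metrics.treatment_fraction.TreatmentFraction()
```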

tf_models_nightly-2.17.0.dev20240410.dist-info/METADATA → tf_models_nightly-2.17.0.dev20240411.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tf-models-nightly
- Version: 2.17.0.dev20240410
+ Version: 2.17.0.dev20240411
  Summary: TensorFlow Official Models
  Home-page: https://github.com/tensorflow/models
  Author: Google Inc.

tf_models_nightly-2.17.0.dev20240410.dist-info/RECORD → tf_models_nightly-2.17.0.dev20240411.dist-info/RECORD

@@ -218,7 +218,7 @@ official/modeling/multitask/base_trainer_test.py,sha256=qJ7z4kid2XAX6hOIvUHa7dwq
  official/modeling/multitask/configs.py,sha256=ZO2waQrMn9CAgyFpsmeQvplCF5VeXz7tCPmIuy5jvlc,3164
  official/modeling/multitask/evaluator.py,sha256=spDm2X8EX62qsxI2ehVjrkIKoo-omQQOYcAVKZNgxHc,6078
  official/modeling/multitask/evaluator_test.py,sha256=vU-q-gM7GqiMqE5zbBnOT8mPFhQmHjniMyNnwganhso,4643
- official/modeling/multitask/interleaving_trainer.py,sha256=ZZHKsqbJKLqvwtgy-PUv_S_8bDG0MhJDNwWICY_IF6Q,4458
+ official/modeling/multitask/interleaving_trainer.py,sha256=G4jwnWywn9tlEkdHgN8BApA3ymXFHIGnquq8rbVVBHo,4463
  official/modeling/multitask/interleaving_trainer_test.py,sha256=MeQQxpcinPTQuTrAcITjwHa2bAj-XCBCqYsrbxPBus8,4305
  official/modeling/multitask/multitask.py,sha256=DV-ysfhPiIZgsrzZNylsPBxKNBf_xzPxJYjF4buWVgE,5948
  official/modeling/multitask/task_sampler.py,sha256=SGVVdjMb5oG4vnCczpfdgBtbsdsXiyBLl9si_0V6nko,4897
@@ -912,8 +912,8 @@ official/recommendation/uplift/metrics/label_mean.py,sha256=ECaes7FZmsksnwySn7jf
  official/recommendation/uplift/metrics/label_mean_test.py,sha256=b_d3lNlpkDm2xKLUkxfiXeQg7pjL8HNx7y9NaYarpV0,7083
  official/recommendation/uplift/metrics/label_variance.py,sha256=9DCl42BJkehxfWD3pSbZnRNvwfhVM6VyHwivGdaU72s,3610
  official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53-HIEO5HGtfp1MleifD-h4bZNKtTvM3Ws,7681
- official/recommendation/uplift/metrics/loss_metric.py,sha256=qWMcLk1JmLkru5T9qcqO2UbB_AJQKjQ8uUSAxjR1l58,6883
- official/recommendation/uplift/metrics/loss_metric_test.py,sha256=QIUujZvUud-A2itI2vs2KpzWfLwbq2XDd9H9CFq5D4Y,14929
+ official/recommendation/uplift/metrics/loss_metric.py,sha256=owN7A98TCc_UhOURvGfccaoVGOthdHdx1_fawEUGnmw,7289
+ official/recommendation/uplift/metrics/loss_metric_test.py,sha256=48rQG8bKFdy0xBFjoOLXKRUlYpCEyAzSmPOFoF7FX94,16021
  official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
  official/recommendation/uplift/metrics/sliced_metric.py,sha256=O2I2apZK6IfOQK9Q_mgSiTUCnGokczp4e14zrrYNeRU,8564
  official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=dhY41X8lqT_WW04XLjyjDerZwujEBGeTXtxf4NkYThw,11359
@@ -1199,13 +1199,14 @@ orbit/utils/summary_manager.py,sha256=oTdTQYDjH315rbJ__NFALldruHX5lLrwY34IU1KJwN
  orbit/utils/summary_manager_interface.py,sha256=MXXPd3aWqushByK-gh-Jk9A42VC5ETEEm7XY5DCB-9w,2253
  orbit/utils/tpu_summaries.py,sha256=vjfLJlNifFJayLo_nkdRD85XoisVBV2VDuzgZWZud6U,5641
  orbit/utils/tpu_summaries_test.py,sha256=ZRYguds3BcXn1gOXBHVJ5rGRZdeXQLa5w3IbwEBwGS4,4523
- tensorflow_models/__init__.py,sha256=etxw45SHxuwFCRX5qGxGMP83II0JfJulzNl5GSNJvhw,909
- tensorflow_models/tensorflow_models_test.py,sha256=AxUYUdiQn416UR7jg0h6rmv688esvlKDfpyDCIQkF18,1395
+ tensorflow_models/__init__.py,sha256=n7s5ZGPDM-RqOQNldPyipjL8aEB4DdAQQUENYhgf_Q8,946
+ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbdanoDB12hb32dzPc,1897
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
+ tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
- tf_models_nightly-2.17.0.dev20240410.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
- tf_models_nightly-2.17.0.dev20240410.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
- tf_models_nightly-2.17.0.dev20240410.dist-info/METADATA,sha256=Bh56e1__Pn_yvpm9QMq4h12nKucvnKAftuk0ElQqn_w,1432
- tf_models_nightly-2.17.0.dev20240410.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
- tf_models_nightly-2.17.0.dev20240410.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
- tf_models_nightly-2.17.0.dev20240410.dist-info/RECORD,,
+ tf_models_nightly-2.17.0.dev20240411.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+ tf_models_nightly-2.17.0.dev20240411.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+ tf_models_nightly-2.17.0.dev20240411.dist-info/METADATA,sha256=I12qTeIZ3Y8oc91Y2TGdZ1WonCJu2KZaGuZdmIB7aa0,1432
+ tf_models_nightly-2.17.0.dev20240411.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+ tf_models_nightly-2.17.0.dev20240411.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+ tf_models_nightly-2.17.0.dev20240411.dist-info/RECORD,,