tf-models-nightly 2.17.0.dev20240415__py2.py3-none-any.whl → 2.17.0.dev20240417__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/modeling/multitask/base_trainer.py +1 -1
- official/modeling/multitask/interleaving_trainer.py +10 -3
- official/recommendation/uplift/metrics/poisson_metrics.py +77 -0
- official/recommendation/uplift/metrics/poisson_metrics_test.py +81 -2
- official/vision/modeling/backbones/__init__.py +1 -0
- {tf_models_nightly-2.17.0.dev20240415.dist-info → tf_models_nightly-2.17.0.dev20240417.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.17.0.dev20240415.dist-info → tf_models_nightly-2.17.0.dev20240417.dist-info}/RECORD +11 -11
- {tf_models_nightly-2.17.0.dev20240415.dist-info → tf_models_nightly-2.17.0.dev20240417.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.17.0.dev20240415.dist-info → tf_models_nightly-2.17.0.dev20240417.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.17.0.dev20240415.dist-info → tf_models_nightly-2.17.0.dev20240417.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.17.0.dev20240415.dist-info → tf_models_nightly-2.17.0.dev20240417.dist-info}/top_level.txt +0 -0
@@ -164,7 +164,7 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
           task_metrics=self.training_metrics)
       for key, loss in losses.items():
         self.training_losses[key].update_state(loss)
+      self.global_step.assign_add(1)
 
     self.strategy.run(
         step_fn, args=(tf.nest.map_structure(next, iterator_map),))
-    self.global_step.assign_add(1)
@@ -81,6 +81,15 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
   def task_step_counter(self, name):
     return self._task_step_counters[name]
 
+  def _task_train_step(self, name):
+    """Runs one training step and updates counters."""
+    def _step_fn(inputs):
+      self._task_train_step_map[name](inputs)
+      self.global_step.assign_add(1)
+      self.task_step_counter(name).assign_add(1)
+
+    return _step_fn
+
   def train_step(self, iterator_map):
     # Sample one task to train according to a multinomial distribution
     rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
@@ -96,9 +105,7 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
       end = cumulative_sample_distribution[idx + 1]
      if rn >= begin and rn < end:
         self._strategy.run(
-            self.
-        self.global_step.assign_add(1)
-        self.task_step_counter(name).assign_add(1)
+            self._task_train_step(name), args=(next(iterator_map[name]),))
 
   def train_loop_end(self):
     """Record loss and metric values per task."""
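Note on the two trainer hunks above: the step-counter updates (the global step and the per-task counter) move inside the function that `strategy.run` executes, instead of being applied after the run returns. The following is an illustrative, self-contained sketch of that pattern only; the variable and function names here are stand-ins, not the Model Garden API.

import tensorflow as tf

# Illustrative sketch: counters are bumped inside the step function that the
# distribution strategy executes, mirroring _task_train_step above.
global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
task_steps = tf.Variable(0, dtype=tf.int64, trainable=False)

def make_task_step(task_fn):
  def _step_fn(inputs):
    task_fn(inputs)              # the task's actual train step
    global_step.assign_add(1)    # updated within the replica step
    task_steps.assign_add(1)
  return _step_fn

strategy = tf.distribute.get_strategy()
strategy.run(make_task_step(lambda x: None), args=(tf.constant(0.0),))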
@@ -22,6 +22,7 @@ import tensorflow as tf, tf_keras
 
 from official.recommendation.uplift import types
 from official.recommendation.uplift.metrics import loss_metric
+from official.recommendation.uplift.metrics import treatment_sliced_metric
 
 
 @tf_keras.utils.register_keras_serializable(package="Uplift")
@@ -121,3 +122,79 @@ class LogLoss(loss_metric.LossMetric):
   @classmethod
   def from_config(cls, config: dict[str, Any]) -> LogLoss:
     return cls(**config)
+
+
+def _safe_x_minus_xlogx(x: tf.Tensor) -> tf.Tensor:
+  """Computes x - x * log(x) with 0 as its continuity point when x equals 0."""
+  values = x * (1.0 - tf.math.log(x))
+  return tf.where(tf.equal(x, 0.0), tf.zeros_like(x), values)
+
+
+class LogLossMeanBaseline(tf_keras.metrics.Metric):
+  """Computes the (weighted) poisson log loss for a mean predictor."""
+
+  def __init__(
+      self,
+      compute_full_loss: bool = False,
+      slice_by_treatment: bool = True,
+      name: str = "poisson_log_loss_mean_baseline",
+      dtype: tf.DType = tf.float32,
+  ):
+    """Initializes the instance.
+
+    Args:
+      compute_full_loss: Specifies whether to compute the full poisson log loss
+        for the mean predictor or not. Defaults to `False`.
+      slice_by_treatment: Specifies whether the loss should be sliced by the
+        treatment indicator tensor. If `True`, the metric's result will return
+        the loss values sliced by the treatment group. Note that this can only
+        be set to `True` when `y_pred` is of type `TwoTowerTrainingOutputs`.
+      name: Optional name for the instance.
+      dtype: Optional data type for the instance.
+    """
+    super().__init__(name=name, dtype=dtype)
+
+    if compute_full_loss:
+      raise NotImplementedError("Full loss computation is not yet supported.")
+
+    self._compute_full_loss = compute_full_loss
+    self._slice_by_treatment = slice_by_treatment
+
+    if slice_by_treatment:
+      self._mean_label = treatment_sliced_metric.TreatmentSlicedMetric(
+          metric=tf_keras.metrics.Mean(name=name, dtype=dtype)
+      )
+    else:
+      self._mean_label = tf_keras.metrics.Mean(name=name, dtype=dtype)
+
+  def update_state(
+      self,
+      y_true: tf.Tensor,
+      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor | None = None,
+      sample_weight: tf.Tensor | None = None,
+  ):
+    is_treatment = {}
+    if self._slice_by_treatment:
+      if not isinstance(y_pred, types.TwoTowerTrainingOutputs):
+        raise ValueError(
+            "`slice_by_treatment` must be set to `False` when `y_pred` is not"
+            " of type `TwoTowerTrainingOutputs`."
+        )
+      is_treatment["is_treatment"] = y_pred.is_treatment
+
+    self._mean_label.update_state(
+        y_true, sample_weight=sample_weight, **is_treatment
+    )
+
+  def result(self) -> tf.Tensor | dict[str, tf.Tensor]:
+    return tf.nest.map_structure(_safe_x_minus_xlogx, self._mean_label.result())
+
+  def get_config(self) -> dict[str, Any]:
+    config = super().get_config()
+    config["compute_full_loss"] = self._compute_full_loss
+    config["slice_by_treatment"] = self._slice_by_treatment
+    return config
+
+  @classmethod
+  def from_config(cls, config: dict[str, Any]) -> LogLossMeanBaseline:
+    return cls(**config)
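For reference, the `_safe_x_minus_xlogx` helper above reflects that the partial Poisson log loss `m - y * log(m)` of a constant predictor `m` is minimized at `m = mean(y)`, where it reduces to `mean(y) - mean(y) * log(mean(y))`. A small standalone check of that value in plain TensorFlow (not the metric class itself), consistent with the `two_tower_outputs` test case in the test file below:

import tensorflow as tf

# Labels with mean 0.5, as in the two_tower_outputs test case.
y_true = tf.constant([0.0, 1.0])
mean_label = tf.reduce_mean(y_true)  # 0.5
# Partial Poisson log loss of the mean predictor (log(y!) term omitted).
baseline_loss = mean_label - mean_label * tf.math.log(mean_label)
print(float(baseline_loss))  # 0.5 - 0.5 * log(0.5), roughly 0.8466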
@@ -23,7 +23,8 @@ from official.recommendation.uplift.metrics import poisson_metrics
 
 
 def _get_two_tower_outputs(
-
+    is_treatment: tf.Tensor,
+    true_logits: tf.Tensor | None = None,
 ) -> types.TwoTowerTrainingOutputs:
   # Only the true_logits and is_treatment tensors are needed for testing.
   return types.TwoTowerTrainingOutputs(
@@ -33,7 +34,9 @@ def _get_two_tower_outputs(
       uplift=tf.ones_like(is_treatment),
       control_logits=tf.ones_like(is_treatment),
       treatment_logits=tf.ones_like(is_treatment),
-      true_logits=
+      true_logits=(
+          true_logits if true_logits is not None else tf.ones_like(is_treatment)
+      ),
       true_predictions=tf.ones_like(is_treatment),
       is_treatment=is_treatment,
   )
@@ -177,5 +180,81 @@ class LogLossTest(keras_test_case.KerasTestCase, parameterized.TestCase):
     )
 
 
+class LogLossMeanBaselineTest(
+    keras_test_case.KerasTestCase, parameterized.TestCase
+):
+
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "label_zero",
+          "expected_loss": 0.0,
+          "y_true": tf.constant([0], dtype=tf.float32),
+      },
+      {
+          "testcase_name": "small_positive_label",
+          "expected_loss": 0.0,
+          "y_true": tf.constant([1e-10], dtype=tf.float32),
+      },
+      {
+          "testcase_name": "label_one",
+          "expected_loss": 1.0,
+          "y_true": tf.constant([1], dtype=tf.float32),
+      },
+      {
+          "testcase_name": "weighted_loss",
+          "expected_loss": 1.0,
+          "y_true": tf.constant([[0], [1]], dtype=tf.float32),
+          "sample_weight": tf.constant([[0], [1]], dtype=tf.float32),
+      },
+      {
+          "testcase_name": "two_tower_outputs",
+          "expected_loss": 0.5 - 0.5 * tf.math.log(0.5),
+          "y_true": tf.constant([[0], [1]], dtype=tf.float32),
+          "y_pred": _get_two_tower_outputs(
+              is_treatment=tf.constant([[0], [1]], dtype=tf.float32),
+          ),
+      },
+      {
+          "testcase_name": "two_tower_outputs_sliced_loss",
+          "expected_loss": {
+              "loss": 0.5 - 0.5 * tf.math.log(0.5),
+              "loss/control": 0.0,
+              "loss/treatment": 1.0,
+          },
+          "y_true": tf.constant([[0], [1]], dtype=tf.float32),
+          "y_pred": _get_two_tower_outputs(
+              is_treatment=tf.constant([[0], [1]], dtype=tf.float32),
+          ),
+          "slice_by_treatment": True,
+      },
+  )
+  def test_metric_computes_correct_loss(
+      self,
+      expected_loss: tf.Tensor,
+      y_true: tf.Tensor,
+      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor | None = None,
+      sample_weight: tf.Tensor | None = None,
+      slice_by_treatment: bool = False,
+  ):
+    metric = poisson_metrics.LogLossMeanBaseline(
+        slice_by_treatment=slice_by_treatment, name="loss"
+    )
+    metric.update_state(y_true, y_pred, sample_weight=sample_weight)
+    self.assertAllClose(expected_loss, metric.result())
+
+  def test_negative_label_returns_nan_loss(self):
+    metric = poisson_metrics.LogLossMeanBaseline(slice_by_treatment=False)
+    metric.update_state(tf.constant([-1.0]))
+    self.assertTrue(tf.math.is_nan(metric.result()).numpy().item())
+
+  def test_metric_is_configurable(self):
+    metric = poisson_metrics.LogLossMeanBaseline(slice_by_treatment=False)
+    self.assertLayerConfigurable(
+        layer=metric,
+        y_true=tf.constant([[0], [0], [2], [7]], dtype=tf.float32),
+        serializable=True,
+    )
+
+
 if __name__ == "__main__":
   tf.test.main()
@@ -14,6 +14,7 @@
 
 """Backbones package definition."""
 
+from official.projects.maskconver.modeling.resnet_unet import ResNetUNet
 from official.vision.modeling.backbones.efficientnet import EfficientNet
 from official.vision.modeling.backbones.mobiledet import MobileDet
 from official.vision.modeling.backbones.mobilenet import MobileNet
@@ -213,12 +213,12 @@ official/modeling/hyperparams/params_dict.py,sha256=63fftQdUlycgJErxcyIj7655zL57
 official/modeling/hyperparams/params_dict_test.py,sha256=WPX-VU7L3JVjS42b7BWe77QxP1kdwch05Ib7LvNOTYs,14673
 official/modeling/multitask/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
 official/modeling/multitask/base_model.py,sha256=QI8qb8ipj75IUj6bKNjcAFHPjeqmNjqHr7nUbPd6a-o,1946
-official/modeling/multitask/base_trainer.py,sha256=
+official/modeling/multitask/base_trainer.py,sha256=83cLDajiyS2lJPMhllTdIsKXqiVTFLDaGZaherTPCa8,5858
 official/modeling/multitask/base_trainer_test.py,sha256=qJ7z4kid2XAX6hOIvUHa7dwqxouemMekS9ZXhPjWW9w,3663
 official/modeling/multitask/configs.py,sha256=ZO2waQrMn9CAgyFpsmeQvplCF5VeXz7tCPmIuy5jvlc,3164
 official/modeling/multitask/evaluator.py,sha256=spDm2X8EX62qsxI2ehVjrkIKoo-omQQOYcAVKZNgxHc,6078
 official/modeling/multitask/evaluator_test.py,sha256=vU-q-gM7GqiMqE5zbBnOT8mPFhQmHjniMyNnwganhso,4643
-official/modeling/multitask/interleaving_trainer.py,sha256=
+official/modeling/multitask/interleaving_trainer.py,sha256=f111ZhknyS34hpP0FfdWjX3_iiLViHfBd0VSuC715s0,4635
 official/modeling/multitask/interleaving_trainer_test.py,sha256=MeQQxpcinPTQuTrAcITjwHa2bAj-XCBCqYsrbxPBus8,4305
 official/modeling/multitask/multitask.py,sha256=DV-ysfhPiIZgsrzZNylsPBxKNBf_xzPxJYjF4buWVgE,5948
 official/modeling/multitask/task_sampler.py,sha256=SGVVdjMb5oG4vnCczpfdgBtbsdsXiyBLl9si_0V6nko,4897
@@ -915,8 +915,8 @@ official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53
 official/recommendation/uplift/metrics/loss_metric.py,sha256=gYZdnTsuL_2q1FZuPip-DaWxt_Q-02YYaePyMBVNx7w,7344
 official/recommendation/uplift/metrics/loss_metric_test.py,sha256=48rQG8bKFdy0xBFjoOLXKRUlYpCEyAzSmPOFoF7FX94,16021
 official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
-official/recommendation/uplift/metrics/poisson_metrics.py,sha256=
-official/recommendation/uplift/metrics/poisson_metrics_test.py,sha256=
+official/recommendation/uplift/metrics/poisson_metrics.py,sha256=1a0sHwdOLUVhmsdpnZIQxAWkt3QAPWS9TZCK6P9H7ug,7443
+official/recommendation/uplift/metrics/poisson_metrics_test.py,sha256=sLx4RjtvJiD8kRkXBDfsUANK0P5LSMfTTPj9GF73wOw,9473
 official/recommendation/uplift/metrics/sliced_metric.py,sha256=uhvzudOWtMNKZ0avwGhX-37UELR9Cq9b4C0g8erBkXw,8688
 official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=bhVGyI1tOkFkVOtruJo3p6XopDFyG1JW5qdZm9-RqeU,12248
 official/recommendation/uplift/metrics/treatment_fraction.py,sha256=WHrKfsN42xU7S-pK99xEVpVtd3zLD7UidLT1K8vgIn4,2757
@@ -1042,7 +1042,7 @@ official/vision/modeling/segmentation_model.py,sha256=BrH3w0yZ60uwyEGP08iYt0rAWV
 official/vision/modeling/segmentation_model_test.py,sha256=fimmpt0pFMcX214Vuf0C4NcRdWLLB9F2IJeiuTFUO54,2817
 official/vision/modeling/video_classification_model.py,sha256=RKrT22D5yW435S9AyYlgOFCIi0knbmZFvgwh_5JubPI,4713
 official/vision/modeling/video_classification_model_test.py,sha256=HL0hVwOFJCxSjDNFVK42MwzuoDx5r4rvraEbBXVkX8c,3271
-official/vision/modeling/backbones/__init__.py,sha256=
+official/vision/modeling/backbones/__init__.py,sha256=21iuq1HPa3KcN-6ljjRlUyklRV8Ynyc3Cz47XivHWMw,1402
 official/vision/modeling/backbones/efficientnet.py,sha256=j716OGSpkzgpDA4jV-Hk73mAGPULrPwlwHuS_9j40bE,12438
 official/vision/modeling/backbones/efficientnet_test.py,sha256=TYsUieiLrEU5913s3Yxhv-9eaolQK_kfzJm77lyCF0M,3762
 official/vision/modeling/backbones/factory.py,sha256=coJKJpPMhgM9gAc2Q7I5_CuzAaHZNJwPcvGbaUYp8gU,3504
@@ -1206,9 +1206,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
 tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
 tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
 tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
+tf_models_nightly-2.17.0.dev20240417.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.17.0.dev20240417.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.17.0.dev20240417.dist-info/METADATA,sha256=Sft4CH1GT2IMtgKOIyKA8O9FFE9nlzhWqHn2oHYV-G4,1432
+tf_models_nightly-2.17.0.dev20240417.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.17.0.dev20240417.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.17.0.dev20240417.dist-info/RECORD,,