tf-models-nightly 2.17.0.dev20240409__py2.py3-none-any.whl → 2.17.0.dev20240410__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/recommendation/uplift/metrics/loss_metric.py +24 -16
- official/recommendation/uplift/metrics/loss_metric_test.py +45 -9
- {tf_models_nightly-2.17.0.dev20240409.dist-info → tf_models_nightly-2.17.0.dev20240410.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.17.0.dev20240409.dist-info → tf_models_nightly-2.17.0.dev20240410.dist-info}/RECORD +8 -8
- {tf_models_nightly-2.17.0.dev20240409.dist-info → tf_models_nightly-2.17.0.dev20240410.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.17.0.dev20240409.dist-info → tf_models_nightly-2.17.0.dev20240410.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.17.0.dev20240409.dist-info → tf_models_nightly-2.17.0.dev20240410.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.17.0.dev20240409.dist-info → tf_models_nightly-2.17.0.dev20240410.dist-info}/top_level.txt +0 -0
@@ -62,6 +62,7 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
62
62
|
Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | tf_keras.metrics.Metric
|
63
63
|
),
|
64
64
|
from_logits: bool = True,
|
65
|
+
slice_by_treatment: bool = True,
|
65
66
|
name: str = "loss",
|
66
67
|
dtype: tf.DType = tf.float32,
|
67
68
|
**loss_fn_kwargs,
|
@@ -76,6 +77,10 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
76
77
|
from_logits: Specifies whether the true logits or true predictions should
|
77
78
|
be used from the model outputs to compute the loss. Defaults to using
|
78
79
|
the true logits.
|
80
|
+
slice_by_treatment: Specifies whether the loss should be sliced by the
|
81
|
+
treatment indicator tensor. If `True`, `loss_fn` will be wrapped in a
|
82
|
+
`TreatmentSlicedMetric` to report the loss values sliced by the
|
83
|
+
treatment group.
|
79
84
|
name: Optional name for the instance. If `loss_fn` is a Keras metric then
|
80
85
|
its name will be used instead.
|
81
86
|
dtype: Optional data type for the instance. If `loss_fn` is a Keras metric
|
@@ -99,6 +104,7 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
99
104
|
self._loss_fn = loss_fn
|
100
105
|
self._from_logits = from_logits
|
101
106
|
self._loss_fn_kwargs = loss_fn_kwargs
|
107
|
+
self._slice_by_treatment = slice_by_treatment
|
102
108
|
|
103
109
|
if isinstance(loss_fn, tf_keras.metrics.Metric):
|
104
110
|
metric_from_logits = loss_fn.get_config().get("from_logits", from_logits)
|
@@ -108,20 +114,17 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
108
114
|
" the `from_logits` value passed to the `loss_fn` metric"
|
109
115
|
f" ({metric_from_logits}). Ensure that they have the same value."
|
110
116
|
)
|
111
|
-
|
112
|
-
self._treatment_sliced_loss = (
|
113
|
-
treatment_sliced_metric.TreatmentSlicedMetric(loss_fn)
|
114
|
-
)
|
117
|
+
loss_metric = loss_fn
|
115
118
|
|
116
119
|
else:
|
117
120
|
if "from_logits" in inspect.signature(loss_fn).parameters:
|
118
121
|
self._loss_fn_kwargs.update({"from_logits": from_logits})
|
122
|
+
loss_metric = tf_keras.metrics.Mean(name=name, dtype=dtype)
|
119
123
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
)
|
124
|
+
if slice_by_treatment:
|
125
|
+
self._loss = treatment_sliced_metric.TreatmentSlicedMetric(loss_metric)
|
126
|
+
else:
|
127
|
+
self._loss = loss_metric
|
125
128
|
|
126
129
|
def update_state(
|
127
130
|
self,
|
@@ -151,27 +154,32 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
151
154
|
|
152
155
|
pred = y_pred.true_logits if self._from_logits else y_pred.true_predictions
|
153
156
|
|
157
|
+
is_treatment = {}
|
158
|
+
if self._slice_by_treatment:
|
159
|
+
is_treatment["is_treatment"] = y_pred.is_treatment
|
160
|
+
|
154
161
|
if isinstance(self._loss_fn, tf_keras.metrics.Metric):
|
155
|
-
self._treatment_sliced_loss.update_state(
|
162
|
+
self._loss.update_state(
|
156
163
|
y_true,
|
157
164
|
y_pred=pred,
|
158
|
-
is_treatment=y_pred.is_treatment,
|
159
165
|
sample_weight=sample_weight,
|
166
|
+
**is_treatment,
|
160
167
|
)
|
161
168
|
else:
|
162
|
-
self._treatment_sliced_loss.update_state(
|
169
|
+
self._loss.update_state(
|
163
170
|
values=self._loss_fn(y_true, pred, **self._loss_fn_kwargs),
|
164
|
-
is_treatment=y_pred.is_treatment,
|
165
171
|
sample_weight=sample_weight,
|
172
|
+
**is_treatment,
|
166
173
|
)
|
167
174
|
|
168
|
-
def result(self) -> dict[str, tf.Tensor]:
|
169
|
-
return self._treatment_sliced_loss.result()
|
175
|
+
def result(self) -> tf.Tensor | dict[str, tf.Tensor]:
|
176
|
+
return self._loss.result()
|
170
177
|
|
171
178
|
def get_config(self) -> dict[str, Any]:
|
172
179
|
config = super().get_config()
|
173
180
|
config["loss_fn"] = tf_keras.utils.serialize_keras_object(self._loss_fn)
|
174
181
|
config["from_logits"] = self._from_logits
|
182
|
+
config["slice_by_treatment"] = self._slice_by_treatment
|
175
183
|
config.update(self._loss_fn_kwargs)
|
176
184
|
return config
|
177
185
|
|
@@ -180,4 +188,4 @@ class LossMetric(tf_keras.metrics.Metric):
|
|
180
188
|
config["loss_fn"] = tf_keras.utils.deserialize_keras_object(
|
181
189
|
config["loss_fn"]
|
182
190
|
)
|
183
|
-
return
|
191
|
+
return cls(**config)
|
@@ -25,9 +25,7 @@ from official.recommendation.uplift import types
|
|
25
25
|
from official.recommendation.uplift.metrics import loss_metric
|
26
26
|
|
27
27
|
|
28
|
-
class LossMetricTest(
|
29
|
-
keras_test_case.KerasTestCase, parameterized.TestCase
|
30
|
-
):
|
28
|
+
class LossMetricTest(keras_test_case.KerasTestCase, parameterized.TestCase):
|
31
29
|
|
32
30
|
def _get_outputs(
|
33
31
|
self,
|
@@ -236,6 +234,28 @@ class LossMetricTest(
|
|
236
234
|
"loss/treatment": 0.0,
|
237
235
|
},
|
238
236
|
},
|
237
|
+
{
|
238
|
+
"testcase_name": "no_treatment_slice",
|
239
|
+
"loss_fn": tf_keras.losses.binary_crossentropy,
|
240
|
+
"from_logits": True,
|
241
|
+
"y_true": tf.constant([[0.0, 1.0]]),
|
242
|
+
"y_pred": tf.constant([[0.0, 1.0]]),
|
243
|
+
"is_treatment": tf.constant([[0], [0]]),
|
244
|
+
"sample_weight": None,
|
245
|
+
"expected_losses": 0.50320446,
|
246
|
+
"slice_by_treatment": False,
|
247
|
+
},
|
248
|
+
{
|
249
|
+
"testcase_name": "no_treatment_slice_metric",
|
250
|
+
"loss_fn": tf_keras.metrics.BinaryCrossentropy(from_logits=False),
|
251
|
+
"from_logits": False,
|
252
|
+
"y_true": tf.constant([[0.0, 1.0]]),
|
253
|
+
"y_pred": tf.constant([[0.0, 1.0]]),
|
254
|
+
"is_treatment": tf.constant([[0], [0]]),
|
255
|
+
"sample_weight": None,
|
256
|
+
"expected_losses": 0,
|
257
|
+
"slice_by_treatment": False,
|
258
|
+
},
|
239
259
|
)
|
240
260
|
def test_metric_computes_sliced_losses(
|
241
261
|
self,
|
@@ -245,7 +265,8 @@ class LossMetricTest(
|
|
245
265
|
y_pred: tf.Tensor,
|
246
266
|
is_treatment: tf.Tensor,
|
247
267
|
sample_weight: tf.Tensor | None,
|
248
|
-
expected_losses: dict[str, float],
|
268
|
+
expected_losses: float | dict[str, float],
|
269
|
+
slice_by_treatment: bool = True,
|
249
270
|
):
|
250
271
|
if from_logits:
|
251
272
|
true_logits = y_pred
|
@@ -254,7 +275,11 @@ class LossMetricTest(
|
|
254
275
|
true_logits = tf.zeros_like(y_pred) # Irrelevant for testing.
|
255
276
|
true_predictions = y_pred
|
256
277
|
|
257
|
-
metric = loss_metric.LossMetric(
|
278
|
+
metric = loss_metric.LossMetric(
|
279
|
+
loss_fn=loss_fn,
|
280
|
+
from_logits=from_logits,
|
281
|
+
slice_by_treatment=slice_by_treatment,
|
282
|
+
)
|
258
283
|
outputs = self._get_outputs(
|
259
284
|
true_logits=true_logits,
|
260
285
|
true_predictions=true_predictions,
|
@@ -335,17 +360,28 @@ class LossMetricTest(
|
|
335
360
|
metric.reset_states()
|
336
361
|
self.assertEqual(expected_initial_result, metric.result())
|
337
362
|
|
338
|
-
@parameterized.
|
339
|
-
|
340
|
-
|
363
|
+
@parameterized.product(
|
364
|
+
loss_fn=(
|
365
|
+
tf_keras.losses.binary_crossentropy,
|
366
|
+
tf_keras.metrics.BinaryCrossentropy(
|
367
|
+
from_logits=True, name="bce_loss"
|
368
|
+
),
|
369
|
+
),
|
370
|
+
slice_by_treatment=(True, False),
|
341
371
|
)
|
342
372
|
def test_metric_is_configurable(
|
343
373
|
self,
|
344
374
|
loss_fn: (
|
345
375
|
Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | tf_keras.metrics.Metric
|
346
376
|
),
|
377
|
+
slice_by_treatment: bool,
|
347
378
|
):
|
348
|
-
metric = loss_metric.LossMetric(
|
379
|
+
metric = loss_metric.LossMetric(
|
380
|
+
loss_fn,
|
381
|
+
from_logits=True,
|
382
|
+
slice_by_treatment=slice_by_treatment,
|
383
|
+
name="bce_loss",
|
384
|
+
)
|
349
385
|
self.assertLayerConfigurable(
|
350
386
|
layer=metric,
|
351
387
|
y_true=tf.constant([[1], [1], [0]]),
|
@@ -912,8 +912,8 @@ official/recommendation/uplift/metrics/label_mean.py,sha256=ECaes7FZmsksnwySn7jf
|
|
912
912
|
official/recommendation/uplift/metrics/label_mean_test.py,sha256=b_d3lNlpkDm2xKLUkxfiXeQg7pjL8HNx7y9NaYarpV0,7083
|
913
913
|
official/recommendation/uplift/metrics/label_variance.py,sha256=9DCl42BJkehxfWD3pSbZnRNvwfhVM6VyHwivGdaU72s,3610
|
914
914
|
official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53-HIEO5HGtfp1MleifD-h4bZNKtTvM3Ws,7681
|
915
|
-
official/recommendation/uplift/metrics/loss_metric.py,sha256=
|
916
|
-
official/recommendation/uplift/metrics/loss_metric_test.py,sha256=
|
915
|
+
official/recommendation/uplift/metrics/loss_metric.py,sha256=qWMcLk1JmLkru5T9qcqO2UbB_AJQKjQ8uUSAxjR1l58,6883
|
916
|
+
official/recommendation/uplift/metrics/loss_metric_test.py,sha256=QIUujZvUud-A2itI2vs2KpzWfLwbq2XDd9H9CFq5D4Y,14929
|
917
917
|
official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
|
918
918
|
official/recommendation/uplift/metrics/sliced_metric.py,sha256=O2I2apZK6IfOQK9Q_mgSiTUCnGokczp4e14zrrYNeRU,8564
|
919
919
|
official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=dhY41X8lqT_WW04XLjyjDerZwujEBGeTXtxf4NkYThw,11359
|
@@ -1203,9 +1203,9 @@ tensorflow_models/__init__.py,sha256=etxw45SHxuwFCRX5qGxGMP83II0JfJulzNl5GSNJvhw
|
|
1203
1203
|
tensorflow_models/tensorflow_models_test.py,sha256=AxUYUdiQn416UR7jg0h6rmv688esvlKDfpyDCIQkF18,1395
|
1204
1204
|
tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
|
1205
1205
|
tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
|
1206
|
-
tf_models_nightly-2.17.0.
|
1207
|
-
tf_models_nightly-2.17.0.
|
1208
|
-
tf_models_nightly-2.17.0.
|
1209
|
-
tf_models_nightly-2.17.0.
|
1210
|
-
tf_models_nightly-2.17.0.
|
1211
|
-
tf_models_nightly-2.17.0.
|
1206
|
+
tf_models_nightly-2.17.0.dev20240410.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1207
|
+
tf_models_nightly-2.17.0.dev20240410.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1208
|
+
tf_models_nightly-2.17.0.dev20240410.dist-info/METADATA,sha256=Bh56e1__Pn_yvpm9QMq4h12nKucvnKAftuk0ElQqn_w,1432
|
1209
|
+
tf_models_nightly-2.17.0.dev20240410.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1210
|
+
tf_models_nightly-2.17.0.dev20240410.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1211
|
+
tf_models_nightly-2.17.0.dev20240410.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|