tf-models-nightly 2.17.0.dev20240409__py2.py3-none-any.whl → 2.17.0.dev20240410__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -62,6 +62,7 @@ class LossMetric(tf_keras.metrics.Metric):
62
62
  Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | tf_keras.metrics.Metric
63
63
  ),
64
64
  from_logits: bool = True,
65
+ slice_by_treatment: bool = True,
65
66
  name: str = "loss",
66
67
  dtype: tf.DType = tf.float32,
67
68
  **loss_fn_kwargs,
@@ -76,6 +77,10 @@ class LossMetric(tf_keras.metrics.Metric):
76
77
  from_logits: Specifies whether the true logits or true predictions should
77
78
  be used from the model outputs to compute the loss. Defaults to using
78
79
  the true logits.
80
+ slice_by_treatment: Specifies whether the loss should be sliced by the
81
+ treatment indicator tensor. If `True`, `loss_fn` will be wrapped in a
82
+ `TreatmentSlicedMetric` to report the loss values sliced by the
83
+ treatment group.
79
84
  name: Optional name for the instance. If `loss_fn` is a Keras metric then
80
85
  its name will be used instead.
81
86
  dtype: Optional data type for the instance. If `loss_fn` is a Keras metric
@@ -99,6 +104,7 @@ class LossMetric(tf_keras.metrics.Metric):
99
104
  self._loss_fn = loss_fn
100
105
  self._from_logits = from_logits
101
106
  self._loss_fn_kwargs = loss_fn_kwargs
107
+ self._slice_by_treatment = slice_by_treatment
102
108
 
103
109
  if isinstance(loss_fn, tf_keras.metrics.Metric):
104
110
  metric_from_logits = loss_fn.get_config().get("from_logits", from_logits)
@@ -108,20 +114,17 @@ class LossMetric(tf_keras.metrics.Metric):
108
114
  " the `from_logits` value passed to the `loss_fn` metric"
109
115
  f" ({metric_from_logits}). Ensure that they have the same value."
110
116
  )
111
-
112
- self._treatment_sliced_loss = (
113
- treatment_sliced_metric.TreatmentSlicedMetric(loss_fn)
114
- )
117
+ loss_metric = loss_fn
115
118
 
116
119
  else:
117
120
  if "from_logits" in inspect.signature(loss_fn).parameters:
118
121
  self._loss_fn_kwargs.update({"from_logits": from_logits})
122
+ loss_metric = tf_keras.metrics.Mean(name=name, dtype=dtype)
119
123
 
120
- self._treatment_sliced_loss = (
121
- treatment_sliced_metric.TreatmentSlicedMetric(
122
- tf_keras.metrics.Mean(name=name, dtype=dtype)
123
- )
124
- )
124
+ if slice_by_treatment:
125
+ self._loss = treatment_sliced_metric.TreatmentSlicedMetric(loss_metric)
126
+ else:
127
+ self._loss = loss_metric
125
128
 
126
129
  def update_state(
127
130
  self,
@@ -151,27 +154,32 @@ class LossMetric(tf_keras.metrics.Metric):
151
154
 
152
155
  pred = y_pred.true_logits if self._from_logits else y_pred.true_predictions
153
156
 
157
+ is_treatment = {}
158
+ if self._slice_by_treatment:
159
+ is_treatment["is_treatment"] = y_pred.is_treatment
160
+
154
161
  if isinstance(self._loss_fn, tf_keras.metrics.Metric):
155
- self._treatment_sliced_loss.update_state(
162
+ self._loss.update_state(
156
163
  y_true,
157
164
  y_pred=pred,
158
- is_treatment=y_pred.is_treatment,
159
165
  sample_weight=sample_weight,
166
+ **is_treatment,
160
167
  )
161
168
  else:
162
- self._treatment_sliced_loss.update_state(
169
+ self._loss.update_state(
163
170
  values=self._loss_fn(y_true, pred, **self._loss_fn_kwargs),
164
- is_treatment=y_pred.is_treatment,
165
171
  sample_weight=sample_weight,
172
+ **is_treatment,
166
173
  )
167
174
 
168
- def result(self) -> dict[str, tf.Tensor]:
169
- return self._treatment_sliced_loss.result()
175
+ def result(self) -> tf.Tensor | dict[str, tf.Tensor]:
176
+ return self._loss.result()
170
177
 
171
178
  def get_config(self) -> dict[str, Any]:
172
179
  config = super().get_config()
173
180
  config["loss_fn"] = tf_keras.utils.serialize_keras_object(self._loss_fn)
174
181
  config["from_logits"] = self._from_logits
182
+ config["slice_by_treatment"] = self._slice_by_treatment
175
183
  config.update(self._loss_fn_kwargs)
176
184
  return config
177
185
 
@@ -180,4 +188,4 @@ class LossMetric(tf_keras.metrics.Metric):
180
188
  config["loss_fn"] = tf_keras.utils.deserialize_keras_object(
181
189
  config["loss_fn"]
182
190
  )
183
- return LossMetric(**config)
191
+ return cls(**config)
@@ -25,9 +25,7 @@ from official.recommendation.uplift import types
25
25
  from official.recommendation.uplift.metrics import loss_metric
26
26
 
27
27
 
28
- class LossMetricTest(
29
- keras_test_case.KerasTestCase, parameterized.TestCase
30
- ):
28
+ class LossMetricTest(keras_test_case.KerasTestCase, parameterized.TestCase):
31
29
 
32
30
  def _get_outputs(
33
31
  self,
@@ -236,6 +234,28 @@ class LossMetricTest(
236
234
  "loss/treatment": 0.0,
237
235
  },
238
236
  },
237
+ {
238
+ "testcase_name": "no_treatment_slice",
239
+ "loss_fn": tf_keras.losses.binary_crossentropy,
240
+ "from_logits": True,
241
+ "y_true": tf.constant([[0.0, 1.0]]),
242
+ "y_pred": tf.constant([[0.0, 1.0]]),
243
+ "is_treatment": tf.constant([[0], [0]]),
244
+ "sample_weight": None,
245
+ "expected_losses": 0.50320446,
246
+ "slice_by_treatment": False,
247
+ },
248
+ {
249
+ "testcase_name": "no_treatment_slice_metric",
250
+ "loss_fn": tf_keras.metrics.BinaryCrossentropy(from_logits=False),
251
+ "from_logits": False,
252
+ "y_true": tf.constant([[0.0, 1.0]]),
253
+ "y_pred": tf.constant([[0.0, 1.0]]),
254
+ "is_treatment": tf.constant([[0], [0]]),
255
+ "sample_weight": None,
256
+ "expected_losses": 0,
257
+ "slice_by_treatment": False,
258
+ },
239
259
  )
240
260
  def test_metric_computes_sliced_losses(
241
261
  self,
@@ -245,7 +265,8 @@ class LossMetricTest(
245
265
  y_pred: tf.Tensor,
246
266
  is_treatment: tf.Tensor,
247
267
  sample_weight: tf.Tensor | None,
248
- expected_losses: dict[str, float],
268
+ expected_losses: float | dict[str, float],
269
+ slice_by_treatment: bool = True,
249
270
  ):
250
271
  if from_logits:
251
272
  true_logits = y_pred
@@ -254,7 +275,11 @@ class LossMetricTest(
254
275
  true_logits = tf.zeros_like(y_pred) # Irrelevant for testing.
255
276
  true_predictions = y_pred
256
277
 
257
- metric = loss_metric.LossMetric(loss_fn=loss_fn, from_logits=from_logits)
278
+ metric = loss_metric.LossMetric(
279
+ loss_fn=loss_fn,
280
+ from_logits=from_logits,
281
+ slice_by_treatment=slice_by_treatment,
282
+ )
258
283
  outputs = self._get_outputs(
259
284
  true_logits=true_logits,
260
285
  true_predictions=true_predictions,
@@ -335,17 +360,28 @@ class LossMetricTest(
335
360
  metric.reset_states()
336
361
  self.assertEqual(expected_initial_result, metric.result())
337
362
 
338
- @parameterized.parameters(
339
- tf_keras.losses.binary_crossentropy,
340
- tf_keras.metrics.BinaryCrossentropy(from_logits=True, name="bce_loss"),
363
+ @parameterized.product(
364
+ loss_fn=(
365
+ tf_keras.losses.binary_crossentropy,
366
+ tf_keras.metrics.BinaryCrossentropy(
367
+ from_logits=True, name="bce_loss"
368
+ ),
369
+ ),
370
+ slice_by_treatment=(True, False),
341
371
  )
342
372
  def test_metric_is_configurable(
343
373
  self,
344
374
  loss_fn: (
345
375
  Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | tf_keras.metrics.Metric
346
376
  ),
377
+ slice_by_treatment: bool,
347
378
  ):
348
- metric = loss_metric.LossMetric(loss_fn, from_logits=True, name="bce_loss")
379
+ metric = loss_metric.LossMetric(
380
+ loss_fn,
381
+ from_logits=True,
382
+ slice_by_treatment=slice_by_treatment,
383
+ name="bce_loss",
384
+ )
349
385
  self.assertLayerConfigurable(
350
386
  layer=metric,
351
387
  y_true=tf.constant([[1], [1], [0]]),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.17.0.dev20240409
3
+ Version: 2.17.0.dev20240410
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -912,8 +912,8 @@ official/recommendation/uplift/metrics/label_mean.py,sha256=ECaes7FZmsksnwySn7jf
912
912
  official/recommendation/uplift/metrics/label_mean_test.py,sha256=b_d3lNlpkDm2xKLUkxfiXeQg7pjL8HNx7y9NaYarpV0,7083
913
913
  official/recommendation/uplift/metrics/label_variance.py,sha256=9DCl42BJkehxfWD3pSbZnRNvwfhVM6VyHwivGdaU72s,3610
914
914
  official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53-HIEO5HGtfp1MleifD-h4bZNKtTvM3Ws,7681
915
- official/recommendation/uplift/metrics/loss_metric.py,sha256=8Lfi4FNR6uj8eLSdT5t0XjEzTZtwL-xv5e5KNGzrYWU,6498
916
- official/recommendation/uplift/metrics/loss_metric_test.py,sha256=yropE1t8PblsJ_kJEno5ratkY_ka81PHgcfroWtRKVI,13768
915
+ official/recommendation/uplift/metrics/loss_metric.py,sha256=qWMcLk1JmLkru5T9qcqO2UbB_AJQKjQ8uUSAxjR1l58,6883
916
+ official/recommendation/uplift/metrics/loss_metric_test.py,sha256=QIUujZvUud-A2itI2vs2KpzWfLwbq2XDd9H9CFq5D4Y,14929
917
917
  official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
918
918
  official/recommendation/uplift/metrics/sliced_metric.py,sha256=O2I2apZK6IfOQK9Q_mgSiTUCnGokczp4e14zrrYNeRU,8564
919
919
  official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=dhY41X8lqT_WW04XLjyjDerZwujEBGeTXtxf4NkYThw,11359
@@ -1203,9 +1203,9 @@ tensorflow_models/__init__.py,sha256=etxw45SHxuwFCRX5qGxGMP83II0JfJulzNl5GSNJvhw
1203
1203
  tensorflow_models/tensorflow_models_test.py,sha256=AxUYUdiQn416UR7jg0h6rmv688esvlKDfpyDCIQkF18,1395
1204
1204
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1205
1205
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1206
- tf_models_nightly-2.17.0.dev20240409.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1207
- tf_models_nightly-2.17.0.dev20240409.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1208
- tf_models_nightly-2.17.0.dev20240409.dist-info/METADATA,sha256=738N3KUX3sKTeF-7mDL3_2r1htitc4BCT9A4QMgB9U4,1432
1209
- tf_models_nightly-2.17.0.dev20240409.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1210
- tf_models_nightly-2.17.0.dev20240409.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1211
- tf_models_nightly-2.17.0.dev20240409.dist-info/RECORD,,
1206
+ tf_models_nightly-2.17.0.dev20240410.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1207
+ tf_models_nightly-2.17.0.dev20240410.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1208
+ tf_models_nightly-2.17.0.dev20240410.dist-info/METADATA,sha256=Bh56e1__Pn_yvpm9QMq4h12nKucvnKAftuk0ElQqn_w,1432
1209
+ tf_models_nightly-2.17.0.dev20240410.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1210
+ tf_models_nightly-2.17.0.dev20240410.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1211
+ tf_models_nightly-2.17.0.dev20240410.dist-info/RECORD,,