tf-models-nightly 2.17.0.dev20240418__py2.py3-none-any.whl → 2.17.0.dev20240420__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,6 +18,7 @@ from official.recommendation.uplift.metrics import label_mean
18
18
  from official.recommendation.uplift.metrics import label_variance
19
19
  from official.recommendation.uplift.metrics import loss_metric
20
20
  from official.recommendation.uplift.metrics import metric_configs
21
+ from official.recommendation.uplift.metrics import poisson_metrics
21
22
  from official.recommendation.uplift.metrics import sliced_metric
22
23
  from official.recommendation.uplift.metrics import treatment_fraction
23
24
  from official.recommendation.uplift.metrics import treatment_sliced_metric
@@ -276,3 +276,129 @@ class LogLossMinimum(tf_keras.metrics.Metric):
276
276
  @classmethod
  def from_config(cls, config: dict[str, Any]) -> LogLossMinimum:
    """Creates an instance by passing each `config` entry as a constructor kwarg."""
    return cls(**config)
279
+
280
+
281
class PseudoRSquared(tf_keras.metrics.Metric):
  """Pseudo R-squared for Poisson regression.

  The score compares the fitted model's log likelihood against two reference
  log likelihoods:

  1) LLbaseline: log likelihood of a predictor that always outputs the mean
     label.
  2) LLfit: log likelihood of the fitted model.
  3) LLmax: the largest achievable log likelihood, attained when every
     prediction equals its label.

  They are combined as:

      R_squared = (LLfit - LLbaseline) / (LLmax - LLbaseline)

  Each log likelihood is tracked through its negated counterpart, a log loss.
  The ratio is therefore evaluated directly on the accumulated losses. When
  LLmax equals LLbaseline, the denominator is zero and the result is 0 instead
  of NaN.
  """

  def __init__(
      self,
      from_logits: bool = True,
      slice_by_treatment: bool = True,
      name: str = "pseudo_r_squared",
      dtype: tf.DType = tf.float32,
  ):
    """Initializes the instance.

    Args:
      from_logits: When `y_pred` is a `tf.Tensor`, whether it holds the
        model's logits (`True`) or its predictions (`False`). When `y_pred` is
        a `TwoTowerTrainingOutputs`, set this to `True` so that the losses are
        computed from the true logits.
      slice_by_treatment: Whether to also report the pseudo R-squared per
        treatment group. If `True`, `result()` returns a dict holding the
        overall value plus the control and treatment values. This requires
        `y_pred` to be a `TwoTowerTrainingOutputs`; `update_state` raises a
        `ValueError` otherwise.
      name: Optional name for the instance.
      dtype: Optional data type for the instance.
    """
    super().__init__(name=name, dtype=dtype)

    self._from_logits = from_logits
    self._slice_by_treatment = slice_by_treatment

    # log_loss == -log_likelihood, so accumulating the three losses is enough.
    # The components are always built unsliced; slicing, when requested, is
    # added by wrapping them below. They all reuse this metric's `name`, which
    # keeps the (possibly sliced) result structures identical across the three
    # components. `tf.nest.map_structure` in `result()` relies on that.
    shared_kwargs = dict(
        compute_full_loss=False,
        slice_by_treatment=False,
        name=name,
        dtype=dtype,
    )
    fitted = LogLoss(from_logits=from_logits, **shared_kwargs)
    best_possible = LogLossMinimum(**shared_kwargs)
    baseline = LogLossMeanBaseline(**shared_kwargs)

    def _maybe_slice(
        metric: tf_keras.metrics.Metric,
    ) -> tf_keras.metrics.Metric:
      if not slice_by_treatment:
        return metric
      return treatment_sliced_metric.TreatmentSlicedMetric(metric=metric)

    self._model_loss = _maybe_slice(fitted)
    self._minimum_loss = _maybe_slice(best_possible)
    self._mean_baseline_loss = _maybe_slice(baseline)

  def update_state(
      self,
      y_true: tf.Tensor,
      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor,
      sample_weight: tf.Tensor | None = None,
  ):
    """Accumulates the fitted, minimum and mean-baseline losses for a batch.

    Args:
      y_true: Labels.
      y_pred: Model outputs, either a `tf.Tensor` or a
        `TwoTowerTrainingOutputs`.
      sample_weight: Optional weights, forwarded unchanged to every component
        metric.

    Raises:
      ValueError: If `slice_by_treatment` is `True` but `y_pred` is not a
        `TwoTowerTrainingOutputs`.
    """
    if self._slice_by_treatment and not isinstance(
        y_pred, types.TwoTowerTrainingOutputs
    ):
      raise ValueError(
          "`slice_by_treatment` must be set to `False` when `y_pred` is not"
          " of type `TwoTowerTrainingOutputs`."
      )

    # The treatment-sliced wrappers additionally need the treatment indicator.
    slicing_kwargs = (
        {"is_treatment": y_pred.is_treatment}
        if self._slice_by_treatment
        else {}
    )

    for component in (
        self._model_loss,
        self._minimum_loss,
        self._mean_baseline_loss,
    ):
      component.update_state(
          y_true, y_pred=y_pred, sample_weight=sample_weight, **slicing_kwargs
      )

  def result(self) -> tf.Tensor | dict[str, tf.Tensor]:
    """Returns the pseudo R-squared, or a dict of values when sliced."""
    fitted = self._model_loss.result()
    baseline = self._mean_baseline_loss.result()
    best_possible = self._minimum_loss.result()

    # With L = -LL:
    #   (LLfit - LLbaseline) / (LLmax - LLbaseline)
    #     == (L_fit - L_baseline) / (L_min - L_baseline).
    # divide_no_nan returns 0 when L_min == L_baseline.
    return tf.nest.map_structure(
        lambda fit, base, best: tf.math.divide_no_nan(fit - base, best - base),
        fitted,
        baseline,
        best_possible,
    )

  def get_config(self) -> dict[str, Any]:
    """Returns the base config extended with this metric's constructor flags."""
    return {
        **super().get_config(),
        "from_logits": self._from_logits,
        "slice_by_treatment": self._slice_by_treatment,
    }

  @classmethod
  def from_config(cls, config: dict[str, Any]) -> PseudoRSquared:
    """Creates an instance by passing each `config` entry as a constructor kwarg."""
    return cls(**config)
@@ -330,5 +330,120 @@ class LogLossMinimumTest(keras_test_case.KerasTestCase, parameterized.TestCase):
330
330
  )
331
331
 
332
332
 
333
class PseudoRSquaredTest(keras_test_case.KerasTestCase, parameterized.TestCase):
  """Tests for `poisson_metrics.PseudoRSquared`."""

  # Shared expected-value terms for labels [0, 1] (mean label 0.5), used by
  # the logit-based cases below.
  # Negated LLbaseline: loss of always predicting the mean label.
  _BASELINE_LOSS = 0.5 - 0.5 * tf.math.log(0.5)
  # Negated LLmax: loss when predictions equal the labels.
  _MINIMUM_LOSS = 0.5
  # Expected R^2 for true logits [0, 1]; the numerator's first term is LLfit.
  _TWO_TOWER_R2 = (
      ((tf.math.exp(1.0) - 1) + 1) / 2 - _BASELINE_LOSS
  ) / (_MINIMUM_LOSS - _BASELINE_LOSS)

  @parameterized.named_parameters(
      {
          "testcase_name": "no_data",
          "expected_result": 0.0,
          "y_true": tf.constant([], dtype=tf.float32),
          "y_pred": tf.constant([], dtype=tf.float32),
      },
      {
          "testcase_name": "one_correct_prediction",
          "expected_result": 0.0,
          "y_true": tf.constant([1], dtype=tf.float32),
          "y_pred": tf.constant([1], dtype=tf.float32),
      },
      {
          "testcase_name": "one_wrong_prediction",
          "expected_result": 0.0,  # LLmax and LLbaseline are equal.
          "y_true": tf.constant([0], dtype=tf.float32),
          "y_pred": tf.constant([1], dtype=tf.float32),
      },
      {
          "testcase_name": "all_correct_predictions",
          "expected_result": 1.0,
          "y_true": tf.constant([[1], [2], [3]], dtype=tf.float32),
          "y_pred": tf.constant([[1], [2], [3]], dtype=tf.float32),
          "from_logits": False,
      },
      {
          "testcase_name": "almost_correct_predictions",
          "expected_result": 1.0,
          "y_true": tf.constant([[1], [2], [3]], dtype=tf.float32),
          "y_pred": tf.constant([[1], [1.9999], [3.0001]], dtype=tf.float32),
      },
      {
          "testcase_name": "from_logits",
          "expected_result": (
              (tf.math.exp(1.0) / 2) - _BASELINE_LOSS  # LLfit - LLbaseline
          ) / (_MINIMUM_LOSS - _BASELINE_LOSS),
          "y_true": tf.constant([[0], [1]], dtype=tf.float32),
          "y_pred": tf.constant([[0], [1]], dtype=tf.float32),
          "from_logits": True,
      },
      {
          "testcase_name": "two_tower_outputs",
          "expected_result": _TWO_TOWER_R2,
          "y_true": tf.constant([[0], [1]], dtype=tf.float32),
          "y_pred": _get_two_tower_outputs(
              true_logits=tf.constant([[0], [1]], dtype=tf.float32),
              is_treatment=tf.constant([[0], [1]], dtype=tf.float32),
          ),
          "from_logits": True,
      },
      {
          "testcase_name": "two_tower_outputs_sliced_loss",
          "expected_result": {
              "r2": _TWO_TOWER_R2,
              # Each slice holds a single example, so LLmax == LLbaseline.
              "r2/control": 0.0,
              "r2/treatment": 0.0,
          },
          "y_true": tf.constant([[0], [1]], dtype=tf.float32),
          "y_pred": _get_two_tower_outputs(
              true_logits=tf.constant([[0], [1]], dtype=tf.float32),
              is_treatment=tf.constant([[0], [1]], dtype=tf.float32),
          ),
          "from_logits": True,
          "slice_by_treatment": True,
      },
  )
  def test_metric_computation_is_correct(
      self,
      expected_result: tf.Tensor | float | dict[str, tf.Tensor | float],
      y_true: tf.Tensor,
      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor,
      sample_weight: tf.Tensor | None = None,
      from_logits: bool = False,
      slice_by_treatment: bool = False,
  ):
    r2_metric = poisson_metrics.PseudoRSquared(
        from_logits=from_logits,
        slice_by_treatment=slice_by_treatment,
        name="r2",
    )
    r2_metric.update_state(y_true, y_pred, sample_weight=sample_weight)
    self.assertAllClose(expected_result, r2_metric.result())

  def test_slicing_raises_error_when_input_is_tensor(self):
    # The default constructor enables slicing, which a plain tensor can't use.
    r2_metric = poisson_metrics.PseudoRSquared()
    labels = tf.constant([[0], [0], [2], [7]], dtype=tf.float32)
    predictions = tf.constant([[1], [2], [3], [4]], dtype=tf.float32)
    expected_message = (
        "`slice_by_treatment` must be set to `False` when `y_pred` is not of"
        " type `TwoTowerTrainingOutputs`."
    )
    with self.assertRaisesRegex(ValueError, expected_message):
      r2_metric(labels, predictions)

  @parameterized.parameters(True, False)
  def test_metric_is_configurable(self, from_logits: bool):
    r2_metric = poisson_metrics.PseudoRSquared(
        from_logits=from_logits, slice_by_treatment=False
    )
    self.assertLayerConfigurable(
        layer=r2_metric,
        y_true=tf.constant([[0], [0], [2], [7]], dtype=tf.float32),
        y_pred=tf.constant([[1], [2], [3], [4]], dtype=tf.float32),
        serializable=True,
    )
446
+
447
+
333
448
  if __name__ == "__main__":
    # Script entry point: hand control to the TensorFlow test runner.
    tf.test.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.17.0.dev20240418
3
+ Version: 2.17.0.dev20240420
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -907,7 +907,7 @@ official/recommendation/uplift/layers/uplift_networks/two_tower_uplift_network_t
907
907
  official/recommendation/uplift/losses/__init__.py,sha256=b1_FBAD9SWUy8-Wfs2KeiER5WappdJ_sE4dDzip2k6w,711
908
908
  official/recommendation/uplift/losses/true_logits_loss.py,sha256=YrI89Fgcg1fEIpAFczzyNS0Ty9YXa_XClXgqlDsNh10,3524
909
909
  official/recommendation/uplift/losses/true_logits_loss_test.py,sha256=xF91tzTarqG9lNfE8HSYVvRy-dpjP8bJ8FYpHUg3H6U,3105
910
- official/recommendation/uplift/metrics/__init__.py,sha256=qdlYbLPbVMpHUACS1Mobo9chcOyiRj3uLp0ZM-hOVH8,1234
910
+ official/recommendation/uplift/metrics/__init__.py,sha256=0YyJbxTptbh3RQYTeqwnD_YIIc86KeX2ot-C-vQcD6E,1301
911
911
  official/recommendation/uplift/metrics/label_mean.py,sha256=ECaes7FZmsksnwySn7jfTR4IZeOu3X0ZOOx4UiCo7CI,3494
912
912
  official/recommendation/uplift/metrics/label_mean_test.py,sha256=b_d3lNlpkDm2xKLUkxfiXeQg7pjL8HNx7y9NaYarpV0,7083
913
913
  official/recommendation/uplift/metrics/label_variance.py,sha256=9DCl42BJkehxfWD3pSbZnRNvwfhVM6VyHwivGdaU72s,3610
@@ -915,8 +915,8 @@ official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53
915
915
  official/recommendation/uplift/metrics/loss_metric.py,sha256=gYZdnTsuL_2q1FZuPip-DaWxt_Q-02YYaePyMBVNx7w,7344
916
916
  official/recommendation/uplift/metrics/loss_metric_test.py,sha256=48rQG8bKFdy0xBFjoOLXKRUlYpCEyAzSmPOFoF7FX94,16021
917
917
  official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
918
- official/recommendation/uplift/metrics/poisson_metrics.py,sha256=1zzwots4WkpxoYmOICt_CZxuAeXWyT9dktdDgT2IAu4,10197
919
- official/recommendation/uplift/metrics/poisson_metrics_test.py,sha256=o0efkAz1OusU2C85qLl9ZZJLfXflV7acfvWp1igT-1U,12016
918
+ official/recommendation/uplift/metrics/poisson_metrics.py,sha256=ZltEopHvYJlrqjVr9X_NhNQOwWcv2O3rFqvKVVTC57I,14475
919
+ official/recommendation/uplift/metrics/poisson_metrics_test.py,sha256=zGo7Y7XJY4kvecoeAy7Jci_N3YZKMTKpj28j7ZtoTLc,16378
920
920
  official/recommendation/uplift/metrics/sliced_metric.py,sha256=uhvzudOWtMNKZ0avwGhX-37UELR9Cq9b4C0g8erBkXw,8688
921
921
  official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=bhVGyI1tOkFkVOtruJo3p6XopDFyG1JW5qdZm9-RqeU,12248
922
922
  official/recommendation/uplift/metrics/treatment_fraction.py,sha256=WHrKfsN42xU7S-pK99xEVpVtd3zLD7UidLT1K8vgIn4,2757
@@ -1206,9 +1206,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
1206
1206
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1207
1207
  tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
1208
1208
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1209
- tf_models_nightly-2.17.0.dev20240418.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1210
- tf_models_nightly-2.17.0.dev20240418.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1211
- tf_models_nightly-2.17.0.dev20240418.dist-info/METADATA,sha256=Qga8KYfDdCPyinPp79D7UyN7doPZH066mZ98Wp5QlUM,1432
1212
- tf_models_nightly-2.17.0.dev20240418.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1213
- tf_models_nightly-2.17.0.dev20240418.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1214
- tf_models_nightly-2.17.0.dev20240418.dist-info/RECORD,,
1209
+ tf_models_nightly-2.17.0.dev20240420.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1210
+ tf_models_nightly-2.17.0.dev20240420.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1211
+ tf_models_nightly-2.17.0.dev20240420.dist-info/METADATA,sha256=sKpzXk5RWNvqAoLbbpYJF3EETKC71XE_H9WqbhV1WOE,1432
1212
+ tf_models_nightly-2.17.0.dev20240420.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1213
+ tf_models_nightly-2.17.0.dev20240420.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1214
+ tf_models_nightly-2.17.0.dev20240420.dist-info/RECORD,,