tf-models-nightly 2.17.0.dev20240417__py2.py3-none-any.whl → 2.17.0.dev20240418__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -198,3 +198,81 @@ class LogLossMeanBaseline(tf_keras.metrics.Metric):
198
198
  @classmethod
199
199
  def from_config(cls, config: dict[str, Any]) -> LogLossMeanBaseline:
200
200
  return cls(**config)
201
+
202
+
203
class LogLossMinimum(tf_keras.metrics.Metric):
  """Tracks the smallest attainable (weighted) Poisson log loss.

  For labels `y`, the Poisson log loss of a prediction `x` is minimized at
  `x == y`, where it evaluates to:
    `loss = y - y * log(y) + [y * log(y) - y + 0.5 * log(2 * pi * y)]`

  The bracketed term is only included when `compute_full_loss` is `True`,
  which is currently rejected with `NotImplementedError`.
  """

  def __init__(
      self,
      compute_full_loss: bool = False,
      slice_by_treatment: bool = True,
      name: str = "poisson_log_loss_minimum",
      dtype: tf.DType = tf.float32,
  ):
    """Creates the metric.

    Args:
      compute_full_loss: Whether to add the bracketed term to the minimum
        loss. Defaults to `False`; `True` is not implemented yet.
      slice_by_treatment: Whether the result should also be sliced by the
        treatment indicator. When `True`, `update_state` must be given a
        `TwoTowerTrainingOutputs` as `y_pred`.
      name: Optional name for the instance.
      dtype: Optional data type for the instance.

    Raises:
      NotImplementedError: If `compute_full_loss` is `True`.
    """
    super().__init__(name=name, dtype=dtype)

    if compute_full_loss:
      raise NotImplementedError("Full loss computation is not yet supported.")

    self._compute_full_loss = compute_full_loss
    self._slice_by_treatment = slice_by_treatment

    # A single weighted mean accumulates the per-example minimum loss; when
    # slicing is requested it is wrapped so results are also split by the
    # treatment indicator.
    mean_metric = tf_keras.metrics.Mean(name=name, dtype=dtype)
    self._loss = (
        treatment_sliced_metric.TreatmentSlicedMetric(metric=mean_metric)
        if slice_by_treatment
        else mean_metric
    )

  def update_state(
      self,
      y_true: tf.Tensor,
      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor | None = None,
      sample_weight: tf.Tensor | None = None,
  ):
    """Accumulates the minimum loss computed from `y_true`.

    Args:
      y_true: Labels.
      y_pred: Model outputs. Only read when `slice_by_treatment` is `True`, in
        which case it must be a `TwoTowerTrainingOutputs` and its
        `is_treatment` field is forwarded to the sliced metric.
      sample_weight: Optional weights forwarded to the underlying mean.

    Raises:
      ValueError: If `slice_by_treatment` is `True` and `y_pred` is not a
        `TwoTowerTrainingOutputs`.
    """
    if self._slice_by_treatment:
      if not isinstance(y_pred, types.TwoTowerTrainingOutputs):
        raise ValueError(
            "`slice_by_treatment` must be set to `False` when `y_pred` is not"
            " of type `TwoTowerTrainingOutputs`."
        )
      slice_kwargs = {"is_treatment": y_pred.is_treatment}
    else:
      slice_kwargs = {}

    self._loss.update_state(
        _safe_x_minus_xlogx(y_true), sample_weight=sample_weight, **slice_kwargs
    )

  def result(self) -> tf.Tensor | dict[str, tf.Tensor]:
    """Returns the accumulated result of the wrapped mean (or sliced) metric."""
    return self._loss.result()

  def get_config(self) -> dict[str, Any]:
    """Returns the base metric config extended with this metric's options."""
    return {
        **super().get_config(),
        "compute_full_loss": self._compute_full_loss,
        "slice_by_treatment": self._slice_by_treatment,
    }

  @classmethod
  def from_config(cls, config: dict[str, Any]) -> LogLossMinimum:
    """Rebuilds the metric from a dict produced by `get_config`."""
    return cls(**config)
@@ -256,5 +256,79 @@ class LogLossMeanBaselineTest(
256
256
  )
257
257
 
258
258
 
259
class LogLossMinimumTest(keras_test_case.KerasTestCase, parameterized.TestCase):
  """Tests for `poisson_metrics.LogLossMinimum`."""

  @parameterized.named_parameters(
      {
          "testcase_name": "label_zero",
          "expected_loss": 0.0,
          "y_true": tf.constant([0], dtype=tf.float32),
      },
      {
          # y - y * log(y) is ~2.4e-9 here, which is within the default
          # assertAllClose tolerance of 0.
          "testcase_name": "small_positive_label",
          "expected_loss": 0.0,
          "y_true": tf.constant([1e-10], dtype=tf.float32),
      },
      {
          "testcase_name": "label_one",
          "expected_loss": 1.0,
          "y_true": tf.constant([1], dtype=tf.float32),
      },
      {
          # The zero-weighted example is ignored, leaving only the label-one
          # example's loss of 1.
          "testcase_name": "weighted_loss",
          "expected_loss": 1.0,
          "y_true": tf.constant([[0], [1]], dtype=tf.float32),
          "sample_weight": tf.constant([[0], [1]], dtype=tf.float32),
      },
      {
          "testcase_name": "two_tower_outputs",
          "expected_loss": 0.5,
          "y_true": tf.constant([[0], [1]], dtype=tf.float32),
          "y_pred": _get_two_tower_outputs(
              is_treatment=tf.constant([[0], [1]], dtype=tf.float32),
          ),
      },
      {
          "testcase_name": "two_tower_outputs_sliced_loss",
          "expected_loss": {
              "loss": 0.5,
              "loss/control": 0.0,
              "loss/treatment": 1.0,
          },
          "y_true": tf.constant([[0], [1]], dtype=tf.float32),
          "y_pred": _get_two_tower_outputs(
              is_treatment=tf.constant([[0], [1]], dtype=tf.float32),
          ),
          "slice_by_treatment": True,
      },
  )
  def test_metric_computes_correct_loss(
      self,
      expected_loss: float | dict[str, float],
      y_true: tf.Tensor,
      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor | None = None,
      sample_weight: tf.Tensor | None = None,
      slice_by_treatment: bool = False,
  ):
    metric = poisson_metrics.LogLossMinimum(
        slice_by_treatment=slice_by_treatment, name="loss"
    )
    metric.update_state(y_true, y_pred, sample_weight=sample_weight)
    self.assertAllClose(expected_loss, metric.result())

  def test_negative_label_returns_nan_loss(self):
    metric = poisson_metrics.LogLossMinimum(slice_by_treatment=False)
    metric.update_state(tf.constant([-1.0]))
    self.assertTrue(tf.math.is_nan(metric.result()).numpy().item())

  def test_compute_full_loss_raises_not_implemented_error(self):
    with self.assertRaises(NotImplementedError):
      poisson_metrics.LogLossMinimum(compute_full_loss=True)

  def test_slice_by_treatment_with_tensor_y_pred_raises_value_error(self):
    # Slicing needs the treatment indicator, which a plain tensor lacks.
    metric = poisson_metrics.LogLossMinimum(slice_by_treatment=True)
    with self.assertRaises(ValueError):
      metric.update_state(tf.constant([1.0]), tf.constant([1.0]))

  def test_metric_is_configurable(self):
    metric = poisson_metrics.LogLossMinimum(slice_by_treatment=False)
    self.assertLayerConfigurable(
        layer=metric,
        y_true=tf.constant([[0], [0], [2], [7]], dtype=tf.float32),
        serializable=True,
    )
331
+
332
+
259
333
  # Script entry point: run this module's tests with the TensorFlow test runner.
  if __name__ == "__main__":
260
334
  tf.test.main()
@@ -625,6 +625,14 @@ MNV3SmallReducedFilters = {
625
625
  }
626
626
 
627
627
 
628
+ """
629
+ Architecture: https://arxiv.org/abs/2404.10518
630
+
631
+ "MobileNetV4 - Universal Models for the Mobile Ecosystem"
632
+ Danfeng Qin, Chas Leichner, Manolis Delakis, Marco Fornoni, Shixin Luo, Fan
633
+ Yang, Weijun Wang, Colby Banbury, Chengxi Ye, Berkin Akin, Vaibhav Aggarwal,
634
+ Tenghui Zhu, Daniele Moro, Andrew Howard
635
+ """
628
636
  MNV4ConvSmall_BLOCK_SPECS = {
629
637
  'spec_name': 'MobileNetV4ConvSmall',
630
638
  'block_spec_schema': [
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.17.0.dev20240417
3
+ Version: 2.17.0.dev20240418
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -915,8 +915,8 @@ official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53
915
915
  official/recommendation/uplift/metrics/loss_metric.py,sha256=gYZdnTsuL_2q1FZuPip-DaWxt_Q-02YYaePyMBVNx7w,7344
916
916
  official/recommendation/uplift/metrics/loss_metric_test.py,sha256=48rQG8bKFdy0xBFjoOLXKRUlYpCEyAzSmPOFoF7FX94,16021
917
917
  official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
918
- official/recommendation/uplift/metrics/poisson_metrics.py,sha256=1a0sHwdOLUVhmsdpnZIQxAWkt3QAPWS9TZCK6P9H7ug,7443
919
- official/recommendation/uplift/metrics/poisson_metrics_test.py,sha256=sLx4RjtvJiD8kRkXBDfsUANK0P5LSMfTTPj9GF73wOw,9473
918
+ official/recommendation/uplift/metrics/poisson_metrics.py,sha256=1zzwots4WkpxoYmOICt_CZxuAeXWyT9dktdDgT2IAu4,10197
919
+ official/recommendation/uplift/metrics/poisson_metrics_test.py,sha256=o0efkAz1OusU2C85qLl9ZZJLfXflV7acfvWp1igT-1U,12016
920
920
  official/recommendation/uplift/metrics/sliced_metric.py,sha256=uhvzudOWtMNKZ0avwGhX-37UELR9Cq9b4C0g8erBkXw,8688
921
921
  official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=bhVGyI1tOkFkVOtruJo3p6XopDFyG1JW5qdZm9-RqeU,12248
922
922
  official/recommendation/uplift/metrics/treatment_fraction.py,sha256=WHrKfsN42xU7S-pK99xEVpVtd3zLD7UidLT1K8vgIn4,2757
@@ -1049,7 +1049,7 @@ official/vision/modeling/backbones/factory.py,sha256=coJKJpPMhgM9gAc2Q7I5_CuzAaH
1049
1049
  official/vision/modeling/backbones/factory_test.py,sha256=7ZJRDSQ_cqJFyfqLK375V_wEqgrQpqibzNDZzNbhthU,8635
1050
1050
  official/vision/modeling/backbones/mobiledet.py,sha256=iEC_KbqYqUBBBwZUfRCVtqllQwK6N4T1jmiDl29B-Ys,24896
1051
1051
  official/vision/modeling/backbones/mobiledet_test.py,sha256=O2yfL7MSCGtKsnXr0IVUtjicrhZGGkwTXWCLtqdsL0Y,3804
1052
- official/vision/modeling/backbones/mobilenet.py,sha256=dVABJm7mRizcqwVFiWjulb6aIGgVoa4JuUX5ms_H7II,60978
1052
+ official/vision/modeling/backbones/mobilenet.py,sha256=iwUS9WSAZA6cyagw2Ld1zUlyR1MmvA9-AS7Gm4M8HZA,61286
1053
1053
  official/vision/modeling/backbones/mobilenet_test.py,sha256=7cl5eerD5j5UqHL8SLmpou-PjufBz8oz_cn3tqwW1vM,13057
1054
1054
  official/vision/modeling/backbones/resnet.py,sha256=dnYkdlYUzChGLOrQnUbwb9YJ7BDiFwgnLptks7kFb7k,16384
1055
1055
  official/vision/modeling/backbones/resnet_3d.py,sha256=Cq1lrlRqIg9ss_ud1iM_axW9lsTVtGYe3iA4DL9Orzk,18657
@@ -1206,9 +1206,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
1206
1206
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1207
1207
  tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
1208
1208
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1209
- tf_models_nightly-2.17.0.dev20240417.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1210
- tf_models_nightly-2.17.0.dev20240417.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1211
- tf_models_nightly-2.17.0.dev20240417.dist-info/METADATA,sha256=Sft4CH1GT2IMtgKOIyKA8O9FFE9nlzhWqHn2oHYV-G4,1432
1212
- tf_models_nightly-2.17.0.dev20240417.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1213
- tf_models_nightly-2.17.0.dev20240417.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1214
- tf_models_nightly-2.17.0.dev20240417.dist-info/RECORD,,
1209
+ tf_models_nightly-2.17.0.dev20240418.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1210
+ tf_models_nightly-2.17.0.dev20240418.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1211
+ tf_models_nightly-2.17.0.dev20240418.dist-info/METADATA,sha256=Qga8KYfDdCPyinPp79D7UyN7doPZH066mZ98Wp5QlUM,1432
1212
+ tf_models_nightly-2.17.0.dev20240418.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1213
+ tf_models_nightly-2.17.0.dev20240418.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1214
+ tf_models_nightly-2.17.0.dev20240418.dist-info/RECORD,,