tf-models-nightly 2.17.0.dev20240407__py2.py3-none-any.whl → 2.17.0.dev20240409__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -40,7 +40,7 @@ class LossMetric(tf_keras.metrics.Metric):
40
40
  ... true_logits=tf.constant([1, 2, 3, 4]),
41
41
  ... is_treatment=tf.constant([True, False, True, False]),
42
42
  ... )
43
- >>> sliced_loss(y_true=tf.zeros(4), y_pred=y_pred)
43
+ >>> sliced_loss(y_true=y_true, y_pred=y_pred)
44
44
  {
45
45
  "loss": 2.5
46
46
  "loss/control": 4.0
@@ -58,7 +58,9 @@ class LossMetric(tf_keras.metrics.Metric):
58
58
 
59
59
  def __init__(
60
60
  self,
61
- loss_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
61
+ loss_fn: (
62
+ Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | tf_keras.metrics.Metric
63
+ ),
62
64
  from_logits: bool = True,
63
65
  name: str = "loss",
64
66
  dtype: tf.DType = tf.float32,
@@ -67,27 +69,59 @@ class LossMetric(tf_keras.metrics.Metric):
67
69
  """Initializes the instance.
68
70
 
69
71
  Args:
70
- loss_fn: The loss function to apply between the targets and model outputs,
71
- with signature `loss_fn(y_true, y_pred, **loss_fn_kwargs)`.
72
+ loss_fn: The loss function or Keras metric to apply with call signature
73
+ `__call__(y_true: tf.Tensor, y_pred: tf.Tensor, **loss_fn_kwargs)`. Note
74
+ that the `loss_fn_kwargs` will not be passed to the `__call__` method if
75
+ `loss_fn` is a Keras metric.
72
76
  from_logits: Specifies whether the true logits or true predictions should
73
77
  be used from the model outputs to compute the loss. Defaults to using
74
78
  the true logits.
75
- name: Optional name for the instance.
76
- dtype: Optional data type for the instance.
79
+ name: Optional name for the instance. If `loss_fn` is a Keras metric then
80
+ its name will be used instead.
81
+ dtype: Optional data type for the instance. If `loss_fn` is a Keras metric
82
+ then its `dtype` will be used instead.
77
83
  **loss_fn_kwargs: The keyword arguments that are passed on to `loss_fn`.
84
+ These arguments will be ignored if `loss_fn` is a Keras metric.
78
85
  """
86
+ # Do not accept Loss objects as they reduce tensors before weighting.
87
+ if isinstance(loss_fn, tf_keras.losses.Loss):
88
+ raise TypeError(
89
+ "`loss_fn` cannot be a Keras `Loss` object, pass a non-reducing loss"
90
+ " function or a metric instance instead."
91
+ )
92
+
93
+ if isinstance(loss_fn, tf_keras.metrics.Metric):
94
+ name = loss_fn.name
95
+ dtype = loss_fn.dtype
96
+
79
97
  super().__init__(name=name, dtype=dtype)
80
98
 
81
99
  self._loss_fn = loss_fn
82
100
  self._from_logits = from_logits
83
101
  self._loss_fn_kwargs = loss_fn_kwargs
84
102
 
85
- if "from_logits" in inspect.signature(loss_fn).parameters:
86
- self._loss_fn_kwargs.update({"from_logits": from_logits})
103
+ if isinstance(loss_fn, tf_keras.metrics.Metric):
104
+ metric_from_logits = loss_fn.get_config().get("from_logits", from_logits)
105
+ if from_logits != metric_from_logits:
106
+ raise ValueError(
107
+ f"Value passed to `from_logits` ({from_logits}) is conflicting with"
108
+ " the `from_logits` value passed to the `loss_fn` metric"
109
+ f" ({metric_from_logits}). Ensure that they have the same value."
110
+ )
111
+
112
+ self._treatment_sliced_loss = (
113
+ treatment_sliced_metric.TreatmentSlicedMetric(loss_fn)
114
+ )
87
115
 
88
- self._treatment_sliced_loss = treatment_sliced_metric.TreatmentSlicedMetric(
89
- tf_keras.metrics.Mean(name=name, dtype=dtype)
90
- )
116
+ else:
117
+ if "from_logits" in inspect.signature(loss_fn).parameters:
118
+ self._loss_fn_kwargs.update({"from_logits": from_logits})
119
+
120
+ self._treatment_sliced_loss = (
121
+ treatment_sliced_metric.TreatmentSlicedMetric(
122
+ tf_keras.metrics.Mean(name=name, dtype=dtype)
123
+ )
124
+ )
91
125
 
92
126
  def update_state(
93
127
  self,
@@ -115,16 +149,21 @@ class LossMetric(tf_keras.metrics.Metric):
115
149
  f" {type(y_pred)} instead."
116
150
  )
117
151
 
118
- if self._from_logits:
119
- loss = self._loss_fn(y_true, y_pred.true_logits)
120
- else:
121
- loss = self._loss_fn(y_true, y_pred.true_predictions)
152
+ pred = y_pred.true_logits if self._from_logits else y_pred.true_predictions
122
153
 
123
- self._treatment_sliced_loss.update_state(
124
- values=loss,
125
- is_treatment=y_pred.is_treatment,
126
- sample_weight=sample_weight,
127
- )
154
+ if isinstance(self._loss_fn, tf_keras.metrics.Metric):
155
+ self._treatment_sliced_loss.update_state(
156
+ y_true,
157
+ y_pred=pred,
158
+ is_treatment=y_pred.is_treatment,
159
+ sample_weight=sample_weight,
160
+ )
161
+ else:
162
+ self._treatment_sliced_loss.update_state(
163
+ values=self._loss_fn(y_true, pred, **self._loss_fn_kwargs),
164
+ is_treatment=y_pred.is_treatment,
165
+ sample_weight=sample_weight,
166
+ )
128
167
 
129
168
  def result(self) -> dict[str, tf.Tensor]:
130
169
  return self._treatment_sliced_loss.result()
@@ -25,7 +25,7 @@ from official.recommendation.uplift import types
25
25
  from official.recommendation.uplift.metrics import loss_metric
26
26
 
27
27
 
28
- class TrueLogitsTreatmentLossTest(
28
+ class LossMetricTest(
29
29
  keras_test_case.KerasTestCase, parameterized.TestCase
30
30
  ):
31
31
 
@@ -64,6 +64,20 @@ class TrueLogitsTreatmentLossTest(
64
64
  "loss/treatment": 1.0,
65
65
  },
66
66
  },
67
+ {
68
+ "testcase_name": "unweighted_metric",
69
+ "loss_fn": tf_keras.metrics.MeanSquaredError(name="loss"),
70
+ "from_logits": False,
71
+ "y_true": tf.constant([[0], [0], [2], [2]]),
72
+ "y_pred": tf.constant([[1], [2], [3], [4]]),
73
+ "is_treatment": tf.constant([[True], [False], [True], [False]]),
74
+ "sample_weight": None,
75
+ "expected_losses": {
76
+ "loss": 2.5,
77
+ "loss/control": 4.0,
78
+ "loss/treatment": 1.0,
79
+ },
80
+ },
67
81
  {
68
82
  "testcase_name": "weighted",
69
83
  "loss_fn": tf_keras.losses.mean_absolute_error,
@@ -78,6 +92,20 @@ class TrueLogitsTreatmentLossTest(
78
92
  "loss/treatment": 1.0,
79
93
  },
80
94
  },
95
+ {
96
+ "testcase_name": "weighted_keras_metric",
97
+ "loss_fn": tf_keras.metrics.MeanAbsoluteError(name="loss"),
98
+ "from_logits": False,
99
+ "y_true": tf.constant([[0], [0], [2], [7]]),
100
+ "y_pred": tf.constant([[1], [2], [3], [4]]),
101
+ "is_treatment": tf.constant([[True], [False], [True], [False]]),
102
+ "sample_weight": tf.constant([[0.5], [0.5], [0.7], [1.8]]),
103
+ "expected_losses": {
104
+ "loss": np.average([1, 2, 1, 3], weights=[0.5, 0.5, 0.7, 1.8]),
105
+ "loss/control": np.average([2, 3], weights=[0.5, 1.8]),
106
+ "loss/treatment": 1.0,
107
+ },
108
+ },
81
109
  {
82
110
  "testcase_name": "only_control",
83
111
  "loss_fn": tf_keras.metrics.mean_squared_error,
@@ -92,6 +120,20 @@ class TrueLogitsTreatmentLossTest(
92
120
  "loss/treatment": 0.0,
93
121
  },
94
122
  },
123
+ {
124
+ "testcase_name": "only_control_metric",
125
+ "loss_fn": tf_keras.metrics.MeanSquaredError(name="loss"),
126
+ "from_logits": False,
127
+ "y_true": tf.constant([[0], [1], [5]]),
128
+ "y_pred": tf.constant([[1], [2], [5]]),
129
+ "is_treatment": tf.constant([[False], [False], [False]]),
130
+ "sample_weight": tf.constant([1, 0, 1]),
131
+ "expected_losses": {
132
+ "loss": 0.5,
133
+ "loss/control": 0.5,
134
+ "loss/treatment": 0.0,
135
+ },
136
+ },
95
137
  {
96
138
  "testcase_name": "only_treatment",
97
139
  "loss_fn": tf_keras.metrics.mean_absolute_error,
@@ -106,6 +148,20 @@ class TrueLogitsTreatmentLossTest(
106
148
  "loss/treatment": 0.5,
107
149
  },
108
150
  },
151
+ {
152
+ "testcase_name": "only_treatment_metric",
153
+ "loss_fn": tf_keras.metrics.MeanAbsoluteError(name="loss"),
154
+ "from_logits": False,
155
+ "y_true": tf.constant([[0], [1], [5]]),
156
+ "y_pred": tf.constant([[1], [2], [5]]),
157
+ "is_treatment": tf.constant([[True], [True], [True]]),
158
+ "sample_weight": tf.constant([1, 0, 1]),
159
+ "expected_losses": {
160
+ "loss": 0.5,
161
+ "loss/control": 0.0,
162
+ "loss/treatment": 0.5,
163
+ },
164
+ },
109
165
  {
110
166
  "testcase_name": "one_entry",
111
167
  "loss_fn": tf.nn.log_poisson_loss,
@@ -138,6 +194,48 @@ class TrueLogitsTreatmentLossTest(
138
194
  "loss/treatment": 0.0,
139
195
  },
140
196
  },
197
+ {
198
+ "testcase_name": "no_entry_metric",
199
+ "loss_fn": tf_keras.metrics.BinaryCrossentropy(name="loss"),
200
+ "from_logits": False,
201
+ "y_true": tf.constant([[]]),
202
+ "y_pred": tf.constant([[]]),
203
+ "is_treatment": tf.constant([[]]),
204
+ "sample_weight": tf.constant([[]]),
205
+ "expected_losses": {
206
+ "loss": 0.0,
207
+ "loss/control": 0.0,
208
+ "loss/treatment": 0.0,
209
+ },
210
+ },
211
+ {
212
+ "testcase_name": "auc_metric",
213
+ "loss_fn": tf_keras.metrics.AUC(from_logits=True, name="loss"),
214
+ "from_logits": True,
215
+ "y_true": tf.constant([[0], [0], [1], [1]]),
216
+ "y_pred": tf.constant([[0], [0.5], [0.3], [0.9]]),
217
+ "is_treatment": tf.constant([[1], [1], [1], [1]]),
218
+ "sample_weight": None,
219
+ "expected_losses": {
220
+ "loss": 0.75,
221
+ "loss/control": 0.0,
222
+ "loss/treatment": 0.75,
223
+ },
224
+ },
225
+ {
226
+ "testcase_name": "loss_fn_with_from_logits",
227
+ "loss_fn": tf_keras.losses.binary_crossentropy,
228
+ "from_logits": True,
229
+ "y_true": tf.constant([[0.0, 1.0]]),
230
+ "y_pred": tf.constant([[0.0, 1.0]]),
231
+ "is_treatment": tf.constant([[0], [0]]),
232
+ "sample_weight": None,
233
+ "expected_losses": {
234
+ "loss": 0.50320446,
235
+ "loss/control": 0.50320446,
236
+ "loss/treatment": 0.0,
237
+ },
238
+ },
141
239
  )
142
240
  def test_metric_computes_sliced_losses(
143
241
  self,
@@ -237,10 +335,17 @@ class TrueLogitsTreatmentLossTest(
237
335
  metric.reset_states()
238
336
  self.assertEqual(expected_initial_result, metric.result())
239
337
 
240
- def test_metric_is_configurable(self):
241
- metric = loss_metric.LossMetric(
242
- tf_keras.losses.binary_crossentropy, from_logits=True, name="bce_loss"
243
- )
338
+ @parameterized.parameters(
339
+ tf_keras.losses.binary_crossentropy,
340
+ tf_keras.metrics.BinaryCrossentropy(from_logits=True, name="bce_loss"),
341
+ )
342
+ def test_metric_is_configurable(
343
+ self,
344
+ loss_fn: (
345
+ Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | tf_keras.metrics.Metric
346
+ ),
347
+ ):
348
+ metric = loss_metric.LossMetric(loss_fn, from_logits=True, name="bce_loss")
244
349
  self.assertLayerConfigurable(
245
350
  layer=metric,
246
351
  y_true=tf.constant([[1], [1], [0]]),
@@ -261,6 +366,19 @@ class TrueLogitsTreatmentLossTest(
261
366
  ):
262
367
  metric.update_state(y_true=tf.ones((3, 1)), y_pred=tf.ones((3, 1)))
263
368
 
369
+ def test_passing_loss_object_raises_error(self):
370
+ with self.assertRaisesRegex(
371
+ TypeError, "`loss_fn` cannot be a Keras `Loss` object"
372
+ ):
373
+ loss_metric.LossMetric(loss_fn=tf_keras.losses.MeanAbsoluteError())
374
+
375
+ def test_conflicting_from_logits_values_raises_error(self):
376
+ with self.assertRaises(ValueError):
377
+ loss_metric.LossMetric(
378
+ loss_fn=tf_keras.metrics.BinaryCrossentropy(from_logits=True),
379
+ from_logits=False,
380
+ )
381
+
264
382
 
265
383
  if __name__ == "__main__":
266
384
  tf.test.main()
@@ -39,11 +39,11 @@ class Parser(parser.Parser):
39
39
 
40
40
  def __init__(self,
41
41
  output_size,
42
- min_level,
42
+ min_level: int | None,
43
43
  max_level,
44
- num_scales,
45
- aspect_ratios,
46
- anchor_size,
44
+ num_scales: int | None,
45
+ aspect_ratios: list[float] | None,
46
+ anchor_size: float | None,
47
47
  match_threshold=0.5,
48
48
  unmatched_threshold=0.5,
49
49
  box_coder_weights=None,
@@ -63,20 +63,28 @@ class Parser(parser.Parser):
63
63
  keep_aspect_ratio=True):
64
64
  """Initializes parameters for parsing annotations in the dataset.
65
65
 
66
+ If one provides `input_anchor` when calling `_parse_eval_data()` and
67
+ `_parse_train_data()`, the `min_level`, `num_scales`, `aspect_ratios`, and
68
+ `anchor_size` can be `None`.
69
+
66
70
  Args:
67
71
  output_size: `Tensor` or `list` for [height, width] of output image. The
68
72
  output_size should be divided by the largest feature stride 2^max_level.
69
73
  min_level: `int` number of minimum level of the output feature pyramid.
74
+ Can be `None` if `input_anchor` is provided in `_parse_*_data()`.
70
75
  max_level: `int` number of maximum level of the output feature pyramid.
71
76
  num_scales: `int` number representing intermediate scales added on each
72
77
  level. For instances, num_scales=2 adds one additional intermediate
73
- anchor scales [2^0, 2^0.5] on each level.
78
+ anchor scales [2^0, 2^0.5] on each level. Can be `None` if
79
+ `input_anchor` is provided in `_parse_*_data()`.
74
80
  aspect_ratios: `list` of float numbers representing the aspect ratio
75
81
  anchors added on each level. The number indicates the ratio of width to
76
82
  height. For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
77
- on each scale level.
83
+ on each scale level. Can be `None` if `input_anchor` is provided in
84
+ `_parse_*_data()`.
78
85
  anchor_size: `float` number representing the scale of size of the base
79
- anchor to the feature stride 2^level.
86
+ anchor to the feature stride 2^level. Can be `None` if `input_anchor` is
87
+ provided in `_parse_*_data()`.
80
88
  match_threshold: `float` number between 0 and 1 representing the
81
89
  lower-bound threshold to assign positive labels for anchors. An anchor
82
90
  with a score over the threshold is labeled positive.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.17.0.dev20240407
3
+ Version: 2.17.0.dev20240409
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -912,8 +912,8 @@ official/recommendation/uplift/metrics/label_mean.py,sha256=ECaes7FZmsksnwySn7jf
912
912
  official/recommendation/uplift/metrics/label_mean_test.py,sha256=b_d3lNlpkDm2xKLUkxfiXeQg7pjL8HNx7y9NaYarpV0,7083
913
913
  official/recommendation/uplift/metrics/label_variance.py,sha256=9DCl42BJkehxfWD3pSbZnRNvwfhVM6VyHwivGdaU72s,3610
914
914
  official/recommendation/uplift/metrics/label_variance_test.py,sha256=k0mdEU1WU53-HIEO5HGtfp1MleifD-h4bZNKtTvM3Ws,7681
915
- official/recommendation/uplift/metrics/loss_metric.py,sha256=8GAcqehBgHVz0TZn6obRhyQlLhz_-jmIpP0EAvJLBYI,4870
916
- official/recommendation/uplift/metrics/loss_metric_test.py,sha256=IwbWd6PQOYG4uGzobWNxP2Mh6PrVTEcFV_-6GGxaJTU,9346
915
+ official/recommendation/uplift/metrics/loss_metric.py,sha256=8Lfi4FNR6uj8eLSdT5t0XjEzTZtwL-xv5e5KNGzrYWU,6498
916
+ official/recommendation/uplift/metrics/loss_metric_test.py,sha256=yropE1t8PblsJ_kJEno5ratkY_ka81PHgcfroWtRKVI,13768
917
917
  official/recommendation/uplift/metrics/metric_configs.py,sha256=Z-r79orE4EycQ5TJ7xdI5LhjOHT3wzChYyDxcxGqLXk,1670
918
918
  official/recommendation/uplift/metrics/sliced_metric.py,sha256=O2I2apZK6IfOQK9Q_mgSiTUCnGokczp4e14zrrYNeRU,8564
919
919
  official/recommendation/uplift/metrics/sliced_metric_test.py,sha256=dhY41X8lqT_WW04XLjyjDerZwujEBGeTXtxf4NkYThw,11359
@@ -987,7 +987,7 @@ official/vision/dataloaders/input_reader.py,sha256=CHojw8PJKf74jl8Q3rtH2ylwhmTYg
987
987
  official/vision/dataloaders/input_reader_factory.py,sha256=WpvSA8qyqAo3wkmme4WqXpICBVg0SuR6_nNWHZ0ECM0,1623
988
988
  official/vision/dataloaders/maskrcnn_input.py,sha256=iCc08yYD-7mvIPojgBjm_nSvoQACXWCIeZNZN8CfXSs,16822
989
989
  official/vision/dataloaders/parser.py,sha256=nMXnhigMa_ascSJ2OK88xi4HdE9xvfL3G4oMrHau-t4,2315
990
- official/vision/dataloaders/retinanet_input.py,sha256=joxJL4hQVPw-FW5iUc7RsxP60N7iYGRuVFpU3gC5flE,18291
990
+ official/vision/dataloaders/retinanet_input.py,sha256=bU1fDpJuOtBZVJSg3Fzaku2PjxHh22E4d3M7B3Vu8ZQ,18831
991
991
  official/vision/dataloaders/segmentation_input.py,sha256=Klg5KAChYZDRvqzZfyIzdPy54rTlWYZp2AotolD3WX8,12934
992
992
  official/vision/dataloaders/tf_example_decoder.py,sha256=9yCT6uSLMpmw50w7zdaRR_BXy6vIvliLZntrYAgzD18,8647
993
993
  official/vision/dataloaders/tf_example_decoder_test.py,sha256=PHxneXHn5-eIMdmk1uI4IPLa178kTCifa4EF53ik2Jo,12629
@@ -1203,9 +1203,9 @@ tensorflow_models/__init__.py,sha256=etxw45SHxuwFCRX5qGxGMP83II0JfJulzNl5GSNJvhw
1203
1203
  tensorflow_models/tensorflow_models_test.py,sha256=AxUYUdiQn416UR7jg0h6rmv688esvlKDfpyDCIQkF18,1395
1204
1204
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1205
1205
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1206
- tf_models_nightly-2.17.0.dev20240407.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1207
- tf_models_nightly-2.17.0.dev20240407.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1208
- tf_models_nightly-2.17.0.dev20240407.dist-info/METADATA,sha256=fmtDQInbfhr3S9Uy1tWUiCvoT9_bN6MLoG6GG3jQkKg,1432
1209
- tf_models_nightly-2.17.0.dev20240407.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1210
- tf_models_nightly-2.17.0.dev20240407.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1211
- tf_models_nightly-2.17.0.dev20240407.dist-info/RECORD,,
1206
+ tf_models_nightly-2.17.0.dev20240409.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1207
+ tf_models_nightly-2.17.0.dev20240409.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1208
+ tf_models_nightly-2.17.0.dev20240409.dist-info/METADATA,sha256=738N3KUX3sKTeF-7mDL3_2r1htitc4BCT9A4QMgB9U4,1432
1209
+ tf_models_nightly-2.17.0.dev20240409.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1210
+ tf_models_nightly-2.17.0.dev20240409.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1211
+ tf_models_nightly-2.17.0.dev20240409.dist-info/RECORD,,