tf-models-nightly 2.19.0.dev20250224__py2.py3-none-any.whl → 2.19.0.dev20250226__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,7 @@
15
15
  """Multitask training driver library."""
16
16
  # pytype: disable=attribute-error
17
17
  import os
18
- from typing import Any, List, Mapping, Optional, Tuple, Union
18
+ from typing import Any, List, Mapping, Optional, Tuple, Union, Callable
19
19
  from absl import logging
20
20
  import orbit
21
21
  import tensorflow as tf, tf_keras
@@ -157,6 +157,25 @@ def run_experiment(
157
157
  return model
158
158
 
159
159
 
160
+ TrainActionsFactoryType = Callable[
161
+ [
162
+ configs.MultiEvalExperimentConfig,
163
+ orbit.StandardTrainer,
164
+ str,
165
+ tf.train.CheckpointManager,
166
+ ],
167
+ List[orbit.Action],
168
+ ]
169
+ EvalActionsFactoryType = Callable[
170
+ [
171
+ configs.MultiEvalExperimentConfig,
172
+ orbit.AbstractEvaluator,
173
+ str,
174
+ ],
175
+ List[orbit.Action],
176
+ ]
177
+
178
+
160
179
  def run_experiment_with_multitask_eval(
161
180
  *,
162
181
  distribution_strategy: tf.distribute.Strategy,
@@ -171,6 +190,8 @@ def run_experiment_with_multitask_eval(
171
190
  eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
172
191
  best_ckpt_exporter_creator: Optional[Any] = train_utils
173
192
  .maybe_create_best_ckpt_exporter,
193
+ train_actions_factory: Optional[TrainActionsFactoryType] = None,
194
+ eval_actions_factory: Optional[EvalActionsFactoryType] = None,
174
195
  ) -> Tuple[Any, Any]:
175
196
  """Runs train/eval configured by the experiment params.
176
197
 
@@ -193,6 +214,8 @@ def run_experiment_with_multitask_eval(
193
214
  will be created internally for TensorBoard summaries by default from the
194
215
  `eval_summary_dir`.
195
216
  best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
217
+ train_actions_factory: Optional factory function to create train actions.
218
+ eval_actions_factory: Optional factory function to create eval actions.
196
219
 
197
220
  Returns:
198
221
  model: `tf_keras.Model` instance.
@@ -214,7 +237,6 @@ def run_experiment_with_multitask_eval(
214
237
 
215
238
  # Build the model or fetch the pre-cached one (which could be either
216
239
  # multi-task model or single task model).
217
- model = None
218
240
  if trainer is None:
219
241
  if isinstance(train_task, multitask.MultiTask):
220
242
  model = train_task.build_multitask_model()
@@ -254,6 +276,23 @@ def run_experiment_with_multitask_eval(
254
276
  checkpoint_interval=params.trainer.checkpoint_interval,
255
277
  init_fn=trainer.initialize if trainer else None)
256
278
 
279
+ if trainer and train_actions_factory:
280
+ # pytype: disable=wrong-keyword-args
281
+ train_actions = train_actions_factory(
282
+ params=params,
283
+ trainer=trainer,
284
+ model_dir=model_dir,
285
+ checkpoint_manager=checkpoint_manager,
286
+ )
287
+ # pytype: enable=wrong-keyword-args
288
+ else:
289
+ train_actions = None
290
+
291
+ if evaluator and eval_actions_factory:
292
+ eval_actions = eval_actions_factory(params, evaluator, model_dir)
293
+ else:
294
+ eval_actions = None
295
+
257
296
  controller = orbit.Controller(
258
297
  strategy=distribution_strategy,
259
298
  trainer=trainer,
@@ -266,7 +305,10 @@ def run_experiment_with_multitask_eval(
266
305
  (save_summary) else None,
267
306
  eval_summary_manager=eval_summary_manager,
268
307
  summary_interval=params.trainer.summary_interval if
269
- (save_summary) else None)
308
+ (save_summary) else None,
309
+ train_actions=train_actions,
310
+ eval_actions=eval_actions,
311
+ )
270
312
 
271
313
  logging.info('Starts to execute mode: %s', mode)
272
314
  with distribution_strategy.scope():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.19.0.dev20250224
3
+ Version: 2.19.0.dev20250226
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -224,7 +224,7 @@ official/modeling/multitask/multitask.py,sha256=DV-ysfhPiIZgsrzZNylsPBxKNBf_xzPx
224
224
  official/modeling/multitask/task_sampler.py,sha256=SGVVdjMb5oG4vnCczpfdgBtbsdsXiyBLl9si_0V6nko,4897
225
225
  official/modeling/multitask/task_sampler_test.py,sha256=wkPTp1LCNx4uJbfHfVbKrQULQxumZS3ctBemEbnMokk,3037
226
226
  official/modeling/multitask/test_utils.py,sha256=fPi_TxtzHy_NGTYmdDAQay9TRjwG70UWHwysDqol4jw,4315
227
- official/modeling/multitask/train_lib.py,sha256=XmWRj-IOKijIo9EDCRQNRU9rj3qiIDhExiPxDN6mi6A,11633
227
+ official/modeling/multitask/train_lib.py,sha256=ZHRN6afigqcupAz61P6qBEh_RhomAwVgvjgkRHSv0eU,12827
228
228
  official/modeling/multitask/train_lib_test.py,sha256=F_uDEQhSlXq9BH1QBlamUMQ_vRUflli7ef2P9MUKXeg,4766
229
229
  official/modeling/optimization/__init__.py,sha256=BI86b89P0xOksBlzOmbAdWOhlJQe6cxtGIRN_zqCwa8,1201
230
230
  official/modeling/optimization/adafactor_optimizer.py,sha256=zEHHFg9iH1UwcBSzzYE8bCNWln7QIIftXJEXTGTtCtE,792
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
1248
1248
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1249
1249
  tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
1250
1250
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1251
- tf_models_nightly-2.19.0.dev20250224.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1252
- tf_models_nightly-2.19.0.dev20250224.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1253
- tf_models_nightly-2.19.0.dev20250224.dist-info/METADATA,sha256=Cd694VFHARcVs9nk3pbtQAcpd7fP1TMWRRA265kwmkg,1432
1254
- tf_models_nightly-2.19.0.dev20250224.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1255
- tf_models_nightly-2.19.0.dev20250224.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1256
- tf_models_nightly-2.19.0.dev20250224.dist-info/RECORD,,
1251
+ tf_models_nightly-2.19.0.dev20250226.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1252
+ tf_models_nightly-2.19.0.dev20250226.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1253
+ tf_models_nightly-2.19.0.dev20250226.dist-info/METADATA,sha256=pXPpfanhACUR5sKJVyUH2c8SY0vi-8tcXz-JPP71KpQ,1432
1254
+ tf_models_nightly-2.19.0.dev20250226.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1255
+ tf_models_nightly-2.19.0.dev20250226.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1256
+ tf_models_nightly-2.19.0.dev20250226.dist-info/RECORD,,