tf-models-nightly 2.19.0.dev20250225__py2.py3-none-any.whl → 2.19.0.dev20250227__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,7 @@
15
15
  """Multitask training driver library."""
16
16
  # pytype: disable=attribute-error
17
17
  import os
18
- from typing import Any, List, Mapping, Optional, Tuple, Union
18
+ from typing import Any, List, Mapping, Optional, Tuple, Union, Callable
19
19
  from absl import logging
20
20
  import orbit
21
21
  import tensorflow as tf, tf_keras
@@ -78,15 +78,8 @@ def run_experiment(
78
78
  is_training = 'train' in mode
79
79
  is_eval = 'eval' in mode
80
80
  with distribution_strategy.scope():
81
- optimizer = train_utils.create_optimizer(task, params)
82
- kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
83
- if params.trainer.trainer_type == 'interleaving':
84
- sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
85
- task.task_weights)
86
- kwargs.update(dict(task_sampler=sampler))
87
- if trainer is None:
88
- trainer = TRAINERS[params.trainer.trainer_type](
89
- **kwargs) if is_training else None
81
+ if is_training and trainer is None:
82
+ trainer = get_trainer(distribution_strategy, params, task, model)
90
83
  if is_eval:
91
84
  eval_steps = task.task_eval_steps
92
85
  evaluator = evaluator_lib.MultiTaskEvaluator(
@@ -157,6 +150,57 @@ def run_experiment(
157
150
  return model
158
151
 
159
152
 
153
+ def get_trainer(
154
+ distribution_strategy: tf.distribute.Strategy,
155
+ params: configs.MultiEvalExperimentConfig,
156
+ task: multitask.MultiTask,
157
+ model: base_model.MultiTaskBaseModel | tf_keras.Model,
158
+ ) -> orbit.StandardTrainer:
159
+ """Creates a multi-task trainer for the given task.
160
+
161
+ Args:
162
+ distribution_strategy: A distribution strategy.
163
+ params: ExperimentConfig instance.
164
+ task: A MultiTaskTask instance.
165
+ model: A MultiTaskBaseModel instance.
166
+
167
+ Returns:
168
+ An Orbit trainer instance.
169
+ """
170
+ with distribution_strategy.scope():
171
+ kwargs = dict(
172
+ multi_task=task,
173
+ multi_task_model=model,
174
+ optimizer=train_utils.create_optimizer(task, params),
175
+ )
176
+ if params.trainer.trainer_type == 'interleaving':
177
+ kwargs.update(
178
+ task_sampler=task_sampler.get_task_sampler(
179
+ params.trainer.task_sampler, task.task_weights
180
+ )
181
+ )
182
+ return TRAINERS[params.trainer.trainer_type](**kwargs)
183
+
184
+
185
+ TrainActionsFactoryType = Callable[
186
+ [
187
+ configs.MultiEvalExperimentConfig,
188
+ orbit.StandardTrainer,
189
+ str,
190
+ tf.train.CheckpointManager,
191
+ ],
192
+ List[orbit.Action],
193
+ ]
194
+ EvalActionsFactoryType = Callable[
195
+ [
196
+ configs.MultiEvalExperimentConfig,
197
+ orbit.AbstractEvaluator,
198
+ str,
199
+ ],
200
+ List[orbit.Action],
201
+ ]
202
+
203
+
160
204
  def run_experiment_with_multitask_eval(
161
205
  *,
162
206
  distribution_strategy: tf.distribute.Strategy,
@@ -171,6 +215,8 @@ def run_experiment_with_multitask_eval(
171
215
  eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
172
216
  best_ckpt_exporter_creator: Optional[Any] = train_utils
173
217
  .maybe_create_best_ckpt_exporter,
218
+ train_actions_factory: Optional[TrainActionsFactoryType] = None,
219
+ eval_actions_factory: Optional[EvalActionsFactoryType] = None,
174
220
  ) -> Tuple[Any, Any]:
175
221
  """Runs train/eval configured by the experiment params.
176
222
 
@@ -193,6 +239,8 @@ def run_experiment_with_multitask_eval(
193
239
  will be created internally for TensorBoard summaries by default from the
194
240
  `eval_summary_dir`.
195
241
  best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
242
+ train_actions_factory: Optional factory function to create train actions.
243
+ eval_actions_factory: Optional factory function to create eval actions.
196
244
 
197
245
  Returns:
198
246
  model: `tf_keras.Model` instance.
@@ -214,7 +262,6 @@ def run_experiment_with_multitask_eval(
214
262
 
215
263
  # Build the model or fetch the pre-cached one (which could be either
216
264
  # multi-task model or single task model).
217
- model = None
218
265
  if trainer is None:
219
266
  if isinstance(train_task, multitask.MultiTask):
220
267
  model = train_task.build_multitask_model()
@@ -254,6 +301,23 @@ def run_experiment_with_multitask_eval(
254
301
  checkpoint_interval=params.trainer.checkpoint_interval,
255
302
  init_fn=trainer.initialize if trainer else None)
256
303
 
304
+ if trainer and train_actions_factory:
305
+ # pytype: disable=wrong-keyword-args
306
+ train_actions = train_actions_factory(
307
+ params=params,
308
+ trainer=trainer,
309
+ model_dir=model_dir,
310
+ checkpoint_manager=checkpoint_manager,
311
+ )
312
+ # pytype: enable=wrong-keyword-args
313
+ else:
314
+ train_actions = None
315
+
316
+ if evaluator and eval_actions_factory:
317
+ eval_actions = eval_actions_factory(params, evaluator, model_dir)
318
+ else:
319
+ eval_actions = None
320
+
257
321
  controller = orbit.Controller(
258
322
  strategy=distribution_strategy,
259
323
  trainer=trainer,
@@ -266,7 +330,10 @@ def run_experiment_with_multitask_eval(
266
330
  (save_summary) else None,
267
331
  eval_summary_manager=eval_summary_manager,
268
332
  summary_interval=params.trainer.summary_interval if
269
- (save_summary) else None)
333
+ (save_summary) else None,
334
+ train_actions=train_actions,
335
+ eval_actions=eval_actions,
336
+ )
270
337
 
271
338
  logging.info('Starts to execute mode: %s', mode)
272
339
  with distribution_strategy.scope():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.19.0.dev20250225
3
+ Version: 2.19.0.dev20250227
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -224,7 +224,7 @@ official/modeling/multitask/multitask.py,sha256=DV-ysfhPiIZgsrzZNylsPBxKNBf_xzPx
224
224
  official/modeling/multitask/task_sampler.py,sha256=SGVVdjMb5oG4vnCczpfdgBtbsdsXiyBLl9si_0V6nko,4897
225
225
  official/modeling/multitask/task_sampler_test.py,sha256=wkPTp1LCNx4uJbfHfVbKrQULQxumZS3ctBemEbnMokk,3037
226
226
  official/modeling/multitask/test_utils.py,sha256=fPi_TxtzHy_NGTYmdDAQay9TRjwG70UWHwysDqol4jw,4315
227
- official/modeling/multitask/train_lib.py,sha256=XmWRj-IOKijIo9EDCRQNRU9rj3qiIDhExiPxDN6mi6A,11633
227
+ official/modeling/multitask/train_lib.py,sha256=TbqjoGlrNxxbq3XxKZ04PTxtZFQUlfwX5sd1-9dm5QI,13403
228
228
  official/modeling/multitask/train_lib_test.py,sha256=F_uDEQhSlXq9BH1QBlamUMQ_vRUflli7ef2P9MUKXeg,4766
229
229
  official/modeling/optimization/__init__.py,sha256=BI86b89P0xOksBlzOmbAdWOhlJQe6cxtGIRN_zqCwa8,1201
230
230
  official/modeling/optimization/adafactor_optimizer.py,sha256=zEHHFg9iH1UwcBSzzYE8bCNWln7QIIftXJEXTGTtCtE,792
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
1248
1248
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1249
1249
  tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
1250
1250
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1251
- tf_models_nightly-2.19.0.dev20250225.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1252
- tf_models_nightly-2.19.0.dev20250225.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1253
- tf_models_nightly-2.19.0.dev20250225.dist-info/METADATA,sha256=vS0DCt_fQoAAf-PhHxDVQVBSvWehhkLQBvFpZIkYaks,1432
1254
- tf_models_nightly-2.19.0.dev20250225.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1255
- tf_models_nightly-2.19.0.dev20250225.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1256
- tf_models_nightly-2.19.0.dev20250225.dist-info/RECORD,,
1251
+ tf_models_nightly-2.19.0.dev20250227.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1252
+ tf_models_nightly-2.19.0.dev20250227.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1253
+ tf_models_nightly-2.19.0.dev20250227.dist-info/METADATA,sha256=c2C7Tg_sj_ybjTh9TtaYbTlf7RMk16JqrLmVI6K3DUI,1432
1254
+ tf_models_nightly-2.19.0.dev20250227.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1255
+ tf_models_nightly-2.19.0.dev20250227.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1256
+ tf_models_nightly-2.19.0.dev20250227.dist-info/RECORD,,