tf-models-nightly 2.19.0.dev20250226__py2.py3-none-any.whl → 2.19.0.dev20250227__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -78,15 +78,8 @@ def run_experiment(
78
78
  is_training = 'train' in mode
79
79
  is_eval = 'eval' in mode
80
80
  with distribution_strategy.scope():
81
- optimizer = train_utils.create_optimizer(task, params)
82
- kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
83
- if params.trainer.trainer_type == 'interleaving':
84
- sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
85
- task.task_weights)
86
- kwargs.update(dict(task_sampler=sampler))
87
- if trainer is None:
88
- trainer = TRAINERS[params.trainer.trainer_type](
89
- **kwargs) if is_training else None
81
+ if is_training and trainer is None:
82
+ trainer = get_trainer(distribution_strategy, params, task, model)
90
83
  if is_eval:
91
84
  eval_steps = task.task_eval_steps
92
85
  evaluator = evaluator_lib.MultiTaskEvaluator(
@@ -157,6 +150,38 @@ def run_experiment(
157
150
  return model
158
151
 
159
152
 
153
def get_trainer(
    distribution_strategy: tf.distribute.Strategy,
    params: configs.MultiEvalExperimentConfig,
    task: multitask.MultiTask,
    model: base_model.MultiTaskBaseModel | tf_keras.Model,
) -> orbit.StandardTrainer:
  """Creates a multi-task trainer for the given task.

  The optimizer and the trainer are constructed under
  `distribution_strategy.scope()`.

  Args:
    distribution_strategy: A distribution strategy whose scope is entered
      while the optimizer and trainer are built.
    params: A `MultiEvalExperimentConfig` instance.
      `params.trainer.trainer_type` selects the trainer class from `TRAINERS`;
      when it is `'interleaving'`, `params.trainer.task_sampler` configures
      the task sampler passed to the trainer.
    task: A `MultiTask` instance. Its `task_weights` are passed to the task
      sampler for the interleaving trainer.
    model: A `MultiTaskBaseModel` or `tf_keras.Model` instance.

  Returns:
    An Orbit trainer instance built by `TRAINERS[params.trainer.trainer_type]`.

  Raises:
    KeyError: If `params.trainer.trainer_type` is not a registered trainer.
  """
  trainer_type = params.trainer.trainer_type
  # Resolve the trainer class before building the optimizer so a bad config
  # fails fast. KeyError is kept so existing callers that catch it still work.
  try:
    trainer_cls = TRAINERS[trainer_type]
  except KeyError:
    raise KeyError(
        f'Unknown trainer_type {trainer_type!r}; expected one of '
        f'{list(TRAINERS)}.'
    ) from None

  with distribution_strategy.scope():
    kwargs = dict(
        multi_task=task,
        multi_task_model=model,
        optimizer=train_utils.create_optimizer(task, params),
    )
    # Only the interleaving trainer consumes a task sampler.
    if trainer_type == 'interleaving':
      kwargs.update(
          task_sampler=task_sampler.get_task_sampler(
              params.trainer.task_sampler, task.task_weights
          )
      )
    return trainer_cls(**kwargs)
183
+
184
+
160
185
  TrainActionsFactoryType = Callable[
161
186
  [
162
187
  configs.MultiEvalExperimentConfig,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.19.0.dev20250226
3
+ Version: 2.19.0.dev20250227
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -224,7 +224,7 @@ official/modeling/multitask/multitask.py,sha256=DV-ysfhPiIZgsrzZNylsPBxKNBf_xzPx
224
224
  official/modeling/multitask/task_sampler.py,sha256=SGVVdjMb5oG4vnCczpfdgBtbsdsXiyBLl9si_0V6nko,4897
225
225
  official/modeling/multitask/task_sampler_test.py,sha256=wkPTp1LCNx4uJbfHfVbKrQULQxumZS3ctBemEbnMokk,3037
226
226
  official/modeling/multitask/test_utils.py,sha256=fPi_TxtzHy_NGTYmdDAQay9TRjwG70UWHwysDqol4jw,4315
227
- official/modeling/multitask/train_lib.py,sha256=ZHRN6afigqcupAz61P6qBEh_RhomAwVgvjgkRHSv0eU,12827
227
+ official/modeling/multitask/train_lib.py,sha256=TbqjoGlrNxxbq3XxKZ04PTxtZFQUlfwX5sd1-9dm5QI,13403
228
228
  official/modeling/multitask/train_lib_test.py,sha256=F_uDEQhSlXq9BH1QBlamUMQ_vRUflli7ef2P9MUKXeg,4766
229
229
  official/modeling/optimization/__init__.py,sha256=BI86b89P0xOksBlzOmbAdWOhlJQe6cxtGIRN_zqCwa8,1201
230
230
  official/modeling/optimization/adafactor_optimizer.py,sha256=zEHHFg9iH1UwcBSzzYE8bCNWln7QIIftXJEXTGTtCtE,792
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
1248
1248
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1249
1249
  tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
1250
1250
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1251
- tf_models_nightly-2.19.0.dev20250226.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1252
- tf_models_nightly-2.19.0.dev20250226.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1253
- tf_models_nightly-2.19.0.dev20250226.dist-info/METADATA,sha256=pXPpfanhACUR5sKJVyUH2c8SY0vi-8tcXz-JPP71KpQ,1432
1254
- tf_models_nightly-2.19.0.dev20250226.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1255
- tf_models_nightly-2.19.0.dev20250226.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1256
- tf_models_nightly-2.19.0.dev20250226.dist-info/RECORD,,
1251
+ tf_models_nightly-2.19.0.dev20250227.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1252
+ tf_models_nightly-2.19.0.dev20250227.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1253
+ tf_models_nightly-2.19.0.dev20250227.dist-info/METADATA,sha256=c2C7Tg_sj_ybjTh9TtaYbTlf7RMk16JqrLmVI6K3DUI,1432
1254
+ tf_models_nightly-2.19.0.dev20250227.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1255
+ tf_models_nightly-2.19.0.dev20250227.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1256
+ tf_models_nightly-2.19.0.dev20250227.dist-info/RECORD,,