tf-models-nightly 2.19.0.dev20250226__py2.py3-none-any.whl → 2.19.0.dev20250227__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/modeling/multitask/train_lib.py +34 -9
- {tf_models_nightly-2.19.0.dev20250226.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.19.0.dev20250226.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/RECORD +7 -7
- {tf_models_nightly-2.19.0.dev20250226.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.19.0.dev20250226.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.19.0.dev20250226.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.19.0.dev20250226.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/top_level.txt +0 -0
@@ -78,15 +78,8 @@ def run_experiment(
|
|
78
78
|
is_training = 'train' in mode
|
79
79
|
is_eval = 'eval' in mode
|
80
80
|
with distribution_strategy.scope():
|
81
|
-
|
82
|
-
|
83
|
-
if params.trainer.trainer_type == 'interleaving':
|
84
|
-
sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
|
85
|
-
task.task_weights)
|
86
|
-
kwargs.update(dict(task_sampler=sampler))
|
87
|
-
if trainer is None:
|
88
|
-
trainer = TRAINERS[params.trainer.trainer_type](
|
89
|
-
**kwargs) if is_training else None
|
81
|
+
if is_training and trainer is None:
|
82
|
+
trainer = get_trainer(distribution_strategy, params, task, model)
|
90
83
|
if is_eval:
|
91
84
|
eval_steps = task.task_eval_steps
|
92
85
|
evaluator = evaluator_lib.MultiTaskEvaluator(
|
@@ -157,6 +150,38 @@ def run_experiment(
|
|
157
150
|
return model
|
158
151
|
|
159
152
|
|
153
|
+
def get_trainer(
|
154
|
+
distribution_strategy: tf.distribute.Strategy,
|
155
|
+
params: configs.MultiEvalExperimentConfig,
|
156
|
+
task: multitask.MultiTask,
|
157
|
+
model: base_model.MultiTaskBaseModel | tf_keras.Model,
|
158
|
+
) -> orbit.StandardTrainer:
|
159
|
+
"""Creates a multi-task trainer for the given task.
|
160
|
+
|
161
|
+
Args:
|
162
|
+
distribution_strategy: A distribution strategy.
|
163
|
+
params: ExperimentConfig instance.
|
164
|
+
task: A MultiTaskTask instance.
|
165
|
+
model: A MultiTaskBaseModel instance.
|
166
|
+
|
167
|
+
Returns:
|
168
|
+
An Orbit trainer instance.
|
169
|
+
"""
|
170
|
+
with distribution_strategy.scope():
|
171
|
+
kwargs = dict(
|
172
|
+
multi_task=task,
|
173
|
+
multi_task_model=model,
|
174
|
+
optimizer=train_utils.create_optimizer(task, params),
|
175
|
+
)
|
176
|
+
if params.trainer.trainer_type == 'interleaving':
|
177
|
+
kwargs.update(
|
178
|
+
task_sampler=task_sampler.get_task_sampler(
|
179
|
+
params.trainer.task_sampler, task.task_weights
|
180
|
+
)
|
181
|
+
)
|
182
|
+
return TRAINERS[params.trainer.trainer_type](**kwargs)
|
183
|
+
|
184
|
+
|
160
185
|
TrainActionsFactoryType = Callable[
|
161
186
|
[
|
162
187
|
configs.MultiEvalExperimentConfig,
|
@@ -224,7 +224,7 @@ official/modeling/multitask/multitask.py,sha256=DV-ysfhPiIZgsrzZNylsPBxKNBf_xzPx
|
|
224
224
|
official/modeling/multitask/task_sampler.py,sha256=SGVVdjMb5oG4vnCczpfdgBtbsdsXiyBLl9si_0V6nko,4897
|
225
225
|
official/modeling/multitask/task_sampler_test.py,sha256=wkPTp1LCNx4uJbfHfVbKrQULQxumZS3ctBemEbnMokk,3037
|
226
226
|
official/modeling/multitask/test_utils.py,sha256=fPi_TxtzHy_NGTYmdDAQay9TRjwG70UWHwysDqol4jw,4315
|
227
|
-
official/modeling/multitask/train_lib.py,sha256=
|
227
|
+
official/modeling/multitask/train_lib.py,sha256=TbqjoGlrNxxbq3XxKZ04PTxtZFQUlfwX5sd1-9dm5QI,13403
|
228
228
|
official/modeling/multitask/train_lib_test.py,sha256=F_uDEQhSlXq9BH1QBlamUMQ_vRUflli7ef2P9MUKXeg,4766
|
229
229
|
official/modeling/optimization/__init__.py,sha256=BI86b89P0xOksBlzOmbAdWOhlJQe6cxtGIRN_zqCwa8,1201
|
230
230
|
official/modeling/optimization/adafactor_optimizer.py,sha256=zEHHFg9iH1UwcBSzzYE8bCNWln7QIIftXJEXTGTtCtE,792
|
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
|
|
1248
1248
|
tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
|
1249
1249
|
tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
|
1250
1250
|
tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
|
1251
|
-
tf_models_nightly-2.19.0.
|
1252
|
-
tf_models_nightly-2.19.0.
|
1253
|
-
tf_models_nightly-2.19.0.
|
1254
|
-
tf_models_nightly-2.19.0.
|
1255
|
-
tf_models_nightly-2.19.0.
|
1256
|
-
tf_models_nightly-2.19.0.
|
1251
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1252
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1253
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/METADATA,sha256=c2C7Tg_sj_ybjTh9TtaYbTlf7RMk16JqrLmVI6K3DUI,1432
|
1254
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1255
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1256
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|