tf-models-nightly 2.19.0.dev20250225__py2.py3-none-any.whl → 2.19.0.dev20250227__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/modeling/multitask/train_lib.py +79 -12
- {tf_models_nightly-2.19.0.dev20250225.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.19.0.dev20250225.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/RECORD +7 -7
- {tf_models_nightly-2.19.0.dev20250225.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.19.0.dev20250225.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.19.0.dev20250225.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.19.0.dev20250225.dist-info → tf_models_nightly-2.19.0.dev20250227.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Multitask training driver library."""
|
16
16
|
# pytype: disable=attribute-error
|
17
17
|
import os
|
18
|
-
from typing import Any, List, Mapping, Optional, Tuple, Union
|
18
|
+
from typing import Any, List, Mapping, Optional, Tuple, Union, Callable
|
19
19
|
from absl import logging
|
20
20
|
import orbit
|
21
21
|
import tensorflow as tf, tf_keras
|
@@ -78,15 +78,8 @@ def run_experiment(
|
|
78
78
|
is_training = 'train' in mode
|
79
79
|
is_eval = 'eval' in mode
|
80
80
|
with distribution_strategy.scope():
|
81
|
-
|
82
|
-
|
83
|
-
if params.trainer.trainer_type == 'interleaving':
|
84
|
-
sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
|
85
|
-
task.task_weights)
|
86
|
-
kwargs.update(dict(task_sampler=sampler))
|
87
|
-
if trainer is None:
|
88
|
-
trainer = TRAINERS[params.trainer.trainer_type](
|
89
|
-
**kwargs) if is_training else None
|
81
|
+
if is_training and trainer is None:
|
82
|
+
trainer = get_trainer(distribution_strategy, params, task, model)
|
90
83
|
if is_eval:
|
91
84
|
eval_steps = task.task_eval_steps
|
92
85
|
evaluator = evaluator_lib.MultiTaskEvaluator(
|
@@ -157,6 +150,57 @@ def run_experiment(
|
|
157
150
|
return model
|
158
151
|
|
159
152
|
|
153
|
+
def get_trainer(
|
154
|
+
distribution_strategy: tf.distribute.Strategy,
|
155
|
+
params: configs.MultiEvalExperimentConfig,
|
156
|
+
task: multitask.MultiTask,
|
157
|
+
model: base_model.MultiTaskBaseModel | tf_keras.Model,
|
158
|
+
) -> orbit.StandardTrainer:
|
159
|
+
"""Creates a multi-task trainer for the given task.
|
160
|
+
|
161
|
+
Args:
|
162
|
+
distribution_strategy: A distribution strategy.
|
163
|
+
params: ExperimentConfig instance.
|
164
|
+
task: A MultiTaskTask instance.
|
165
|
+
model: A MultiTaskBaseModel instance.
|
166
|
+
|
167
|
+
Returns:
|
168
|
+
An Orbit trainer instance.
|
169
|
+
"""
|
170
|
+
with distribution_strategy.scope():
|
171
|
+
kwargs = dict(
|
172
|
+
multi_task=task,
|
173
|
+
multi_task_model=model,
|
174
|
+
optimizer=train_utils.create_optimizer(task, params),
|
175
|
+
)
|
176
|
+
if params.trainer.trainer_type == 'interleaving':
|
177
|
+
kwargs.update(
|
178
|
+
task_sampler=task_sampler.get_task_sampler(
|
179
|
+
params.trainer.task_sampler, task.task_weights
|
180
|
+
)
|
181
|
+
)
|
182
|
+
return TRAINERS[params.trainer.trainer_type](**kwargs)
|
183
|
+
|
184
|
+
|
185
|
+
TrainActionsFactoryType = Callable[
|
186
|
+
[
|
187
|
+
configs.MultiEvalExperimentConfig,
|
188
|
+
orbit.StandardTrainer,
|
189
|
+
str,
|
190
|
+
tf.train.CheckpointManager,
|
191
|
+
],
|
192
|
+
List[orbit.Action],
|
193
|
+
]
|
194
|
+
EvalActionsFactoryType = Callable[
|
195
|
+
[
|
196
|
+
configs.MultiEvalExperimentConfig,
|
197
|
+
orbit.AbstractEvaluator,
|
198
|
+
str,
|
199
|
+
],
|
200
|
+
List[orbit.Action],
|
201
|
+
]
|
202
|
+
|
203
|
+
|
160
204
|
def run_experiment_with_multitask_eval(
|
161
205
|
*,
|
162
206
|
distribution_strategy: tf.distribute.Strategy,
|
@@ -171,6 +215,8 @@ def run_experiment_with_multitask_eval(
|
|
171
215
|
eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
|
172
216
|
best_ckpt_exporter_creator: Optional[Any] = train_utils
|
173
217
|
.maybe_create_best_ckpt_exporter,
|
218
|
+
train_actions_factory: Optional[TrainActionsFactoryType] = None,
|
219
|
+
eval_actions_factory: Optional[EvalActionsFactoryType] = None,
|
174
220
|
) -> Tuple[Any, Any]:
|
175
221
|
"""Runs train/eval configured by the experiment params.
|
176
222
|
|
@@ -193,6 +239,8 @@ def run_experiment_with_multitask_eval(
|
|
193
239
|
will be created internally for TensorBoard summaries by default from the
|
194
240
|
`eval_summary_dir`.
|
195
241
|
best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
|
242
|
+
train_actions_factory: Optional factory function to create train actions.
|
243
|
+
eval_actions_factory: Optional factory function to create eval actions.
|
196
244
|
|
197
245
|
Returns:
|
198
246
|
model: `tf_keras.Model` instance.
|
@@ -214,7 +262,6 @@ def run_experiment_with_multitask_eval(
|
|
214
262
|
|
215
263
|
# Build the model or fetch the pre-cached one (which could be either
|
216
264
|
# multi-task model or single task model).
|
217
|
-
model = None
|
218
265
|
if trainer is None:
|
219
266
|
if isinstance(train_task, multitask.MultiTask):
|
220
267
|
model = train_task.build_multitask_model()
|
@@ -254,6 +301,23 @@ def run_experiment_with_multitask_eval(
|
|
254
301
|
checkpoint_interval=params.trainer.checkpoint_interval,
|
255
302
|
init_fn=trainer.initialize if trainer else None)
|
256
303
|
|
304
|
+
if trainer and train_actions_factory:
|
305
|
+
# pytype: disable=wrong-keyword-args
|
306
|
+
train_actions = train_actions_factory(
|
307
|
+
params=params,
|
308
|
+
trainer=trainer,
|
309
|
+
model_dir=model_dir,
|
310
|
+
checkpoint_manager=checkpoint_manager,
|
311
|
+
)
|
312
|
+
# pytype: enable=wrong-keyword-args
|
313
|
+
else:
|
314
|
+
train_actions = None
|
315
|
+
|
316
|
+
if evaluator and eval_actions_factory:
|
317
|
+
eval_actions = eval_actions_factory(params, evaluator, model_dir)
|
318
|
+
else:
|
319
|
+
eval_actions = None
|
320
|
+
|
257
321
|
controller = orbit.Controller(
|
258
322
|
strategy=distribution_strategy,
|
259
323
|
trainer=trainer,
|
@@ -266,7 +330,10 @@ def run_experiment_with_multitask_eval(
|
|
266
330
|
(save_summary) else None,
|
267
331
|
eval_summary_manager=eval_summary_manager,
|
268
332
|
summary_interval=params.trainer.summary_interval if
|
269
|
-
(save_summary) else None
|
333
|
+
(save_summary) else None,
|
334
|
+
train_actions=train_actions,
|
335
|
+
eval_actions=eval_actions,
|
336
|
+
)
|
270
337
|
|
271
338
|
logging.info('Starts to execute mode: %s', mode)
|
272
339
|
with distribution_strategy.scope():
|
@@ -224,7 +224,7 @@ official/modeling/multitask/multitask.py,sha256=DV-ysfhPiIZgsrzZNylsPBxKNBf_xzPx
|
|
224
224
|
official/modeling/multitask/task_sampler.py,sha256=SGVVdjMb5oG4vnCczpfdgBtbsdsXiyBLl9si_0V6nko,4897
|
225
225
|
official/modeling/multitask/task_sampler_test.py,sha256=wkPTp1LCNx4uJbfHfVbKrQULQxumZS3ctBemEbnMokk,3037
|
226
226
|
official/modeling/multitask/test_utils.py,sha256=fPi_TxtzHy_NGTYmdDAQay9TRjwG70UWHwysDqol4jw,4315
|
227
|
-
official/modeling/multitask/train_lib.py,sha256=
|
227
|
+
official/modeling/multitask/train_lib.py,sha256=TbqjoGlrNxxbq3XxKZ04PTxtZFQUlfwX5sd1-9dm5QI,13403
|
228
228
|
official/modeling/multitask/train_lib_test.py,sha256=F_uDEQhSlXq9BH1QBlamUMQ_vRUflli7ef2P9MUKXeg,4766
|
229
229
|
official/modeling/optimization/__init__.py,sha256=BI86b89P0xOksBlzOmbAdWOhlJQe6cxtGIRN_zqCwa8,1201
|
230
230
|
official/modeling/optimization/adafactor_optimizer.py,sha256=zEHHFg9iH1UwcBSzzYE8bCNWln7QIIftXJEXTGTtCtE,792
|
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
|
|
1248
1248
|
tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
|
1249
1249
|
tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
|
1250
1250
|
tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
|
1251
|
-
tf_models_nightly-2.19.0.
|
1252
|
-
tf_models_nightly-2.19.0.
|
1253
|
-
tf_models_nightly-2.19.0.
|
1254
|
-
tf_models_nightly-2.19.0.
|
1255
|
-
tf_models_nightly-2.19.0.
|
1256
|
-
tf_models_nightly-2.19.0.
|
1251
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1252
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1253
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/METADATA,sha256=c2C7Tg_sj_ybjTh9TtaYbTlf7RMk16JqrLmVI6K3DUI,1432
|
1254
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1255
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1256
|
+
tf_models_nightly-2.19.0.dev20250227.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|