truefoundry-0.11.3rc2-py3-none-any.whl → truefoundry-0.11.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truefoundry might be problematic; further details were linked from the original page.
- truefoundry/ml/artifact/truefoundry_artifact_repo.py +1 -1
- truefoundry/ml/integrations/__init__.py +0 -0
- truefoundry/ml/integrations/huggingface/__init__.py +0 -0
- truefoundry/ml/integrations/huggingface/trainer_callback.py +198 -0
- {truefoundry-0.11.3rc2.dist-info → truefoundry-0.11.4.dist-info}/METADATA +1 -1
- {truefoundry-0.11.3rc2.dist-info → truefoundry-0.11.4.dist-info}/RECORD +8 -5
- {truefoundry-0.11.3rc2.dist-info → truefoundry-0.11.4.dist-info}/WHEEL +0 -0
- {truefoundry-0.11.3rc2.dist-info → truefoundry-0.11.4.dist-info}/entry_points.txt +0 -0
|
@@ -654,7 +654,7 @@ class MlFoundryArtifactsRepository:
|
|
|
654
654
|
artifact_identifier=self.artifact_identifier, paths=[remote_file_path]
|
|
655
655
|
)[0]
|
|
656
656
|
|
|
657
|
-
if progress_bar is None or
|
|
657
|
+
if progress_bar is None or progress_bar.disable:
|
|
658
658
|
logger.info("Downloading %s to %s", remote_file_path, local_path)
|
|
659
659
|
|
|
660
660
|
if progress_bar is not None:
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import math
|
|
3
|
+
import os
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from truefoundry import ml
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from transformers.integrations.integration_utils import rewrite_logs
|
|
12
|
+
from transformers.trainer_callback import TrainerCallback
|
|
13
|
+
except ImportError as e:
|
|
14
|
+
raise ImportError(
|
|
15
|
+
"Importing this module requires `transformers` to be installed"
|
|
16
|
+
) from e
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from transformers.trainer_callback import TrainerControl, TrainerState
|
|
20
|
+
from transformers.training_args import TrainingArguments
|
|
21
|
+
|
|
22
|
+
from truefoundry.ml import MlFoundryRun
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TrueFoundryMLCallback(TrainerCallback):
    """Hugging Face ``Trainer`` callback that reports to a TrueFoundry ML run.

    On the main process (``state.is_world_process_zero``) it:

    * logs every finite numeric value passed to ``on_log`` as a metric at
      ``state.global_step`` (keys are rewritten with ``rewrite_logs``);
    * when ``log_checkpoints`` is enabled, uploads each saved checkpoint
      directory (``<args.output_dir>/checkpoint-<global_step>``) as a version
      of the ``checkpoint_artifact_name`` artifact, attaching the values logged
      at that step as metadata.

    If ``auto_end_run_on_train_end`` is set, ``on_train_end`` ends the run
    (that hook does not check the process rank).
    """

    def __init__(
        self,
        run: "MlFoundryRun",
        log_checkpoints: bool = True,
        checkpoint_artifact_name: Optional[str] = None,
        auto_end_run_on_train_end: bool = False,
    ):
        """
        Args:
            run: The run entity to log metrics to.
            log_checkpoints: Whether to log checkpoints or not, defaults to True.
            checkpoint_artifact_name: The name of the artifact to log checkpoints to, required if log_checkpoints is True.
            auto_end_run_on_train_end: Whether to end the run automatically when training ends, defaults to False.

        Raises:
            ValueError: If ``log_checkpoints`` is True and no
                ``checkpoint_artifact_name`` is given.

        Usage:
            from transformers import Trainer
            from truefoundry.ml.integrations.huggingface.trainer_callback import TrueFoundryMLCallback
            from truefoundry.ml import get_client

            client = get_client()
            run = client.create_run(ml_repo="my-ml-repo", run_name="my-run", auto_end=False)

            callback = TrueFoundryMLCallback(
                run=run,
                log_checkpoints=True,
                checkpoint_artifact_name="my-checkpoint",
                auto_end_run_on_train_end=True,
            )

            trainer = Trainer(
                ...,
                callbacks=[callback]
            )
        """
        self._run = run
        self._log_checkpoints = log_checkpoints
        if self._log_checkpoints and not checkpoint_artifact_name:
            raise ValueError(
                "`checkpoint_artifact_name` is required when `log_checkpoints` is True"
            )
        self._checkpoint_artifact_name = checkpoint_artifact_name
        self._auto_end_run_on_train_end = auto_end_run_on_train_end

    @classmethod
    def with_managed_run(
        cls,
        ml_repo: str,
        run_name: Optional[str] = None,
        log_checkpoints: bool = True,
        checkpoint_artifact_name: Optional[str] = None,
        auto_end_run_on_train_end: bool = True,
    ) -> "TrueFoundryMLCallback":
        """
        Create a new run in ``ml_repo`` and return a callback bound to it.

        Args:
            ml_repo: The name of the ML Repository to log metrics and data to.
            run_name: The name of the run, if not provided, a random name will be generated.
            log_checkpoints: Whether to log checkpoints or not, defaults to True.
            checkpoint_artifact_name: The name of the artifact to log checkpoints to, required if log_checkpoints is True.
            auto_end_run_on_train_end: Whether to end the run automatically when training ends, defaults to True.

        Usage:
            from transformers import Trainer
            from truefoundry.ml.integrations.huggingface.trainer_callback import TrueFoundryMLCallback

            callback = TrueFoundryMLCallback.with_managed_run(
                ml_repo="my-ml-repo",
                run_name="my-run",
                log_checkpoints=True,
                checkpoint_artifact_name="my-checkpoint",
                auto_end_run_on_train_end=True,
            )
            trainer = Trainer(
                ...,
                callbacks=[callback]
            )
        """
        # NOTE(review): this creates a run in every process that calls it; in
        # multi-process training consider calling it on the main process only.
        run = ml.get_client().create_run(
            ml_repo=ml_repo, run_name=run_name, auto_end=False
        )
        return cls(
            run=run,
            log_checkpoints=log_checkpoints,
            checkpoint_artifact_name=checkpoint_artifact_name,
            auto_end_run_on_train_end=auto_end_run_on_train_end,
        )

    @staticmethod
    def _is_finite_number(value: Any) -> bool:
        """Return True if ``value`` is a finite Python/NumPy int or float."""
        if isinstance(value, (int, np.integer)):
            # Integers are always finite. math.isfinite() would convert a huge
            # Python int to float and raise OverflowError instead.
            return True
        if isinstance(value, (float, np.floating)):
            return math.isfinite(value)
        return False

    def _drop_non_finite_values(self, dct: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``dct`` keeping only finite int/float values.

        Every dropped entry is reported with a warning.
        """
        sanitized = {}
        for k, v in dct.items():
            if self._is_finite_number(v):
                sanitized[k] = v
            else:
                logger.warning(
                    'Trainer is attempting to log a value of "%s" of'
                    ' type %s for key "%s" as a metric.'
                    " Mlfoundry's log_metric() only accepts finite float and"
                    " int types so we dropped this attribute.",
                    v,
                    type(v),
                    k,
                )
        return sanitized

    @property
    def run(self) -> "MlFoundryRun":
        """The run this callback logs to."""
        return self._run

    # noinspection PyMethodOverriding
    def on_log(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        logs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Log the Trainer's numeric values as metrics at the current global step."""
        logs = logs or {}
        if not state.is_world_process_zero:
            return

        metrics = self._drop_non_finite_values(logs)
        self._run.log_metrics(rewrite_logs(metrics), step=state.global_step)

    def on_save(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ) -> None:
        """Upload the just-saved checkpoint directory as an artifact version."""
        if not state.is_world_process_zero:
            return

        if not self._log_checkpoints:
            return

        # Defensive: __init__ already rejects this combination.
        if not self._checkpoint_artifact_name:
            return

        ckpt_dir = f"checkpoint-{state.global_step}"
        artifact_path = os.path.join(args.output_dir, ckpt_dir)

        # When running inside a TrueFoundry job, record which job/run produced it.
        description = None
        _job_name = os.getenv("TFY_INTERNAL_COMPONENT_NAME")
        _job_run_name = os.getenv("TFY_INTERNAL_JOB_RUN_NAME")
        if _job_name:
            description = f"Checkpoint from job={_job_name} run={_job_run_name}"

        logger.info("Uploading checkpoint %s ...", ckpt_dir)

        # The Trainer may append several log_history entries for the same step
        # (e.g. training loss and evaluation metrics). Merge all of them so the
        # checkpoint metadata does not keep only the last one.
        metadata: Dict[str, Any] = {}
        for log in state.log_history:
            if isinstance(log, dict) and log.get("step") == state.global_step:
                metadata.update(log)
        metadata = self._drop_non_finite_values(metadata)

        self._run.log_artifact(
            name=self._checkpoint_artifact_name,
            artifact_paths=[(artifact_path, None)],
            metadata=metadata,
            step=state.global_step,
            description=description,
        )

    def on_train_end(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ) -> None:
        """
        Event called at the end of training.

        Ends the run if ``auto_end_run_on_train_end`` was set.
        """
        if self._auto_end_run_on_train_end:
            self._run.end()
|
|
@@ -349,7 +349,7 @@ truefoundry/ml/_autogen/models/schema.py,sha256=a_bp42MMPUbwO3407m0UW2W8EOhnxZXf
|
|
|
349
349
|
truefoundry/ml/_autogen/models/signature.py,sha256=rBjpxUIsEeWM0sIyYG5uCJB18DKHR4k5yZw8TzuoP48,4987
|
|
350
350
|
truefoundry/ml/_autogen/models/utils.py,sha256=c7RtSLXhOLcP8rjuUtfnMdaKVTZvvbsmw98gPAkAFrs,24371
|
|
351
351
|
truefoundry/ml/artifact/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
352
|
-
truefoundry/ml/artifact/truefoundry_artifact_repo.py,sha256=
|
|
352
|
+
truefoundry/ml/artifact/truefoundry_artifact_repo.py,sha256=8BFKaXDxutw8bPJLnDI0bO0oNS_xJKo2ijubc2PLFsU,35688
|
|
353
353
|
truefoundry/ml/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
354
354
|
truefoundry/ml/cli/cli.py,sha256=MwpY7z_NEeJE_XIP7XbZELjNeu2vpMmohttHCKDRk54,335
|
|
355
355
|
truefoundry/ml/cli/utils.py,sha256=j6_mZ4Spn114mz3P4QQ8jx0tmorXIuyQnHXVUSDvZi4,1035
|
|
@@ -357,6 +357,9 @@ truefoundry/ml/cli/commands/__init__.py,sha256=diDUiRUX4l6TtNLI4iF-ZblczkELM7FRV
|
|
|
357
357
|
truefoundry/ml/cli/commands/download.py,sha256=N9MhsEQ3U24v_OmnMZT8Q4SoAi38Sm7a21unrACOSDw,2573
|
|
358
358
|
truefoundry/ml/cli/commands/model_init.py,sha256=INyUAU6hiFClI8cZqX5hgnrtNbeKxlZxrjFrjzStU18,2664
|
|
359
359
|
truefoundry/ml/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
360
|
+
truefoundry/ml/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
361
|
+
truefoundry/ml/integrations/huggingface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
362
|
+
truefoundry/ml/integrations/huggingface/trainer_callback.py,sha256=Zu5AUbH_ct8I1dHyNYJQZBj9Y__hKo0sc2OxpPXJARE,6952
|
|
360
363
|
truefoundry/ml/log_types/__init__.py,sha256=g4u4D4Jaj0aBK5GtrLV88-qThKZR9pSZ17vFEkN-LmM,125
|
|
361
364
|
truefoundry/ml/log_types/plot.py,sha256=LDh4uy6z2P_a2oPM2lc85c0lt8utVvunohzeMawFjZw,7572
|
|
362
365
|
truefoundry/ml/log_types/pydantic_base.py,sha256=eBlw_AEyAz4iJKDP4zgJOCFWcldwQqpf7FADW1jzIQY,272
|
|
@@ -383,7 +386,7 @@ truefoundry/workflow/remote_filesystem/__init__.py,sha256=LQ95ViEjJ7Ts4JcCGOxMPs
|
|
|
383
386
|
truefoundry/workflow/remote_filesystem/logger.py,sha256=em2l7D6sw7xTLDP0kQSLpgfRRCLpN14Qw85TN7ujQcE,1022
|
|
384
387
|
truefoundry/workflow/remote_filesystem/tfy_signed_url_client.py,sha256=xcT0wQmQlgzcj0nP3tJopyFSVWT1uv3nhiTIuwfXYeg,12342
|
|
385
388
|
truefoundry/workflow/remote_filesystem/tfy_signed_url_fs.py,sha256=nSGPZu0Gyd_jz0KsEE-7w_BmnTD8CVF1S8cUJoxaCbc,13305
|
|
386
|
-
truefoundry-0.11.
|
|
387
|
-
truefoundry-0.11.
|
|
388
|
-
truefoundry-0.11.
|
|
389
|
-
truefoundry-0.11.
|
|
389
|
+
truefoundry-0.11.4.dist-info/METADATA,sha256=RD0XhZ5hvcV7BAguapQ9yYssfoEDIXwNUi11w5riKtc,2759
|
|
390
|
+
truefoundry-0.11.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
391
|
+
truefoundry-0.11.4.dist-info/entry_points.txt,sha256=xVjn7RMN-MW2-9f7YU-bBdlZSvvrwzhpX1zmmRmsNPU,98
|
|
392
|
+
truefoundry-0.11.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|