truefoundry 0.11.3rc2__py3-none-any.whl → 0.11.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. See the linked advisory for more details (the link was not preserved in this copy).

@@ -654,7 +654,7 @@ class MlFoundryArtifactsRepository:
654
654
  artifact_identifier=self.artifact_identifier, paths=[remote_file_path]
655
655
  )[0]
656
656
 
657
- if progress_bar is None or not progress_bar.disable:
657
+ if progress_bar is None or progress_bar.disable:
658
658
  logger.info("Downloading %s to %s", remote_file_path, local_path)
659
659
 
660
660
  if progress_bar is not None:
File without changes
File without changes
@@ -0,0 +1,198 @@
1
+ import logging
2
+ import math
3
+ import os
4
+ from typing import TYPE_CHECKING, Any, Dict, Optional
5
+
6
+ import numpy as np
7
+
8
+ from truefoundry import ml
9
+
10
+ try:
11
+ from transformers.integrations.integration_utils import rewrite_logs
12
+ from transformers.trainer_callback import TrainerCallback
13
+ except ImportError as e:
14
+ raise ImportError(
15
+ "Importing this module requires `transformers` to be installed"
16
+ ) from e
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.trainer_callback import TrainerControl, TrainerState
20
+ from transformers.training_args import TrainingArguments
21
+
22
+ from truefoundry.ml import MlFoundryRun
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class TrueFoundryMLCallback(TrainerCallback):
    """Log Hugging Face ``Trainer`` metrics and checkpoints to a TrueFoundry run.

    * ``on_log``: numeric, finite values from ``logs`` are renamed with
      ``rewrite_logs`` and logged as metrics at ``state.global_step``.
    * ``on_save``: when checkpoint logging is enabled, the directory
      ``<args.output_dir>/checkpoint-<global_step>`` is uploaded as a version of
      the ``checkpoint_artifact_name`` artifact, with the log-history entries
      recorded for that step attached as metadata.
    * ``on_train_end``: optionally ends the run.

    ``on_log`` and ``on_save`` return immediately on every process except
    world process zero, so only one rank talks to TrueFoundry.
    """

    def __init__(
        self,
        run: "MlFoundryRun",
        log_checkpoints: bool = True,
        checkpoint_artifact_name: Optional[str] = None,
        auto_end_run_on_train_end: bool = False,
    ):
        """
        Args:
            run: The run entity to log metrics to.
            log_checkpoints: Whether to log checkpoints or not, defaults to True.
            checkpoint_artifact_name: The name of the artifact to log checkpoints to, required if log_checkpoints is True.
            auto_end_run_on_train_end: Whether to end the run automatically when training ends, defaults to False.

        Raises:
            ValueError: If `log_checkpoints` is True and `checkpoint_artifact_name` is empty or None.

        Usage:
            from transformers import Trainer
            from truefoundry.ml.integrations.huggingface.trainer_callback import TrueFoundryMLCallback
            from truefoundry.ml import get_client

            client = get_client()
            run = client.create_run(ml_repo="my-ml-repo", run_name="my-run", auto_end=False)

            callback = TrueFoundryMLCallback(
                run=run,
                log_checkpoints=True,
                checkpoint_artifact_name="my-checkpoint",
                auto_end_run_on_train_end=True,
            )

            trainer = Trainer(
                ...,
                callbacks=[callback]
            )
        """
        self._validate_checkpoint_args(
            log_checkpoints=log_checkpoints,
            checkpoint_artifact_name=checkpoint_artifact_name,
        )
        self._run = run
        self._log_checkpoints = log_checkpoints
        self._checkpoint_artifact_name = checkpoint_artifact_name
        self._auto_end_run_on_train_end = auto_end_run_on_train_end

    @staticmethod
    def _validate_checkpoint_args(
        log_checkpoints: bool, checkpoint_artifact_name: Optional[str]
    ) -> None:
        """Raise ``ValueError`` if checkpoint logging is on but no artifact name is set."""
        if log_checkpoints and not checkpoint_artifact_name:
            raise ValueError(
                "`checkpoint_artifact_name` is required when `log_checkpoints` is True"
            )

    @classmethod
    def with_managed_run(
        cls,
        ml_repo: str,
        run_name: Optional[str] = None,
        log_checkpoints: bool = True,
        checkpoint_artifact_name: Optional[str] = None,
        auto_end_run_on_train_end: bool = True,
    ) -> "TrueFoundryMLCallback":
        """
        Args:
            ml_repo: The name of the ML Repository to log metrics and data to.
            run_name: The name of the run, if not provided, a random name will be generated.
            log_checkpoints: Whether to log checkpoints or not, defaults to True.
            checkpoint_artifact_name: The name of the artifact to log checkpoints to, required if log_checkpoints is True.
            auto_end_run_on_train_end: Whether to end the run automatically when training ends, defaults to True.

        Raises:
            ValueError: If `log_checkpoints` is True and `checkpoint_artifact_name` is empty or None.
                This is checked before any run is created.

        Usage:
            from transformers import Trainer
            from truefoundry.ml.integrations.huggingface.trainer_callback import TrueFoundryMLCallback

            callback = TrueFoundryMLCallback.with_managed_run(
                ml_repo="my-ml-repo",
                run_name="my-run",
                log_checkpoints=True,
                checkpoint_artifact_name="my-checkpoint",
                auto_end_run_on_train_end=True,
            )
            trainer = Trainer(
                ...,
                callbacks=[callback]
            )
        """
        # Validate before creating the run. Otherwise the constructor's
        # ValueError would leave behind a run created with auto_end=False
        # that nothing ever ends.
        cls._validate_checkpoint_args(
            log_checkpoints=log_checkpoints,
            checkpoint_artifact_name=checkpoint_artifact_name,
        )
        run = ml.get_client().create_run(
            ml_repo=ml_repo, run_name=run_name, auto_end=False
        )
        return cls(
            run=run,
            log_checkpoints=log_checkpoints,
            checkpoint_artifact_name=checkpoint_artifact_name,
            auto_end_run_on_train_end=auto_end_run_on_train_end,
        )

    def _drop_non_finite_values(self, dct: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``dct`` with only finite int/float (incl. numpy) values.

        Every dropped entry is reported with a warning.
        """
        sanitized = {}
        for k, v in dct.items():
            if isinstance(v, (int, float, np.integer, np.floating)) and math.isfinite(
                v
            ):
                sanitized[k] = v
            else:
                logger.warning(
                    'Trainer is attempting to log a value of "%s" of'
                    ' type %s for key "%s" as a metric.'
                    " Mlfoundry's log_metric() only accepts finite float and"
                    " int types so we dropped this attribute.",
                    v,
                    type(v),
                    k,
                )
        return sanitized

    @property
    def run(self) -> "MlFoundryRun":
        """The run this callback logs to."""
        return self._run

    # noinspection PyMethodOverriding
    def on_log(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        logs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Log the sanitized, ``rewrite_logs``-renamed ``logs`` at ``state.global_step``."""
        logs = logs or {}
        if not state.is_world_process_zero:
            return

        metrics = self._drop_non_finite_values(logs)
        self._run.log_metrics(rewrite_logs(metrics), step=state.global_step)

    def on_save(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        """Upload ``<args.output_dir>/checkpoint-<global_step>`` as an artifact version.

        The artifact metadata is every ``state.log_history`` entry recorded at the
        current step, merged together and then sanitized. If the env var
        ``TFY_INTERNAL_COMPONENT_NAME`` is set, the description names the job
        and job run.
        """
        if not state.is_world_process_zero:
            return

        if not self._log_checkpoints:
            return

        if not self._checkpoint_artifact_name:
            return

        ckpt_dir = f"checkpoint-{state.global_step}"
        artifact_path = os.path.join(args.output_dir, ckpt_dir)
        description = None
        _job_name = os.getenv("TFY_INTERNAL_COMPONENT_NAME")
        _job_run_name = os.getenv("TFY_INTERNAL_JOB_RUN_NAME")
        if _job_name:
            description = f"Checkpoint from job={_job_name} run={_job_run_name}"
        logger.info("Uploading checkpoint %s ...", ckpt_dir)
        metadata: Dict[str, Any] = {}
        for log in state.log_history:
            if isinstance(log, dict) and log.get("step") == state.global_step:
                # Several entries can share one step (e.g. a training-loss entry
                # and an evaluation entry). Merge them instead of keeping only
                # the last one.
                metadata.update(log)
        metadata = self._drop_non_finite_values(metadata)
        self._run.log_artifact(
            name=self._checkpoint_artifact_name,
            artifact_paths=[(artifact_path, None)],
            metadata=metadata,
            step=state.global_step,
            description=description,
        )

    def on_train_end(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        """
        Event called at the end of training.

        Ends the run only if ``auto_end_run_on_train_end`` was set.
        """
        if self._auto_end_run_on_train_end:
            self._run.end()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truefoundry
3
- Version: 0.11.3rc2
3
+ Version: 0.11.4
4
4
  Summary: TrueFoundry CLI
5
5
  Author-email: TrueFoundry Team <abhishek@truefoundry.com>
6
6
  Requires-Python: <3.14,>=3.8.1
@@ -349,7 +349,7 @@ truefoundry/ml/_autogen/models/schema.py,sha256=a_bp42MMPUbwO3407m0UW2W8EOhnxZXf
349
349
  truefoundry/ml/_autogen/models/signature.py,sha256=rBjpxUIsEeWM0sIyYG5uCJB18DKHR4k5yZw8TzuoP48,4987
350
350
  truefoundry/ml/_autogen/models/utils.py,sha256=c7RtSLXhOLcP8rjuUtfnMdaKVTZvvbsmw98gPAkAFrs,24371
351
351
  truefoundry/ml/artifact/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
352
- truefoundry/ml/artifact/truefoundry_artifact_repo.py,sha256=hbgLxSoihkLVuICzRueuh8iAIc-yruCW5TuMXYQ-aCU,35692
352
+ truefoundry/ml/artifact/truefoundry_artifact_repo.py,sha256=8BFKaXDxutw8bPJLnDI0bO0oNS_xJKo2ijubc2PLFsU,35688
353
353
  truefoundry/ml/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
354
354
  truefoundry/ml/cli/cli.py,sha256=MwpY7z_NEeJE_XIP7XbZELjNeu2vpMmohttHCKDRk54,335
355
355
  truefoundry/ml/cli/utils.py,sha256=j6_mZ4Spn114mz3P4QQ8jx0tmorXIuyQnHXVUSDvZi4,1035
@@ -357,6 +357,9 @@ truefoundry/ml/cli/commands/__init__.py,sha256=diDUiRUX4l6TtNLI4iF-ZblczkELM7FRV
357
357
  truefoundry/ml/cli/commands/download.py,sha256=N9MhsEQ3U24v_OmnMZT8Q4SoAi38Sm7a21unrACOSDw,2573
358
358
  truefoundry/ml/cli/commands/model_init.py,sha256=INyUAU6hiFClI8cZqX5hgnrtNbeKxlZxrjFrjzStU18,2664
359
359
  truefoundry/ml/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
360
+ truefoundry/ml/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
361
+ truefoundry/ml/integrations/huggingface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
362
+ truefoundry/ml/integrations/huggingface/trainer_callback.py,sha256=Zu5AUbH_ct8I1dHyNYJQZBj9Y__hKo0sc2OxpPXJARE,6952
360
363
  truefoundry/ml/log_types/__init__.py,sha256=g4u4D4Jaj0aBK5GtrLV88-qThKZR9pSZ17vFEkN-LmM,125
361
364
  truefoundry/ml/log_types/plot.py,sha256=LDh4uy6z2P_a2oPM2lc85c0lt8utVvunohzeMawFjZw,7572
362
365
  truefoundry/ml/log_types/pydantic_base.py,sha256=eBlw_AEyAz4iJKDP4zgJOCFWcldwQqpf7FADW1jzIQY,272
@@ -383,7 +386,7 @@ truefoundry/workflow/remote_filesystem/__init__.py,sha256=LQ95ViEjJ7Ts4JcCGOxMPs
383
386
  truefoundry/workflow/remote_filesystem/logger.py,sha256=em2l7D6sw7xTLDP0kQSLpgfRRCLpN14Qw85TN7ujQcE,1022
384
387
  truefoundry/workflow/remote_filesystem/tfy_signed_url_client.py,sha256=xcT0wQmQlgzcj0nP3tJopyFSVWT1uv3nhiTIuwfXYeg,12342
385
388
  truefoundry/workflow/remote_filesystem/tfy_signed_url_fs.py,sha256=nSGPZu0Gyd_jz0KsEE-7w_BmnTD8CVF1S8cUJoxaCbc,13305
386
- truefoundry-0.11.3rc2.dist-info/METADATA,sha256=JnuUsg_bJq6c07XAzjN-khBwx5sDL9nVkdA7NEGGvlk,2762
387
- truefoundry-0.11.3rc2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
388
- truefoundry-0.11.3rc2.dist-info/entry_points.txt,sha256=xVjn7RMN-MW2-9f7YU-bBdlZSvvrwzhpX1zmmRmsNPU,98
389
- truefoundry-0.11.3rc2.dist-info/RECORD,,
389
+ truefoundry-0.11.4.dist-info/METADATA,sha256=RD0XhZ5hvcV7BAguapQ9yYssfoEDIXwNUi11w5riKtc,2759
390
+ truefoundry-0.11.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
391
+ truefoundry-0.11.4.dist-info/entry_points.txt,sha256=xVjn7RMN-MW2-9f7YU-bBdlZSvvrwzhpX1zmmRmsNPU,98
392
+ truefoundry-0.11.4.dist-info/RECORD,,