truefoundry-0.5.0rc2-py3-none-any.whl → truefoundry-0.5.0rc4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic; see the registry's advisory page for more details.

Files changed (31)
  1. truefoundry/deploy/auto_gen/models.py +12 -8
  2. truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +2 -2
  3. truefoundry/deploy/python_deploy_codegen.py +1 -0
  4. truefoundry/deploy/v2/lib/deploy_workflow.py +8 -2
  5. truefoundry/ml/__init__.py +6 -3
  6. truefoundry/ml/autogen/client/__init__.py +9 -4
  7. truefoundry/ml/autogen/client/models/__init__.py +9 -4
  8. truefoundry/ml/autogen/client/models/artifact_version_dto.py +3 -5
  9. truefoundry/ml/autogen/client/models/artifact_version_manifest.py +111 -0
  10. truefoundry/ml/autogen/client/models/{external_model_source.py → external_artifact_source.py} +8 -8
  11. truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +3 -5
  12. truefoundry/ml/autogen/client/models/manifest.py +154 -0
  13. truefoundry/ml/autogen/client/models/model_version_manifest.py +3 -3
  14. truefoundry/ml/autogen/client/models/source.py +23 -23
  15. truefoundry/ml/autogen/client/models/source1.py +154 -0
  16. truefoundry/ml/autogen/client/models/transformers_framework.py +6 -1
  17. truefoundry/ml/autogen/client/models/{truefoundry_model_source.py → true_foundry_artifact_source.py} +9 -9
  18. truefoundry/ml/autogen/client/models/update_artifact_version_request_dto.py +11 -1
  19. truefoundry/ml/autogen/client_README.md +5 -2
  20. truefoundry/ml/autogen/entities/artifacts.py +22 -8
  21. truefoundry/ml/log_types/artifacts/artifact.py +131 -63
  22. truefoundry/ml/log_types/artifacts/general_artifact.py +7 -26
  23. truefoundry/ml/log_types/artifacts/model.py +74 -81
  24. truefoundry/ml/mlfoundry_api.py +4 -4
  25. truefoundry/ml/mlfoundry_run.py +8 -5
  26. truefoundry/ml/model_framework.py +2 -2
  27. truefoundry/workflow/remote_filesystem/tfy_signed_url_client.py +41 -8
  28. {truefoundry-0.5.0rc2.dist-info → truefoundry-0.5.0rc4.dist-info}/METADATA +2 -2
  29. {truefoundry-0.5.0rc2.dist-info → truefoundry-0.5.0rc4.dist-info}/RECORD +31 -28
  30. {truefoundry-0.5.0rc2.dist-info → truefoundry-0.5.0rc4.dist-info}/WHEEL +0 -0
  31. {truefoundry-0.5.0rc2.dist-info → truefoundry-0.5.0rc4.dist-info}/entry_points.txt +0 -0
@@ -4,6 +4,7 @@ import json
4
4
  import logging
5
5
  import os.path
6
6
  import tempfile
7
+ import typing
7
8
  import uuid
8
9
  import warnings
9
10
  from pathlib import Path
@@ -15,25 +16,23 @@ from truefoundry.ml.artifact.truefoundry_artifact_repo import (
15
16
  )
16
17
  from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
17
18
  ArtifactType,
18
- ArtifactVersionSerializationFormat,
19
19
  CreateArtifactVersionRequestDto,
20
20
  DeleteArtifactVersionsRequestDto,
21
- ExternalModelSource,
21
+ ExternalArtifactSource,
22
22
  FinalizeArtifactVersionRequestDto,
23
23
  Framework,
24
+ Manifest,
24
25
  MlfoundryArtifactsApi,
25
26
  ModelDto,
26
27
  ModelVersionDto,
27
28
  ModelVersionManifest,
28
29
  NotifyArtifactVersionFailureDto,
29
- TruefoundryModelSource,
30
+ TrueFoundryArtifactSource,
30
31
  UpdateModelVersionRequestDto,
31
32
  )
32
- from truefoundry.ml.autogen.client import (
33
- Source as ModelVersionSource,
34
- )
35
33
  from truefoundry.ml.enums import ModelFramework
36
34
  from truefoundry.ml.exceptions import MlFoundryException
35
+ from truefoundry.ml.log_types.artifacts.artifact import BlobStorageDirectory
37
36
  from truefoundry.ml.log_types.artifacts.constants import (
38
37
  INTERNAL_METADATA_PATH,
39
38
  )
@@ -46,7 +45,7 @@ from truefoundry.ml.log_types.artifacts.utils import (
46
45
  )
47
46
  from truefoundry.ml.model_framework import ModelFrameworkType, _ModelFramework
48
47
  from truefoundry.ml.session import _get_api_client
49
- from truefoundry.pydantic_v1 import BaseModel, Extra, StrictStr
48
+ from truefoundry.pydantic_v1 import BaseModel, Extra
50
49
 
51
50
  if TYPE_CHECKING:
52
51
  from truefoundry.ml.mlfoundry_run import MlFoundryRun
@@ -75,10 +74,6 @@ class ModelVersionInternalMetadata(BaseModel):
75
74
  return dct
76
75
 
77
76
 
78
- class BlobStorageModelDirectory(BaseModel):
79
- uri: StrictStr
80
-
81
-
82
77
  class ModelVersionDownloadInfo(BaseModel):
83
78
  download_dir: str
84
79
  model_dir: str
@@ -101,7 +96,7 @@ class ModelVersion:
101
96
  self._deleted = False
102
97
  self._description: str = ""
103
98
  self._metadata: Dict[str, Any] = {}
104
- self._framework: Optional["Framework"] = None
99
+ self._framework: Optional[ModelFrameworkType] = None
105
100
  self._set_mutable_attrs()
106
101
 
107
102
  @classmethod
@@ -140,20 +135,29 @@ class ModelVersion:
140
135
  )
141
136
 
142
137
  def _set_mutable_attrs(self):
143
- self._description = self._model_version.description or ""
144
- self._metadata = copy.deepcopy(self._model_version.artifact_metadata)
145
- self._framework = (
146
- copy.deepcopy(self._model_version.manifest.framework)
147
- if self._model_version.manifest
148
- else None
149
- )
138
+ if self._model_version.manifest:
139
+ self._description = self._model_version.manifest.description or ""
140
+ self._metadata = copy.deepcopy(self._model_version.manifest.metadata)
141
+ if self._model_version.manifest.framework:
142
+ self._framework = copy.deepcopy(
143
+ self._model_version.manifest.framework.actual_instance
144
+ )
145
+ else:
146
+ self._framework = None
147
+ else:
148
+ self._description = self._model_version.description or ""
149
+ self._metadata = copy.deepcopy(self._model_version.artifact_metadata)
150
+ self._framework = _ModelFramework.to_model_framework_type(
151
+ self._model_version.model_framework
152
+ )
150
153
 
151
- def _refetch_model_version(self):
154
+ def _refetch_model_version(self, reset_mutable_attrs: bool = True):
152
155
  _model_version = self._mlfoundry_artifacts_api.get_model_version_get(
153
156
  id=self._model_version.id
154
157
  )
155
158
  self._model_version = _model_version.model_version
156
- self._set_mutable_attrs()
159
+ if reset_mutable_attrs:
160
+ self._set_mutable_attrs()
157
161
 
158
162
  def __repr__(self):
159
163
  return f"{self.__class__.__name__}(fqn={self.fqn!r})"
@@ -187,12 +191,14 @@ class ModelVersion:
187
191
  return self._model_version.fqn
188
192
 
189
193
  @property
190
- def step(self) -> int:
194
+ def step(self) -> Optional[int]:
191
195
  """Get the step in which model was created"""
196
+ if self._model_version.manifest:
197
+ return self._model_version.manifest.step
192
198
  return self._model_version.step
193
199
 
194
200
  @property
195
- def description(self) -> Optional[str]:
201
+ def description(self) -> str:
196
202
  """Get description of the model"""
197
203
  return self._description
198
204
 
@@ -201,8 +207,6 @@ class ModelVersion:
201
207
  """set the description of the model"""
202
208
  _validate_description(value)
203
209
  self._description = value
204
- if self._model_version.manifest:
205
- self._model_version.manifest.description = value
206
210
 
207
211
  @property
208
212
  def metadata(self) -> Dict[str, Any]:
@@ -213,52 +217,34 @@ class ModelVersion:
213
217
  def metadata(self, value: Dict[str, Any]):
214
218
  """set the metadata for current model"""
215
219
  _validate_artifact_metadata(value)
216
- self._metadata = value
217
- if self._model_version.manifest:
218
- self._model_version.manifest.metadata = value
220
+ self._metadata = copy.deepcopy(value)
219
221
 
220
222
  @property
221
223
  def framework(self) -> Optional["ModelFrameworkType"]:
222
224
  """Get the framework of the model"""
223
- return (
224
- _ModelFramework.from_dict(self._framework.actual_instance.dict())
225
- if self._framework and self._framework.actual_instance
226
- else None
227
- )
225
+ return self._framework
228
226
 
229
227
  @framework.setter
230
228
  def framework(
231
229
  self, value: Optional[Union[str, ModelFramework, "ModelFrameworkType"]]
232
230
  ):
233
- """set the framework of the model"""
234
- model_framework_type: Optional[ModelFrameworkType] = None
235
-
236
- if value is not None:
237
- model_framework_type = _ModelFramework.to_model_framework_type(value)
238
-
239
- if (
240
- self._model_version.serialization_format
241
- == ArtifactVersionSerializationFormat.V2
242
- ):
243
- self._model_version.manifest.framework = self._framework = (
244
- Framework.from_dict(model_framework_type.dict())
245
- if model_framework_type
246
- else None
247
- )
248
- else:
231
+ """Set the framework of the model"""
232
+ if not self._model_version.manifest:
249
233
  warnings.warn(
250
234
  message="This model version was created using an older serialization format. Framework will not be updated",
251
235
  category=DeprecationWarning,
252
236
  stacklevel=2,
253
237
  )
238
+ return
239
+ self._framework = _ModelFramework.to_model_framework_type(value)
254
240
 
255
241
  @property
256
- def created_at(self) -> datetime.datetime:
242
+ def created_at(self) -> Optional[datetime.datetime]:
257
243
  """Get the time at which model version was created"""
258
244
  return self._model_version.created_at
259
245
 
260
246
  @property
261
- def updated_at(self) -> datetime.datetime:
247
+ def updated_at(self) -> Optional[datetime.datetime]:
262
248
  """Get the information about when the model version was updated"""
263
249
  return self._model_version.updated_at
264
250
 
@@ -311,13 +297,11 @@ class ModelVersion:
311
297
  path=path, overwrite=overwrite, progress=progress
312
298
  )
313
299
 
314
- if (
315
- self._model_version.serialization_format
316
- == ArtifactVersionSerializationFormat.V2
317
- ):
300
+ if self._model_version.manifest:
301
+ _framework = self._model_version.manifest.framework
318
302
  model_framework = (
319
- ModelFramework(self.framework.type)
320
- if self.framework
303
+ ModelFramework(_framework.actual_instance.type)
304
+ if _framework
321
305
  else ModelFramework.UNKNOWN
322
306
  )
323
307
  download_info = ModelVersionDownloadInfo(
@@ -429,22 +413,34 @@ class ModelVersion:
429
413
  ```
430
414
  """
431
415
  self._ensure_not_deleted()
432
- _model_version = self._mlfoundry_artifacts_api.update_model_version_post(
433
- update_model_version_request_dto=UpdateModelVersionRequestDto(
434
- id=self._model_version.id,
435
- description=self.description,
436
- artifact_metadata=self.metadata,
437
- manifest=self._model_version.manifest,
416
+ if self._model_version.manifest:
417
+ self._model_version.manifest.description = self.description
418
+ self._model_version.manifest.metadata = self.metadata
419
+ self._model_version.manifest.framework = (
420
+ Framework.from_dict(self.framework.dict()) if self.framework else None
438
421
  )
439
- )
440
- self._model_version = _model_version.model_version
441
- self._set_mutable_attrs()
422
+ try:
423
+ _model_version = self._mlfoundry_artifacts_api.update_model_version_post(
424
+ update_model_version_request_dto=UpdateModelVersionRequestDto(
425
+ id=self._model_version.id,
426
+ description=self.description,
427
+ artifact_metadata=self.metadata,
428
+ manifest=self._model_version.manifest,
429
+ )
430
+ )
431
+ except Exception:
432
+ # rollback edits to internal object
433
+ self._refetch_model_version(reset_mutable_attrs=False)
434
+ raise
435
+ else:
436
+ self._model_version = _model_version.model_version
437
+ self._set_mutable_attrs()
442
438
 
443
439
 
444
440
  def _log_model_version( # noqa: C901
445
441
  run: Optional["MlFoundryRun"],
446
442
  name: str,
447
- model_file_or_folder: Union[str, BlobStorageModelDirectory],
443
+ model_file_or_folder: Union[str, BlobStorageDirectory],
448
444
  mlfoundry_artifacts_api: Optional[MlfoundryArtifactsApi] = None,
449
445
  ml_repo_id: Optional[str] = None,
450
446
  additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
@@ -459,9 +455,9 @@ def _log_model_version( # noqa: C901
459
455
  "Exactly one of run, mlfoundry_artifacts_api should be passed"
460
456
  )
461
457
 
462
- if not isinstance(model_file_or_folder, (str, BlobStorageModelDirectory)):
458
+ if not isinstance(model_file_or_folder, (str, BlobStorageDirectory)):
463
459
  raise MlFoundryException(
464
- "model_file_or_folder should be of type str or BlobStorageModelDirectory"
460
+ "model_file_or_folder should be of type str or BlobStorageDirectory"
465
461
  )
466
462
 
467
463
  if mlfoundry_artifacts_api and not ml_repo_id:
@@ -522,11 +518,9 @@ def _log_model_version( # noqa: C901
522
518
  version_id = _create_artifact_version_response.id
523
519
  artifact_storage_root = _create_artifact_version_response.artifact_storage_root
524
520
  if isinstance(model_file_or_folder, str):
525
- # Source is of type TruefoundryModelSource
526
- source = ModelVersionSource.from_json(
527
- TruefoundryModelSource(
528
- type="truefoundry", uri=artifact_storage_root
529
- ).to_json()
521
+ # Source is of type TrueFoundryArtifactSource
522
+ source = TrueFoundryArtifactSource(
523
+ type="truefoundry", uri=artifact_storage_root
530
524
  )
531
525
  artifacts_repo = MlFoundryArtifactsRepository(
532
526
  artifact_identifier=ArtifactIdentifier(
@@ -555,20 +549,19 @@ def _log_model_version( # noqa: C901
555
549
  raise MlFoundryException("Failed to log model") from e
556
550
  finally:
557
551
  temp_dir.cleanup()
558
- elif isinstance(model_file_or_folder, BlobStorageModelDirectory):
559
- source = ModelVersionSource.from_json(
560
- ExternalModelSource(type="external", uri=model_file_or_folder.uri).to_json()
561
- )
552
+ elif isinstance(model_file_or_folder, BlobStorageDirectory):
553
+ source = ExternalArtifactSource(type="external", uri=model_file_or_folder.uri)
562
554
  else:
563
555
  raise MlFoundryException("Invalid model_file_or_folder provided")
564
556
 
565
557
  _framework = _ModelFramework.to_model_framework_type(framework)
558
+ _source_cls = typing.get_type_hints(ModelVersionManifest)["source"]
566
559
  model_manifest = ModelVersionManifest(
567
560
  description=description,
568
561
  metadata=metadata,
569
- source=source,
562
+ source=_source_cls.from_dict(source.dict()),
570
563
  framework=Framework.from_dict(_framework.dict()) if _framework else None,
571
- step=step if run else 0,
564
+ step=step,
572
565
  )
573
566
  artifact_version_response = mlfoundry_artifacts_api.finalize_artifact_version_post(
574
567
  finalize_artifact_version_request_dto=FinalizeArtifactVersionRequestDto(
@@ -578,7 +571,7 @@ def _log_model_version( # noqa: C901
578
571
  artifact_metadata=metadata,
579
572
  internal_metadata=None,
580
573
  step=model_manifest.step,
581
- manifest=model_manifest,
574
+ manifest=Manifest.from_dict(model_manifest.to_dict()),
582
575
  )
583
576
  )
584
577
  return ModelVersion.from_fqn(fqn=artifact_version_response.artifact_version.fqn)
@@ -48,12 +48,12 @@ from truefoundry.ml.internal_namespace import NAMESPACE
48
48
  from truefoundry.ml.log_types.artifacts.artifact import (
49
49
  ArtifactPath,
50
50
  ArtifactVersion,
51
+ BlobStorageDirectory,
51
52
  ChatPromptVersion,
52
53
  )
53
54
  from truefoundry.ml.log_types.artifacts.dataset import DataDirectory
54
55
  from truefoundry.ml.log_types.artifacts.general_artifact import _log_artifact_version
55
56
  from truefoundry.ml.log_types.artifacts.model import (
56
- BlobStorageModelDirectory,
57
57
  ModelVersion,
58
58
  _log_model_version,
59
59
  )
@@ -1226,7 +1226,7 @@ class MlFoundry:
1226
1226
  *,
1227
1227
  ml_repo: str,
1228
1228
  name: str,
1229
- model_file_or_folder: Union[str, BlobStorageModelDirectory],
1229
+ model_file_or_folder: Union[str, BlobStorageDirectory],
1230
1230
  additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
1231
1231
  description: Optional[str] = None,
1232
1232
  metadata: Optional[Dict[str, Any]] = None,
@@ -1244,12 +1244,12 @@ class MlFoundry:
1244
1244
  the logged model will be added as a new version under that `name`. If no models exist with the given
1245
1245
  `name`, the given model will be logged as version 1.
1246
1246
 
1247
- model_file_or_folder (Union[str, BlobStorageModelDirectory]):
1247
+ model_file_or_folder (Union[str, BlobStorageDirectory]):
1248
1248
  str:
1249
1249
  Path to either a single file or a folder containing model files.
1250
1250
  This folder is typically created using serialization methods from libraries or frameworks,
1251
1251
  e.g., `joblib.dump`, `model.save_pretrained(...)`, `torch.save(...)`, or `model.save(...)`.
1252
- BlobStorageModelDirectory:
1252
+ BlobStorageDirectory:
1253
1253
  uri (str): URI to the model file or folder in a storage integration associated with the specified ML Repo.
1254
1254
  The model files or folder must reside within the same storage integration as the specified ML Repo.
1255
1255
  Accepted URI formats include `s3://integration-bucket-name/prefix/path/to/model` or `gs://integration-bucket-name/prefix/path/to/model`.
@@ -44,10 +44,13 @@ from truefoundry.ml.enums import ModelFramework, RunStatus
44
44
  from truefoundry.ml.exceptions import MlFoundryException
45
45
  from truefoundry.ml.internal_namespace import NAMESPACE
46
46
  from truefoundry.ml.log_types import Image, Plot
47
- from truefoundry.ml.log_types.artifacts.artifact import ArtifactPath, ArtifactVersion
47
+ from truefoundry.ml.log_types.artifacts.artifact import (
48
+ ArtifactPath,
49
+ ArtifactVersion,
50
+ BlobStorageDirectory,
51
+ )
48
52
  from truefoundry.ml.log_types.artifacts.general_artifact import _log_artifact_version
49
53
  from truefoundry.ml.log_types.artifacts.model import (
50
- BlobStorageModelDirectory,
51
54
  ModelVersion,
52
55
  _log_model_version,
53
56
  )
@@ -926,7 +929,7 @@ class MlFoundryRun:
926
929
  self,
927
930
  *,
928
931
  name: str,
929
- model_file_or_folder: Union[str, BlobStorageModelDirectory],
932
+ model_file_or_folder: Union[str, BlobStorageDirectory],
930
933
  additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
931
934
  description: Optional[str] = None,
932
935
  metadata: Optional[Dict[str, Any]] = None,
@@ -945,12 +948,12 @@ class MlFoundryRun:
945
948
  name (str): Name of the model. If a model with this name already exists under the current ML Repo,
946
949
  the logged model will be added as a new version under that `name`. If no models exist with the given
947
950
  `name`, the given model will be logged as version 1.
948
- model_file_or_folder (Union[str, BlobStorageModelDirectory]):
951
+ model_file_or_folder (Union[str, BlobStorageDirectory]):
949
952
  str:
950
953
  Path to either a single file or a folder containing model files.
951
954
  This folder is typically created using serialization methods from libraries or frameworks,
952
955
  e.g., `joblib.dump`, `model.save_pretrained(...)`, `torch.save(...)`, or `model.save(...)`.
953
- BlobStorageModelDirectory:
956
+ BlobStorageDirectory:
954
957
  uri (str): URI to the model file or folder in a storage integration associated with the specified ML Repo.
955
958
  The model files or folder must reside within the same storage integration as the specified ML Repo.
956
959
  Accepted URI formats include `s3://integration-bucket-name/prefix/path/to/model` or `gs://integration-bucket-name/prefix/path/to/model`.
@@ -1,5 +1,5 @@
1
1
  import warnings
2
- from typing import Literal, Optional, Union, get_args
2
+ from typing import Any, Dict, Literal, Optional, Union, get_args
3
3
 
4
4
  from truefoundry.ml import ModelFramework
5
5
  from truefoundry.ml.autogen.entities import artifacts as autogen_artifacts
@@ -161,7 +161,7 @@ class _ModelFramework(BaseModel):
161
161
  )
162
162
 
163
163
  @classmethod
164
- def from_dict(cls, obj: dict) -> ModelFrameworkType:
164
+ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[ModelFrameworkType]:
165
165
  """Create an instance of ModelFramework from a dict"""
166
166
  if obj is None:
167
167
  return None
@@ -1,4 +1,6 @@
1
1
  # file: client.py
2
+ import io
3
+ import os
2
4
  from enum import Enum
3
5
  from typing import Any, Dict, List, Optional, Union
4
6
  from urllib.parse import urlencode, urljoin
@@ -125,7 +127,12 @@ class SignedURLClient:
125
127
  )
126
128
 
127
129
  @log_time(prefix=LOG_PREFIX)
128
- def _upload_data(self, signed_url: str, data: Any) -> None:
130
+ def _upload_data(
131
+ self,
132
+ signed_url: str,
133
+ data: Union[bytes, io.BufferedReader],
134
+ headers: Optional[Dict] = None,
135
+ ) -> None:
129
136
  """
130
137
  Upload data to the specified storage path using a signed URL.
131
138
 
@@ -133,10 +140,15 @@ class SignedURLClient:
133
140
  signed_url: str: The signed URL to upload the data to.
134
141
  data: Bytes or IO: The data to upload.
135
142
  """
143
+ if isinstance(data, io.BufferedReader):
144
+ if os.fstat(data.fileno()).st_size == 0:
145
+ data = b""
136
146
  try:
147
+ headers = headers or {}
148
+ headers["Content-Type"] = "application/octet-stream"
137
149
  response = self.session.put(
138
150
  url=signed_url,
139
- headers={"Content-Type": "application/octet-stream"},
151
+ headers=headers,
140
152
  data=data,
141
153
  timeout=REQUEST_TIMEOUT,
142
154
  )
@@ -153,7 +165,11 @@ class SignedURLClient:
153
165
  headers=self.signed_url_server_headers,
154
166
  )
155
167
  pre_signed_object_dto = SignedURLAPIResponseDto.parse_obj(signed_object)
156
- self._upload_data(pre_signed_object_dto.signed_url, data)
168
+ self._upload_data(
169
+ signed_url=pre_signed_object_dto.signed_url,
170
+ data=data,
171
+ headers=pre_signed_object_dto.headers,
172
+ )
157
173
  return storage_uri
158
174
 
159
175
  @log_time(prefix=LOG_PREFIX)
@@ -167,19 +183,30 @@ class SignedURLClient:
167
183
  )
168
184
  pre_signed_object_dto = SignedURLAPIResponseDto.parse_obj(response)
169
185
  with open(file_path, "rb") as file:
170
- self._upload_data(pre_signed_object_dto.signed_url, file)
186
+ self._upload_data(
187
+ signed_url=pre_signed_object_dto.signed_url,
188
+ data=file,
189
+ headers=pre_signed_object_dto.headers,
190
+ )
171
191
  return storage_uri
172
192
 
173
193
  @log_time(prefix=LOG_PREFIX)
174
194
  def _download_file(
175
- self, signed_url: str, local_path: Optional[str] = None
195
+ self,
196
+ signed_url: str,
197
+ local_path: Optional[str] = None,
198
+ headers: Optional[Dict] = None,
176
199
  ) -> Optional[bytes]:
177
200
  """Common method to download a file using a signed URL."""
178
201
  try:
202
+ if headers is None:
203
+ headers = {"Content-Type": "application/octet-stream"}
204
+ else:
205
+ headers["Content-Type"] = "application/octet-stream"
179
206
  response = self.session.get(
180
207
  signed_url,
181
208
  stream=True,
182
- headers={"Content-Type": "application/octet-stream"},
209
+ headers=headers,
183
210
  timeout=REQUEST_TIMEOUT,
184
211
  )
185
212
  response.raise_for_status()
@@ -202,7 +229,11 @@ class SignedURLClient:
202
229
  headers=self.signed_url_server_headers,
203
230
  )
204
231
  presigned_object = SignedURLAPIResponseDto.parse_obj(response)
205
- self._download_file(presigned_object.signed_url, local_path)
232
+ self._download_file(
233
+ signed_url=presigned_object.signed_url,
234
+ local_path=local_path,
235
+ headers=presigned_object.headers,
236
+ )
206
237
  return local_path
207
238
 
208
239
  @log_time(prefix=LOG_PREFIX)
@@ -214,7 +245,9 @@ class SignedURLClient:
214
245
  headers=self.signed_url_server_headers,
215
246
  )
216
247
  presigned_object = SignedURLAPIResponseDto.parse_obj(response)
217
- return self._download_file(presigned_object.signed_url)
248
+ return self._download_file(
249
+ signed_url=presigned_object.signed_url, headers=presigned_object.headers
250
+ )
218
251
 
219
252
  @log_time(prefix=LOG_PREFIX)
220
253
  def exists(self, uri: str) -> bool:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: truefoundry
3
- Version: 0.5.0rc2
3
+ Version: 0.5.0rc4
4
4
  Summary: Truefoundry CLI
5
5
  Author: Abhishek Choudhary
6
6
  Author-email: abhishek@truefoundry.com
@@ -21,7 +21,7 @@ Requires-Dist: coolname (>=1.1.0,<2.0.0)
21
21
  Requires-Dist: docker (>=6.1.2,<8.0.0)
22
22
  Requires-Dist: fastapi (>=0.56.0,<0.200.0)
23
23
  Requires-Dist: filelock (>=3.8.0,<4.0.0)
24
- Requires-Dist: flytekit (==1.12.2) ; extra == "workflow"
24
+ Requires-Dist: flytekit (==1.13.13) ; extra == "workflow"
25
25
  Requires-Dist: gitignorefile (>=1.1.2,<2.0.0)
26
26
  Requires-Dist: importlib-metadata (>=4.11.3,<9.0.0)
27
27
  Requires-Dist: importlib-resources (>=5.2.0,<7.0.0)