truefoundry 0.4.4rc12__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic; see the package's registry page for more details.

Files changed (55)
  1. truefoundry/common/constants.py +6 -1
  2. truefoundry/common/utils.py +0 -18
  3. truefoundry/logger.py +1 -0
  4. truefoundry/ml/__init__.py +36 -0
  5. truefoundry/ml/artifact/truefoundry_artifact_repo.py +433 -415
  6. truefoundry/ml/autogen/client/__init__.py +29 -6
  7. truefoundry/ml/autogen/client/api/__init__.py +3 -3
  8. truefoundry/ml/autogen/client/api/deprecated_api.py +7 -7
  9. truefoundry/ml/autogen/client/api/generate_code_snippet_api.py +526 -0
  10. truefoundry/ml/autogen/client/models/__init__.py +26 -3
  11. truefoundry/ml/autogen/client/models/command.py +152 -0
  12. truefoundry/ml/autogen/client/models/create_workflow_task_config_request_dto.py +72 -0
  13. truefoundry/ml/autogen/client/models/external_model_source.py +3 -2
  14. truefoundry/ml/autogen/client/models/fast_ai_framework.py +75 -0
  15. truefoundry/ml/autogen/client/models/framework.py +250 -14
  16. truefoundry/ml/autogen/client/models/gluon_framework.py +74 -0
  17. truefoundry/ml/autogen/client/models/{upload_model_source.py → h2_o_framework.py} +11 -11
  18. truefoundry/ml/autogen/client/models/keras_framework.py +74 -0
  19. truefoundry/ml/autogen/client/models/light_gbm_framework.py +75 -0
  20. truefoundry/ml/autogen/client/models/model_version_manifest.py +1 -1
  21. truefoundry/ml/autogen/client/models/onnx_framework.py +74 -0
  22. truefoundry/ml/autogen/client/models/paddle_framework.py +75 -0
  23. truefoundry/ml/autogen/client/models/py_torch_framework.py +75 -0
  24. truefoundry/ml/autogen/client/models/sklearn_framework.py +75 -0
  25. truefoundry/ml/autogen/client/models/source.py +9 -32
  26. truefoundry/ml/autogen/client/models/spa_cy_framework.py +74 -0
  27. truefoundry/ml/autogen/client/models/stats_models_framework.py +75 -0
  28. truefoundry/ml/autogen/client/models/{tensorflow_framework.py → tensor_flow_framework.py} +10 -9
  29. truefoundry/ml/autogen/client/models/transformers_framework.py +3 -2
  30. truefoundry/ml/autogen/client/models/trigger_job_run_config_request_dto.py +90 -0
  31. truefoundry/ml/autogen/client/models/trigger_job_run_config_response_dto.py +71 -0
  32. truefoundry/ml/autogen/client/models/truefoundry_model_source.py +5 -3
  33. truefoundry/ml/autogen/client/models/xg_boost_framework.py +75 -0
  34. truefoundry/ml/autogen/client_README.md +22 -5
  35. truefoundry/ml/autogen/entities/artifacts.py +19 -2
  36. truefoundry/ml/log_types/artifacts/artifact.py +10 -6
  37. truefoundry/ml/log_types/artifacts/dataset.py +13 -5
  38. truefoundry/ml/log_types/artifacts/general_artifact.py +3 -1
  39. truefoundry/ml/log_types/artifacts/model.py +172 -194
  40. truefoundry/ml/log_types/artifacts/utils.py +43 -26
  41. truefoundry/ml/log_types/image/image.py +2 -0
  42. truefoundry/ml/log_types/plot.py +2 -0
  43. truefoundry/ml/mlfoundry_api.py +47 -18
  44. truefoundry/ml/mlfoundry_run.py +27 -12
  45. truefoundry/ml/model_framework.py +169 -0
  46. truefoundry/workflow/__init__.py +3 -1
  47. truefoundry/workflow/remote_filesystem/__init__.py +8 -0
  48. truefoundry/workflow/remote_filesystem/logger.py +36 -0
  49. truefoundry/{common → workflow/remote_filesystem}/tfy_signed_url_client.py +1 -2
  50. truefoundry/{common → workflow/remote_filesystem}/tfy_signed_url_fs.py +5 -2
  51. {truefoundry-0.4.4rc12.dist-info → truefoundry-0.5.0rc1.dist-info}/METADATA +1 -1
  52. {truefoundry-0.4.4rc12.dist-info → truefoundry-0.5.0rc1.dist-info}/RECORD +54 -36
  53. truefoundry/ml/autogen/client/api/python_deployment_config_api.py +0 -201
  54. {truefoundry-0.4.4rc12.dist-info → truefoundry-0.5.0rc1.dist-info}/WHEEL +0 -0
  55. {truefoundry-0.4.4rc12.dist-info → truefoundry-0.5.0rc1.dist-info}/entry_points.txt +0 -0
@@ -29,9 +29,10 @@ from truefoundry.ml.autogen.entities.artifacts import ChatPrompt
29
29
  from truefoundry.ml.exceptions import MlFoundryException
30
30
  from truefoundry.ml.log_types.artifacts.constants import INTERNAL_METADATA_PATH
31
31
  from truefoundry.ml.log_types.artifacts.utils import (
32
+ _get_src_dest_pairs,
32
33
  _validate_artifact_metadata,
33
34
  _validate_description,
34
- calculate_local_directory_size,
35
+ calculate_total_size,
35
36
  )
36
37
  from truefoundry.ml.logger import logger
37
38
  from truefoundry.ml.session import _get_api_client
@@ -397,7 +398,7 @@ class ChatPromptVersion(ArtifactVersion):
397
398
  @property
398
399
  def extra_parameters(self) -> Dict[str, Any]:
399
400
  _extra_parameters = self._chat_prompt.model_configuration.extra_parameters
400
- return _extra_parameters.dict(exclude_unset=True) if _extra_parameters else {}
401
+ return _extra_parameters if _extra_parameters else {}
401
402
 
402
403
  @property
403
404
  def variables(self) -> Dict[str, Any]:
@@ -409,6 +410,7 @@ def _log_artifact_version_helper(
409
410
  name: str,
410
411
  artifact_type: ArtifactType,
411
412
  artifact_dir: tempfile.TemporaryDirectory,
413
+ dest_to_src_map: Dict[str, str],
412
414
  mlfoundry_artifacts_api: Optional[MlfoundryArtifactsApi] = None,
413
415
  ml_repo_id: Optional[str] = None,
414
416
  description: Optional[str] = None,
@@ -443,15 +445,17 @@ def _log_artifact_version_helper(
443
445
  ),
444
446
  api_client=mlfoundry_artifacts_api.api_client,
445
447
  )
446
- total_size = calculate_local_directory_size(artifact_dir)
448
+
449
+ total_size = calculate_total_size(list(dest_to_src_map.values()))
447
450
  try:
448
451
  logger.info(
449
- "Packaging and uploading files to remote with Artifact Size: %.6f MB",
452
+ "Packaging and uploading files to remote with size: %.6f MB",
450
453
  total_size / 1000000.0,
451
454
  )
452
- artifacts_repo.log_artifacts(
453
- local_dir=artifact_dir.name, artifact_path=None, progress=progress
455
+ src_dest_pairs = _get_src_dest_pairs(
456
+ root_dir=artifact_dir.name, dest_to_src_map=dest_to_src_map
454
457
  )
458
+ artifacts_repo.log_artifacts(src_dest_pairs=src_dest_pairs, progress=progress)
455
459
  except Exception as e:
456
460
  mlfoundry_artifacts_api.notify_failure_post(
457
461
  notify_artifact_version_failure_dto=NotifyArtifactVersionFailureDto(
@@ -19,9 +19,10 @@ from truefoundry.ml.entities import FileInfo
19
19
  from truefoundry.ml.exceptions import MlFoundryException
20
20
  from truefoundry.ml.log_types.artifacts.utils import (
21
21
  _copy_additional_files,
22
+ _get_src_dest_pairs,
22
23
  _validate_artifact_metadata,
23
24
  _validate_description,
24
- calculate_local_directory_size,
25
+ calculate_total_size,
25
26
  )
26
27
  from truefoundry.ml.logger import logger
27
28
  from truefoundry.ml.session import _get_api_client
@@ -139,6 +140,7 @@ class DataDirectory:
139
140
  file_paths: List[
140
141
  Union[Tuple[str], Tuple[str, Optional[str]], DataDirectoryPath]
141
142
  ],
143
+ progress: Optional[bool] = None,
142
144
  ) -> None:
143
145
  """Logs File in the `DataDirectory`.
144
146
 
@@ -207,7 +209,7 @@ class DataDirectory:
207
209
 
208
210
  try:
209
211
  logger.info("Copying the files to add")
210
- _copy_additional_files(
212
+ temp_dest_to_src_map = _copy_additional_files(
211
213
  root_dir=temp_dir.name,
212
214
  files_dir="",
213
215
  model_dir=None,
@@ -219,13 +221,19 @@ class DataDirectory:
219
221
  raise MlFoundryException("Failed to Add Files to DataDirectory") from e
220
222
 
221
223
  artifacts_repo = self._get_artifacts_repo()
222
- total_size = calculate_local_directory_size(temp_dir)
224
+ total_size = calculate_total_size(list(temp_dest_to_src_map.values()))
223
225
  try:
224
226
  logger.info(
225
- "Packaging and uploading files to remote with Size: %.6f MB",
227
+ "Packaging and uploading files to remote with size: %.6f MB",
226
228
  total_size / 1000000.0,
227
229
  )
228
- artifacts_repo.log_artifacts(local_dir=temp_dir.name, artifact_path=None)
230
+ src_dest_pairs = _get_src_dest_pairs(
231
+ root_dir=temp_dir.name, dest_to_src_map=temp_dest_to_src_map
232
+ )
233
+ artifacts_repo.log_artifacts(
234
+ src_dest_pairs=src_dest_pairs,
235
+ progress=progress,
236
+ )
229
237
  except Exception as e:
230
238
  raise MlFoundryException("Failed to Add Files to DataDirectory") from e
231
239
  finally:
@@ -77,7 +77,7 @@ def _log_artifact_version(
77
77
  os.makedirs(local_files_dir, exist_ok=True)
78
78
 
79
79
  logger.info("Copying the files to log")
80
- _copy_additional_files(
80
+ temp_dest_to_src_map = _copy_additional_files(
81
81
  root_dir=temp_dir.name,
82
82
  files_dir=internal_metadata.files_dir,
83
83
  model_dir=None,
@@ -94,6 +94,7 @@ def _log_artifact_version(
94
94
  os.makedirs(os.path.dirname(local_internal_metadata_path), exist_ok=True)
95
95
  with open(local_internal_metadata_path, "w") as f:
96
96
  json.dump(internal_metadata.dict(), f)
97
+ temp_dest_to_src_map[local_internal_metadata_path] = local_internal_metadata_path
97
98
 
98
99
  return _log_artifact_version_helper(
99
100
  run=run,
@@ -101,6 +102,7 @@ def _log_artifact_version(
101
102
  name=name,
102
103
  artifact_type=ArtifactType.ARTIFACT,
103
104
  artifact_dir=temp_dir,
105
+ dest_to_src_map=temp_dest_to_src_map,
104
106
  mlfoundry_artifacts_api=mlfoundry_artifacts_api,
105
107
  description=description,
106
108
  internal_metadata=internal_metadata,
@@ -5,57 +5,57 @@ import logging
5
5
  import os.path
6
6
  import tempfile
7
7
  import uuid
8
+ import warnings
8
9
  from pathlib import Path
9
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
10
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union
10
11
 
11
12
  from truefoundry.ml.artifact.truefoundry_artifact_repo import (
12
13
  ArtifactIdentifier,
13
14
  MlFoundryArtifactsRepository,
14
15
  )
15
16
  from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
16
- AddCustomMetricsToModelVersionRequestDto,
17
17
  ArtifactType,
18
+ ArtifactVersionSerializationFormat,
18
19
  CreateArtifactVersionRequestDto,
19
- CreateModelVersionRequestDto,
20
20
  DeleteArtifactVersionsRequestDto,
21
+ ExternalModelSource,
21
22
  FinalizeArtifactVersionRequestDto,
22
- MetricDto,
23
+ Framework,
23
24
  MlfoundryArtifactsApi,
24
25
  ModelDto,
25
26
  ModelVersionDto,
27
+ ModelVersionManifest,
26
28
  NotifyArtifactVersionFailureDto,
29
+ TruefoundryModelSource,
27
30
  UpdateModelVersionRequestDto,
28
31
  )
29
- from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
30
- InternalMetadata as InternalMetadataDto,
32
+ from truefoundry.ml.autogen.client import (
33
+ Source as ModelVersionSource,
31
34
  )
32
35
  from truefoundry.ml.enums import ModelFramework
33
36
  from truefoundry.ml.exceptions import MlFoundryException
34
37
  from truefoundry.ml.log_types.artifacts.constants import (
35
- FILES_DIR,
36
38
  INTERNAL_METADATA_PATH,
37
- MODEL_DIR_NAME,
38
- MODEL_SCHEMA_UPDATE_FAILURE_HELP,
39
39
  )
40
- from truefoundry.ml.log_types.artifacts.model_extras import CustomMetric, ModelSchema
41
40
  from truefoundry.ml.log_types.artifacts.utils import (
42
41
  _copy_additional_files,
42
+ _get_src_dest_pairs,
43
43
  _validate_artifact_metadata,
44
44
  _validate_description,
45
+ calculate_total_size,
45
46
  )
47
+ from truefoundry.ml.model_framework import ModelFrameworkType, _ModelFramework
46
48
  from truefoundry.ml.session import _get_api_client
47
- from truefoundry.pydantic_v1 import BaseModel, Extra
48
- from truefoundry.version import __version__
49
+ from truefoundry.pydantic_v1 import BaseModel, Extra, StrictStr
49
50
 
50
51
  if TYPE_CHECKING:
51
52
  from truefoundry.ml.mlfoundry_run import MlFoundryRun
52
53
 
53
- logger = logging.getLogger("truefoundry.ml")
54
-
55
54
 
56
- # TODO: Support async download and upload
55
+ logger = logging.getLogger(__name__)
57
56
 
58
57
 
58
+ # TODO: Support async download and upload
59
59
  class ModelVersionInternalMetadata(BaseModel):
60
60
  class Config:
61
61
  extra = Extra.allow
@@ -75,6 +75,10 @@ class ModelVersionInternalMetadata(BaseModel):
75
75
  return dct
76
76
 
77
77
 
78
+ class BlobStorageModelDirectory(BaseModel):
79
+ uri: StrictStr
80
+
81
+
78
82
  class ModelVersionDownloadInfo(BaseModel):
79
83
  download_dir: str
80
84
  model_dir: str
@@ -97,8 +101,7 @@ class ModelVersion:
97
101
  self._deleted = False
98
102
  self._description: str = ""
99
103
  self._metadata: Dict[str, Any] = {}
100
- self._metrics: List[MetricDto] = []
101
- self._set_metrics_attr()
104
+ self._framework: Optional["Framework"] = None
102
105
  self._set_mutable_attrs()
103
106
 
104
107
  @classmethod
@@ -136,21 +139,20 @@ class ModelVersion:
136
139
  "Model Version was deleted, cannot perform updates on a deleted version"
137
140
  )
138
141
 
139
- def _set_metrics_attr(self):
140
- self._metrics = sorted(
141
- self._model_version.metrics or [], key=lambda m: m.timestamp
142
- )
143
-
144
142
  def _set_mutable_attrs(self):
145
143
  self._description = self._model_version.description or ""
146
144
  self._metadata = copy.deepcopy(self._model_version.artifact_metadata)
145
+ self._framework = (
146
+ copy.deepcopy(self._model_version.manifest.framework)
147
+ if self._model_version.manifest
148
+ else None
149
+ )
147
150
 
148
151
  def _refetch_model_version(self):
149
152
  _model_version = self._mlfoundry_artifacts_api.get_model_version_get(
150
153
  id=self._model_version.id
151
154
  )
152
155
  self._model_version = _model_version.model_version
153
- self._set_metrics_attr()
154
156
  self._set_mutable_attrs()
155
157
 
156
158
  def __repr__(self):
@@ -199,6 +201,8 @@ class ModelVersion:
199
201
  """set the description of the model"""
200
202
  _validate_description(value)
201
203
  self._description = value
204
+ if self._model_version.manifest:
205
+ self._model_version.manifest.description = value
202
206
 
203
207
  @property
204
208
  def metadata(self) -> Dict[str, Any]:
@@ -210,14 +214,43 @@ class ModelVersion:
210
214
  """set the metadata for current model"""
211
215
  _validate_artifact_metadata(value)
212
216
  self._metadata = value
217
+ if self._model_version.manifest:
218
+ self._model_version.manifest.metadata = value
213
219
 
214
220
  @property
215
- def metrics(self) -> Dict[str, Union[float, int]]:
216
- """get the metrics for the current version of the model"""
217
- metrics_as_kv: Dict[str, Union[float, int]] = {}
218
- for metric in self._metrics:
219
- metrics_as_kv[metric.key] = metric.value
220
- return metrics_as_kv
221
+ def framework(self) -> Optional["ModelFrameworkType"]:
222
+ """Get the framework of the model"""
223
+ return (
224
+ _ModelFramework.from_dict(self._framework.actual_instance.dict())
225
+ if self._framework and self._framework.actual_instance
226
+ else None
227
+ )
228
+
229
+ @framework.setter
230
+ def framework(
231
+ self, value: Optional[Union[str, ModelFramework, "ModelFrameworkType"]]
232
+ ):
233
+ """set the framework of the model"""
234
+ model_framework_type: Optional[ModelFrameworkType] = None
235
+
236
+ if value is not None:
237
+ model_framework_type = _ModelFramework.to_model_framework_type(value)
238
+
239
+ if (
240
+ self._model_version.serialization_format
241
+ == ArtifactVersionSerializationFormat.V2
242
+ ):
243
+ self._model_version.manifest.framework = self._framework = (
244
+ Framework.from_dict(model_framework_type.dict())
245
+ if model_framework_type
246
+ else None
247
+ )
248
+ else:
249
+ warnings.warn(
250
+ message="This model version was created using an older serialization format. Framework will not be updated",
251
+ category=DeprecationWarning,
252
+ stacklevel=2,
253
+ )
221
254
 
222
255
  @property
223
256
  def created_at(self) -> datetime.datetime:
@@ -272,11 +305,29 @@ class ModelVersion:
272
305
  path: Optional[Union[str, Path]],
273
306
  overwrite: bool = False,
274
307
  progress: Optional[bool] = None,
275
- ) -> Tuple[ModelVersionInternalMetadata, ModelVersionDownloadInfo]:
308
+ ) -> ModelVersionDownloadInfo:
276
309
  self._ensure_not_deleted()
277
310
  download_dir = self.raw_download(
278
311
  path=path, overwrite=overwrite, progress=progress
279
312
  )
313
+
314
+ if (
315
+ self._model_version.serialization_format
316
+ == ArtifactVersionSerializationFormat.V2
317
+ ):
318
+ model_framework = (
319
+ ModelFramework(self.framework.type)
320
+ if self.framework
321
+ else ModelFramework.UNKNOWN
322
+ )
323
+ download_info = ModelVersionDownloadInfo(
324
+ download_dir=download_dir,
325
+ model_dir=download_dir,
326
+ model_framework=model_framework,
327
+ model_filename=None,
328
+ )
329
+ return download_info
330
+
280
331
  internal_metadata_path = os.path.join(download_dir, INTERNAL_METADATA_PATH)
281
332
  if not os.path.exists(internal_metadata_path):
282
333
  raise MlFoundryException(
@@ -288,12 +339,14 @@ class ModelVersion:
288
339
  download_info = ModelVersionDownloadInfo(
289
340
  download_dir=os.path.join(download_dir, internal_metadata.files_dir),
290
341
  model_dir=os.path.join(
291
- download_dir, internal_metadata.files_dir, internal_metadata.model_dir
342
+ download_dir,
343
+ internal_metadata.files_dir,
344
+ internal_metadata.model_dir,
292
345
  ),
293
346
  model_framework=internal_metadata.framework,
294
347
  model_filename=internal_metadata.model_filename,
295
348
  )
296
- return internal_metadata, download_info
349
+ return download_info
297
350
 
298
351
  def download(
299
352
  self,
@@ -329,7 +382,7 @@ class ModelVersion:
329
382
  print(download_info.model_dir)
330
383
  ```
331
384
  """
332
- _, download_info = self._download(
385
+ download_info = self._download(
333
386
  path=path, overwrite=overwrite, progress=progress
334
387
  )
335
388
  return download_info
@@ -381,50 +434,36 @@ class ModelVersion:
381
434
  id=self._model_version.id,
382
435
  description=self.description,
383
436
  artifact_metadata=self.metadata,
437
+ manifest=self._model_version.manifest,
384
438
  )
385
439
  )
386
440
  self._model_version = _model_version.model_version
387
- self._set_metrics_attr()
388
441
  self._set_mutable_attrs()
389
442
 
390
443
 
391
- def calculate_model_size(artifact_dir: tempfile.TemporaryDirectory):
392
- """
393
- Tells about the size of the model
394
-
395
- Args:
396
- artifact_dir (str): directory in which model is present.
397
-
398
- Returns:
399
- total size of the model
400
- """
401
- total_size = 0
402
- for path, _dirs, files in os.walk(artifact_dir.name):
403
- for f in files:
404
- file_path = os.path.join(path, f)
405
- total_size += os.stat(file_path).st_size
406
- return total_size
407
-
408
-
409
444
  def _log_model_version( # noqa: C901
410
445
  run: Optional["MlFoundryRun"],
411
446
  name: str,
412
- model_file_or_folder: str,
413
- framework: Optional[Union[ModelFramework, str]],
447
+ model_file_or_folder: Union[str, BlobStorageModelDirectory],
414
448
  mlfoundry_artifacts_api: Optional[MlfoundryArtifactsApi] = None,
415
449
  ml_repo_id: Optional[str] = None,
416
450
  additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
417
451
  description: Optional[str] = None,
418
452
  metadata: Optional[Dict[str, Any]] = None,
419
- model_schema: Optional[Union[Dict[str, Any], ModelSchema]] = None,
420
- custom_metrics: Optional[List[Union[CustomMetric, Dict[str, Any]]]] = None,
421
453
  step: Optional[int] = 0,
422
454
  progress: Optional[bool] = None,
455
+ framework: Optional[Union[str, ModelFramework, "ModelFrameworkType"]] = None,
423
456
  ) -> ModelVersion:
424
457
  if (run and mlfoundry_artifacts_api) or (not run and not mlfoundry_artifacts_api):
425
458
  raise MlFoundryException(
426
459
  "Exactly one of run, mlfoundry_artifacts_api should be passed"
427
460
  )
461
+
462
+ if not isinstance(model_file_or_folder, (str, BlobStorageModelDirectory)):
463
+ raise MlFoundryException(
464
+ "model_file_or_folder should be of type str or BlobStorageModelDirectory"
465
+ )
466
+
428
467
  if mlfoundry_artifacts_api and not ml_repo_id:
429
468
  raise MlFoundryException(
430
469
  "If mlfoundry_artifacts_api is passed, ml_repo_id must also be passed"
@@ -434,87 +473,41 @@ def _log_model_version( # noqa: C901
434
473
 
435
474
  assert mlfoundry_artifacts_api is not None
436
475
 
437
- custom_metrics = custom_metrics or []
476
+ step = step or 0
477
+ total_size = None
438
478
  metadata = metadata or {}
439
479
  additional_files = additional_files or {}
440
- step = step or 0
441
-
442
- # validations
443
- if framework is None:
444
- framework = ModelFramework.UNKNOWN
445
- elif not isinstance(framework, ModelFramework):
446
- framework = ModelFramework(framework)
447
480
 
448
481
  _validate_description(description)
449
482
  _validate_artifact_metadata(metadata)
450
483
 
451
- if model_schema is not None and not isinstance(model_schema, ModelSchema):
452
- model_schema = ModelSchema.parse_obj(model_schema)
453
-
454
- if custom_metrics and not model_schema:
455
- raise MlFoundryException(
456
- "Custom Metrics defined without adding the Model Schema"
457
- )
458
- custom_metrics = [
459
- CustomMetric.parse_obj(cm) if not isinstance(cm, CustomMetric) else cm
460
- for cm in custom_metrics
461
- ]
462
-
463
- logger.info("Logging model and additional files, this might take a while ...")
464
- temp_dir = tempfile.TemporaryDirectory(prefix="truefoundry-")
465
-
466
- internal_metadata = ModelVersionInternalMetadata(
467
- framework=framework,
468
- files_dir=FILES_DIR,
469
- model_dir=MODEL_DIR_NAME,
470
- model_filename=(
471
- os.path.basename(model_file_or_folder)
472
- if model_file_or_folder and os.path.isfile(model_file_or_folder)
473
- else None
474
- ),
475
- mlfoundry_version=__version__,
476
- truefoundry_version=__version__,
477
- )
478
-
479
- try:
480
- local_files_dir = os.path.join(temp_dir.name, internal_metadata.files_dir)
481
- os.makedirs(local_files_dir, exist_ok=True)
482
- # in case model was None, we still create an empty dir
483
- local_model_dir = os.path.join(local_files_dir, internal_metadata.model_dir)
484
- os.makedirs(local_model_dir, exist_ok=True)
485
-
486
- logger.info("Adding model file/folder to model version content")
487
- model_file_or_folder = [
488
- (model_file_or_folder, MODEL_DIR_NAME.rstrip(os.sep) + os.sep)
489
- ]
490
- _copy_additional_files(
491
- root_dir=temp_dir.name,
492
- files_dir=internal_metadata.files_dir,
493
- model_dir=internal_metadata.model_dir,
494
- additional_files=model_file_or_folder,
495
- ignore_model_dir_dest_conflict=True,
496
- )
497
-
498
- # verify additional files and paths, copy additional files
499
- if additional_files:
500
- logger.info("Adding `additional_files` to model version contents")
501
- _copy_additional_files(
484
+ if isinstance(model_file_or_folder, str):
485
+ logger.info("Logging model and additional files, this might take a while ...")
486
+ temp_dir = tempfile.TemporaryDirectory(prefix="truefoundry-")
487
+ try:
488
+ logger.info("Adding model file/folder to model version content")
489
+ temp_dest_to_src_map = _copy_additional_files(
502
490
  root_dir=temp_dir.name,
503
- files_dir=internal_metadata.files_dir,
504
- model_dir=internal_metadata.model_dir,
505
- additional_files=additional_files,
506
- ignore_model_dir_dest_conflict=False,
491
+ files_dir="",
492
+ model_dir=None,
493
+ additional_files=[(model_file_or_folder, "")],
494
+ ignore_model_dir_dest_conflict=True,
507
495
  )
508
496
 
509
- except Exception as e:
510
- temp_dir.cleanup()
511
- raise MlFoundryException("Failed to log model") from e
512
-
513
- # save internal metadata
514
- local_internal_metadata_path = os.path.join(temp_dir.name, INTERNAL_METADATA_PATH)
515
- os.makedirs(os.path.dirname(local_internal_metadata_path), exist_ok=True)
516
- with open(local_internal_metadata_path, "w") as f:
517
- json.dump(internal_metadata.dict(), f)
497
+ # verify additional files and paths, copy additional files
498
+ if additional_files:
499
+ logger.info("Adding `additional_files` to model version contents")
500
+ temp_dest_to_src_map = _copy_additional_files(
501
+ root_dir=temp_dir.name,
502
+ files_dir="",
503
+ model_dir=None,
504
+ additional_files=additional_files,
505
+ ignore_model_dir_dest_conflict=False,
506
+ existing_dest_to_src_map=temp_dest_to_src_map,
507
+ )
508
+ except Exception as e:
509
+ temp_dir.cleanup()
510
+ raise MlFoundryException("Failed to log model") from e
518
511
 
519
512
  # create entry
520
513
  _create_artifact_version_response = (
@@ -527,80 +520,65 @@ def _log_model_version( # noqa: C901
527
520
  )
528
521
  )
529
522
  version_id = _create_artifact_version_response.id
530
- artifacts_repo = MlFoundryArtifactsRepository(
531
- artifact_identifier=ArtifactIdentifier(
532
- artifact_version_id=uuid.UUID(version_id)
533
- ),
534
- api_client=mlfoundry_artifacts_api.api_client,
535
- )
536
- model_size = calculate_model_size(temp_dir)
537
- try:
538
- logger.info(
539
- "Packaging and uploading files to remote with Total Size: %.6f MB",
540
- model_size / 1000000.0,
523
+ artifact_storage_root = _create_artifact_version_response.artifact_storage_root
524
+ if isinstance(model_file_or_folder, str):
525
+ # Source is of type TruefoundryModelSource
526
+ source = ModelVersionSource.from_json(
527
+ TruefoundryModelSource(
528
+ type="truefoundry", uri=artifact_storage_root
529
+ ).to_json()
541
530
  )
542
- artifacts_repo.log_artifacts(
543
- local_dir=temp_dir.name, artifact_path=None, progress=progress
531
+ artifacts_repo = MlFoundryArtifactsRepository(
532
+ artifact_identifier=ArtifactIdentifier(
533
+ artifact_version_id=uuid.UUID(version_id)
534
+ ),
535
+ api_client=mlfoundry_artifacts_api.api_client,
544
536
  )
545
- except Exception as e:
546
- mlfoundry_artifacts_api.notify_failure_post(
547
- notify_artifact_version_failure_dto=NotifyArtifactVersionFailureDto(
548
- id=version_id
537
+ total_size = calculate_total_size(list(temp_dest_to_src_map.values()))
538
+ try:
539
+ logger.info(
540
+ "Packaging and uploading files to remote with Total Size: %.6f MB",
541
+ total_size / 1000000.0,
542
+ )
543
+ src_dest_pairs = _get_src_dest_pairs(
544
+ root_dir=temp_dir.name, dest_to_src_map=temp_dest_to_src_map
545
+ )
546
+ artifacts_repo.log_artifacts(
547
+ src_dest_pairs=src_dest_pairs, progress=progress
548
+ )
549
+ except Exception as e:
550
+ mlfoundry_artifacts_api.notify_failure_post(
551
+ notify_artifact_version_failure_dto=NotifyArtifactVersionFailureDto(
552
+ id=version_id
553
+ )
549
554
  )
555
+ raise MlFoundryException("Failed to log model") from e
556
+ finally:
557
+ temp_dir.cleanup()
558
+ elif isinstance(model_file_or_folder, BlobStorageModelDirectory):
559
+ source = ModelVersionSource.from_json(
560
+ ExternalModelSource(type="external", uri=model_file_or_folder.uri).to_json()
550
561
  )
551
- raise MlFoundryException("Failed to log model") from e
552
- finally:
553
- temp_dir.cleanup()
554
-
555
- # Note: Here we call from_dict instead of directly passing in init and relying on it
556
- # to convert because the complicated union of types generates a custom type to handle casting
557
- # Check the source of `InternalMetadataDto` to see the generated code
558
- internal_metadata_dto = InternalMetadataDto.from_dict(
559
- internal_metadata.dict() if internal_metadata is not None else {}
562
+ else:
563
+ raise MlFoundryException("Invalid model_file_or_folder provided")
564
+
565
+ _framework = _ModelFramework.to_model_framework_type(framework)
566
+ model_manifest = ModelVersionManifest(
567
+ description=description,
568
+ metadata=metadata,
569
+ source=source,
570
+ framework=Framework.from_dict(_framework.dict()) if _framework else None,
571
+ step=step if run else 0,
560
572
  )
561
- mlfoundry_artifacts_api.finalize_artifact_version_post(
573
+ artifact_version_response = mlfoundry_artifacts_api.finalize_artifact_version_post(
562
574
  finalize_artifact_version_request_dto=FinalizeArtifactVersionRequestDto(
563
575
  id=version_id,
564
576
  run_uuid=run.run_id if run else None,
565
- artifact_size=model_size,
566
- internal_metadata=internal_metadata_dto,
567
- step=step if run else None,
568
- )
569
- )
570
- _model_version = mlfoundry_artifacts_api.create_model_version_post(
571
- create_model_version_request_dto=CreateModelVersionRequestDto(
572
- artifact_version_id=version_id,
573
- description=description,
577
+ artifact_size=total_size,
574
578
  artifact_metadata=metadata,
575
- internal_metadata=internal_metadata_dto,
576
- data_path=INTERNAL_METADATA_PATH,
577
- step=step if run else None,
579
+ internal_metadata=None,
580
+ step=model_manifest.step,
581
+ manifest=model_manifest,
578
582
  )
579
583
  )
580
- model_version = _model_version.model_version
581
-
582
- # update model schema at end
583
- update_args: Dict[str, Any] = {
584
- "id": version_id,
585
- "model_framework": framework.value,
586
- }
587
- if model_schema:
588
- update_args["model_schema"] = model_schema
589
-
590
- try:
591
- _model_version = mlfoundry_artifacts_api.update_model_version_post(
592
- update_model_version_request_dto=UpdateModelVersionRequestDto(**update_args)
593
- )
594
- model_version = _model_version.model_version
595
- if model_schema:
596
- _model_version = mlfoundry_artifacts_api.add_custom_metrics_to_model_version_post(
597
- add_custom_metrics_to_model_version_request_dto=AddCustomMetricsToModelVersionRequestDto(
598
- id=version_id, custom_metrics=custom_metrics
599
- )
600
- )
601
- model_version = _model_version.model_version
602
- except Exception:
603
- # TODO (chiragjn): what is the best exception to catch here?
604
- logger.error(MODEL_SCHEMA_UPDATE_FAILURE_HELP.format(fqn=model_version.fqn))
605
-
606
- return ModelVersion.from_fqn(fqn=model_version.fqn)
584
+ return ModelVersion.from_fqn(fqn=artifact_version_response.artifact_version.fqn)