truefoundry 0.4.10__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. Click here for more details.

Files changed (76) hide show
  1. truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +2 -2
  2. truefoundry/deploy/lib/dao/application.py +2 -1
  3. truefoundry/ml/__init__.py +41 -1
  4. truefoundry/ml/autogen/client/__init__.py +44 -14
  5. truefoundry/ml/autogen/client/api/__init__.py +3 -3
  6. truefoundry/ml/autogen/client/api/deprecated_api.py +333 -0
  7. truefoundry/ml/autogen/client/api/generate_code_snippet_api.py +526 -0
  8. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +0 -322
  9. truefoundry/ml/autogen/client/api_client.py +8 -1
  10. truefoundry/ml/autogen/client/models/__init__.py +41 -11
  11. truefoundry/ml/autogen/client/models/add_features_to_model_version_request_dto.py +3 -17
  12. truefoundry/ml/autogen/client/models/agent.py +1 -1
  13. truefoundry/ml/autogen/client/models/agent_app.py +1 -1
  14. truefoundry/ml/autogen/client/models/agent_open_api_tool.py +1 -1
  15. truefoundry/ml/autogen/client/models/agent_open_api_tool_with_fqn.py +1 -1
  16. truefoundry/ml/autogen/client/models/agent_with_fqn.py +1 -1
  17. truefoundry/ml/autogen/client/models/artifact_version_dto.py +3 -5
  18. truefoundry/ml/autogen/client/models/artifact_version_manifest.py +111 -0
  19. truefoundry/ml/autogen/client/models/assistant_message.py +1 -1
  20. truefoundry/ml/autogen/client/models/blob_storage_reference.py +1 -1
  21. truefoundry/ml/autogen/client/models/chat_prompt.py +1 -1
  22. truefoundry/ml/autogen/client/models/command.py +152 -0
  23. truefoundry/ml/autogen/client/models/{feature_dto.py → create_workflow_task_config_request_dto.py} +18 -14
  24. truefoundry/ml/autogen/client/models/{external_model_source.py → external_artifact_source.py} +12 -11
  25. truefoundry/ml/autogen/client/models/fast_ai_framework.py +75 -0
  26. truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +3 -5
  27. truefoundry/ml/autogen/client/models/framework.py +250 -14
  28. truefoundry/ml/autogen/client/models/gluon_framework.py +74 -0
  29. truefoundry/ml/autogen/client/models/{upload_model_source.py → h2_o_framework.py} +11 -11
  30. truefoundry/ml/autogen/client/models/image_content_part.py +1 -1
  31. truefoundry/ml/autogen/client/models/keras_framework.py +74 -0
  32. truefoundry/ml/autogen/client/models/light_gbm_framework.py +75 -0
  33. truefoundry/ml/autogen/client/models/manifest.py +154 -0
  34. truefoundry/ml/autogen/client/models/model_version_dto.py +7 -8
  35. truefoundry/ml/autogen/client/models/model_version_environment.py +97 -0
  36. truefoundry/ml/autogen/client/models/model_version_manifest.py +30 -6
  37. truefoundry/ml/autogen/client/models/onnx_framework.py +74 -0
  38. truefoundry/ml/autogen/client/models/paddle_framework.py +75 -0
  39. truefoundry/ml/autogen/client/models/py_torch_framework.py +75 -0
  40. truefoundry/ml/autogen/client/models/{feature_value_type.py → serialization_format.py} +8 -8
  41. truefoundry/ml/autogen/client/models/sklearn_framework.py +92 -0
  42. truefoundry/ml/autogen/client/models/source.py +23 -46
  43. truefoundry/ml/autogen/client/models/source1.py +154 -0
  44. truefoundry/ml/autogen/client/models/spa_cy_framework.py +74 -0
  45. truefoundry/ml/autogen/client/models/stats_models_framework.py +75 -0
  46. truefoundry/ml/autogen/client/models/system_message.py +1 -1
  47. truefoundry/ml/autogen/client/models/{tensorflow_framework.py → tensor_flow_framework.py} +11 -10
  48. truefoundry/ml/autogen/client/models/text_content_part.py +1 -1
  49. truefoundry/ml/autogen/client/models/transformers_framework.py +10 -4
  50. truefoundry/ml/autogen/client/models/trigger_job_run_config_request_dto.py +90 -0
  51. truefoundry/ml/autogen/client/models/trigger_job_run_config_response_dto.py +71 -0
  52. truefoundry/ml/autogen/client/models/{truefoundry_model_source.py → true_foundry_artifact_source.py} +13 -11
  53. truefoundry/ml/autogen/client/models/update_artifact_version_request_dto.py +11 -1
  54. truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +1 -13
  55. truefoundry/ml/autogen/client/models/user_message.py +1 -1
  56. truefoundry/ml/autogen/client/models/xg_boost_framework.py +92 -0
  57. truefoundry/ml/autogen/client_README.md +30 -12
  58. truefoundry/ml/autogen/entities/artifacts.py +87 -9
  59. truefoundry/ml/autogen/models/__init__.py +4 -0
  60. truefoundry/ml/autogen/models/exceptions.py +30 -0
  61. truefoundry/ml/autogen/models/schema.py +1547 -0
  62. truefoundry/ml/autogen/models/signature.py +139 -0
  63. truefoundry/ml/autogen/models/utils.py +699 -0
  64. truefoundry/ml/log_types/artifacts/artifact.py +131 -63
  65. truefoundry/ml/log_types/artifacts/general_artifact.py +7 -26
  66. truefoundry/ml/log_types/artifacts/model.py +195 -197
  67. truefoundry/ml/mlfoundry_api.py +47 -52
  68. truefoundry/ml/mlfoundry_run.py +35 -43
  69. truefoundry/ml/model_framework.py +169 -0
  70. {truefoundry-0.4.10.dist-info → truefoundry-0.5.0.dist-info}/METADATA +1 -1
  71. {truefoundry-0.4.10.dist-info → truefoundry-0.5.0.dist-info}/RECORD +73 -51
  72. truefoundry/ml/autogen/client/api/python_deployment_config_api.py +0 -201
  73. truefoundry/ml/autogen/client/models/model_schema_dto.py +0 -85
  74. truefoundry/ml/autogen/client/models/prediction_type.py +0 -34
  75. {truefoundry-0.4.10.dist-info → truefoundry-0.5.0.dist-info}/WHEEL +0 -0
  76. {truefoundry-0.4.10.dist-info → truefoundry-0.5.0.dist-info}/entry_points.txt +0 -0
@@ -4,40 +4,39 @@ import json
4
4
  import logging
5
5
  import os.path
6
6
  import tempfile
7
+ import typing
7
8
  import uuid
9
+ import warnings
8
10
  from pathlib import Path
9
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
11
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
10
12
 
11
13
  from truefoundry.ml.artifact.truefoundry_artifact_repo import (
12
14
  ArtifactIdentifier,
13
15
  MlFoundryArtifactsRepository,
14
16
  )
15
17
  from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
16
- AddCustomMetricsToModelVersionRequestDto,
17
18
  ArtifactType,
18
19
  CreateArtifactVersionRequestDto,
19
- CreateModelVersionRequestDto,
20
20
  DeleteArtifactVersionsRequestDto,
21
+ ExternalArtifactSource,
21
22
  FinalizeArtifactVersionRequestDto,
22
- MetricDto,
23
+ Framework,
24
+ Manifest,
23
25
  MlfoundryArtifactsApi,
24
26
  ModelDto,
25
27
  ModelVersionDto,
28
+ ModelVersionEnvironment,
29
+ ModelVersionManifest,
26
30
  NotifyArtifactVersionFailureDto,
31
+ TrueFoundryArtifactSource,
27
32
  UpdateModelVersionRequestDto,
28
33
  )
29
- from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
30
- InternalMetadata as InternalMetadataDto,
31
- )
32
34
  from truefoundry.ml.enums import ModelFramework
33
35
  from truefoundry.ml.exceptions import MlFoundryException
36
+ from truefoundry.ml.log_types.artifacts.artifact import BlobStorageDirectory
34
37
  from truefoundry.ml.log_types.artifacts.constants import (
35
- FILES_DIR,
36
38
  INTERNAL_METADATA_PATH,
37
- MODEL_DIR_NAME,
38
- MODEL_SCHEMA_UPDATE_FAILURE_HELP,
39
39
  )
40
- from truefoundry.ml.log_types.artifacts.model_extras import CustomMetric, ModelSchema
41
40
  from truefoundry.ml.log_types.artifacts.utils import (
42
41
  _copy_additional_files,
43
42
  _get_src_dest_pairs,
@@ -45,19 +44,18 @@ from truefoundry.ml.log_types.artifacts.utils import (
45
44
  _validate_description,
46
45
  calculate_total_size,
47
46
  )
47
+ from truefoundry.ml.model_framework import ModelFrameworkType, _ModelFramework
48
48
  from truefoundry.ml.session import _get_api_client
49
49
  from truefoundry.pydantic_v1 import BaseModel, Extra
50
- from truefoundry.version import __version__
51
50
 
52
51
  if TYPE_CHECKING:
53
52
  from truefoundry.ml.mlfoundry_run import MlFoundryRun
54
53
 
54
+
55
55
  logger = logging.getLogger(__name__)
56
56
 
57
57
 
58
58
  # TODO: Support async download and upload
59
-
60
-
61
59
  class ModelVersionInternalMetadata(BaseModel):
62
60
  class Config:
63
61
  extra = Extra.allow
@@ -99,8 +97,8 @@ class ModelVersion:
99
97
  self._deleted = False
100
98
  self._description: str = ""
101
99
  self._metadata: Dict[str, Any] = {}
102
- self._metrics: List[MetricDto] = []
103
- self._set_metrics_attr()
100
+ self._model_schema: Optional[Dict[str, Any]] = None
101
+ self._framework: Optional[ModelFrameworkType] = None
104
102
  self._set_mutable_attrs()
105
103
 
106
104
  @classmethod
@@ -138,22 +136,34 @@ class ModelVersion:
138
136
  "Model Version was deleted, cannot perform updates on a deleted version"
139
137
  )
140
138
 
141
- def _set_metrics_attr(self):
142
- self._metrics = sorted(
143
- self._model_version.metrics or [], key=lambda m: m.timestamp
144
- )
145
-
146
139
  def _set_mutable_attrs(self):
147
- self._description = self._model_version.description or ""
148
- self._metadata = copy.deepcopy(self._model_version.artifact_metadata)
140
+ if self._model_version.manifest:
141
+ self._description = self._model_version.manifest.description or ""
142
+ self._metadata = copy.deepcopy(self._model_version.manifest.metadata)
143
+ self._model_schema = copy.deepcopy(
144
+ self._model_version.manifest.model_schema
145
+ )
146
+ if self._model_version.manifest.framework:
147
+ self._framework = copy.deepcopy(
148
+ self._model_version.manifest.framework.actual_instance
149
+ )
150
+ else:
151
+ self._framework = None
152
+ else:
153
+ self._description = self._model_version.description or ""
154
+ self._metadata = copy.deepcopy(self._model_version.artifact_metadata)
155
+ self._framework = _ModelFramework.to_model_framework_type(
156
+ self._model_version.model_framework
157
+ )
158
+ self._model_schema = None
149
159
 
150
- def _refetch_model_version(self):
160
+ def _refetch_model_version(self, reset_mutable_attrs: bool = True):
151
161
  _model_version = self._mlfoundry_artifacts_api.get_model_version_get(
152
162
  id=self._model_version.id
153
163
  )
154
164
  self._model_version = _model_version.model_version
155
- self._set_metrics_attr()
156
- self._set_mutable_attrs()
165
+ if reset_mutable_attrs:
166
+ self._set_mutable_attrs()
157
167
 
158
168
  def __repr__(self):
159
169
  return f"{self.__class__.__name__}(fqn={self.fqn!r})"
@@ -187,12 +197,14 @@ class ModelVersion:
187
197
  return self._model_version.fqn
188
198
 
189
199
  @property
190
- def step(self) -> int:
200
+ def step(self) -> Optional[int]:
191
201
  """Get the step in which model was created"""
202
+ if self._model_version.manifest:
203
+ return self._model_version.manifest.step
192
204
  return self._model_version.step
193
205
 
194
206
  @property
195
- def description(self) -> Optional[str]:
207
+ def description(self) -> str:
196
208
  """Get description of the model"""
197
209
  return self._description
198
210
 
@@ -211,23 +223,51 @@ class ModelVersion:
211
223
  def metadata(self, value: Dict[str, Any]):
212
224
  """set the metadata for current model"""
213
225
  _validate_artifact_metadata(value)
214
- self._metadata = value
226
+ self._metadata = copy.deepcopy(value)
227
+
228
+ @property
229
+ def model_schema(self) -> Optional[Dict[str, Any]]:
230
+ """Get model_schema for the current model"""
231
+ return self._model_schema
232
+
233
+ @model_schema.setter
234
+ def model_schema(self, value: Optional[Dict[str, Any]]):
235
+ """set the model_schema for current model"""
236
+ if not self._model_version.manifest:
237
+ warnings.warn(
238
+ message="This model version was created using an older serialization format. model_schema will not be updated",
239
+ category=DeprecationWarning,
240
+ stacklevel=2,
241
+ )
242
+ return
243
+ self._model_schema = copy.deepcopy(value)
215
244
 
216
245
  @property
217
- def metrics(self) -> Dict[str, Union[float, int]]:
218
- """get the metrics for the current version of the model"""
219
- metrics_as_kv: Dict[str, Union[float, int]] = {}
220
- for metric in self._metrics:
221
- metrics_as_kv[metric.key] = metric.value
222
- return metrics_as_kv
246
+ def framework(self) -> Optional["ModelFrameworkType"]:
247
+ """Get the framework of the model"""
248
+ return self._framework
249
+
250
+ @framework.setter
251
+ def framework(
252
+ self, value: Optional[Union[str, ModelFramework, "ModelFrameworkType"]]
253
+ ):
254
+ """Set the framework of the model"""
255
+ if not self._model_version.manifest:
256
+ warnings.warn(
257
+ message="This model version was created using an older serialization format. Framework will not be updated",
258
+ category=DeprecationWarning,
259
+ stacklevel=2,
260
+ )
261
+ return
262
+ self._framework = _ModelFramework.to_model_framework_type(value)
223
263
 
224
264
  @property
225
- def created_at(self) -> datetime.datetime:
265
+ def created_at(self) -> Optional[datetime.datetime]:
226
266
  """Get the time at which model version was created"""
227
267
  return self._model_version.created_at
228
268
 
229
269
  @property
230
- def updated_at(self) -> datetime.datetime:
270
+ def updated_at(self) -> Optional[datetime.datetime]:
231
271
  """Get the information about when the model version was updated"""
232
272
  return self._model_version.updated_at
233
273
 
@@ -274,11 +314,27 @@ class ModelVersion:
274
314
  path: Optional[Union[str, Path]],
275
315
  overwrite: bool = False,
276
316
  progress: Optional[bool] = None,
277
- ) -> Tuple[ModelVersionInternalMetadata, ModelVersionDownloadInfo]:
317
+ ) -> ModelVersionDownloadInfo:
278
318
  self._ensure_not_deleted()
279
319
  download_dir = self.raw_download(
280
320
  path=path, overwrite=overwrite, progress=progress
281
321
  )
322
+
323
+ if self._model_version.manifest:
324
+ _framework = self._model_version.manifest.framework
325
+ model_framework = (
326
+ ModelFramework(_framework.actual_instance.type)
327
+ if _framework
328
+ else ModelFramework.UNKNOWN
329
+ )
330
+ download_info = ModelVersionDownloadInfo(
331
+ download_dir=download_dir,
332
+ model_dir=download_dir,
333
+ model_framework=model_framework,
334
+ model_filename=None,
335
+ )
336
+ return download_info
337
+
282
338
  internal_metadata_path = os.path.join(download_dir, INTERNAL_METADATA_PATH)
283
339
  if not os.path.exists(internal_metadata_path):
284
340
  raise MlFoundryException(
@@ -290,12 +346,14 @@ class ModelVersion:
290
346
  download_info = ModelVersionDownloadInfo(
291
347
  download_dir=os.path.join(download_dir, internal_metadata.files_dir),
292
348
  model_dir=os.path.join(
293
- download_dir, internal_metadata.files_dir, internal_metadata.model_dir
349
+ download_dir,
350
+ internal_metadata.files_dir,
351
+ internal_metadata.model_dir,
294
352
  ),
295
353
  model_framework=internal_metadata.framework,
296
354
  model_filename=internal_metadata.model_filename,
297
355
  )
298
- return internal_metadata, download_info
356
+ return download_info
299
357
 
300
358
  def download(
301
359
  self,
@@ -331,7 +389,7 @@ class ModelVersion:
331
389
  print(download_info.model_dir)
332
390
  ```
333
391
  """
334
- _, download_info = self._download(
392
+ download_info = self._download(
335
393
  path=path, overwrite=overwrite, progress=progress
336
394
  )
337
395
  return download_info
@@ -378,37 +436,55 @@ class ModelVersion:
378
436
  ```
379
437
  """
380
438
  self._ensure_not_deleted()
381
- _model_version = self._mlfoundry_artifacts_api.update_model_version_post(
382
- update_model_version_request_dto=UpdateModelVersionRequestDto(
383
- id=self._model_version.id,
384
- description=self.description,
385
- artifact_metadata=self.metadata,
439
+ if self._model_version.manifest:
440
+ self._model_version.manifest.description = self.description
441
+ self._model_version.manifest.metadata = self.metadata
442
+ self._model_version.manifest.model_schema = self.model_schema
443
+ self._model_version.manifest.framework = (
444
+ Framework.from_dict(self.framework.dict()) if self.framework else None
386
445
  )
387
- )
388
- self._model_version = _model_version.model_version
389
- self._set_metrics_attr()
390
- self._set_mutable_attrs()
446
+ try:
447
+ _model_version = self._mlfoundry_artifacts_api.update_model_version_post(
448
+ update_model_version_request_dto=UpdateModelVersionRequestDto(
449
+ id=self._model_version.id,
450
+ description=self.description,
451
+ artifact_metadata=self.metadata,
452
+ manifest=self._model_version.manifest,
453
+ )
454
+ )
455
+ except Exception:
456
+ # rollback edits to internal object
457
+ self._refetch_model_version(reset_mutable_attrs=False)
458
+ raise
459
+ else:
460
+ self._model_version = _model_version.model_version
461
+ self._set_mutable_attrs()
391
462
 
392
463
 
393
464
  def _log_model_version( # noqa: C901
394
465
  run: Optional["MlFoundryRun"],
395
466
  name: str,
396
- model_file_or_folder: str,
397
- framework: Optional[Union[ModelFramework, str]],
467
+ model_file_or_folder: Union[str, BlobStorageDirectory],
398
468
  mlfoundry_artifacts_api: Optional[MlfoundryArtifactsApi] = None,
399
469
  ml_repo_id: Optional[str] = None,
400
- additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
401
470
  description: Optional[str] = None,
402
471
  metadata: Optional[Dict[str, Any]] = None,
403
- model_schema: Optional[Union[Dict[str, Any], ModelSchema]] = None,
404
- custom_metrics: Optional[List[Union[CustomMetric, Dict[str, Any]]]] = None,
405
472
  step: Optional[int] = 0,
406
473
  progress: Optional[bool] = None,
474
+ framework: Optional[Union[str, ModelFramework, "ModelFrameworkType"]] = None,
475
+ environment: Optional[ModelVersionEnvironment] = None,
476
+ model_schema: Optional[Dict[str, Any]] = None,
407
477
  ) -> ModelVersion:
408
478
  if (run and mlfoundry_artifacts_api) or (not run and not mlfoundry_artifacts_api):
409
479
  raise MlFoundryException(
410
480
  "Exactly one of run, mlfoundry_artifacts_api should be passed"
411
481
  )
482
+
483
+ if not isinstance(model_file_or_folder, (str, BlobStorageDirectory)):
484
+ raise MlFoundryException(
485
+ "model_file_or_folder should be of type str or BlobStorageDirectory"
486
+ )
487
+
412
488
  if mlfoundry_artifacts_api and not ml_repo_id:
413
489
  raise MlFoundryException(
414
490
  "If mlfoundry_artifacts_api is passed, ml_repo_id must also be passed"
@@ -418,89 +494,29 @@ def _log_model_version( # noqa: C901
418
494
 
419
495
  assert mlfoundry_artifacts_api is not None
420
496
 
421
- custom_metrics = custom_metrics or []
422
- metadata = metadata or {}
423
- additional_files = additional_files or {}
424
497
  step = step or 0
425
-
426
- # validations
427
- if framework is None:
428
- framework = ModelFramework.UNKNOWN
429
- elif not isinstance(framework, ModelFramework):
430
- framework = ModelFramework(framework)
498
+ total_size = None
499
+ metadata = metadata or {}
431
500
 
432
501
  _validate_description(description)
433
502
  _validate_artifact_metadata(metadata)
434
503
 
435
- if model_schema is not None and not isinstance(model_schema, ModelSchema):
436
- model_schema = ModelSchema.parse_obj(model_schema)
437
-
438
- if custom_metrics and not model_schema:
439
- raise MlFoundryException(
440
- "Custom Metrics defined without adding the Model Schema"
441
- )
442
- custom_metrics = [
443
- CustomMetric.parse_obj(cm) if not isinstance(cm, CustomMetric) else cm
444
- for cm in custom_metrics
445
- ]
446
-
447
- logger.info("Logging model and additional files, this might take a while ...")
448
- temp_dir = tempfile.TemporaryDirectory(prefix="truefoundry-")
449
-
450
- internal_metadata = ModelVersionInternalMetadata(
451
- framework=framework,
452
- files_dir=FILES_DIR,
453
- model_dir=MODEL_DIR_NAME,
454
- model_filename=(
455
- os.path.basename(model_file_or_folder)
456
- if model_file_or_folder and os.path.isfile(model_file_or_folder)
457
- else None
458
- ),
459
- mlfoundry_version=__version__,
460
- truefoundry_version=__version__,
461
- )
462
-
463
- try:
464
- local_files_dir = os.path.join(temp_dir.name, internal_metadata.files_dir)
465
- os.makedirs(local_files_dir, exist_ok=True)
466
- # in case model was None, we still create an empty dir
467
- local_model_dir = os.path.join(local_files_dir, internal_metadata.model_dir)
468
- os.makedirs(local_model_dir, exist_ok=True)
469
-
470
- logger.info("Adding model file/folder to model version content")
471
- _model_file_or_folder: Sequence[Tuple[str, str]] = [
472
- (model_file_or_folder, MODEL_DIR_NAME.rstrip(os.sep) + os.sep),
473
- ]
474
-
475
- temp_dest_to_src_map = _copy_additional_files(
476
- root_dir=temp_dir.name,
477
- files_dir=internal_metadata.files_dir,
478
- model_dir=internal_metadata.model_dir,
479
- additional_files=_model_file_or_folder,
480
- ignore_model_dir_dest_conflict=True,
481
- )
482
-
483
- # verify additional files and paths, copy additional files
484
- if additional_files:
485
- logger.info("Adding `additional_files` to model version contents")
504
+ if isinstance(model_file_or_folder, str):
505
+ logger.info("Logging model and additional files, this might take a while ...")
506
+ temp_dir = tempfile.TemporaryDirectory(prefix="truefoundry-")
507
+ try:
508
+ logger.info("Adding model file/folder to model version content")
486
509
  temp_dest_to_src_map = _copy_additional_files(
487
510
  root_dir=temp_dir.name,
488
- files_dir=internal_metadata.files_dir,
489
- model_dir=internal_metadata.model_dir,
490
- additional_files=additional_files,
491
- ignore_model_dir_dest_conflict=False,
492
- existing_dest_to_src_map=temp_dest_to_src_map,
511
+ files_dir="",
512
+ model_dir=None,
513
+ additional_files=[(model_file_or_folder, "")],
514
+ ignore_model_dir_dest_conflict=True,
493
515
  )
494
- except Exception as e:
495
- temp_dir.cleanup()
496
- raise MlFoundryException("Failed to log model") from e
497
516
 
498
- # save internal metadata
499
- local_internal_metadata_path = os.path.join(temp_dir.name, INTERNAL_METADATA_PATH)
500
- os.makedirs(os.path.dirname(local_internal_metadata_path), exist_ok=True)
501
- with open(local_internal_metadata_path, "w") as f:
502
- json.dump(internal_metadata.dict(), f)
503
- temp_dest_to_src_map[local_internal_metadata_path] = local_internal_metadata_path
517
+ except Exception as e:
518
+ temp_dir.cleanup()
519
+ raise MlFoundryException("Failed to log model") from e
504
520
 
505
521
  # create entry
506
522
  _create_artifact_version_response = (
@@ -513,82 +529,64 @@ def _log_model_version( # noqa: C901
513
529
  )
514
530
  )
515
531
  version_id = _create_artifact_version_response.id
516
- artifacts_repo = MlFoundryArtifactsRepository(
517
- artifact_identifier=ArtifactIdentifier(
518
- artifact_version_id=uuid.UUID(version_id)
519
- ),
520
- api_client=mlfoundry_artifacts_api.api_client,
521
- )
522
-
523
- total_size = calculate_total_size(list(temp_dest_to_src_map.values()))
524
- try:
525
- logger.info(
526
- "Packaging and uploading files to remote with size: %.6f MB",
527
- total_size / 1000000.0,
532
+ artifact_storage_root = _create_artifact_version_response.artifact_storage_root
533
+ if isinstance(model_file_or_folder, str):
534
+ # Source is of type TrueFoundryArtifactSource
535
+ source = TrueFoundryArtifactSource(
536
+ type="truefoundry", uri=artifact_storage_root
528
537
  )
529
- src_dest_pairs = _get_src_dest_pairs(
530
- root_dir=temp_dir.name, dest_to_src_map=temp_dest_to_src_map
538
+ artifacts_repo = MlFoundryArtifactsRepository(
539
+ artifact_identifier=ArtifactIdentifier(
540
+ artifact_version_id=uuid.UUID(version_id)
541
+ ),
542
+ api_client=mlfoundry_artifacts_api.api_client,
531
543
  )
532
- artifacts_repo.log_artifacts(src_dest_pairs=src_dest_pairs, progress=progress)
533
- except Exception as e:
534
- mlfoundry_artifacts_api.notify_failure_post(
535
- notify_artifact_version_failure_dto=NotifyArtifactVersionFailureDto(
536
- id=version_id
544
+ total_size = calculate_total_size(list(temp_dest_to_src_map.values()))
545
+ try:
546
+ logger.info(
547
+ "Packaging and uploading files to remote with Total Size: %.6f MB",
548
+ total_size / 1000000.0,
537
549
  )
538
- )
539
- raise MlFoundryException("Failed to log model") from e
540
- finally:
541
- temp_dir.cleanup()
542
-
543
- # Note: Here we call from_dict instead of directly passing in init and relying on it
544
- # to convert because the complicated union of types generates a custom type to handle casting
545
- # Check the source of `InternalMetadataDto` to see the generated code
546
- internal_metadata_dto = InternalMetadataDto.from_dict(
547
- internal_metadata.dict() if internal_metadata is not None else {}
550
+ src_dest_pairs = _get_src_dest_pairs(
551
+ root_dir=temp_dir.name, dest_to_src_map=temp_dest_to_src_map
552
+ )
553
+ artifacts_repo.log_artifacts(
554
+ src_dest_pairs=src_dest_pairs, progress=progress
555
+ )
556
+ except Exception as e:
557
+ mlfoundry_artifacts_api.notify_failure_post(
558
+ notify_artifact_version_failure_dto=NotifyArtifactVersionFailureDto(
559
+ id=version_id
560
+ )
561
+ )
562
+ raise MlFoundryException("Failed to log model") from e
563
+ finally:
564
+ temp_dir.cleanup()
565
+ elif isinstance(model_file_or_folder, BlobStorageDirectory):
566
+ source = ExternalArtifactSource(type="external", uri=model_file_or_folder.uri)
567
+ else:
568
+ raise MlFoundryException("Invalid model_file_or_folder provided")
569
+
570
+ _framework = _ModelFramework.to_model_framework_type(framework)
571
+ _source_cls = typing.get_type_hints(ModelVersionManifest)["source"]
572
+ model_manifest = ModelVersionManifest(
573
+ description=description,
574
+ metadata=metadata,
575
+ source=_source_cls.from_dict(source.dict()),
576
+ framework=Framework.from_dict(_framework.dict()) if _framework else None,
577
+ environment=environment,
578
+ step=step,
579
+ model_schema=model_schema,
548
580
  )
549
- mlfoundry_artifacts_api.finalize_artifact_version_post(
581
+ artifact_version_response = mlfoundry_artifacts_api.finalize_artifact_version_post(
550
582
  finalize_artifact_version_request_dto=FinalizeArtifactVersionRequestDto(
551
583
  id=version_id,
552
584
  run_uuid=run.run_id if run else None,
553
585
  artifact_size=total_size,
554
- internal_metadata=internal_metadata_dto,
555
- step=step if run else None,
556
- )
557
- )
558
- _model_version = mlfoundry_artifacts_api.create_model_version_post(
559
- create_model_version_request_dto=CreateModelVersionRequestDto(
560
- artifact_version_id=version_id,
561
- description=description,
562
586
  artifact_metadata=metadata,
563
- internal_metadata=internal_metadata_dto,
564
- data_path=INTERNAL_METADATA_PATH,
565
- step=step if run else None,
587
+ internal_metadata=None,
588
+ step=model_manifest.step,
589
+ manifest=Manifest.from_dict(model_manifest.to_dict()),
566
590
  )
567
591
  )
568
- model_version = _model_version.model_version
569
-
570
- # update model schema at end
571
- update_args: Dict[str, Any] = {
572
- "id": version_id,
573
- "model_framework": framework.value,
574
- }
575
- if model_schema:
576
- update_args["model_schema"] = model_schema
577
-
578
- try:
579
- _model_version = mlfoundry_artifacts_api.update_model_version_post(
580
- update_model_version_request_dto=UpdateModelVersionRequestDto(**update_args)
581
- )
582
- model_version = _model_version.model_version
583
- if model_schema:
584
- _model_version = mlfoundry_artifacts_api.add_custom_metrics_to_model_version_post(
585
- add_custom_metrics_to_model_version_request_dto=AddCustomMetricsToModelVersionRequestDto(
586
- id=version_id, custom_metrics=custom_metrics
587
- )
588
- )
589
- model_version = _model_version.model_version
590
- except Exception:
591
- # TODO (chiragjn): what is the best exception to catch here?
592
- logger.error(MODEL_SCHEMA_UPDATE_FAILURE_HELP.format(fqn=model_version.fqn))
593
-
594
- return ModelVersion.from_fqn(fqn=model_version.fqn)
592
+ return ModelVersion.from_fqn(fqn=artifact_version_response.artifact_version.fqn)