truefoundry 0.4.6__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. Click here for more details.

Files changed (40)
  1. truefoundry/ml/__init__.py +36 -0
  2. truefoundry/ml/autogen/client/__init__.py +29 -6
  3. truefoundry/ml/autogen/client/api/__init__.py +3 -3
  4. truefoundry/ml/autogen/client/api/deprecated_api.py +7 -7
  5. truefoundry/ml/autogen/client/api/generate_code_snippet_api.py +526 -0
  6. truefoundry/ml/autogen/client/models/__init__.py +26 -3
  7. truefoundry/ml/autogen/client/models/command.py +152 -0
  8. truefoundry/ml/autogen/client/models/create_workflow_task_config_request_dto.py +72 -0
  9. truefoundry/ml/autogen/client/models/external_model_source.py +3 -2
  10. truefoundry/ml/autogen/client/models/fast_ai_framework.py +75 -0
  11. truefoundry/ml/autogen/client/models/framework.py +250 -14
  12. truefoundry/ml/autogen/client/models/gluon_framework.py +74 -0
  13. truefoundry/ml/autogen/client/models/{upload_model_source.py → h2_o_framework.py} +11 -11
  14. truefoundry/ml/autogen/client/models/keras_framework.py +74 -0
  15. truefoundry/ml/autogen/client/models/light_gbm_framework.py +75 -0
  16. truefoundry/ml/autogen/client/models/model_version_manifest.py +1 -1
  17. truefoundry/ml/autogen/client/models/onnx_framework.py +74 -0
  18. truefoundry/ml/autogen/client/models/paddle_framework.py +75 -0
  19. truefoundry/ml/autogen/client/models/py_torch_framework.py +75 -0
  20. truefoundry/ml/autogen/client/models/sklearn_framework.py +75 -0
  21. truefoundry/ml/autogen/client/models/source.py +9 -32
  22. truefoundry/ml/autogen/client/models/spa_cy_framework.py +74 -0
  23. truefoundry/ml/autogen/client/models/stats_models_framework.py +75 -0
  24. truefoundry/ml/autogen/client/models/{tensorflow_framework.py → tensor_flow_framework.py} +10 -9
  25. truefoundry/ml/autogen/client/models/transformers_framework.py +3 -2
  26. truefoundry/ml/autogen/client/models/trigger_job_run_config_request_dto.py +90 -0
  27. truefoundry/ml/autogen/client/models/trigger_job_run_config_response_dto.py +71 -0
  28. truefoundry/ml/autogen/client/models/truefoundry_model_source.py +5 -3
  29. truefoundry/ml/autogen/client/models/xg_boost_framework.py +75 -0
  30. truefoundry/ml/autogen/client_README.md +22 -5
  31. truefoundry/ml/autogen/entities/artifacts.py +19 -2
  32. truefoundry/ml/log_types/artifacts/model.py +167 -177
  33. truefoundry/ml/mlfoundry_api.py +47 -18
  34. truefoundry/ml/mlfoundry_run.py +27 -12
  35. truefoundry/ml/model_framework.py +169 -0
  36. {truefoundry-0.4.6.dist-info → truefoundry-0.5.0rc2.dist-info}/METADATA +1 -1
  37. {truefoundry-0.4.6.dist-info → truefoundry-0.5.0rc2.dist-info}/RECORD +39 -23
  38. truefoundry/ml/autogen/client/api/python_deployment_config_api.py +0 -201
  39. {truefoundry-0.4.6.dist-info → truefoundry-0.5.0rc2.dist-info}/WHEEL +0 -0
  40. {truefoundry-0.4.6.dist-info → truefoundry-0.5.0rc2.dist-info}/entry_points.txt +0 -0
@@ -5,39 +5,38 @@ import logging
5
5
  import os.path
6
6
  import tempfile
7
7
  import uuid
8
+ import warnings
8
9
  from pathlib import Path
9
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
10
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union
10
11
 
11
12
  from truefoundry.ml.artifact.truefoundry_artifact_repo import (
12
13
  ArtifactIdentifier,
13
14
  MlFoundryArtifactsRepository,
14
15
  )
15
16
  from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
16
- AddCustomMetricsToModelVersionRequestDto,
17
17
  ArtifactType,
18
+ ArtifactVersionSerializationFormat,
18
19
  CreateArtifactVersionRequestDto,
19
- CreateModelVersionRequestDto,
20
20
  DeleteArtifactVersionsRequestDto,
21
+ ExternalModelSource,
21
22
  FinalizeArtifactVersionRequestDto,
22
- MetricDto,
23
+ Framework,
23
24
  MlfoundryArtifactsApi,
24
25
  ModelDto,
25
26
  ModelVersionDto,
27
+ ModelVersionManifest,
26
28
  NotifyArtifactVersionFailureDto,
29
+ TruefoundryModelSource,
27
30
  UpdateModelVersionRequestDto,
28
31
  )
29
- from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
30
- InternalMetadata as InternalMetadataDto,
32
+ from truefoundry.ml.autogen.client import (
33
+ Source as ModelVersionSource,
31
34
  )
32
35
  from truefoundry.ml.enums import ModelFramework
33
36
  from truefoundry.ml.exceptions import MlFoundryException
34
37
  from truefoundry.ml.log_types.artifacts.constants import (
35
- FILES_DIR,
36
38
  INTERNAL_METADATA_PATH,
37
- MODEL_DIR_NAME,
38
- MODEL_SCHEMA_UPDATE_FAILURE_HELP,
39
39
  )
40
- from truefoundry.ml.log_types.artifacts.model_extras import CustomMetric, ModelSchema
41
40
  from truefoundry.ml.log_types.artifacts.utils import (
42
41
  _copy_additional_files,
43
42
  _get_src_dest_pairs,
@@ -45,19 +44,18 @@ from truefoundry.ml.log_types.artifacts.utils import (
45
44
  _validate_description,
46
45
  calculate_total_size,
47
46
  )
47
+ from truefoundry.ml.model_framework import ModelFrameworkType, _ModelFramework
48
48
  from truefoundry.ml.session import _get_api_client
49
- from truefoundry.pydantic_v1 import BaseModel, Extra
50
- from truefoundry.version import __version__
49
+ from truefoundry.pydantic_v1 import BaseModel, Extra, StrictStr
51
50
 
52
51
  if TYPE_CHECKING:
53
52
  from truefoundry.ml.mlfoundry_run import MlFoundryRun
54
53
 
54
+
55
55
  logger = logging.getLogger(__name__)
56
56
 
57
57
 
58
58
  # TODO: Support async download and upload
59
-
60
-
61
59
  class ModelVersionInternalMetadata(BaseModel):
62
60
  class Config:
63
61
  extra = Extra.allow
@@ -77,6 +75,10 @@ class ModelVersionInternalMetadata(BaseModel):
77
75
  return dct
78
76
 
79
77
 
78
+ class BlobStorageModelDirectory(BaseModel):
79
+ uri: StrictStr
80
+
81
+
80
82
  class ModelVersionDownloadInfo(BaseModel):
81
83
  download_dir: str
82
84
  model_dir: str
@@ -99,8 +101,7 @@ class ModelVersion:
99
101
  self._deleted = False
100
102
  self._description: str = ""
101
103
  self._metadata: Dict[str, Any] = {}
102
- self._metrics: List[MetricDto] = []
103
- self._set_metrics_attr()
104
+ self._framework: Optional["Framework"] = None
104
105
  self._set_mutable_attrs()
105
106
 
106
107
  @classmethod
@@ -138,21 +139,20 @@ class ModelVersion:
138
139
  "Model Version was deleted, cannot perform updates on a deleted version"
139
140
  )
140
141
 
141
- def _set_metrics_attr(self):
142
- self._metrics = sorted(
143
- self._model_version.metrics or [], key=lambda m: m.timestamp
144
- )
145
-
146
142
  def _set_mutable_attrs(self):
147
143
  self._description = self._model_version.description or ""
148
144
  self._metadata = copy.deepcopy(self._model_version.artifact_metadata)
145
+ self._framework = (
146
+ copy.deepcopy(self._model_version.manifest.framework)
147
+ if self._model_version.manifest
148
+ else None
149
+ )
149
150
 
150
151
  def _refetch_model_version(self):
151
152
  _model_version = self._mlfoundry_artifacts_api.get_model_version_get(
152
153
  id=self._model_version.id
153
154
  )
154
155
  self._model_version = _model_version.model_version
155
- self._set_metrics_attr()
156
156
  self._set_mutable_attrs()
157
157
 
158
158
  def __repr__(self):
@@ -201,6 +201,8 @@ class ModelVersion:
201
201
  """set the description of the model"""
202
202
  _validate_description(value)
203
203
  self._description = value
204
+ if self._model_version.manifest:
205
+ self._model_version.manifest.description = value
204
206
 
205
207
  @property
206
208
  def metadata(self) -> Dict[str, Any]:
@@ -212,14 +214,43 @@ class ModelVersion:
212
214
  """set the metadata for current model"""
213
215
  _validate_artifact_metadata(value)
214
216
  self._metadata = value
217
+ if self._model_version.manifest:
218
+ self._model_version.manifest.metadata = value
215
219
 
216
220
  @property
217
- def metrics(self) -> Dict[str, Union[float, int]]:
218
- """get the metrics for the current version of the model"""
219
- metrics_as_kv: Dict[str, Union[float, int]] = {}
220
- for metric in self._metrics:
221
- metrics_as_kv[metric.key] = metric.value
222
- return metrics_as_kv
221
+ def framework(self) -> Optional["ModelFrameworkType"]:
222
+ """Get the framework of the model"""
223
+ return (
224
+ _ModelFramework.from_dict(self._framework.actual_instance.dict())
225
+ if self._framework and self._framework.actual_instance
226
+ else None
227
+ )
228
+
229
+ @framework.setter
230
+ def framework(
231
+ self, value: Optional[Union[str, ModelFramework, "ModelFrameworkType"]]
232
+ ):
233
+ """set the framework of the model"""
234
+ model_framework_type: Optional[ModelFrameworkType] = None
235
+
236
+ if value is not None:
237
+ model_framework_type = _ModelFramework.to_model_framework_type(value)
238
+
239
+ if (
240
+ self._model_version.serialization_format
241
+ == ArtifactVersionSerializationFormat.V2
242
+ ):
243
+ self._model_version.manifest.framework = self._framework = (
244
+ Framework.from_dict(model_framework_type.dict())
245
+ if model_framework_type
246
+ else None
247
+ )
248
+ else:
249
+ warnings.warn(
250
+ message="This model version was created using an older serialization format. Framework will not be updated",
251
+ category=DeprecationWarning,
252
+ stacklevel=2,
253
+ )
223
254
 
224
255
  @property
225
256
  def created_at(self) -> datetime.datetime:
@@ -274,11 +305,29 @@ class ModelVersion:
274
305
  path: Optional[Union[str, Path]],
275
306
  overwrite: bool = False,
276
307
  progress: Optional[bool] = None,
277
- ) -> Tuple[ModelVersionInternalMetadata, ModelVersionDownloadInfo]:
308
+ ) -> ModelVersionDownloadInfo:
278
309
  self._ensure_not_deleted()
279
310
  download_dir = self.raw_download(
280
311
  path=path, overwrite=overwrite, progress=progress
281
312
  )
313
+
314
+ if (
315
+ self._model_version.serialization_format
316
+ == ArtifactVersionSerializationFormat.V2
317
+ ):
318
+ model_framework = (
319
+ ModelFramework(self.framework.type)
320
+ if self.framework
321
+ else ModelFramework.UNKNOWN
322
+ )
323
+ download_info = ModelVersionDownloadInfo(
324
+ download_dir=download_dir,
325
+ model_dir=download_dir,
326
+ model_framework=model_framework,
327
+ model_filename=None,
328
+ )
329
+ return download_info
330
+
282
331
  internal_metadata_path = os.path.join(download_dir, INTERNAL_METADATA_PATH)
283
332
  if not os.path.exists(internal_metadata_path):
284
333
  raise MlFoundryException(
@@ -290,12 +339,14 @@ class ModelVersion:
290
339
  download_info = ModelVersionDownloadInfo(
291
340
  download_dir=os.path.join(download_dir, internal_metadata.files_dir),
292
341
  model_dir=os.path.join(
293
- download_dir, internal_metadata.files_dir, internal_metadata.model_dir
342
+ download_dir,
343
+ internal_metadata.files_dir,
344
+ internal_metadata.model_dir,
294
345
  ),
295
346
  model_framework=internal_metadata.framework,
296
347
  model_filename=internal_metadata.model_filename,
297
348
  )
298
- return internal_metadata, download_info
349
+ return download_info
299
350
 
300
351
  def download(
301
352
  self,
@@ -331,7 +382,7 @@ class ModelVersion:
331
382
  print(download_info.model_dir)
332
383
  ```
333
384
  """
334
- _, download_info = self._download(
385
+ download_info = self._download(
335
386
  path=path, overwrite=overwrite, progress=progress
336
387
  )
337
388
  return download_info
@@ -383,32 +434,36 @@ class ModelVersion:
383
434
  id=self._model_version.id,
384
435
  description=self.description,
385
436
  artifact_metadata=self.metadata,
437
+ manifest=self._model_version.manifest,
386
438
  )
387
439
  )
388
440
  self._model_version = _model_version.model_version
389
- self._set_metrics_attr()
390
441
  self._set_mutable_attrs()
391
442
 
392
443
 
393
444
  def _log_model_version( # noqa: C901
394
445
  run: Optional["MlFoundryRun"],
395
446
  name: str,
396
- model_file_or_folder: str,
397
- framework: Optional[Union[ModelFramework, str]],
447
+ model_file_or_folder: Union[str, BlobStorageModelDirectory],
398
448
  mlfoundry_artifacts_api: Optional[MlfoundryArtifactsApi] = None,
399
449
  ml_repo_id: Optional[str] = None,
400
450
  additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
401
451
  description: Optional[str] = None,
402
452
  metadata: Optional[Dict[str, Any]] = None,
403
- model_schema: Optional[Union[Dict[str, Any], ModelSchema]] = None,
404
- custom_metrics: Optional[List[Union[CustomMetric, Dict[str, Any]]]] = None,
405
453
  step: Optional[int] = 0,
406
454
  progress: Optional[bool] = None,
455
+ framework: Optional[Union[str, ModelFramework, "ModelFrameworkType"]] = None,
407
456
  ) -> ModelVersion:
408
457
  if (run and mlfoundry_artifacts_api) or (not run and not mlfoundry_artifacts_api):
409
458
  raise MlFoundryException(
410
459
  "Exactly one of run, mlfoundry_artifacts_api should be passed"
411
460
  )
461
+
462
+ if not isinstance(model_file_or_folder, (str, BlobStorageModelDirectory)):
463
+ raise MlFoundryException(
464
+ "model_file_or_folder should be of type str or BlobStorageModelDirectory"
465
+ )
466
+
412
467
  if mlfoundry_artifacts_api and not ml_repo_id:
413
468
  raise MlFoundryException(
414
469
  "If mlfoundry_artifacts_api is passed, ml_repo_id must also be passed"
@@ -418,89 +473,41 @@ def _log_model_version( # noqa: C901
418
473
 
419
474
  assert mlfoundry_artifacts_api is not None
420
475
 
421
- custom_metrics = custom_metrics or []
476
+ step = step or 0
477
+ total_size = None
422
478
  metadata = metadata or {}
423
479
  additional_files = additional_files or {}
424
- step = step or 0
425
-
426
- # validations
427
- if framework is None:
428
- framework = ModelFramework.UNKNOWN
429
- elif not isinstance(framework, ModelFramework):
430
- framework = ModelFramework(framework)
431
480
 
432
481
  _validate_description(description)
433
482
  _validate_artifact_metadata(metadata)
434
483
 
435
- if model_schema is not None and not isinstance(model_schema, ModelSchema):
436
- model_schema = ModelSchema.parse_obj(model_schema)
437
-
438
- if custom_metrics and not model_schema:
439
- raise MlFoundryException(
440
- "Custom Metrics defined without adding the Model Schema"
441
- )
442
- custom_metrics = [
443
- CustomMetric.parse_obj(cm) if not isinstance(cm, CustomMetric) else cm
444
- for cm in custom_metrics
445
- ]
446
-
447
- logger.info("Logging model and additional files, this might take a while ...")
448
- temp_dir = tempfile.TemporaryDirectory(prefix="truefoundry-")
449
-
450
- internal_metadata = ModelVersionInternalMetadata(
451
- framework=framework,
452
- files_dir=FILES_DIR,
453
- model_dir=MODEL_DIR_NAME,
454
- model_filename=(
455
- os.path.basename(model_file_or_folder)
456
- if model_file_or_folder and os.path.isfile(model_file_or_folder)
457
- else None
458
- ),
459
- mlfoundry_version=__version__,
460
- truefoundry_version=__version__,
461
- )
462
-
463
- try:
464
- local_files_dir = os.path.join(temp_dir.name, internal_metadata.files_dir)
465
- os.makedirs(local_files_dir, exist_ok=True)
466
- # in case model was None, we still create an empty dir
467
- local_model_dir = os.path.join(local_files_dir, internal_metadata.model_dir)
468
- os.makedirs(local_model_dir, exist_ok=True)
469
-
470
- logger.info("Adding model file/folder to model version content")
471
- _model_file_or_folder: Sequence[Tuple[str, str]] = [
472
- (model_file_or_folder, MODEL_DIR_NAME.rstrip(os.sep) + os.sep),
473
- ]
474
-
475
- temp_dest_to_src_map = _copy_additional_files(
476
- root_dir=temp_dir.name,
477
- files_dir=internal_metadata.files_dir,
478
- model_dir=internal_metadata.model_dir,
479
- additional_files=_model_file_or_folder,
480
- ignore_model_dir_dest_conflict=True,
481
- )
482
-
483
- # verify additional files and paths, copy additional files
484
- if additional_files:
485
- logger.info("Adding `additional_files` to model version contents")
484
+ if isinstance(model_file_or_folder, str):
485
+ logger.info("Logging model and additional files, this might take a while ...")
486
+ temp_dir = tempfile.TemporaryDirectory(prefix="truefoundry-")
487
+ try:
488
+ logger.info("Adding model file/folder to model version content")
486
489
  temp_dest_to_src_map = _copy_additional_files(
487
490
  root_dir=temp_dir.name,
488
- files_dir=internal_metadata.files_dir,
489
- model_dir=internal_metadata.model_dir,
490
- additional_files=additional_files,
491
- ignore_model_dir_dest_conflict=False,
492
- existing_dest_to_src_map=temp_dest_to_src_map,
491
+ files_dir="",
492
+ model_dir=None,
493
+ additional_files=[(model_file_or_folder, "")],
494
+ ignore_model_dir_dest_conflict=True,
493
495
  )
494
- except Exception as e:
495
- temp_dir.cleanup()
496
- raise MlFoundryException("Failed to log model") from e
497
496
 
498
- # save internal metadata
499
- local_internal_metadata_path = os.path.join(temp_dir.name, INTERNAL_METADATA_PATH)
500
- os.makedirs(os.path.dirname(local_internal_metadata_path), exist_ok=True)
501
- with open(local_internal_metadata_path, "w") as f:
502
- json.dump(internal_metadata.dict(), f)
503
- temp_dest_to_src_map[local_internal_metadata_path] = local_internal_metadata_path
497
+ # verify additional files and paths, copy additional files
498
+ if additional_files:
499
+ logger.info("Adding `additional_files` to model version contents")
500
+ temp_dest_to_src_map = _copy_additional_files(
501
+ root_dir=temp_dir.name,
502
+ files_dir="",
503
+ model_dir=None,
504
+ additional_files=additional_files,
505
+ ignore_model_dir_dest_conflict=False,
506
+ existing_dest_to_src_map=temp_dest_to_src_map,
507
+ )
508
+ except Exception as e:
509
+ temp_dir.cleanup()
510
+ raise MlFoundryException("Failed to log model") from e
504
511
 
505
512
  # create entry
506
513
  _create_artifact_version_response = (
@@ -513,82 +520,65 @@ def _log_model_version( # noqa: C901
513
520
  )
514
521
  )
515
522
  version_id = _create_artifact_version_response.id
516
- artifacts_repo = MlFoundryArtifactsRepository(
517
- artifact_identifier=ArtifactIdentifier(
518
- artifact_version_id=uuid.UUID(version_id)
519
- ),
520
- api_client=mlfoundry_artifacts_api.api_client,
521
- )
522
-
523
- total_size = calculate_total_size(list(temp_dest_to_src_map.values()))
524
- try:
525
- logger.info(
526
- "Packaging and uploading files to remote with size: %.6f MB",
527
- total_size / 1000000.0,
523
+ artifact_storage_root = _create_artifact_version_response.artifact_storage_root
524
+ if isinstance(model_file_or_folder, str):
525
+ # Source is of type TruefoundryModelSource
526
+ source = ModelVersionSource.from_json(
527
+ TruefoundryModelSource(
528
+ type="truefoundry", uri=artifact_storage_root
529
+ ).to_json()
528
530
  )
529
- src_dest_pairs = _get_src_dest_pairs(
530
- root_dir=temp_dir.name, dest_to_src_map=temp_dest_to_src_map
531
+ artifacts_repo = MlFoundryArtifactsRepository(
532
+ artifact_identifier=ArtifactIdentifier(
533
+ artifact_version_id=uuid.UUID(version_id)
534
+ ),
535
+ api_client=mlfoundry_artifacts_api.api_client,
531
536
  )
532
- artifacts_repo.log_artifacts(src_dest_pairs=src_dest_pairs, progress=progress)
533
- except Exception as e:
534
- mlfoundry_artifacts_api.notify_failure_post(
535
- notify_artifact_version_failure_dto=NotifyArtifactVersionFailureDto(
536
- id=version_id
537
+ total_size = calculate_total_size(list(temp_dest_to_src_map.values()))
538
+ try:
539
+ logger.info(
540
+ "Packaging and uploading files to remote with Total Size: %.6f MB",
541
+ total_size / 1000000.0,
542
+ )
543
+ src_dest_pairs = _get_src_dest_pairs(
544
+ root_dir=temp_dir.name, dest_to_src_map=temp_dest_to_src_map
545
+ )
546
+ artifacts_repo.log_artifacts(
547
+ src_dest_pairs=src_dest_pairs, progress=progress
537
548
  )
549
+ except Exception as e:
550
+ mlfoundry_artifacts_api.notify_failure_post(
551
+ notify_artifact_version_failure_dto=NotifyArtifactVersionFailureDto(
552
+ id=version_id
553
+ )
554
+ )
555
+ raise MlFoundryException("Failed to log model") from e
556
+ finally:
557
+ temp_dir.cleanup()
558
+ elif isinstance(model_file_or_folder, BlobStorageModelDirectory):
559
+ source = ModelVersionSource.from_json(
560
+ ExternalModelSource(type="external", uri=model_file_or_folder.uri).to_json()
538
561
  )
539
- raise MlFoundryException("Failed to log model") from e
540
- finally:
541
- temp_dir.cleanup()
542
-
543
- # Note: Here we call from_dict instead of directly passing in init and relying on it
544
- # to convert because the complicated union of types generates a custom type to handle casting
545
- # Check the source of `InternalMetadataDto` to see the generated code
546
- internal_metadata_dto = InternalMetadataDto.from_dict(
547
- internal_metadata.dict() if internal_metadata is not None else {}
562
+ else:
563
+ raise MlFoundryException("Invalid model_file_or_folder provided")
564
+
565
+ _framework = _ModelFramework.to_model_framework_type(framework)
566
+ model_manifest = ModelVersionManifest(
567
+ description=description,
568
+ metadata=metadata,
569
+ source=source,
570
+ framework=Framework.from_dict(_framework.dict()) if _framework else None,
571
+ step=step if run else 0,
548
572
  )
549
- mlfoundry_artifacts_api.finalize_artifact_version_post(
573
+ artifact_version_response = mlfoundry_artifacts_api.finalize_artifact_version_post(
550
574
  finalize_artifact_version_request_dto=FinalizeArtifactVersionRequestDto(
551
575
  id=version_id,
552
576
  run_uuid=run.run_id if run else None,
553
577
  artifact_size=total_size,
554
- internal_metadata=internal_metadata_dto,
555
- step=step if run else None,
556
- )
557
- )
558
- _model_version = mlfoundry_artifacts_api.create_model_version_post(
559
- create_model_version_request_dto=CreateModelVersionRequestDto(
560
- artifact_version_id=version_id,
561
- description=description,
562
578
  artifact_metadata=metadata,
563
- internal_metadata=internal_metadata_dto,
564
- data_path=INTERNAL_METADATA_PATH,
565
- step=step if run else None,
579
+ internal_metadata=None,
580
+ step=model_manifest.step,
581
+ manifest=model_manifest,
566
582
  )
567
583
  )
568
- model_version = _model_version.model_version
569
-
570
- # update model schema at end
571
- update_args: Dict[str, Any] = {
572
- "id": version_id,
573
- "model_framework": framework.value,
574
- }
575
- if model_schema:
576
- update_args["model_schema"] = model_schema
577
-
578
- try:
579
- _model_version = mlfoundry_artifacts_api.update_model_version_post(
580
- update_model_version_request_dto=UpdateModelVersionRequestDto(**update_args)
581
- )
582
- model_version = _model_version.model_version
583
- if model_schema:
584
- _model_version = mlfoundry_artifacts_api.add_custom_metrics_to_model_version_post(
585
- add_custom_metrics_to_model_version_request_dto=AddCustomMetricsToModelVersionRequestDto(
586
- id=version_id, custom_metrics=custom_metrics
587
- )
588
- )
589
- model_version = _model_version.model_version
590
- except Exception:
591
- # TODO (chiragjn): what is the best exception to catch here?
592
- logger.error(MODEL_SCHEMA_UPDATE_FAILURE_HELP.format(fqn=model_version.fqn))
593
-
594
- return ModelVersion.from_fqn(fqn=model_version.fqn)
584
+ return ModelVersion.from_fqn(fqn=artifact_version_response.artifact_version.fqn)
@@ -2,7 +2,17 @@ import os
2
2
  import time
3
3
  import uuid
4
4
  from pathlib import Path
5
- from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
5
+ from typing import (
6
+ TYPE_CHECKING,
7
+ Any,
8
+ Dict,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ Sequence,
13
+ Tuple,
14
+ Union,
15
+ )
6
16
 
7
17
  import coolname
8
18
  import pandas as pd
@@ -42,7 +52,11 @@ from truefoundry.ml.log_types.artifacts.artifact import (
42
52
  )
43
53
  from truefoundry.ml.log_types.artifacts.dataset import DataDirectory
44
54
  from truefoundry.ml.log_types.artifacts.general_artifact import _log_artifact_version
45
- from truefoundry.ml.log_types.artifacts.model import ModelVersion, _log_model_version
55
+ from truefoundry.ml.log_types.artifacts.model import (
56
+ BlobStorageModelDirectory,
57
+ ModelVersion,
58
+ _log_model_version,
59
+ )
46
60
  from truefoundry.ml.logger import logger
47
61
  from truefoundry.ml.mlfoundry_run import MlFoundryRun
48
62
  from truefoundry.ml.session import (
@@ -57,6 +71,9 @@ from truefoundry.ml.validation_utils import (
57
71
  _validate_run_name,
58
72
  )
59
73
 
74
+ if TYPE_CHECKING:
75
+ from truefoundry.ml import ModelFrameworkType
76
+
60
77
  _SEARCH_MAX_RESULTS_DEFAULT = 1000
61
78
 
62
79
  _INTERNAL_ENV_VARS = [
@@ -1209,12 +1226,12 @@ class MlFoundry:
1209
1226
  *,
1210
1227
  ml_repo: str,
1211
1228
  name: str,
1212
- model_file_or_folder: str,
1213
- framework: Optional[Union[ModelFramework, str]],
1229
+ model_file_or_folder: Union[str, BlobStorageModelDirectory],
1214
1230
  additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
1215
1231
  description: Optional[str] = None,
1216
1232
  metadata: Optional[Dict[str, Any]] = None,
1217
1233
  progress: Optional[bool] = None,
1234
+ framework: Optional[Union[str, ModelFramework, "ModelFrameworkType"]] = None,
1218
1235
  ) -> ModelVersion:
1219
1236
  """
1220
1237
  Serialize and log a versioned model under the current ml_repo. Each logged model generates a new version
@@ -1226,12 +1243,24 @@ class MlFoundry:
1226
1243
  name (str): Name of the model. If a model with this name already exists under the current ML Repo,
1227
1244
  the logged model will be added as a new version under that `name`. If no models exist with the given
1228
1245
  `name`, the given model will be logged as version 1.
1229
- model_file_or_folder (str): Path to either a single file or a folder containing model files. This folder
1230
- is usually created using serialization methods of libraries or frameworks e.g. `joblib.dump`,
1231
- `model.save_pretrained(...)`, `torch.save(...)`, `model.save(...)`
1232
- framework (Union[enums.ModelFramework, str]): Model Framework. Ex:- pytorch, sklearn, tensorflow etc.
1233
- The full list of supported frameworks can be found in `truefoundry.ml.enums.ModelFramework`.
1234
- Can also be `None` when `model` is `None`.
1246
+
1247
+ model_file_or_folder (Union[str, BlobStorageModelDirectory]):
1248
+ str:
1249
+ Path to either a single file or a folder containing model files.
1250
+ This folder is typically created using serialization methods from libraries or frameworks,
1251
+ e.g., `joblib.dump`, `model.save_pretrained(...)`, `torch.save(...)`, or `model.save(...)`.
1252
+ BlobStorageModelDirectory:
1253
+ uri (str): URI to the model file or folder in a storage integration associated with the specified ML Repo.
1254
+ The model files or folder must reside within the same storage integration as the specified ML Repo.
1255
+ Accepted URI formats include `s3://integration-bucket-name/prefix/path/to/model` or `gs://integration-bucket-name/prefix/path/to/model`.
1256
+ If the URI points to a model in a different storage integration, an error will be raised indicating "Invalid source URI."
1257
+
1258
+ framework (Optional[Union[ModelFramework, ModelFrameworkType]]): Framework used for model serialization.
1259
+ Supported frameworks values (ModelFrameworkType) can be imported from `from truefoundry.ml import *`.
1260
+ Supported frameworks can be found in `truefoundry.ml.enums.ModelFramework`.
1261
+ Can also be `None` if the framework is not known or not supported.
1262
+ **Deprecated**: Prefer `ModelFrameworkType` over `enums.ModelFramework`.
1263
+
1235
1264
  additional_files (Sequence[Tuple[Union[str, Path], Optional[str]]], optional): A list of pairs
1236
1265
  of (source path, destination path) to add additional files and folders
1237
1266
  to the model version contents. The first member of the pair should be a file or directory path
@@ -1245,10 +1274,12 @@ class MlFoundry:
1245
1274
  You can also add additional files to model/ subdirectory by specifying the destination path as model/
1246
1275
 
1247
1276
  ```python
1277
+ from truefoundry.ml import TensorFlowFramework
1278
+
1248
1279
  run.log_model(
1249
1280
  name="xyz",
1250
1281
  model_file_or_folder="clf.joblib",
1251
- framework="sklearn",
1282
+ framework=TensorFlowFramework(),
1252
1283
  additional_files=[("foo.txt", "foo/bar/foo.txt"), ("tokenizer/", "foo/tokenizer/")]
1253
1284
  )
1254
1285
  ```
@@ -1281,8 +1312,7 @@ class MlFoundry:
1281
1312
  ### Sklearn
1282
1313
 
1283
1314
  ```python
1284
- from truefoundry.ml import get_client
1285
- from truefoundry.ml.enums import ModelFramework
1315
+ from truefoundry.ml import get_client, SklearnFramework
1286
1316
 
1287
1317
  import joblib
1288
1318
  import numpy as np
@@ -1307,7 +1337,7 @@ class MlFoundry:
1307
1337
  ml_repo="my-classification-project",
1308
1338
  name="my-sklearn-model",
1309
1339
  model_file_or_folder="sklearn-pipeline.joblib",
1310
- framework=ModelFramework.SKLEARN,
1340
+ framework=SklearnFramework(),
1311
1341
  metadata={"accuracy": 0.99, "f1": 0.80},
1312
1342
  step=1, # step number, useful when using iterative algorithms like SGD
1313
1343
  )
@@ -1317,8 +1347,7 @@ class MlFoundry:
1317
1347
  ### Huggingface Transformers
1318
1348
 
1319
1349
  ```python
1320
- from truefoundry.ml import get_client
1321
- from truefoundry.ml.enums import ModelFramework
1350
+ from truefoundry.ml import get_client, TransformersFramework, LibraryName
1322
1351
 
1323
1352
  import torch
1324
1353
  from transformers import AutoTokenizer, AutoConfig, pipeline, AutoModelForCausalLM
@@ -1342,7 +1371,7 @@ class MlFoundry:
1342
1371
  ml_repo="my-llm-project",
1343
1372
  name="my-transformers-model",
1344
1373
  model_file_or_folder="my-transformers-model/",
1345
- framework=ModelFramework.TRANSFORMERS
1374
+ framework=TransformersFramework(library_name=LibraryName.TRANSFORMERS, pipeline_tag='text-generation')
1346
1375
  )
1347
1376
  print(model_version.fqn)
1348
1377
  ```
@@ -1356,12 +1385,12 @@ class MlFoundry:
1356
1385
  ml_repo_id=ml_repo_id,
1357
1386
  name=name,
1358
1387
  model_file_or_folder=model_file_or_folder,
1359
- framework=framework,
1360
1388
  additional_files=additional_files,
1361
1389
  description=description,
1362
1390
  metadata=metadata,
1363
1391
  step=None,
1364
1392
  progress=progress,
1393
+ framework=framework,
1365
1394
  )
1366
1395
  logger.info(f"Logged model successfully with fqn {model_version.fqn!r}")
1367
1396
  return model_version