truefoundry 0.5.1rc6__py3-none-any.whl → 0.5.1rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. Click here for more details.

Files changed (24)
  1. truefoundry/common/constants.py +9 -0
  2. truefoundry/deploy/builder/builders/tfy_notebook_buildpack/__init__.py +4 -2
  3. truefoundry/deploy/builder/builders/tfy_python_buildpack/__init__.py +7 -5
  4. truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +87 -28
  5. truefoundry/deploy/builder/constants.py +8 -0
  6. truefoundry/deploy/builder/utils.py +9 -4
  7. truefoundry/ml/autogen/client/__init__.py +12 -3
  8. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +164 -0
  9. truefoundry/ml/autogen/client/models/__init__.py +12 -3
  10. truefoundry/ml/autogen/client/models/sklearn_framework.py +4 -7
  11. truefoundry/ml/autogen/client/models/sklearn_model_schema.py +1 -1
  12. truefoundry/ml/autogen/client/models/{serialization_format.py → sklearn_serialization_format.py} +5 -5
  13. truefoundry/ml/autogen/client/models/validate_external_storage_root_request_dto.py +71 -0
  14. truefoundry/ml/autogen/client/models/validate_external_storage_root_response_dto.py +69 -0
  15. truefoundry/ml/autogen/client/models/xg_boost_framework.py +4 -7
  16. truefoundry/ml/autogen/client/models/xg_boost_model_schema.py +10 -4
  17. truefoundry/ml/autogen/client/models/xg_boost_serialization_format.py +36 -0
  18. truefoundry/ml/autogen/client_README.md +5 -1
  19. truefoundry/ml/autogen/entities/artifacts.py +49 -36
  20. truefoundry/ml/model_framework.py +97 -69
  21. {truefoundry-0.5.1rc6.dist-info → truefoundry-0.5.1rc8.dist-info}/METADATA +1 -1
  22. {truefoundry-0.5.1rc6.dist-info → truefoundry-0.5.1rc8.dist-info}/RECORD +24 -21
  23. {truefoundry-0.5.1rc6.dist-info → truefoundry-0.5.1rc8.dist-info}/WHEEL +0 -0
  24. {truefoundry-0.5.1rc6.dist-info → truefoundry-0.5.1rc8.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,71 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ FastAPI
5
+
6
+ No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
7
+
8
+ The version of the OpenAPI document: 0.1.0
9
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
10
+
11
+ Do not edit the class manually.
12
+ """ # noqa: E501
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import pprint
18
+ import re # noqa: F401
19
+
20
+ from truefoundry.pydantic_v1 import BaseModel, Field, StrictStr
21
+
22
+
23
+ class ValidateExternalStorageRootRequestDto(BaseModel):
24
+ """
25
+ ValidateExternalStorageRootRequestDto
26
+ """
27
+
28
+ storage_root_uri: StrictStr = Field(...)
29
+ experiment_id: StrictStr = Field(...)
30
+ __properties = ["storage_root_uri", "experiment_id"]
31
+
32
+ class Config:
33
+ """Pydantic configuration"""
34
+
35
+ allow_population_by_field_name = True
36
+ validate_assignment = True
37
+
38
+ def to_str(self) -> str:
39
+ """Returns the string representation of the model using alias"""
40
+ return pprint.pformat(self.dict(by_alias=True))
41
+
42
+ def to_json(self) -> str:
43
+ """Returns the JSON representation of the model using alias"""
44
+ return json.dumps(self.to_dict())
45
+
46
+ @classmethod
47
+ def from_json(cls, json_str: str) -> ValidateExternalStorageRootRequestDto:
48
+ """Create an instance of ValidateExternalStorageRootRequestDto from a JSON string"""
49
+ return cls.from_dict(json.loads(json_str))
50
+
51
+ def to_dict(self):
52
+ """Returns the dictionary representation of the model using alias"""
53
+ _dict = self.dict(by_alias=True, exclude={}, exclude_none=True)
54
+ return _dict
55
+
56
+ @classmethod
57
+ def from_dict(cls, obj: dict) -> ValidateExternalStorageRootRequestDto:
58
+ """Create an instance of ValidateExternalStorageRootRequestDto from a dict"""
59
+ if obj is None:
60
+ return None
61
+
62
+ if not isinstance(obj, dict):
63
+ return ValidateExternalStorageRootRequestDto.parse_obj(obj)
64
+
65
+ _obj = ValidateExternalStorageRootRequestDto.parse_obj(
66
+ {
67
+ "storage_root_uri": obj.get("storage_root_uri"),
68
+ "experiment_id": obj.get("experiment_id"),
69
+ }
70
+ )
71
+ return _obj
@@ -0,0 +1,69 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ FastAPI
5
+
6
+ No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
7
+
8
+ The version of the OpenAPI document: 0.1.0
9
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
10
+
11
+ Do not edit the class manually.
12
+ """ # noqa: E501
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import pprint
18
+ import re # noqa: F401
19
+ from typing import Optional
20
+
21
+ from truefoundry.pydantic_v1 import BaseModel, Field, StrictBool, StrictStr
22
+
23
+
24
+ class ValidateExternalStorageRootResponseDto(BaseModel):
25
+ """
26
+ ValidateExternalStorageRootResponseDto
27
+ """
28
+
29
+ is_valid: StrictBool = Field(...)
30
+ message: Optional[StrictStr] = None
31
+ __properties = ["is_valid", "message"]
32
+
33
+ class Config:
34
+ """Pydantic configuration"""
35
+
36
+ allow_population_by_field_name = True
37
+ validate_assignment = True
38
+
39
+ def to_str(self) -> str:
40
+ """Returns the string representation of the model using alias"""
41
+ return pprint.pformat(self.dict(by_alias=True))
42
+
43
+ def to_json(self) -> str:
44
+ """Returns the JSON representation of the model using alias"""
45
+ return json.dumps(self.to_dict())
46
+
47
+ @classmethod
48
+ def from_json(cls, json_str: str) -> ValidateExternalStorageRootResponseDto:
49
+ """Create an instance of ValidateExternalStorageRootResponseDto from a JSON string"""
50
+ return cls.from_dict(json.loads(json_str))
51
+
52
+ def to_dict(self):
53
+ """Returns the dictionary representation of the model using alias"""
54
+ _dict = self.dict(by_alias=True, exclude={}, exclude_none=True)
55
+ return _dict
56
+
57
+ @classmethod
58
+ def from_dict(cls, obj: dict) -> ValidateExternalStorageRootResponseDto:
59
+ """Create an instance of ValidateExternalStorageRootResponseDto from a dict"""
60
+ if obj is None:
61
+ return None
62
+
63
+ if not isinstance(obj, dict):
64
+ return ValidateExternalStorageRootResponseDto.parse_obj(obj)
65
+
66
+ _obj = ValidateExternalStorageRootResponseDto.parse_obj(
67
+ {"is_valid": obj.get("is_valid"), "message": obj.get("message")}
68
+ )
69
+ return _obj
@@ -18,12 +18,12 @@ import pprint
18
18
  import re # noqa: F401
19
19
  from typing import Optional
20
20
 
21
- from truefoundry.ml.autogen.client.models.serialization_format import (
22
- SerializationFormat,
23
- )
24
21
  from truefoundry.ml.autogen.client.models.xg_boost_model_schema import (
25
22
  XGBoostModelSchema,
26
23
  )
24
+ from truefoundry.ml.autogen.client.models.xg_boost_serialization_format import (
25
+ XGBoostSerializationFormat,
26
+ )
27
27
  from truefoundry.pydantic_v1 import BaseModel, Field, StrictStr, validator
28
28
 
29
29
 
@@ -36,10 +36,7 @@ class XGBoostFramework(BaseModel):
36
36
  default=...,
37
37
  description="+label=Type +usage=Type of the framework +value=xgboost",
38
38
  )
39
- serialization_format: Optional[SerializationFormat] = Field(
40
- default=None,
41
- description="+label=Serialization format +usage=Serialization format used for the model",
42
- )
39
+ serialization_format: Optional[XGBoostSerializationFormat] = None
43
40
  model_filepath: Optional[StrictStr] = Field(
44
41
  default=None,
45
42
  description="+label=Model file path +usage=Relative path to the model file",
@@ -18,16 +18,15 @@ import pprint
18
18
  import re # noqa: F401
19
19
  from typing import Any, Dict
20
20
 
21
- from truefoundry.ml.autogen.client.models.infer_method_name import InferMethodName
22
- from truefoundry.pydantic_v1 import BaseModel, Field, conlist
21
+ from truefoundry.pydantic_v1 import BaseModel, Field, StrictStr, conlist, validator
23
22
 
24
23
 
25
24
  class XGBoostModelSchema(BaseModel):
26
25
  """
27
- XGBoostModelSchema
26
+ +label=XGBoost Model Schema # noqa: E501
28
27
  """
29
28
 
30
- infer_method_name: InferMethodName = Field(
29
+ infer_method_name: StrictStr = Field(
31
30
  default=...,
32
31
  description="+label=Inference Method Name +usage=Name of the method used for inference",
33
32
  )
@@ -39,6 +38,13 @@ class XGBoostModelSchema(BaseModel):
39
38
  )
40
39
  __properties = ["infer_method_name", "inputs", "outputs"]
41
40
 
41
+ @validator("infer_method_name")
42
+ def infer_method_name_validate_enum(cls, value):
43
+ """Validates the enum"""
44
+ if value not in ("predict",):
45
+ raise ValueError("must be one of enum values ('predict')")
46
+ return value
47
+
42
48
  class Config:
43
49
  """Pydantic configuration"""
44
50
 
@@ -0,0 +1,36 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ FastAPI
5
+
6
+ No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
7
+
8
+ The version of the OpenAPI document: 0.1.0
9
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
10
+
11
+ Do not edit the class manually.
12
+ """ # noqa: E501
13
+
14
+ import json
15
+ import re # noqa: F401
16
+
17
+ from aenum import Enum
18
+
19
+
20
+ class XGBoostSerializationFormat(str, Enum):
21
+ """
22
+ +label=Serialization format +usage=Serialization format used for XGBoost models
23
+ """
24
+
25
+ """
26
+ allowed enum values
27
+ """
28
+ CLOUDPICKLE = "cloudpickle"
29
+ JOBLIB = "joblib"
30
+ PICKLE = "pickle"
31
+ JSON = "json"
32
+
33
+ @classmethod
34
+ def from_json(cls, json_str: str) -> XGBoostSerializationFormat:
35
+ """Create an instance of XGBoostSerializationFormat from a JSON string"""
36
+ return XGBoostSerializationFormat(json.loads(json_str))
@@ -133,6 +133,7 @@ Class | Method | HTTP request | Description
133
133
  *MlfoundryArtifactsApi* | [**update_artifact_version_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#update_artifact_version_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/artifact-versions/update | Update Artifact Version
134
134
  *MlfoundryArtifactsApi* | [**update_dataset_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#update_dataset_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/datasets/update | Update Dataset
135
135
  *MlfoundryArtifactsApi* | [**update_model_version_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#update_model_version_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/model-versions/update | Update Model Version
136
+ *MlfoundryArtifactsApi* | [**validate_external_storage_root_path_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#validate_external_storage_root_path_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/artifact-versions/validate-storage-root | Validate External Storage Root Path
136
137
  *RunArtifactsApi* | [**list_run_artifacts_get**](truefoundry/ml/autogen/client/docs/RunArtifactsApi.md#list_run_artifacts_get) | **GET** /api/2.0/mlflow/artifacts/list | List Run Artifacts
137
138
  *RunsApi* | [**create_run_post**](truefoundry/ml/autogen/client/docs/RunsApi.md#create_run_post) | **POST** /api/2.0/mlflow/runs/create | Create Run
138
139
  *RunsApi* | [**delete_run_post**](truefoundry/ml/autogen/client/docs/RunsApi.md#delete_run_post) | **POST** /api/2.0/mlflow/runs/delete | Delete Run
@@ -299,12 +300,12 @@ Class | Method | HTTP request | Description
299
300
  - [RunTagDto](truefoundry/ml/autogen/client/docs/RunTagDto.md)
300
301
  - [SearchRunsRequestDto](truefoundry/ml/autogen/client/docs/SearchRunsRequestDto.md)
301
302
  - [SearchRunsResponseDto](truefoundry/ml/autogen/client/docs/SearchRunsResponseDto.md)
302
- - [SerializationFormat](truefoundry/ml/autogen/client/docs/SerializationFormat.md)
303
303
  - [SetExperimentTagRequestDto](truefoundry/ml/autogen/client/docs/SetExperimentTagRequestDto.md)
304
304
  - [SetTagRequestDto](truefoundry/ml/autogen/client/docs/SetTagRequestDto.md)
305
305
  - [SignedURLDto](truefoundry/ml/autogen/client/docs/SignedURLDto.md)
306
306
  - [SklearnFramework](truefoundry/ml/autogen/client/docs/SklearnFramework.md)
307
307
  - [SklearnModelSchema](truefoundry/ml/autogen/client/docs/SklearnModelSchema.md)
308
+ - [SklearnSerializationFormat](truefoundry/ml/autogen/client/docs/SklearnSerializationFormat.md)
308
309
  - [Source](truefoundry/ml/autogen/client/docs/Source.md)
309
310
  - [Source1](truefoundry/ml/autogen/client/docs/Source1.md)
310
311
  - [SpaCyFramework](truefoundry/ml/autogen/client/docs/SpaCyFramework.md)
@@ -329,10 +330,13 @@ Class | Method | HTTP request | Description
329
330
  - [UpdateRunResponseDto](truefoundry/ml/autogen/client/docs/UpdateRunResponseDto.md)
330
331
  - [Url](truefoundry/ml/autogen/client/docs/Url.md)
331
332
  - [UserMessage](truefoundry/ml/autogen/client/docs/UserMessage.md)
333
+ - [ValidateExternalStorageRootRequestDto](truefoundry/ml/autogen/client/docs/ValidateExternalStorageRootRequestDto.md)
334
+ - [ValidateExternalStorageRootResponseDto](truefoundry/ml/autogen/client/docs/ValidateExternalStorageRootResponseDto.md)
332
335
  - [ValidationError](truefoundry/ml/autogen/client/docs/ValidationError.md)
333
336
  - [ValidationErrorLocInner](truefoundry/ml/autogen/client/docs/ValidationErrorLocInner.md)
334
337
  - [XGBoostFramework](truefoundry/ml/autogen/client/docs/XGBoostFramework.md)
335
338
  - [XGBoostModelSchema](truefoundry/ml/autogen/client/docs/XGBoostModelSchema.md)
339
+ - [XGBoostSerializationFormat](truefoundry/ml/autogen/client/docs/XGBoostSerializationFormat.md)
336
340
 
337
341
 
338
342
  <a id="documentation-for-authorization"></a>
@@ -1,6 +1,6 @@
1
1
  # generated by datamodel-codegen:
2
2
  # filename: artifacts.json
3
- # timestamp: 2024-12-05T14:45:34+00:00
3
+ # timestamp: 2024-12-09T09:04:12+00:00
4
4
 
5
5
  from __future__ import annotations
6
6
 
@@ -60,19 +60,6 @@ class BaseArtifactVersion(BaseModel):
60
60
  )
61
61
 
62
62
 
63
- class BaseModelSchema(BaseModel):
64
- infer_method_name: str = Field(
65
- ...,
66
- description="+label=Inference Method Name\n+usage=Name of the method used for inference",
67
- )
68
- inputs: List[Dict[str, Any]] = Field(
69
- ..., description="+label= Input Schema\n+usage=Schema of the input"
70
- )
71
- outputs: List[Dict[str, Any]] = Field(
72
- ..., description="+label= Output Schema\n+usage=Schema of the output"
73
- )
74
-
75
-
76
63
  class MimeType(str, Enum):
77
64
  """
78
65
  +label=MIME Type
@@ -289,17 +276,6 @@ class PyTorchFramework(BaseModel):
289
276
  )
290
277
 
291
278
 
292
- class SerializationFormat(str, Enum):
293
- """
294
- +label=Serialization format
295
- +usage=Serialization format used for the model
296
- """
297
-
298
- cloudpickle = "cloudpickle"
299
- joblib = "joblib"
300
- pickle = "pickle"
301
-
302
-
303
279
  class InferMethodName(str, Enum):
304
280
  """
305
281
  +label=Inference Method Name
@@ -310,11 +286,32 @@ class InferMethodName(str, Enum):
310
286
  predict_proba = "predict_proba"
311
287
 
312
288
 
313
- class SklearnModelSchema(BaseModelSchema):
289
+ class SklearnModelSchema(BaseModel):
290
+ """
291
+ +label=Sklearn Model Schema
292
+ """
293
+
314
294
  infer_method_name: InferMethodName = Field(
315
295
  ...,
316
296
  description="+label=Inference Method Name\n+usage=Name of the method used for inference",
317
297
  )
298
+ inputs: List[Dict[str, Any]] = Field(
299
+ ..., description="+label= Input Schema\n+usage=Schema of the input"
300
+ )
301
+ outputs: List[Dict[str, Any]] = Field(
302
+ ..., description="+label= Output Schema\n+usage=Schema of the output"
303
+ )
304
+
305
+
306
+ class SklearnSerializationFormat(str, Enum):
307
+ """
308
+ +label=Serialization format
309
+ +usage=Serialization format used for sklearn models
310
+ """
311
+
312
+ cloudpickle = "cloudpickle"
313
+ joblib = "joblib"
314
+ pickle = "pickle"
318
315
 
319
316
 
320
317
  class SpaCyFramework(BaseModel):
@@ -448,11 +445,33 @@ class UserMessage(BaseModel):
448
445
  )
449
446
 
450
447
 
451
- class XGBoostModelSchema(BaseModelSchema):
452
- infer_method_name: InferMethodName = Field(
448
+ class XGBoostModelSchema(BaseModel):
449
+ """
450
+ +label=XGBoost Model Schema
451
+ """
452
+
453
+ infer_method_name: Literal["predict"] = Field(
453
454
  ...,
454
455
  description="+label=Inference Method Name\n+usage=Name of the method used for inference",
455
456
  )
457
+ inputs: List[Dict[str, Any]] = Field(
458
+ ..., description="+label= Input Schema\n+usage=Schema of the input"
459
+ )
460
+ outputs: List[Dict[str, Any]] = Field(
461
+ ..., description="+label= Output Schema\n+usage=Schema of the output"
462
+ )
463
+
464
+
465
+ class XGBoostSerializationFormat(str, Enum):
466
+ """
467
+ +label=Serialization format
468
+ +usage=Serialization format used for XGBoost models
469
+ """
470
+
471
+ cloudpickle = "cloudpickle"
472
+ joblib = "joblib"
473
+ pickle = "pickle"
474
+ json = "json"
456
475
 
457
476
 
458
477
  class AgentOpenAPITool(BaseModel):
@@ -561,10 +580,7 @@ class SklearnFramework(BaseModel):
561
580
  None,
562
581
  description="+label=Model file path\n+usage=Relative path to the model file",
563
582
  )
564
- serialization_format: Optional[SerializationFormat] = Field(
565
- None,
566
- description="+label=Serialization format\n+usage=Serialization format used for the model",
567
- )
583
+ serialization_format: Optional[SklearnSerializationFormat] = None
568
584
  model_schema: Optional[SklearnModelSchema] = None
569
585
 
570
586
 
@@ -577,10 +593,7 @@ class XGBoostFramework(BaseModel):
577
593
  type: Literal["xgboost"] = Field(
578
594
  ..., description="+label=Type\n+usage=Type of the framework\n+value=xgboost"
579
595
  )
580
- serialization_format: Optional[SerializationFormat] = Field(
581
- None,
582
- description="+label=Serialization format\n+usage=Serialization format used for the model",
583
- )
596
+ serialization_format: Optional[XGBoostSerializationFormat] = None
584
597
  model_filepath: Optional[str] = Field(
585
598
  None,
586
599
  description="+label=Model file path\n+usage=Relative path to the model file",
@@ -20,7 +20,10 @@ from truefoundry.common.utils import (
20
20
  get_python_version_major_minor,
21
21
  list_pip_packages_installed,
22
22
  )
23
- from truefoundry.ml.autogen.client import SerializationFormat
23
+ from truefoundry.ml.autogen.client import (
24
+ SklearnSerializationFormat,
25
+ XGBoostSerializationFormat,
26
+ )
24
27
  from truefoundry.ml.autogen.entities import artifacts as autogen_artifacts
25
28
  from truefoundry.ml.autogen.models import infer_signature
26
29
  from truefoundry.ml.enums import ModelFramework
@@ -35,46 +38,13 @@ if TYPE_CHECKING:
35
38
 
36
39
  # Map serialization format to corresponding pip packages
37
40
  SERIALIZATION_FORMAT_TO_PACKAGES_NAME_MAP = {
38
- SerializationFormat.JOBLIB: ["joblib"],
39
- SerializationFormat.CLOUDPICKLE: ["cloudpickle"],
41
+ SklearnSerializationFormat.JOBLIB: ["joblib"],
42
+ SklearnSerializationFormat.CLOUDPICKLE: ["cloudpickle"],
43
+ XGBoostSerializationFormat.JOBLIB: ["joblib"],
44
+ XGBoostSerializationFormat.CLOUDPICKLE: ["cloudpickle"],
40
45
  }
41
46
 
42
47
 
43
- class _SerializationFormatLoaderRegistry:
44
- def __init__(self):
45
- # An OrderedDict is used to maintain the order of loaders based on priority
46
- # The loaders are added in the following order:
47
- # 1. joblib (if available)
48
- # 2. cloudpickle (if available)
49
- # 3. pickle (default fallback)
50
- # This ensures that when looking up a loader, it follows the correct loading priority.
51
- self._loader_map: Dict[SerializationFormat, Callable[[bytes], object]] = (
52
- OrderedDict()
53
- )
54
- try:
55
- from joblib import load as joblib_load
56
-
57
- self._loader_map[SerializationFormat.JOBLIB] = joblib_load
58
- except ImportError:
59
- pass
60
-
61
- try:
62
- from cloudpickle import load as cloudpickle_load
63
-
64
- self._loader_map[SerializationFormat.CLOUDPICKLE] = cloudpickle_load
65
- except ImportError:
66
- pass
67
-
68
- # Add pickle loader as a fallback
69
- self._loader_map[SerializationFormat.PICKLE] = pickle_load
70
-
71
- def get_loader_map(self) -> Dict[SerializationFormat, Callable[[bytes], object]]:
72
- return self._loader_map
73
-
74
-
75
- _serialization_format_loader_registry = _SerializationFormatLoaderRegistry()
76
-
77
-
78
48
  class FastAIFramework(autogen_artifacts.FastAIFramework):
79
49
  """FastAI model Framework"""
80
50
 
@@ -180,6 +150,87 @@ ModelFrameworkType = Union[
180
150
  ]
181
151
 
182
152
 
153
+ class _SerializationFormatLoaderRegistry:
154
+ def __init__(self, framework: Type[Union[SklearnFramework, XGBoostFramework]]):
155
+ # An OrderedDict is used to maintain the order of loaders based on priority
156
+ # The loaders are added in the following order:
157
+ # 1. joblib (if available)
158
+ # 2. cloudpickle (if available)
159
+ # 3. pickle (default fallback)
160
+ # This ensures that when looking up a loader, it follows the correct loading priority.
161
+ self._loader_map: Dict[
162
+ Union[SklearnSerializationFormat, XGBoostSerializationFormat],
163
+ Callable[[bytes], object],
164
+ ] = OrderedDict()
165
+ format_class: Union[SklearnSerializationFormat, XGBoostSerializationFormat] = (
166
+ SklearnSerializationFormat
167
+ if framework == SklearnFramework
168
+ else XGBoostSerializationFormat
169
+ )
170
+ is_xgboost = framework == XGBoostFramework
171
+
172
+ try:
173
+ from joblib import load as joblib_load
174
+
175
+ self._loader_map[format_class.JOBLIB] = joblib_load
176
+ except ImportError:
177
+ pass
178
+
179
+ try:
180
+ from cloudpickle import load as cloudpickle_load
181
+
182
+ self._loader_map[format_class.CLOUDPICKLE] = cloudpickle_load
183
+
184
+ except ImportError:
185
+ pass
186
+
187
+ if is_xgboost:
188
+ try:
189
+ from xgboost import Booster
190
+
191
+ booster = Booster()
192
+ self._loader_map[format_class.JSON] = booster.load_model
193
+ except ImportError:
194
+ pass
195
+
196
+ # Add pickle loader as a fallback
197
+ self._loader_map[format_class.PICKLE] = pickle_load
198
+
199
+ def get_loader_map(
200
+ self,
201
+ ) -> Dict[
202
+ Union[SklearnSerializationFormat, XGBoostSerializationFormat],
203
+ Callable[[bytes], object],
204
+ ]:
205
+ return self._loader_map
206
+
207
+ def _detect_model_serialization_format(
208
+ self,
209
+ model_file_path: str,
210
+ ) -> Optional[Union[SklearnSerializationFormat, XGBoostSerializationFormat]]:
211
+ """
212
+ The function will attempt to load the model using each different serialization format's loader and return the first successful one.
213
+
214
+ Args:
215
+ model_file_path (str): The path to the file to be loaded.
216
+
217
+ Returns:
218
+ Optional[Union[SklearnSerializationFormat, XGBoostSerializationFormat]]: The serialization format if successfully loaded, None otherwise.
219
+ """
220
+ # Attempt to load the model using each framework
221
+ for (
222
+ serialization_format,
223
+ loader,
224
+ ) in self._loader_map.items():
225
+ try:
226
+ with open(model_file_path, "rb") as f:
227
+ loader(f)
228
+ return serialization_format
229
+ except Exception:
230
+ continue
231
+ return None
232
+
233
+
183
234
  class _ModelFramework(BaseModel):
184
235
  __root__: ModelFrameworkType = Field(discriminator="type")
185
236
 
@@ -259,32 +310,6 @@ def _get_required_framework_pip_packages(framework: "ModelFrameworkType") -> Lis
259
310
  return _MODEL_FRAMEWORK_TO_PIP_PACKAGES.get(framework.__class__, [])
260
311
 
261
312
 
262
- def _detect_model_serialization_format(
263
- model_file_path: str,
264
- ) -> Optional[SerializationFormat]:
265
- """
266
- The function will attempt to load the model using each framework's loader and return the first successful one.
267
-
268
- Args:
269
- model_file_path (str): The path to the file to be loaded.
270
-
271
- Returns:
272
- Optional[SerializationFormat]: The serialization format if successfully loaded, None otherwise.
273
- """
274
- # Attempt to load the model using each framework
275
- for (
276
- serialization_format,
277
- loader,
278
- ) in _serialization_format_loader_registry.get_loader_map().items():
279
- try:
280
- with open(model_file_path, "rb") as f:
281
- loader(f)
282
- return serialization_format
283
- except Exception:
284
- continue
285
- return None
286
-
287
-
288
313
  def _fetch_framework_specific_pip_packages(
289
314
  framework: "ModelFrameworkType",
290
315
  ) -> List[str]:
@@ -431,12 +456,15 @@ def auto_update_model_framework_details(
431
456
  if os.path.isfile(model_file_or_folder)
432
457
  else os.path.join(model_file_or_folder, framework.model_filepath)
433
458
  )
434
- framework.serialization_format = (
435
- framework.serialization_format
436
- or _detect_model_serialization_format(
437
- model_file_path=absolute_model_filepath
459
+ if not framework.serialization_format:
460
+ loader_registry = _SerializationFormatLoaderRegistry(
461
+ framework=framework
462
+ )
463
+ framework.serialization_format = (
464
+ loader_registry._detect_model_serialization_format(
465
+ model_file_path=absolute_model_filepath
466
+ )
438
467
  )
439
- )
440
468
 
441
469
 
442
470
  def sklearn_infer_schema(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: truefoundry
3
- Version: 0.5.1rc6
3
+ Version: 0.5.1rc8
4
4
  Summary: Truefoundry CLI
5
5
  Author: Abhishek Choudhary
6
6
  Author-email: abhishek@truefoundry.com