truefoundry 0.5.1rc6__py3-none-any.whl → 0.5.1rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truefoundry might be problematic. Click here for more details.
- truefoundry/common/constants.py +9 -0
- truefoundry/deploy/builder/builders/tfy_notebook_buildpack/__init__.py +4 -2
- truefoundry/deploy/builder/builders/tfy_python_buildpack/__init__.py +7 -5
- truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +87 -28
- truefoundry/deploy/builder/constants.py +8 -0
- truefoundry/deploy/builder/utils.py +9 -4
- truefoundry/ml/autogen/client/__init__.py +12 -3
- truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +164 -0
- truefoundry/ml/autogen/client/models/__init__.py +12 -3
- truefoundry/ml/autogen/client/models/sklearn_framework.py +4 -7
- truefoundry/ml/autogen/client/models/sklearn_model_schema.py +1 -1
- truefoundry/ml/autogen/client/models/{serialization_format.py → sklearn_serialization_format.py} +5 -5
- truefoundry/ml/autogen/client/models/validate_external_storage_root_request_dto.py +71 -0
- truefoundry/ml/autogen/client/models/validate_external_storage_root_response_dto.py +69 -0
- truefoundry/ml/autogen/client/models/xg_boost_framework.py +4 -7
- truefoundry/ml/autogen/client/models/xg_boost_model_schema.py +10 -4
- truefoundry/ml/autogen/client/models/xg_boost_serialization_format.py +36 -0
- truefoundry/ml/autogen/client_README.md +5 -1
- truefoundry/ml/autogen/entities/artifacts.py +49 -36
- truefoundry/ml/model_framework.py +97 -69
- {truefoundry-0.5.1rc6.dist-info → truefoundry-0.5.1rc8.dist-info}/METADATA +1 -1
- {truefoundry-0.5.1rc6.dist-info → truefoundry-0.5.1rc8.dist-info}/RECORD +24 -21
- {truefoundry-0.5.1rc6.dist-info → truefoundry-0.5.1rc8.dist-info}/WHEEL +0 -0
- {truefoundry-0.5.1rc6.dist-info → truefoundry-0.5.1rc8.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
FastAPI
|
|
5
|
+
|
|
6
|
+
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
|
7
|
+
|
|
8
|
+
The version of the OpenAPI document: 0.1.0
|
|
9
|
+
Generated by OpenAPI Generator (https://openapi-generator.tech)
|
|
10
|
+
|
|
11
|
+
Do not edit the class manually.
|
|
12
|
+
""" # noqa: E501
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import pprint
|
|
18
|
+
import re # noqa: F401
|
|
19
|
+
|
|
20
|
+
from truefoundry.pydantic_v1 import BaseModel, Field, StrictStr
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ValidateExternalStorageRootRequestDto(BaseModel):
    """
    ValidateExternalStorageRootRequestDto
    """

    # Both fields are required strict strings: Field(...) declares no default.
    storage_root_uri: StrictStr = Field(...)
    experiment_id: StrictStr = Field(...)
    __properties = ["storage_root_uri", "experiment_id"]

    class Config:
        """Pydantic configuration"""

        allow_population_by_field_name = True
        validate_assignment = True

    def to_str(self) -> str:
        """Pretty-format the alias-keyed dict form of this model."""
        aliased = self.dict(by_alias=True)
        return pprint.pformat(aliased)

    def to_json(self) -> str:
        """Dump the result of ``to_dict`` as a JSON string."""
        payload = self.to_dict()
        return json.dumps(payload)

    @classmethod
    def from_json(cls, json_str: str) -> ValidateExternalStorageRootRequestDto:
        """Parse ``json_str`` and build an instance through ``from_dict``."""
        decoded = json.loads(json_str)
        return cls.from_dict(decoded)

    def to_dict(self):
        """Return the alias-keyed dict form of the model, omitting None values."""
        return self.dict(by_alias=True, exclude={}, exclude_none=True)

    @classmethod
    def from_dict(cls, obj: dict) -> ValidateExternalStorageRootRequestDto:
        """Build an instance from ``obj``.

        ``None`` is passed through as ``None``. A dict is reduced to the two
        declared fields before parsing; any other object is handed to
        ``parse_obj`` unchanged.
        """
        if obj is None:
            return None

        if isinstance(obj, dict):
            # Keep only the declared fields; absent keys become None.
            field_names = ("storage_root_uri", "experiment_id")
            obj = {key: obj.get(key) for key in field_names}

        return ValidateExternalStorageRootRequestDto.parse_obj(obj)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
FastAPI
|
|
5
|
+
|
|
6
|
+
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
|
7
|
+
|
|
8
|
+
The version of the OpenAPI document: 0.1.0
|
|
9
|
+
Generated by OpenAPI Generator (https://openapi-generator.tech)
|
|
10
|
+
|
|
11
|
+
Do not edit the class manually.
|
|
12
|
+
""" # noqa: E501
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import pprint
|
|
18
|
+
import re # noqa: F401
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
from truefoundry.pydantic_v1 import BaseModel, Field, StrictBool, StrictStr
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ValidateExternalStorageRootResponseDto(BaseModel):
    """
    ValidateExternalStorageRootResponseDto
    """

    # is_valid is a required strict bool; message is an optional strict string.
    is_valid: StrictBool = Field(...)
    message: Optional[StrictStr] = None
    __properties = ["is_valid", "message"]

    class Config:
        """Pydantic configuration"""

        allow_population_by_field_name = True
        validate_assignment = True

    def to_str(self) -> str:
        """Pretty-format the alias-keyed dict form of this model."""
        aliased = self.dict(by_alias=True)
        return pprint.pformat(aliased)

    def to_json(self) -> str:
        """Dump the result of ``to_dict`` as a JSON string."""
        payload = self.to_dict()
        return json.dumps(payload)

    @classmethod
    def from_json(cls, json_str: str) -> ValidateExternalStorageRootResponseDto:
        """Parse ``json_str`` and build an instance through ``from_dict``."""
        decoded = json.loads(json_str)
        return cls.from_dict(decoded)

    def to_dict(self):
        """Return the alias-keyed dict form of the model, omitting None values."""
        return self.dict(by_alias=True, exclude={}, exclude_none=True)

    @classmethod
    def from_dict(cls, obj: dict) -> ValidateExternalStorageRootResponseDto:
        """Build an instance from ``obj``.

        ``None`` is passed through as ``None``. A dict is reduced to the two
        declared fields before parsing; any other object is handed to
        ``parse_obj`` unchanged.
        """
        if obj is None:
            return None

        if isinstance(obj, dict):
            # Keep only the declared fields; absent keys become None.
            field_names = ("is_valid", "message")
            obj = {key: obj.get(key) for key in field_names}

        return ValidateExternalStorageRootResponseDto.parse_obj(obj)
|
|
@@ -18,12 +18,12 @@ import pprint
|
|
|
18
18
|
import re # noqa: F401
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
|
-
from truefoundry.ml.autogen.client.models.serialization_format import (
|
|
22
|
-
SerializationFormat,
|
|
23
|
-
)
|
|
24
21
|
from truefoundry.ml.autogen.client.models.xg_boost_model_schema import (
|
|
25
22
|
XGBoostModelSchema,
|
|
26
23
|
)
|
|
24
|
+
from truefoundry.ml.autogen.client.models.xg_boost_serialization_format import (
|
|
25
|
+
XGBoostSerializationFormat,
|
|
26
|
+
)
|
|
27
27
|
from truefoundry.pydantic_v1 import BaseModel, Field, StrictStr, validator
|
|
28
28
|
|
|
29
29
|
|
|
@@ -36,10 +36,7 @@ class XGBoostFramework(BaseModel):
|
|
|
36
36
|
default=...,
|
|
37
37
|
description="+label=Type +usage=Type of the framework +value=xgboost",
|
|
38
38
|
)
|
|
39
|
-
serialization_format: Optional[SerializationFormat] = Field(
|
|
40
|
-
default=None,
|
|
41
|
-
description="+label=Serialization format +usage=Serialization format used for the model",
|
|
42
|
-
)
|
|
39
|
+
serialization_format: Optional[XGBoostSerializationFormat] = None
|
|
43
40
|
model_filepath: Optional[StrictStr] = Field(
|
|
44
41
|
default=None,
|
|
45
42
|
description="+label=Model file path +usage=Relative path to the model file",
|
|
@@ -18,16 +18,15 @@ import pprint
|
|
|
18
18
|
import re # noqa: F401
|
|
19
19
|
from typing import Any, Dict
|
|
20
20
|
|
|
21
|
-
from truefoundry.
|
|
22
|
-
from truefoundry.pydantic_v1 import BaseModel, Field, conlist
|
|
21
|
+
from truefoundry.pydantic_v1 import BaseModel, Field, StrictStr, conlist, validator
|
|
23
22
|
|
|
24
23
|
|
|
25
24
|
class XGBoostModelSchema(BaseModel):
|
|
26
25
|
"""
|
|
27
|
-
|
|
26
|
+
+label=XGBoost Model Schema # noqa: E501
|
|
28
27
|
"""
|
|
29
28
|
|
|
30
|
-
infer_method_name:
|
|
29
|
+
infer_method_name: StrictStr = Field(
|
|
31
30
|
default=...,
|
|
32
31
|
description="+label=Inference Method Name +usage=Name of the method used for inference",
|
|
33
32
|
)
|
|
@@ -39,6 +38,13 @@ class XGBoostModelSchema(BaseModel):
|
|
|
39
38
|
)
|
|
40
39
|
__properties = ["infer_method_name", "inputs", "outputs"]
|
|
41
40
|
|
|
41
|
+
@validator("infer_method_name")
|
|
42
|
+
def infer_method_name_validate_enum(cls, value):
|
|
43
|
+
"""Validates the enum"""
|
|
44
|
+
if value not in ("predict",):
|
|
45
|
+
raise ValueError("must be one of enum values ('predict')")
|
|
46
|
+
return value
|
|
47
|
+
|
|
42
48
|
class Config:
|
|
43
49
|
"""Pydantic configuration"""
|
|
44
50
|
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
FastAPI
|
|
5
|
+
|
|
6
|
+
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
|
7
|
+
|
|
8
|
+
The version of the OpenAPI document: 0.1.0
|
|
9
|
+
Generated by OpenAPI Generator (https://openapi-generator.tech)
|
|
10
|
+
|
|
11
|
+
Do not edit the class manually.
|
|
12
|
+
""" # noqa: E501
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
import re # noqa: F401
|
|
16
|
+
|
|
17
|
+
from aenum import Enum
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class XGBoostSerializationFormat(str, Enum):
    """
    +label=Serialization format +usage=Serialization format used for XGBoost models
    """

    # Allowed enum values.
    CLOUDPICKLE = "cloudpickle"
    JOBLIB = "joblib"
    PICKLE = "pickle"
    JSON = "json"

    @classmethod
    def from_json(cls, json_str: str) -> XGBoostSerializationFormat:
        """Decode a JSON-encoded value and return the matching enum member."""
        decoded = json.loads(json_str)
        return XGBoostSerializationFormat(decoded)
|
|
@@ -133,6 +133,7 @@ Class | Method | HTTP request | Description
|
|
|
133
133
|
*MlfoundryArtifactsApi* | [**update_artifact_version_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#update_artifact_version_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/artifact-versions/update | Update Artifact Version
|
|
134
134
|
*MlfoundryArtifactsApi* | [**update_dataset_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#update_dataset_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/datasets/update | Update Dataset
|
|
135
135
|
*MlfoundryArtifactsApi* | [**update_model_version_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#update_model_version_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/model-versions/update | Update Model Version
|
|
136
|
+
*MlfoundryArtifactsApi* | [**validate_external_storage_root_path_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#validate_external_storage_root_path_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/artifact-versions/validate-storage-root | Validate External Storage Root Path
|
|
136
137
|
*RunArtifactsApi* | [**list_run_artifacts_get**](truefoundry/ml/autogen/client/docs/RunArtifactsApi.md#list_run_artifacts_get) | **GET** /api/2.0/mlflow/artifacts/list | List Run Artifacts
|
|
137
138
|
*RunsApi* | [**create_run_post**](truefoundry/ml/autogen/client/docs/RunsApi.md#create_run_post) | **POST** /api/2.0/mlflow/runs/create | Create Run
|
|
138
139
|
*RunsApi* | [**delete_run_post**](truefoundry/ml/autogen/client/docs/RunsApi.md#delete_run_post) | **POST** /api/2.0/mlflow/runs/delete | Delete Run
|
|
@@ -299,12 +300,12 @@ Class | Method | HTTP request | Description
|
|
|
299
300
|
- [RunTagDto](truefoundry/ml/autogen/client/docs/RunTagDto.md)
|
|
300
301
|
- [SearchRunsRequestDto](truefoundry/ml/autogen/client/docs/SearchRunsRequestDto.md)
|
|
301
302
|
- [SearchRunsResponseDto](truefoundry/ml/autogen/client/docs/SearchRunsResponseDto.md)
|
|
302
|
-
- [SerializationFormat](truefoundry/ml/autogen/client/docs/SerializationFormat.md)
|
|
303
303
|
- [SetExperimentTagRequestDto](truefoundry/ml/autogen/client/docs/SetExperimentTagRequestDto.md)
|
|
304
304
|
- [SetTagRequestDto](truefoundry/ml/autogen/client/docs/SetTagRequestDto.md)
|
|
305
305
|
- [SignedURLDto](truefoundry/ml/autogen/client/docs/SignedURLDto.md)
|
|
306
306
|
- [SklearnFramework](truefoundry/ml/autogen/client/docs/SklearnFramework.md)
|
|
307
307
|
- [SklearnModelSchema](truefoundry/ml/autogen/client/docs/SklearnModelSchema.md)
|
|
308
|
+
- [SklearnSerializationFormat](truefoundry/ml/autogen/client/docs/SklearnSerializationFormat.md)
|
|
308
309
|
- [Source](truefoundry/ml/autogen/client/docs/Source.md)
|
|
309
310
|
- [Source1](truefoundry/ml/autogen/client/docs/Source1.md)
|
|
310
311
|
- [SpaCyFramework](truefoundry/ml/autogen/client/docs/SpaCyFramework.md)
|
|
@@ -329,10 +330,13 @@ Class | Method | HTTP request | Description
|
|
|
329
330
|
- [UpdateRunResponseDto](truefoundry/ml/autogen/client/docs/UpdateRunResponseDto.md)
|
|
330
331
|
- [Url](truefoundry/ml/autogen/client/docs/Url.md)
|
|
331
332
|
- [UserMessage](truefoundry/ml/autogen/client/docs/UserMessage.md)
|
|
333
|
+
- [ValidateExternalStorageRootRequestDto](truefoundry/ml/autogen/client/docs/ValidateExternalStorageRootRequestDto.md)
|
|
334
|
+
- [ValidateExternalStorageRootResponseDto](truefoundry/ml/autogen/client/docs/ValidateExternalStorageRootResponseDto.md)
|
|
332
335
|
- [ValidationError](truefoundry/ml/autogen/client/docs/ValidationError.md)
|
|
333
336
|
- [ValidationErrorLocInner](truefoundry/ml/autogen/client/docs/ValidationErrorLocInner.md)
|
|
334
337
|
- [XGBoostFramework](truefoundry/ml/autogen/client/docs/XGBoostFramework.md)
|
|
335
338
|
- [XGBoostModelSchema](truefoundry/ml/autogen/client/docs/XGBoostModelSchema.md)
|
|
339
|
+
- [XGBoostSerializationFormat](truefoundry/ml/autogen/client/docs/XGBoostSerializationFormat.md)
|
|
336
340
|
|
|
337
341
|
|
|
338
342
|
<a id="documentation-for-authorization"></a>
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# generated by datamodel-codegen:
|
|
2
2
|
# filename: artifacts.json
|
|
3
|
-
# timestamp: 2024-12-
|
|
3
|
+
# timestamp: 2024-12-09T09:04:12+00:00
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
@@ -60,19 +60,6 @@ class BaseArtifactVersion(BaseModel):
|
|
|
60
60
|
)
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
class BaseModelSchema(BaseModel):
|
|
64
|
-
infer_method_name: str = Field(
|
|
65
|
-
...,
|
|
66
|
-
description="+label=Inference Method Name\n+usage=Name of the method used for inference",
|
|
67
|
-
)
|
|
68
|
-
inputs: List[Dict[str, Any]] = Field(
|
|
69
|
-
..., description="+label= Input Schema\n+usage=Schema of the input"
|
|
70
|
-
)
|
|
71
|
-
outputs: List[Dict[str, Any]] = Field(
|
|
72
|
-
..., description="+label= Output Schema\n+usage=Schema of the output"
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
|
|
76
63
|
class MimeType(str, Enum):
|
|
77
64
|
"""
|
|
78
65
|
+label=MIME Type
|
|
@@ -289,17 +276,6 @@ class PyTorchFramework(BaseModel):
|
|
|
289
276
|
)
|
|
290
277
|
|
|
291
278
|
|
|
292
|
-
class SerializationFormat(str, Enum):
|
|
293
|
-
"""
|
|
294
|
-
+label=Serialization format
|
|
295
|
-
+usage=Serialization format used for the model
|
|
296
|
-
"""
|
|
297
|
-
|
|
298
|
-
cloudpickle = "cloudpickle"
|
|
299
|
-
joblib = "joblib"
|
|
300
|
-
pickle = "pickle"
|
|
301
|
-
|
|
302
|
-
|
|
303
279
|
class InferMethodName(str, Enum):
|
|
304
280
|
"""
|
|
305
281
|
+label=Inference Method Name
|
|
@@ -310,11 +286,32 @@ class InferMethodName(str, Enum):
|
|
|
310
286
|
predict_proba = "predict_proba"
|
|
311
287
|
|
|
312
288
|
|
|
313
|
-
class SklearnModelSchema(BaseModelSchema):
|
|
289
|
+
class SklearnModelSchema(BaseModel):
|
|
290
|
+
"""
|
|
291
|
+
+label=Sklearn Model Schema
|
|
292
|
+
"""
|
|
293
|
+
|
|
314
294
|
infer_method_name: InferMethodName = Field(
|
|
315
295
|
...,
|
|
316
296
|
description="+label=Inference Method Name\n+usage=Name of the method used for inference",
|
|
317
297
|
)
|
|
298
|
+
inputs: List[Dict[str, Any]] = Field(
|
|
299
|
+
..., description="+label= Input Schema\n+usage=Schema of the input"
|
|
300
|
+
)
|
|
301
|
+
outputs: List[Dict[str, Any]] = Field(
|
|
302
|
+
..., description="+label= Output Schema\n+usage=Schema of the output"
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
class SklearnSerializationFormat(str, Enum):
|
|
307
|
+
"""
|
|
308
|
+
+label=Serialization format
|
|
309
|
+
+usage=Serialization format used for sklearn models
|
|
310
|
+
"""
|
|
311
|
+
|
|
312
|
+
cloudpickle = "cloudpickle"
|
|
313
|
+
joblib = "joblib"
|
|
314
|
+
pickle = "pickle"
|
|
318
315
|
|
|
319
316
|
|
|
320
317
|
class SpaCyFramework(BaseModel):
|
|
@@ -448,11 +445,33 @@ class UserMessage(BaseModel):
|
|
|
448
445
|
)
|
|
449
446
|
|
|
450
447
|
|
|
451
|
-
class XGBoostModelSchema(BaseModelSchema):
|
|
452
|
-
|
|
448
|
+
class XGBoostModelSchema(BaseModel):
|
|
449
|
+
"""
|
|
450
|
+
+label=XGBoost Model Schema
|
|
451
|
+
"""
|
|
452
|
+
|
|
453
|
+
infer_method_name: Literal["predict"] = Field(
|
|
453
454
|
...,
|
|
454
455
|
description="+label=Inference Method Name\n+usage=Name of the method used for inference",
|
|
455
456
|
)
|
|
457
|
+
inputs: List[Dict[str, Any]] = Field(
|
|
458
|
+
..., description="+label= Input Schema\n+usage=Schema of the input"
|
|
459
|
+
)
|
|
460
|
+
outputs: List[Dict[str, Any]] = Field(
|
|
461
|
+
..., description="+label= Output Schema\n+usage=Schema of the output"
|
|
462
|
+
)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
class XGBoostSerializationFormat(str, Enum):
|
|
466
|
+
"""
|
|
467
|
+
+label=Serialization format
|
|
468
|
+
+usage=Serialization format used for XGBoost models
|
|
469
|
+
"""
|
|
470
|
+
|
|
471
|
+
cloudpickle = "cloudpickle"
|
|
472
|
+
joblib = "joblib"
|
|
473
|
+
pickle = "pickle"
|
|
474
|
+
json = "json"
|
|
456
475
|
|
|
457
476
|
|
|
458
477
|
class AgentOpenAPITool(BaseModel):
|
|
@@ -561,10 +580,7 @@ class SklearnFramework(BaseModel):
|
|
|
561
580
|
None,
|
|
562
581
|
description="+label=Model file path\n+usage=Relative path to the model file",
|
|
563
582
|
)
|
|
564
|
-
serialization_format: Optional[SerializationFormat] = Field(
|
|
565
|
-
None,
|
|
566
|
-
description="+label=Serialization format\n+usage=Serialization format used for the model",
|
|
567
|
-
)
|
|
583
|
+
serialization_format: Optional[SklearnSerializationFormat] = None
|
|
568
584
|
model_schema: Optional[SklearnModelSchema] = None
|
|
569
585
|
|
|
570
586
|
|
|
@@ -577,10 +593,7 @@ class XGBoostFramework(BaseModel):
|
|
|
577
593
|
type: Literal["xgboost"] = Field(
|
|
578
594
|
..., description="+label=Type\n+usage=Type of the framework\n+value=xgboost"
|
|
579
595
|
)
|
|
580
|
-
serialization_format: Optional[SerializationFormat] = Field(
|
|
581
|
-
None,
|
|
582
|
-
description="+label=Serialization format\n+usage=Serialization format used for the model",
|
|
583
|
-
)
|
|
596
|
+
serialization_format: Optional[XGBoostSerializationFormat] = None
|
|
584
597
|
model_filepath: Optional[str] = Field(
|
|
585
598
|
None,
|
|
586
599
|
description="+label=Model file path\n+usage=Relative path to the model file",
|
|
@@ -20,7 +20,10 @@ from truefoundry.common.utils import (
|
|
|
20
20
|
get_python_version_major_minor,
|
|
21
21
|
list_pip_packages_installed,
|
|
22
22
|
)
|
|
23
|
-
from truefoundry.ml.autogen.client import SerializationFormat
|
|
23
|
+
from truefoundry.ml.autogen.client import (
|
|
24
|
+
SklearnSerializationFormat,
|
|
25
|
+
XGBoostSerializationFormat,
|
|
26
|
+
)
|
|
24
27
|
from truefoundry.ml.autogen.entities import artifacts as autogen_artifacts
|
|
25
28
|
from truefoundry.ml.autogen.models import infer_signature
|
|
26
29
|
from truefoundry.ml.enums import ModelFramework
|
|
@@ -35,46 +38,13 @@ if TYPE_CHECKING:
|
|
|
35
38
|
|
|
36
39
|
# Map serialization format to corresponding pip packages
|
|
37
40
|
SERIALIZATION_FORMAT_TO_PACKAGES_NAME_MAP = {
|
|
38
|
-
|
|
39
|
-
|
|
41
|
+
SklearnSerializationFormat.JOBLIB: ["joblib"],
|
|
42
|
+
SklearnSerializationFormat.CLOUDPICKLE: ["cloudpickle"],
|
|
43
|
+
XGBoostSerializationFormat.JOBLIB: ["joblib"],
|
|
44
|
+
XGBoostSerializationFormat.CLOUDPICKLE: ["cloudpickle"],
|
|
40
45
|
}
|
|
41
46
|
|
|
42
47
|
|
|
43
|
-
class _SerializationFormatLoaderRegistry:
|
|
44
|
-
def __init__(self):
|
|
45
|
-
# An OrderedDict is used to maintain the order of loaders based on priority
|
|
46
|
-
# The loaders are added in the following order:
|
|
47
|
-
# 1. joblib (if available)
|
|
48
|
-
# 2. cloudpickle (if available)
|
|
49
|
-
# 3. pickle (default fallback)
|
|
50
|
-
# This ensures that when looking up a loader, it follows the correct loading priority.
|
|
51
|
-
self._loader_map: Dict[SerializationFormat, Callable[[bytes], object]] = (
|
|
52
|
-
OrderedDict()
|
|
53
|
-
)
|
|
54
|
-
try:
|
|
55
|
-
from joblib import load as joblib_load
|
|
56
|
-
|
|
57
|
-
self._loader_map[SerializationFormat.JOBLIB] = joblib_load
|
|
58
|
-
except ImportError:
|
|
59
|
-
pass
|
|
60
|
-
|
|
61
|
-
try:
|
|
62
|
-
from cloudpickle import load as cloudpickle_load
|
|
63
|
-
|
|
64
|
-
self._loader_map[SerializationFormat.CLOUDPICKLE] = cloudpickle_load
|
|
65
|
-
except ImportError:
|
|
66
|
-
pass
|
|
67
|
-
|
|
68
|
-
# Add pickle loader as a fallback
|
|
69
|
-
self._loader_map[SerializationFormat.PICKLE] = pickle_load
|
|
70
|
-
|
|
71
|
-
def get_loader_map(self) -> Dict[SerializationFormat, Callable[[bytes], object]]:
|
|
72
|
-
return self._loader_map
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
_serialization_format_loader_registry = _SerializationFormatLoaderRegistry()
|
|
76
|
-
|
|
77
|
-
|
|
78
48
|
class FastAIFramework(autogen_artifacts.FastAIFramework):
|
|
79
49
|
"""FastAI model Framework"""
|
|
80
50
|
|
|
@@ -180,6 +150,87 @@ ModelFrameworkType = Union[
|
|
|
180
150
|
]
|
|
181
151
|
|
|
182
152
|
|
|
153
|
+
class _SerializationFormatLoaderRegistry:
    """Ordered registry of loaders used to guess a model file's serialization format.

    Loaders are kept in priority order: joblib, cloudpickle, XGBoost JSON
    (XGBoost frameworks only), then pickle as the fallback. Optional backends
    that cannot be imported are skipped.
    """

    def __init__(
        self,
        framework: Union[
            Type[Union[SklearnFramework, XGBoostFramework]],
            SklearnFramework,
            XGBoostFramework,
        ],
    ):
        """Build the loader map for ``framework``.

        Args:
            framework: Either a framework class or a framework instance. Callers
                that hold an instance may pass it directly; it is normalized to
                its class before the format enum is selected.
        """
        # Comparing an instance with `==` against the framework class is always
        # False, which picked the wrong enum for sklearn and never enabled the
        # JSON loader. Normalize to a class and use issubclass instead.
        framework_cls = framework if isinstance(framework, type) else type(framework)
        is_xgboost = issubclass(framework_cls, XGBoostFramework)
        format_class: Type[
            Union[SklearnSerializationFormat, XGBoostSerializationFormat]
        ] = (
            SklearnSerializationFormat
            if issubclass(framework_cls, SklearnFramework)
            else XGBoostSerializationFormat
        )

        # An OrderedDict is used to maintain the order of loaders based on priority.
        self._loader_map: Dict[
            Union[SklearnSerializationFormat, XGBoostSerializationFormat],
            Callable[[bytes], object],
        ] = OrderedDict()

        try:
            from joblib import load as joblib_load

            self._loader_map[format_class.JOBLIB] = joblib_load
        except ImportError:
            pass

        try:
            from cloudpickle import load as cloudpickle_load

            self._loader_map[format_class.CLOUDPICKLE] = cloudpickle_load
        except ImportError:
            pass

        if is_xgboost:
            try:
                from xgboost import Booster
            except ImportError:
                pass
            else:

                def _load_xgboost_model(model_file):
                    # Loaders are called with an open binary file object, but
                    # Booster.load_model takes a path or bytearray, so pass the
                    # handle's path. A fresh Booster per call keeps a probed
                    # model from being retained by the registry.
                    # NOTE(review): load_model may also accept non-JSON XGBoost
                    # files; confirm whether those should be reported as JSON.
                    booster = Booster()
                    booster.load_model(model_file.name)
                    return booster

                self._loader_map[format_class.JSON] = _load_xgboost_model

        # Add pickle loader as a fallback
        self._loader_map[format_class.PICKLE] = pickle_load

    def get_loader_map(
        self,
    ) -> Dict[
        Union[SklearnSerializationFormat, XGBoostSerializationFormat],
        Callable[[bytes], object],
    ]:
        """Return the loader map in priority order (the internal dict, not a copy)."""
        return self._loader_map

    def _detect_model_serialization_format(
        self,
        model_file_path: str,
    ) -> Optional[Union[SklearnSerializationFormat, XGBoostSerializationFormat]]:
        """Try each registered loader on the file and return the first format that loads.

        Args:
            model_file_path (str): The path to the file to be loaded.

        Returns:
            Optional[Union[SklearnSerializationFormat, XGBoostSerializationFormat]]:
                The serialization format whose loader succeeded, None if every
                loader failed.
        """
        # NOTE(security): the joblib/cloudpickle/pickle loaders unpickle the file,
        # which can execute arbitrary code. Only call this on trusted model files.
        for serialization_format, loader in self._loader_map.items():
            try:
                with open(model_file_path, "rb") as model_file:
                    loader(model_file)
            except Exception:
                # Best-effort probing: any failure just means "not this format".
                continue
            return serialization_format
        return None
|
|
232
|
+
|
|
233
|
+
|
|
183
234
|
class _ModelFramework(BaseModel):
|
|
184
235
|
__root__: ModelFrameworkType = Field(discriminator="type")
|
|
185
236
|
|
|
@@ -259,32 +310,6 @@ def _get_required_framework_pip_packages(framework: "ModelFrameworkType") -> Lis
|
|
|
259
310
|
return _MODEL_FRAMEWORK_TO_PIP_PACKAGES.get(framework.__class__, [])
|
|
260
311
|
|
|
261
312
|
|
|
262
|
-
def _detect_model_serialization_format(
|
|
263
|
-
model_file_path: str,
|
|
264
|
-
) -> Optional[SerializationFormat]:
|
|
265
|
-
"""
|
|
266
|
-
The function will attempt to load the model using each framework's loader and return the first successful one.
|
|
267
|
-
|
|
268
|
-
Args:
|
|
269
|
-
model_file_path (str): The path to the file to be loaded.
|
|
270
|
-
|
|
271
|
-
Returns:
|
|
272
|
-
Optional[SerializationFormat]: The serialization format if successfully loaded, None otherwise.
|
|
273
|
-
"""
|
|
274
|
-
# Attempt to load the model using each framework
|
|
275
|
-
for (
|
|
276
|
-
serialization_format,
|
|
277
|
-
loader,
|
|
278
|
-
) in _serialization_format_loader_registry.get_loader_map().items():
|
|
279
|
-
try:
|
|
280
|
-
with open(model_file_path, "rb") as f:
|
|
281
|
-
loader(f)
|
|
282
|
-
return serialization_format
|
|
283
|
-
except Exception:
|
|
284
|
-
continue
|
|
285
|
-
return None
|
|
286
|
-
|
|
287
|
-
|
|
288
313
|
def _fetch_framework_specific_pip_packages(
|
|
289
314
|
framework: "ModelFrameworkType",
|
|
290
315
|
) -> List[str]:
|
|
@@ -431,12 +456,15 @@ def auto_update_model_framework_details(
|
|
|
431
456
|
if os.path.isfile(model_file_or_folder)
|
|
432
457
|
else os.path.join(model_file_or_folder, framework.model_filepath)
|
|
433
458
|
)
|
|
434
|
-
framework.serialization_format
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
459
|
+
if not framework.serialization_format:
|
|
460
|
+
loader_registry = _SerializationFormatLoaderRegistry(
|
|
461
|
+
framework=framework
|
|
462
|
+
)
|
|
463
|
+
framework.serialization_format = (
|
|
464
|
+
loader_registry._detect_model_serialization_format(
|
|
465
|
+
model_file_path=absolute_model_filepath
|
|
466
|
+
)
|
|
438
467
|
)
|
|
439
|
-
)
|
|
440
468
|
|
|
441
469
|
|
|
442
470
|
def sklearn_infer_schema(
|