truefoundry 0.5.1rc7__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truefoundry might be problematic. Click here for more details.
- truefoundry/autodeploy/cli.py +1 -1
- truefoundry/cli/__main__.py +92 -2
- truefoundry/{deploy/cli → cli}/display_util.py +9 -4
- truefoundry/{deploy/cli → cli}/util.py +2 -11
- truefoundry/common/constants.py +11 -0
- truefoundry/common/utils.py +10 -0
- truefoundry/deploy/auto_gen/models.py +3 -3
- truefoundry/deploy/builder/builders/tfy_notebook_buildpack/__init__.py +4 -2
- truefoundry/deploy/builder/builders/tfy_python_buildpack/__init__.py +7 -5
- truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +87 -28
- truefoundry/deploy/builder/constants.py +8 -0
- truefoundry/deploy/builder/utils.py +9 -4
- truefoundry/deploy/cli/commands/apply_command.py +12 -5
- truefoundry/deploy/cli/commands/build_command.py +3 -3
- truefoundry/deploy/cli/commands/build_logs_command.py +2 -2
- truefoundry/deploy/cli/commands/create_command.py +3 -3
- truefoundry/deploy/cli/commands/delete_command.py +4 -4
- truefoundry/deploy/cli/commands/deploy_command.py +2 -2
- truefoundry/deploy/cli/commands/deploy_init_command.py +1 -1
- truefoundry/deploy/cli/commands/get_command.py +5 -5
- truefoundry/deploy/cli/commands/list_command.py +4 -4
- truefoundry/deploy/cli/commands/login_command.py +3 -3
- truefoundry/deploy/cli/commands/logout_command.py +2 -2
- truefoundry/deploy/cli/commands/logs_command.py +2 -2
- truefoundry/deploy/cli/commands/patch_application_command.py +2 -2
- truefoundry/deploy/cli/commands/patch_command.py +2 -2
- truefoundry/deploy/cli/commands/redeploy_command.py +2 -2
- truefoundry/deploy/cli/commands/terminate_comand.py +3 -3
- truefoundry/deploy/cli/commands/trigger_command.py +2 -2
- truefoundry/deploy/lib/clients/servicefoundry_client.py +2 -2
- truefoundry/deploy/lib/const.py +0 -3
- truefoundry/deploy/lib/dao/apply.py +21 -6
- truefoundry/deploy/lib/dao/workspace.py +1 -1
- truefoundry/deploy/lib/model/entity.py +2 -2
- truefoundry/deploy/lib/util.py +0 -14
- truefoundry/ml/__init__.py +8 -4
- truefoundry/ml/cli/cli.py +5 -1
- truefoundry/ml/cli/commands/download.py +16 -3
- truefoundry/ml/cli/commands/model_init.py +4 -4
- truefoundry/ml/mlfoundry_api.py +2 -1
- truefoundry/ml/mlfoundry_run.py +2 -1
- truefoundry/ml/model_framework.py +47 -10
- {truefoundry-0.5.1rc7.dist-info → truefoundry-0.5.2.dist-info}/METADATA +1 -1
- {truefoundry-0.5.1rc7.dist-info → truefoundry-0.5.2.dist-info}/RECORD +49 -52
- truefoundry/deploy/cli/cli.py +0 -96
- truefoundry/deploy/json_util.py +0 -7
- truefoundry/deploy/lib/exceptions.py +0 -10
- /truefoundry/{deploy/cli → cli}/config.py +0 -0
- /truefoundry/{deploy/cli → cli}/console.py +0 -0
- /truefoundry/{deploy/cli → cli}/const.py +0 -0
- {truefoundry-0.5.1rc7.dist-info → truefoundry-0.5.2.dist-info}/WHEEL +0 -0
- {truefoundry-0.5.1rc7.dist-info → truefoundry-0.5.2.dist-info}/entry_points.txt +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import rich_click as click
|
|
2
2
|
|
|
3
|
-
from truefoundry.
|
|
4
|
-
from truefoundry.
|
|
5
|
-
from truefoundry.
|
|
6
|
-
from truefoundry.
|
|
3
|
+
from truefoundry.cli.config import CliConfig
|
|
4
|
+
from truefoundry.cli.const import COMMAND_CLS, GROUP_CLS
|
|
5
|
+
from truefoundry.cli.display_util import print_entity_list, print_json
|
|
6
|
+
from truefoundry.cli.util import handle_exception_wrapper
|
|
7
7
|
from truefoundry.deploy.io.rich_output_callback import RichOutputCallBack
|
|
8
8
|
from truefoundry.deploy.lib.dao import application as application_lib
|
|
9
9
|
from truefoundry.deploy.lib.dao import version as version_lib
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
import rich_click as click
|
|
2
2
|
|
|
3
|
-
from truefoundry.
|
|
4
|
-
from truefoundry.
|
|
5
|
-
from truefoundry.deploy.cli.util import (
|
|
3
|
+
from truefoundry.cli.const import COMMAND_CLS
|
|
4
|
+
from truefoundry.cli.util import (
|
|
6
5
|
_prompt_if_no_value_and_supported,
|
|
7
6
|
handle_exception_wrapper,
|
|
8
7
|
)
|
|
8
|
+
from truefoundry.common.constants import TFY_HOST_ENV_KEY
|
|
9
9
|
from truefoundry.deploy.io.rich_output_callback import RichOutputCallBack
|
|
10
10
|
from truefoundry.deploy.lib.session import login
|
|
11
11
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import rich_click as click
|
|
2
2
|
|
|
3
|
-
from truefoundry.
|
|
4
|
-
from truefoundry.
|
|
3
|
+
from truefoundry.cli.const import COMMAND_CLS
|
|
4
|
+
from truefoundry.cli.util import handle_exception_wrapper
|
|
5
5
|
from truefoundry.deploy.io.rich_output_callback import RichOutputCallBack
|
|
6
6
|
from truefoundry.deploy.lib.session import logout
|
|
7
7
|
|
|
@@ -4,8 +4,8 @@ from typing import Optional
|
|
|
4
4
|
import rich_click as click
|
|
5
5
|
from dateutil.tz import tzlocal
|
|
6
6
|
|
|
7
|
-
from truefoundry.
|
|
8
|
-
from truefoundry.
|
|
7
|
+
from truefoundry.cli.const import COMMAND_CLS
|
|
8
|
+
from truefoundry.cli.util import handle_exception_wrapper
|
|
9
9
|
from truefoundry.deploy.io.rich_output_callback import RichOutputCallBack
|
|
10
10
|
from truefoundry.deploy.lib.clients.servicefoundry_client import (
|
|
11
11
|
ServiceFoundryServiceClient,
|
|
@@ -3,8 +3,8 @@ import json
|
|
|
3
3
|
import rich_click as click
|
|
4
4
|
import yaml
|
|
5
5
|
|
|
6
|
-
from truefoundry.
|
|
7
|
-
from truefoundry.
|
|
6
|
+
from truefoundry.cli.const import GROUP_CLS
|
|
7
|
+
from truefoundry.cli.util import handle_exception_wrapper
|
|
8
8
|
from truefoundry.deploy.lib.dao import application as application_lib
|
|
9
9
|
|
|
10
10
|
|
|
@@ -3,8 +3,8 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
import rich_click as click
|
|
5
5
|
|
|
6
|
-
from truefoundry.
|
|
7
|
-
from truefoundry.
|
|
6
|
+
from truefoundry.cli.const import GROUP_CLS
|
|
7
|
+
from truefoundry.cli.util import handle_exception_wrapper
|
|
8
8
|
from truefoundry.deploy.io.rich_output_callback import RichOutputCallBack
|
|
9
9
|
|
|
10
10
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import rich_click as click
|
|
2
2
|
|
|
3
|
-
from truefoundry.
|
|
4
|
-
from truefoundry.
|
|
3
|
+
from truefoundry.cli.const import COMMAND_CLS
|
|
4
|
+
from truefoundry.cli.util import handle_exception_wrapper
|
|
5
5
|
from truefoundry.deploy.lib.dao import application as application_lib
|
|
6
6
|
|
|
7
7
|
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import rich_click as click
|
|
2
2
|
|
|
3
|
-
from truefoundry.
|
|
4
|
-
from truefoundry.
|
|
5
|
-
from truefoundry.
|
|
3
|
+
from truefoundry.cli.const import COMMAND_CLS, GROUP_CLS
|
|
4
|
+
from truefoundry.cli.display_util import print_json
|
|
5
|
+
from truefoundry.cli.util import handle_exception_wrapper
|
|
6
6
|
from truefoundry.deploy.lib.dao import application as application_lib
|
|
7
7
|
|
|
8
8
|
|
|
@@ -3,8 +3,8 @@ from typing import Optional, Sequence
|
|
|
3
3
|
import rich_click as click
|
|
4
4
|
from click import ClickException
|
|
5
5
|
|
|
6
|
-
from truefoundry.
|
|
7
|
-
from truefoundry.
|
|
6
|
+
from truefoundry.cli.const import COMMAND_CLS, GROUP_CLS
|
|
7
|
+
from truefoundry.cli.util import handle_exception_wrapper
|
|
8
8
|
from truefoundry.deploy.lib.dao import application
|
|
9
9
|
|
|
10
10
|
|
|
@@ -680,9 +680,9 @@ class ServiceFoundryServiceClient(BaseServiceFoundryServiceClient):
|
|
|
680
680
|
return parse_obj_as(List[Deployment], response)
|
|
681
681
|
|
|
682
682
|
@check_min_cli_version
|
|
683
|
-
def apply(self, manifest: Dict[str, Any]):
|
|
683
|
+
def apply(self, manifest: Dict[str, Any], dry_run: bool = False):
|
|
684
684
|
url = f"{self._api_server_url}/{VERSION_PREFIX}/apply"
|
|
685
|
-
body = {"manifest": manifest}
|
|
685
|
+
body = {"manifest": manifest, "dryRun": dry_run}
|
|
686
686
|
response = session_with_retries().put(
|
|
687
687
|
url, headers=self._get_header(), json=body
|
|
688
688
|
)
|
truefoundry/deploy/lib/const.py
CHANGED
|
@@ -15,6 +15,7 @@ def _apply_manifest(
|
|
|
15
15
|
client: Optional[ServiceFoundryServiceClient] = None,
|
|
16
16
|
filename: Optional[str] = None,
|
|
17
17
|
index: Optional[int] = None,
|
|
18
|
+
dry_run: bool = False,
|
|
18
19
|
) -> ApplyResult:
|
|
19
20
|
client = client or ServiceFoundryServiceClient()
|
|
20
21
|
|
|
@@ -32,29 +33,38 @@ def _apply_manifest(
|
|
|
32
33
|
message=f"Failed to apply manifest{file_metadata}. Error: {ex}",
|
|
33
34
|
)
|
|
34
35
|
|
|
36
|
+
prefix = "[Dry Run] " if dry_run else ""
|
|
37
|
+
suffix = " (No changes were applied)" if dry_run else ""
|
|
35
38
|
try:
|
|
36
|
-
client.apply(manifest.dict())
|
|
39
|
+
client.apply(manifest.dict(), dry_run)
|
|
40
|
+
|
|
37
41
|
return ApplyResult(
|
|
38
42
|
success=True,
|
|
39
|
-
message=
|
|
43
|
+
message=(
|
|
44
|
+
f"{prefix}Successfully configured manifest {manifest.name} of type {manifest.type}.{suffix}"
|
|
45
|
+
),
|
|
40
46
|
)
|
|
41
47
|
except Exception as ex:
|
|
42
48
|
return ApplyResult(
|
|
43
49
|
success=False,
|
|
44
|
-
message=
|
|
50
|
+
message=(
|
|
51
|
+
f"{prefix}Failed to apply manifest {manifest.name} of type {manifest.type}. Error: {ex}.{suffix}"
|
|
52
|
+
),
|
|
45
53
|
)
|
|
46
54
|
|
|
47
55
|
|
|
48
56
|
def apply_manifest(
|
|
49
57
|
manifest: Dict[str, Any],
|
|
50
58
|
client: Optional[ServiceFoundryServiceClient] = None,
|
|
59
|
+
dry_run: bool = False,
|
|
51
60
|
) -> ApplyResult:
|
|
52
|
-
return _apply_manifest(manifest=manifest, client=client)
|
|
61
|
+
return _apply_manifest(manifest=manifest, client=client, dry_run=dry_run)
|
|
53
62
|
|
|
54
63
|
|
|
55
64
|
def apply_manifest_file(
|
|
56
65
|
filepath: str,
|
|
57
66
|
client: Optional[ServiceFoundryServiceClient] = None,
|
|
67
|
+
dry_run: bool = False,
|
|
58
68
|
) -> Iterator[ApplyResult]:
|
|
59
69
|
client = client or ServiceFoundryServiceClient()
|
|
60
70
|
filename = Path(filepath).name
|
|
@@ -67,14 +77,19 @@ def apply_manifest_file(
|
|
|
67
77
|
message=f"Failed to read file {filepath} as a valid YAML file. Error: {ex}",
|
|
68
78
|
)
|
|
69
79
|
else:
|
|
80
|
+
prefix = "[Dry Run] " if dry_run else ""
|
|
70
81
|
for index, manifest in enumerate(manifests_it):
|
|
71
82
|
if not isinstance(manifest, dict):
|
|
72
83
|
yield ApplyResult(
|
|
73
84
|
success=False,
|
|
74
|
-
message=f"Failed to apply manifest at index {index} from file {filename}. Error: A manifest must be a dict, got type {type(manifest)}",
|
|
85
|
+
message=f"{prefix}Failed to apply manifest at index {index} from file {filename}. Error: A manifest must be a dict, got type {type(manifest)}",
|
|
75
86
|
)
|
|
76
87
|
continue
|
|
77
88
|
|
|
78
89
|
yield _apply_manifest(
|
|
79
|
-
manifest=manifest,
|
|
90
|
+
manifest=manifest,
|
|
91
|
+
client=client,
|
|
92
|
+
filename=filename,
|
|
93
|
+
index=index,
|
|
94
|
+
dry_run=dry_run,
|
|
80
95
|
)
|
|
@@ -220,7 +220,7 @@ class JobRun(Base):
|
|
|
220
220
|
status: str
|
|
221
221
|
|
|
222
222
|
def list_row_data(self) -> Dict[str, Any]:
|
|
223
|
-
from truefoundry.
|
|
223
|
+
from truefoundry.cli.display_util import display_time_passed
|
|
224
224
|
|
|
225
225
|
triggered_at = (
|
|
226
226
|
(datetime.datetime.now().timestamp() * 1000) - self.createdAt
|
|
@@ -238,7 +238,7 @@ class JobRun(Base):
|
|
|
238
238
|
}
|
|
239
239
|
|
|
240
240
|
def get_data(self) -> Dict[str, Any]:
|
|
241
|
-
from truefoundry.
|
|
241
|
+
from truefoundry.cli.display_util import display_time_passed
|
|
242
242
|
|
|
243
243
|
created_at = datetime.datetime.fromtimestamp(self.createdAt // 1000)
|
|
244
244
|
end_time = ""
|
truefoundry/deploy/lib/util.py
CHANGED
|
@@ -1,20 +1,6 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import re
|
|
3
2
|
from typing import Union
|
|
4
3
|
|
|
5
|
-
from truefoundry.deploy.lib.const import (
|
|
6
|
-
TFY_DEBUG_ENV_KEY,
|
|
7
|
-
TFY_INTERNAL_ENV_KEY,
|
|
8
|
-
)
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def is_debug_env_set() -> bool:
|
|
12
|
-
return bool(os.getenv(TFY_DEBUG_ENV_KEY))
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def is_internal_env_set() -> bool:
|
|
16
|
-
return bool(os.getenv(TFY_INTERNAL_ENV_KEY))
|
|
17
|
-
|
|
18
4
|
|
|
19
5
|
def get_application_fqn_from_deployment_fqn(deployment_fqn: str) -> str:
|
|
20
6
|
if not re.search(r":\d+$", deployment_fqn):
|
truefoundry/ml/__init__.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
from truefoundry.ml.autogen.client.models import (
|
|
2
|
-
|
|
1
|
+
from truefoundry.ml.autogen.client.models import ( # type: ignore[attr-defined]
|
|
2
|
+
InferMethodName,
|
|
3
3
|
ModelVersionEnvironment,
|
|
4
4
|
SklearnModelSchema,
|
|
5
5
|
XGBoostModelSchema,
|
|
6
6
|
)
|
|
7
|
+
from truefoundry.ml.autogen.entities.artifacts import LibraryName
|
|
7
8
|
from truefoundry.ml.enums import (
|
|
8
9
|
DataSlice,
|
|
9
10
|
FileFormat,
|
|
@@ -40,6 +41,7 @@ from truefoundry.ml.model_framework import (
|
|
|
40
41
|
TransformersFramework,
|
|
41
42
|
XGBoostFramework,
|
|
42
43
|
sklearn_infer_schema,
|
|
44
|
+
xgboost_infer_schema,
|
|
43
45
|
)
|
|
44
46
|
|
|
45
47
|
__all__ = [
|
|
@@ -51,9 +53,11 @@ __all__ = [
|
|
|
51
53
|
"DataSlice",
|
|
52
54
|
"FastAIFramework",
|
|
53
55
|
"FileFormat",
|
|
56
|
+
"get_client",
|
|
54
57
|
"GluonFramework",
|
|
55
58
|
"H2OFramework",
|
|
56
59
|
"Image",
|
|
60
|
+
"InferMethodName",
|
|
57
61
|
"KerasFramework",
|
|
58
62
|
"LibraryName",
|
|
59
63
|
"LightGBMFramework",
|
|
@@ -68,6 +72,7 @@ __all__ = [
|
|
|
68
72
|
"PaddleFramework",
|
|
69
73
|
"Plot",
|
|
70
74
|
"PyTorchFramework",
|
|
75
|
+
"sklearn_infer_schema",
|
|
71
76
|
"SklearnFramework",
|
|
72
77
|
"SklearnModelSchema",
|
|
73
78
|
"SpaCyFramework",
|
|
@@ -75,10 +80,9 @@ __all__ = [
|
|
|
75
80
|
"TensorFlowFramework",
|
|
76
81
|
"TransformersFramework",
|
|
77
82
|
"ViewType",
|
|
83
|
+
"xgboost_infer_schema",
|
|
78
84
|
"XGBoostFramework",
|
|
79
85
|
"XGBoostModelSchema",
|
|
80
|
-
"get_client",
|
|
81
|
-
"sklearn_infer_schema",
|
|
82
86
|
]
|
|
83
87
|
|
|
84
88
|
init_logger()
|
truefoundry/ml/cli/cli.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
import rich_click as click
|
|
2
2
|
|
|
3
|
+
from truefoundry.cli.const import GROUP_CLS
|
|
3
4
|
from truefoundry.ml.cli.commands import download
|
|
4
5
|
|
|
5
6
|
click.rich_click.USE_RICH_MARKUP = True
|
|
6
7
|
|
|
7
8
|
|
|
8
|
-
@click.group(
|
|
9
|
+
@click.group(
|
|
10
|
+
name="ml",
|
|
11
|
+
cls=GROUP_CLS,
|
|
12
|
+
)
|
|
9
13
|
def ml():
|
|
10
14
|
"""
|
|
11
15
|
TrueFoundry ML CLI
|
|
@@ -2,14 +2,23 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
import click
|
|
4
4
|
|
|
5
|
+
from truefoundry.cli.const import COMMAND_CLS, GROUP_CLS
|
|
5
6
|
from truefoundry.ml import get_client
|
|
6
7
|
|
|
7
8
|
|
|
8
|
-
@click.group(
|
|
9
|
+
@click.group(
|
|
10
|
+
name="download",
|
|
11
|
+
cls=GROUP_CLS,
|
|
12
|
+
help="Download artifact/model versions logged with TrueFoundry",
|
|
13
|
+
)
|
|
9
14
|
def download(): ...
|
|
10
15
|
|
|
11
16
|
|
|
12
|
-
@download.command(
|
|
17
|
+
@download.command(
|
|
18
|
+
name="model",
|
|
19
|
+
cls=COMMAND_CLS,
|
|
20
|
+
help="Download a model version logged with TrueFoundry",
|
|
21
|
+
)
|
|
13
22
|
@click.option(
|
|
14
23
|
"--fqn",
|
|
15
24
|
required=True,
|
|
@@ -48,7 +57,11 @@ def model(fqn: str, path: str, overwrite: bool, progress: Optional[bool] = None)
|
|
|
48
57
|
print(f"Downloaded model files to {download_path}")
|
|
49
58
|
|
|
50
59
|
|
|
51
|
-
@download.command(
|
|
60
|
+
@download.command(
|
|
61
|
+
name="artifact",
|
|
62
|
+
cls=COMMAND_CLS,
|
|
63
|
+
short_help="Download a artifact version logged with TrueFoundry",
|
|
64
|
+
)
|
|
52
65
|
@click.option(
|
|
53
66
|
"--fqn",
|
|
54
67
|
required=True,
|
|
@@ -3,10 +3,10 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
import rich_click as click
|
|
5
5
|
|
|
6
|
-
from truefoundry.
|
|
7
|
-
from truefoundry.
|
|
8
|
-
from truefoundry.
|
|
9
|
-
from truefoundry.ml.autogen.client
|
|
6
|
+
from truefoundry.cli.console import console
|
|
7
|
+
from truefoundry.cli.const import COMMAND_CLS
|
|
8
|
+
from truefoundry.cli.util import handle_exception_wrapper
|
|
9
|
+
from truefoundry.ml.autogen.client import ModelServer # type: ignore[attr-defined]
|
|
10
10
|
from truefoundry.ml.cli.utils import (
|
|
11
11
|
AppName,
|
|
12
12
|
NonEmptyString,
|
truefoundry/ml/mlfoundry_api.py
CHANGED
|
@@ -18,7 +18,7 @@ from typing import (
|
|
|
18
18
|
import coolname
|
|
19
19
|
|
|
20
20
|
from truefoundry.common.utils import ContextualDirectoryManager, relogin_error_message
|
|
21
|
-
from truefoundry.ml import
|
|
21
|
+
from truefoundry.ml import constants
|
|
22
22
|
from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
23
23
|
ArtifactDto,
|
|
24
24
|
ArtifactType,
|
|
@@ -35,6 +35,7 @@ from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
|
35
35
|
MlfoundryArtifactsApi,
|
|
36
36
|
ModelDto,
|
|
37
37
|
ModelServer,
|
|
38
|
+
ModelVersionEnvironment,
|
|
38
39
|
RunsApi,
|
|
39
40
|
RunTagDto,
|
|
40
41
|
SearchRunsRequestDto,
|
truefoundry/ml/mlfoundry_run.py
CHANGED
|
@@ -18,7 +18,7 @@ from urllib.parse import urljoin, urlsplit
|
|
|
18
18
|
|
|
19
19
|
from truefoundry import version
|
|
20
20
|
from truefoundry.common.utils import relogin_error_message
|
|
21
|
-
from truefoundry.ml import
|
|
21
|
+
from truefoundry.ml import constants
|
|
22
22
|
from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
23
23
|
ArtifactType,
|
|
24
24
|
DeleteRunRequest,
|
|
@@ -29,6 +29,7 @@ from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
|
29
29
|
MetricDto,
|
|
30
30
|
MetricsApi,
|
|
31
31
|
MlfoundryArtifactsApi,
|
|
32
|
+
ModelVersionEnvironment,
|
|
32
33
|
ParamDto,
|
|
33
34
|
RunDataDto,
|
|
34
35
|
RunDto,
|
|
@@ -35,6 +35,7 @@ from truefoundry.pydantic_v1 import BaseModel, Field
|
|
|
35
35
|
|
|
36
36
|
if TYPE_CHECKING:
|
|
37
37
|
from sklearn.base import BaseEstimator
|
|
38
|
+
from xgboost import Booster, XGBModel
|
|
38
39
|
|
|
39
40
|
# Map serialization format to corresponding pip packages
|
|
40
41
|
SERIALIZATION_FORMAT_TO_PACKAGES_NAME_MAP = {
|
|
@@ -467,6 +468,24 @@ def auto_update_model_framework_details(
|
|
|
467
468
|
)
|
|
468
469
|
|
|
469
470
|
|
|
471
|
+
def _infer_schema(
|
|
472
|
+
model_input: Any,
|
|
473
|
+
model: Union["BaseEstimator", "Booster", "XGBModel"],
|
|
474
|
+
infer_method_name: str = "predict",
|
|
475
|
+
) -> Dict[str, Any]:
|
|
476
|
+
if not hasattr(model, infer_method_name):
|
|
477
|
+
raise ValueError(
|
|
478
|
+
f"Model does not have the method '{infer_method_name}' to infer the schema."
|
|
479
|
+
)
|
|
480
|
+
model_infer_method = getattr(model, infer_method_name)
|
|
481
|
+
model_output = model_infer_method(model_input)
|
|
482
|
+
|
|
483
|
+
model_signature = infer_signature(
|
|
484
|
+
model_input=model_input, model_output=model_output
|
|
485
|
+
)
|
|
486
|
+
return model_signature.to_dict()
|
|
487
|
+
|
|
488
|
+
|
|
470
489
|
def sklearn_infer_schema(
|
|
471
490
|
model_input: Any,
|
|
472
491
|
model: "BaseEstimator",
|
|
@@ -483,19 +502,37 @@ def sklearn_infer_schema(
|
|
|
483
502
|
Returns:
|
|
484
503
|
SklearnModelSchema: The inferred schema of the Sklearn model.
|
|
485
504
|
"""
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
f"Model does not have the method '{infer_method_name}' to infer the schema."
|
|
489
|
-
)
|
|
490
|
-
model_infer_method = getattr(model, infer_method_name)
|
|
491
|
-
model_output = model_infer_method(model_input)
|
|
492
|
-
|
|
493
|
-
model_signature = infer_signature(
|
|
494
|
-
model_input=model_input, model_output=model_output
|
|
505
|
+
model_signature_json = _infer_schema(
|
|
506
|
+
model_input=model_input, model=model, infer_method_name=infer_method_name
|
|
495
507
|
)
|
|
496
|
-
model_signature_json = model_signature.to_dict()
|
|
497
508
|
return autogen_artifacts.SklearnModelSchema(
|
|
498
509
|
infer_method_name=infer_method_name,
|
|
499
510
|
inputs=json.loads(model_signature_json["inputs"]),
|
|
500
511
|
outputs=json.loads(model_signature_json["outputs"]),
|
|
501
512
|
)
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def xgboost_infer_schema(
|
|
516
|
+
model_input: Any,
|
|
517
|
+
model: Union["Booster", "XGBModel"],
|
|
518
|
+
infer_method_name: str = "predict",
|
|
519
|
+
) -> autogen_artifacts.XGBoostModelSchema:
|
|
520
|
+
"""
|
|
521
|
+
Infer the schema of an XGBoost model.
|
|
522
|
+
|
|
523
|
+
Args:
|
|
524
|
+
model_input (Any): The input data to be used for schema inference.
|
|
525
|
+
model (Any): The XGBoost model instance.
|
|
526
|
+
infer_method_name (str): The name of the method to be used for schema inference.
|
|
527
|
+
Eg: predict (default), predict_proba
|
|
528
|
+
Returns:
|
|
529
|
+
XGBoostModelSchema: The inferred schema of the XGBoost model.
|
|
530
|
+
"""
|
|
531
|
+
model_signature_json = _infer_schema(
|
|
532
|
+
model_input=model_input, model=model, infer_method_name=infer_method_name
|
|
533
|
+
)
|
|
534
|
+
return autogen_artifacts.XGBoostModelSchema(
|
|
535
|
+
infer_method_name=infer_method_name,
|
|
536
|
+
inputs=json.loads(model_signature_json["inputs"]),
|
|
537
|
+
outputs=json.loads(model_signature_json["outputs"]),
|
|
538
|
+
)
|