snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +66 -35
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +26 -2
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
- snowflake/ml/data/data_connector.py +186 -0
- snowflake/ml/data/data_ingestor.py +45 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/data/ingestor_utils.py +62 -0
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +23 -117
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/example_helper.py +278 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
- snowflake/ml/feature_store/feature_store.py +637 -76
- snowflake/ml/feature_store/feature_view.py +316 -9
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +14 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +7 -4
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,31 @@
|
|
1
|
+
from typing import List, TypedDict
|
2
|
+
|
3
|
+
from typing_extensions import NotRequired, Required
|
4
|
+
|
5
|
+
|
6
|
+
class ModelDict(TypedDict):
|
7
|
+
name: Required[str]
|
8
|
+
version: Required[str]
|
9
|
+
|
10
|
+
|
11
|
+
class ImageBuildDict(TypedDict):
|
12
|
+
compute_pool: Required[str]
|
13
|
+
image_repo: Required[str]
|
14
|
+
image_name: NotRequired[str]
|
15
|
+
force_rebuild: Required[bool]
|
16
|
+
external_access_integrations: Required[List[str]]
|
17
|
+
|
18
|
+
|
19
|
+
class ServiceDict(TypedDict):
|
20
|
+
name: Required[str]
|
21
|
+
compute_pool: Required[str]
|
22
|
+
ingress_enabled: Required[bool]
|
23
|
+
min_instances: Required[int]
|
24
|
+
max_instances: Required[int]
|
25
|
+
gpu: NotRequired[str]
|
26
|
+
|
27
|
+
|
28
|
+
class ModelDeploymentSpecDict(TypedDict):
|
29
|
+
models: Required[List[ModelDict]]
|
30
|
+
image_build: Required[ImageBuildDict]
|
31
|
+
service: Required[ServiceDict]
|
@@ -371,6 +371,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
371
371
|
returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
372
372
|
partition_column: Optional[sql_identifier.SqlIdentifier],
|
373
373
|
statement_params: Optional[Dict[str, Any]] = None,
|
374
|
+
is_partitioned: bool = True,
|
374
375
|
) -> dataframe.DataFrame:
|
375
376
|
with_statements = []
|
376
377
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
@@ -409,12 +410,20 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
409
410
|
|
410
411
|
sql = textwrap.dedent(
|
411
412
|
f"""WITH {','.join(with_statements)}
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
OVER (PARTITION BY {partition_by}))"""
|
413
|
+
SELECT *,
|
414
|
+
FROM {INTERMEDIATE_TABLE_NAME},
|
415
|
+
TABLE({module_version_alias}!{method_name.identifier()}({args_sql}))"""
|
416
416
|
)
|
417
417
|
|
418
|
+
if is_partitioned or partition_column is not None:
|
419
|
+
sql = textwrap.dedent(
|
420
|
+
f"""WITH {','.join(with_statements)}
|
421
|
+
SELECT *,
|
422
|
+
FROM {INTERMEDIATE_TABLE_NAME},
|
423
|
+
TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
|
424
|
+
OVER (PARTITION BY {partition_by}))"""
|
425
|
+
)
|
426
|
+
|
418
427
|
output_df = self._session.sql(sql)
|
419
428
|
|
420
429
|
# Prepare the output
|
@@ -0,0 +1,129 @@
|
|
1
|
+
import textwrap
|
2
|
+
from typing import Any, Dict, List, Optional, Tuple
|
3
|
+
|
4
|
+
from snowflake.ml._internal.utils import (
|
5
|
+
identifier,
|
6
|
+
query_result_checker,
|
7
|
+
sql_identifier,
|
8
|
+
)
|
9
|
+
from snowflake.ml.model._client.sql import _base
|
10
|
+
from snowflake.snowpark import dataframe, functions as F, types as spt
|
11
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
12
|
+
|
13
|
+
|
14
|
+
class ServiceSQLClient(_base._BaseSQLClient):
|
15
|
+
def build_model_container(
|
16
|
+
self,
|
17
|
+
*,
|
18
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
19
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
20
|
+
model_name: sql_identifier.SqlIdentifier,
|
21
|
+
version_name: sql_identifier.SqlIdentifier,
|
22
|
+
compute_pool_name: sql_identifier.SqlIdentifier,
|
23
|
+
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
24
|
+
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
25
|
+
image_repo_name: sql_identifier.SqlIdentifier,
|
26
|
+
gpu: Optional[str],
|
27
|
+
force_rebuild: bool,
|
28
|
+
external_access_integration: sql_identifier.SqlIdentifier,
|
29
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
30
|
+
) -> None:
|
31
|
+
actual_image_repo_database = image_repo_database_name or self._database_name
|
32
|
+
actual_image_repo_schema = image_repo_schema_name or self._schema_name
|
33
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
34
|
+
fq_image_repo_name = "/" + "/".join(
|
35
|
+
[
|
36
|
+
actual_image_repo_database.identifier(),
|
37
|
+
actual_image_repo_schema.identifier(),
|
38
|
+
image_repo_name.identifier(),
|
39
|
+
]
|
40
|
+
)
|
41
|
+
is_gpu = gpu is not None
|
42
|
+
query_result_checker.SqlResultValidator(
|
43
|
+
self._session,
|
44
|
+
(
|
45
|
+
f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
|
46
|
+
f" '{fq_image_repo_name}', '{is_gpu}', '{force_rebuild}', '', '{external_access_integration}')"
|
47
|
+
),
|
48
|
+
statement_params=statement_params,
|
49
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
50
|
+
|
51
|
+
def deploy_model(
|
52
|
+
self,
|
53
|
+
*,
|
54
|
+
stage_path: str,
|
55
|
+
model_deployment_spec_file_rel_path: str,
|
56
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
57
|
+
) -> None:
|
58
|
+
query_result_checker.SqlResultValidator(
|
59
|
+
self._session,
|
60
|
+
f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')",
|
61
|
+
statement_params=statement_params,
|
62
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
63
|
+
|
64
|
+
def invoke_function_method(
|
65
|
+
self,
|
66
|
+
*,
|
67
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
68
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
69
|
+
service_name: sql_identifier.SqlIdentifier,
|
70
|
+
method_name: sql_identifier.SqlIdentifier,
|
71
|
+
input_df: dataframe.DataFrame,
|
72
|
+
input_args: List[sql_identifier.SqlIdentifier],
|
73
|
+
returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
74
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
75
|
+
) -> dataframe.DataFrame:
|
76
|
+
with_statements = []
|
77
|
+
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
78
|
+
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
79
|
+
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
80
|
+
else:
|
81
|
+
actual_database_name = database_name or self._database_name
|
82
|
+
actual_schema_name = schema_name or self._schema_name
|
83
|
+
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
84
|
+
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
85
|
+
actual_database_name.identifier(),
|
86
|
+
actual_schema_name.identifier(),
|
87
|
+
tmp_table_name,
|
88
|
+
)
|
89
|
+
input_df.write.save_as_table(
|
90
|
+
table_name=INTERMEDIATE_TABLE_NAME,
|
91
|
+
mode="errorifexists",
|
92
|
+
table_type="temporary",
|
93
|
+
statement_params=statement_params,
|
94
|
+
)
|
95
|
+
|
96
|
+
INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
|
97
|
+
|
98
|
+
with_sql = f"WITH {','.join(with_statements)}" if with_statements else ""
|
99
|
+
args_sql_list = []
|
100
|
+
for input_arg_value in input_args:
|
101
|
+
args_sql_list.append(input_arg_value)
|
102
|
+
args_sql = ", ".join(args_sql_list)
|
103
|
+
|
104
|
+
sql = textwrap.dedent(
|
105
|
+
f"""{with_sql}
|
106
|
+
SELECT *,
|
107
|
+
{service_name.identifier()}_{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
|
108
|
+
FROM {INTERMEDIATE_TABLE_NAME}"""
|
109
|
+
)
|
110
|
+
|
111
|
+
output_df = self._session.sql(sql)
|
112
|
+
|
113
|
+
# Prepare the output
|
114
|
+
output_cols = []
|
115
|
+
output_names = []
|
116
|
+
|
117
|
+
for output_name, output_type, output_col_name in returns:
|
118
|
+
output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_name].astype(output_type))
|
119
|
+
output_names.append(output_col_name)
|
120
|
+
|
121
|
+
output_df = output_df.with_columns(
|
122
|
+
col_names=output_names,
|
123
|
+
values=output_cols,
|
124
|
+
).drop(INTERMEDIATE_OBJ_NAME)
|
125
|
+
|
126
|
+
if statement_params:
|
127
|
+
output_df._statement_params = statement_params # type: ignore[assignment]
|
128
|
+
|
129
|
+
return output_df
|
@@ -85,9 +85,8 @@ def _run_setup() -> None:
|
|
85
85
|
|
86
86
|
TARGET_METHOD = os.getenv("TARGET_METHOD")
|
87
87
|
|
88
|
-
_concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX",
|
89
|
-
|
90
|
-
_CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env) if _concurrent_requests_max_env else None
|
88
|
+
_concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", "1")
|
89
|
+
_CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env)
|
91
90
|
|
92
91
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
93
92
|
if zipfile.is_zipfile(model_zip_stage_path):
|
@@ -10,8 +10,10 @@ from absl import logging
|
|
10
10
|
from packaging import requirements
|
11
11
|
from typing_extensions import deprecated
|
12
12
|
|
13
|
+
from snowflake import snowpark
|
13
14
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
14
|
-
from snowflake.ml._internal.lineage import
|
15
|
+
from snowflake.ml._internal.lineage import lineage_utils
|
16
|
+
from snowflake.ml.data import data_source
|
15
17
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
16
18
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest
|
17
19
|
from snowflake.ml.model._packager import model_packager
|
@@ -128,16 +130,14 @@ class ModelComposer:
|
|
128
130
|
file_utils.copytree(
|
129
131
|
str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
|
130
132
|
)
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
data_sources=self._get_data_sources(model, sample_input_data),
|
140
|
-
)
|
133
|
+
self.manifest.save(
|
134
|
+
model_meta=self.packager.meta,
|
135
|
+
model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
|
136
|
+
options=options,
|
137
|
+
data_sources=self._get_data_sources(model, sample_input_data),
|
138
|
+
)
|
139
|
+
else:
|
140
|
+
file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
|
141
141
|
|
142
142
|
file_utils.upload_directory_to_stage(
|
143
143
|
self.session,
|
@@ -186,6 +186,6 @@ class ModelComposer:
|
|
186
186
|
data_sources = lineage_utils.get_data_sources(model)
|
187
187
|
if not data_sources and sample_input_data is not None:
|
188
188
|
data_sources = lineage_utils.get_data_sources(sample_input_data)
|
189
|
-
|
190
|
-
|
191
|
-
return
|
189
|
+
if not data_sources and isinstance(sample_input_data, snowpark.DataFrame):
|
190
|
+
data_sources = [data_source.DataFrameInfo(sample_input_data.queries["queries"][-1])]
|
191
|
+
return data_sources
|
@@ -1,11 +1,13 @@
|
|
1
1
|
import collections
|
2
2
|
import copy
|
3
3
|
import pathlib
|
4
|
+
import warnings
|
4
5
|
from typing import List, Optional, cast
|
5
6
|
|
6
7
|
import yaml
|
7
8
|
|
8
|
-
from snowflake.ml._internal
|
9
|
+
from snowflake.ml._internal import env_utils
|
10
|
+
from snowflake.ml.data import data_source
|
9
11
|
from snowflake.ml.model import type_hints
|
10
12
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
11
13
|
from snowflake.ml.model._model_composer.model_method import (
|
@@ -16,7 +18,6 @@ from snowflake.ml.model._packager.model_meta import (
|
|
16
18
|
model_meta as model_meta_api,
|
17
19
|
model_meta_schema,
|
18
20
|
)
|
19
|
-
from snowflake.snowpark import Session
|
20
21
|
|
21
22
|
|
22
23
|
class ModelManifest:
|
@@ -36,9 +37,8 @@ class ModelManifest:
|
|
36
37
|
|
37
38
|
def save(
|
38
39
|
self,
|
39
|
-
session: Session,
|
40
40
|
model_meta: model_meta_api.ModelMetadata,
|
41
|
-
|
41
|
+
model_rel_path: pathlib.PurePosixPath,
|
42
42
|
options: Optional[type_hints.ModelSaveOption] = None,
|
43
43
|
data_sources: Optional[List[data_source.DataSource]] = None,
|
44
44
|
) -> None:
|
@@ -47,10 +47,12 @@ class ModelManifest:
|
|
47
47
|
|
48
48
|
runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
|
49
49
|
runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
|
50
|
-
runtime_to_use.imports.append(
|
51
|
-
runtime_dict = runtime_to_use.save(
|
50
|
+
runtime_to_use.imports.append(str(model_rel_path) + "/")
|
51
|
+
runtime_dict = runtime_to_use.save(
|
52
|
+
self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
53
|
+
)
|
52
54
|
|
53
|
-
self.function_generator = function_generator.FunctionGenerator(
|
55
|
+
self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
|
54
56
|
self.methods: List[model_method.ModelMethod] = []
|
55
57
|
for target_method in model_meta.signatures.keys():
|
56
58
|
method = model_method.ModelMethod(
|
@@ -75,6 +77,16 @@ class ModelManifest:
|
|
75
77
|
"In this case, set case_sensitive as True for those methods to distinguish them."
|
76
78
|
)
|
77
79
|
|
80
|
+
dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
|
81
|
+
if options.get("include_pip_dependencies"):
|
82
|
+
warnings.warn(
|
83
|
+
"`include_pip_dependencies` specified as True: pip dependencies will be included and may not"
|
84
|
+
"be warehouse-compabible. The model may need to be run in SPCS.",
|
85
|
+
category=UserWarning,
|
86
|
+
stacklevel=1,
|
87
|
+
)
|
88
|
+
dependencies["pip"] = runtime_dict["dependencies"]["pip"]
|
89
|
+
|
78
90
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
79
91
|
manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
|
80
92
|
runtimes={
|
@@ -82,9 +94,7 @@ class ModelManifest:
|
|
82
94
|
language="PYTHON",
|
83
95
|
version=runtime_to_use.runtime_env.python_version,
|
84
96
|
imports=runtime_dict["imports"],
|
85
|
-
dependencies=
|
86
|
-
conda=runtime_dict["dependencies"]["conda"]
|
87
|
-
),
|
97
|
+
dependencies=dependencies,
|
88
98
|
)
|
89
99
|
},
|
90
100
|
methods=[
|
@@ -127,12 +137,18 @@ class ModelManifest:
|
|
127
137
|
result = []
|
128
138
|
if data_sources:
|
129
139
|
for source in data_sources:
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
140
|
+
if isinstance(source, data_source.DatasetInfo):
|
141
|
+
result.append(
|
142
|
+
model_manifest_schema.LineageSourceDict(
|
143
|
+
type=model_manifest_schema.LineageSourceTypes.DATASET.value,
|
144
|
+
entity=source.fully_qualified_name,
|
145
|
+
version=source.version,
|
146
|
+
)
|
147
|
+
)
|
148
|
+
elif isinstance(source, data_source.DataFrameInfo):
|
149
|
+
result.append(
|
150
|
+
model_manifest_schema.LineageSourceDict(
|
151
|
+
type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
|
152
|
+
)
|
136
153
|
)
|
137
|
-
)
|
138
154
|
return result
|
@@ -18,7 +18,8 @@ class ModelMethodFunctionTypes(enum.Enum):
|
|
18
18
|
|
19
19
|
|
20
20
|
class ModelRuntimeDependenciesDict(TypedDict):
|
21
|
-
conda:
|
21
|
+
conda: NotRequired[str]
|
22
|
+
pip: NotRequired[str]
|
22
23
|
|
23
24
|
|
24
25
|
class ModelRuntimeDict(TypedDict):
|
@@ -56,12 +57,14 @@ class ModelFunctionInfo(TypedDict):
|
|
56
57
|
target_method: actual target method name to be called.
|
57
58
|
target_method_function_type: target method function type (FUNCTION or TABLE_FUNCTION).
|
58
59
|
signature: The signature of the model method.
|
60
|
+
is_partitioned: Whether the function is partitioned.
|
59
61
|
"""
|
60
62
|
|
61
63
|
name: Required[str]
|
62
64
|
target_method: Required[str]
|
63
65
|
target_method_function_type: Required[str]
|
64
66
|
signature: Required[model_signature.ModelSignature]
|
67
|
+
is_partitioned: Required[bool]
|
65
68
|
|
66
69
|
|
67
70
|
class ModelFunctionInfoDict(TypedDict):
|
@@ -77,6 +80,7 @@ class SnowparkMLDataDict(TypedDict):
|
|
77
80
|
|
78
81
|
class LineageSourceTypes(enum.Enum):
|
79
82
|
DATASET = "DATASET"
|
83
|
+
QUERY = "QUERY"
|
80
84
|
|
81
85
|
|
82
86
|
class LineageSourceDict(TypedDict):
|
@@ -33,9 +33,9 @@ class FunctionGenerator:
|
|
33
33
|
|
34
34
|
def __init__(
|
35
35
|
self,
|
36
|
-
|
36
|
+
model_dir_rel_path: pathlib.PurePosixPath,
|
37
37
|
) -> None:
|
38
|
-
self.
|
38
|
+
self.model_dir_rel_path = model_dir_rel_path
|
39
39
|
|
40
40
|
def generate(
|
41
41
|
self,
|
@@ -67,7 +67,7 @@ class FunctionGenerator:
|
|
67
67
|
)
|
68
68
|
|
69
69
|
udf_code = function_template.format(
|
70
|
-
|
70
|
+
model_dir_name=self.model_dir_rel_path.name,
|
71
71
|
target_method=target_method,
|
72
72
|
max_batch_size=options.get("max_batch_size", None),
|
73
73
|
function_name=FunctionGenerator.FUNCTION_NAME,
|
@@ -1,12 +1,7 @@
|
|
1
|
-
import fcntl
|
2
1
|
import functools
|
3
2
|
import inspect
|
4
3
|
import os
|
5
4
|
import sys
|
6
|
-
import threading
|
7
|
-
import zipfile
|
8
|
-
from types import TracebackType
|
9
|
-
from typing import Optional, Type
|
10
5
|
|
11
6
|
import anyio
|
12
7
|
import pandas as pd
|
@@ -15,42 +10,18 @@ from _snowflake import vectorized
|
|
15
10
|
from snowflake.ml.model._packager import model_packager
|
16
11
|
|
17
12
|
|
18
|
-
class FileLock:
|
19
|
-
def __enter__(self) -> None:
|
20
|
-
self._lock = threading.Lock()
|
21
|
-
self._lock.acquire()
|
22
|
-
self._fd = open("/tmp/lockfile.LOCK", "w+")
|
23
|
-
fcntl.lockf(self._fd, fcntl.LOCK_EX)
|
24
|
-
|
25
|
-
def __exit__(
|
26
|
-
self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
|
27
|
-
) -> None:
|
28
|
-
self._fd.close()
|
29
|
-
self._lock.release()
|
30
|
-
|
31
|
-
|
32
13
|
# User-defined parameters
|
33
|
-
|
14
|
+
MODEL_DIR_REL_PATH = "{model_dir_name}"
|
34
15
|
TARGET_METHOD = "{target_method}"
|
35
16
|
MAX_BATCH_SIZE = {max_batch_size}
|
36
17
|
|
37
|
-
|
38
18
|
# Retrieve the model
|
39
19
|
IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
|
40
20
|
import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
|
41
|
-
|
42
|
-
model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
|
43
|
-
zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
|
44
|
-
extracted = "/tmp/models"
|
45
|
-
extracted_model_dir_path = os.path.join(extracted, model_dir_name)
|
46
|
-
|
47
|
-
with FileLock():
|
48
|
-
if not os.path.isdir(extracted_model_dir_path):
|
49
|
-
with zipfile.ZipFile(zip_model_path, "r") as myzip:
|
50
|
-
myzip.extractall(extracted_model_dir_path)
|
21
|
+
model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
|
51
22
|
|
52
23
|
# Load the model
|
53
|
-
pk = model_packager.ModelPackager(
|
24
|
+
pk = model_packager.ModelPackager(model_dir_path)
|
54
25
|
pk.load(as_custom_model=True)
|
55
26
|
assert pk.model, "model is not loaded"
|
56
27
|
assert pk.meta, "model metadata is not loaded"
|
@@ -15,42 +15,18 @@ from _snowflake import vectorized
|
|
15
15
|
from snowflake.ml.model._packager import model_packager
|
16
16
|
|
17
17
|
|
18
|
-
class FileLock:
|
19
|
-
def __enter__(self) -> None:
|
20
|
-
self._lock = threading.Lock()
|
21
|
-
self._lock.acquire()
|
22
|
-
self._fd = open("/tmp/lockfile.LOCK", "w+")
|
23
|
-
fcntl.lockf(self._fd, fcntl.LOCK_EX)
|
24
|
-
|
25
|
-
def __exit__(
|
26
|
-
self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
|
27
|
-
) -> None:
|
28
|
-
self._fd.close()
|
29
|
-
self._lock.release()
|
30
|
-
|
31
|
-
|
32
18
|
# User-defined parameters
|
33
|
-
|
19
|
+
MODEL_DIR_REL_PATH = "{model_dir_name}"
|
34
20
|
TARGET_METHOD = "{target_method}"
|
35
21
|
MAX_BATCH_SIZE = {max_batch_size}
|
36
22
|
|
37
|
-
|
38
23
|
# Retrieve the model
|
39
24
|
IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
|
40
25
|
import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
|
41
|
-
|
42
|
-
model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
|
43
|
-
zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
|
44
|
-
extracted = "/tmp/models"
|
45
|
-
extracted_model_dir_path = os.path.join(extracted, model_dir_name)
|
46
|
-
|
47
|
-
with FileLock():
|
48
|
-
if not os.path.isdir(extracted_model_dir_path):
|
49
|
-
with zipfile.ZipFile(zip_model_path, "r") as myzip:
|
50
|
-
myzip.extractall(extracted_model_dir_path)
|
26
|
+
model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
|
51
27
|
|
52
28
|
# Load the model
|
53
|
-
pk = model_packager.ModelPackager(
|
29
|
+
pk = model_packager.ModelPackager(model_dir_path)
|
54
30
|
pk.load(as_custom_model=True)
|
55
31
|
assert pk.model, "model is not loaded"
|
56
32
|
assert pk.meta, "model metadata is not loaded"
|
@@ -1,12 +1,7 @@
|
|
1
|
-
import fcntl
|
2
1
|
import functools
|
3
2
|
import inspect
|
4
3
|
import os
|
5
4
|
import sys
|
6
|
-
import threading
|
7
|
-
import zipfile
|
8
|
-
from types import TracebackType
|
9
|
-
from typing import Optional, Type
|
10
5
|
|
11
6
|
import anyio
|
12
7
|
import pandas as pd
|
@@ -15,42 +10,18 @@ from _snowflake import vectorized
|
|
15
10
|
from snowflake.ml.model._packager import model_packager
|
16
11
|
|
17
12
|
|
18
|
-
class FileLock:
|
19
|
-
def __enter__(self) -> None:
|
20
|
-
self._lock = threading.Lock()
|
21
|
-
self._lock.acquire()
|
22
|
-
self._fd = open("/tmp/lockfile.LOCK", "w+")
|
23
|
-
fcntl.lockf(self._fd, fcntl.LOCK_EX)
|
24
|
-
|
25
|
-
def __exit__(
|
26
|
-
self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
|
27
|
-
) -> None:
|
28
|
-
self._fd.close()
|
29
|
-
self._lock.release()
|
30
|
-
|
31
|
-
|
32
13
|
# User-defined parameters
|
33
|
-
|
14
|
+
MODEL_DIR_REL_PATH = "{model_dir_name}"
|
34
15
|
TARGET_METHOD = "{target_method}"
|
35
16
|
MAX_BATCH_SIZE = {max_batch_size}
|
36
17
|
|
37
|
-
|
38
18
|
# Retrieve the model
|
39
19
|
IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
|
40
20
|
import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
|
41
|
-
|
42
|
-
model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
|
43
|
-
zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
|
44
|
-
extracted = "/tmp/models"
|
45
|
-
extracted_model_dir_path = os.path.join(extracted, model_dir_name)
|
46
|
-
|
47
|
-
with FileLock():
|
48
|
-
if not os.path.isdir(extracted_model_dir_path):
|
49
|
-
with zipfile.ZipFile(zip_model_path, "r") as myzip:
|
50
|
-
myzip.extractall(extracted_model_dir_path)
|
21
|
+
model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
|
51
22
|
|
52
23
|
# Load the model
|
53
|
-
pk = model_packager.ModelPackager(
|
24
|
+
pk = model_packager.ModelPackager(model_dir_path)
|
54
25
|
pk.load(as_custom_model=True)
|
55
26
|
assert pk.model, "model is not loaded"
|
56
27
|
assert pk.meta, "model metadata is not loaded"
|
@@ -26,11 +26,14 @@ class ModelMethodOptions(TypedDict):
|
|
26
26
|
def get_model_method_options_from_options(
|
27
27
|
options: type_hints.ModelSaveOption, target_method: str
|
28
28
|
) -> ModelMethodOptions:
|
29
|
+
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
30
|
+
if options.get("enable_explainability", False) and target_method.startswith("explain"):
|
31
|
+
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
29
32
|
method_option = options.get("method_options", {}).get(target_method, {})
|
30
|
-
global_function_type = options.get("function_type",
|
33
|
+
global_function_type = options.get("function_type", default_function_type)
|
31
34
|
function_type = method_option.get("function_type", global_function_type)
|
32
35
|
if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
|
33
|
-
raise NotImplementedError
|
36
|
+
raise NotImplementedError(f"Function type {function_type} is not supported.")
|
34
37
|
|
35
38
|
return ModelMethodOptions(
|
36
39
|
case_sensitive=method_option.get("case_sensitive", False),
|
@@ -363,9 +363,14 @@ class ModelEnv:
|
|
363
363
|
self.cuda_version = env_dict.get("cuda_version", None)
|
364
364
|
self.snowpark_ml_version = env_dict["snowpark_ml_version"]
|
365
365
|
|
366
|
-
def save_as_dict(
|
366
|
+
def save_as_dict(
|
367
|
+
self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
368
|
+
) -> model_meta_schema.ModelEnvDict:
|
367
369
|
env_utils.save_conda_env_file(
|
368
|
-
pathlib.Path(base_dir / self.conda_env_rel_path),
|
370
|
+
pathlib.Path(base_dir / self.conda_env_rel_path),
|
371
|
+
self._conda_dependencies,
|
372
|
+
self.python_version,
|
373
|
+
default_channel_override=default_channel_override,
|
369
374
|
)
|
370
375
|
env_utils.save_requirements_file(
|
371
376
|
pathlib.Path(base_dir / self.pip_requirements_rel_path), self._pip_requirements
|