snowflake-ml-python 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +16 -0
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/telemetry.py +56 -7
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/_entities/run.py +15 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +123 -13
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/access_manager.py +1 -0
- snowflake/ml/feature_store/feature_store.py +1 -1
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/feature_flags.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +360 -357
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +2 -406
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +8 -9
- snowflake/ml/jobs/manager.py +64 -129
- snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
- snowflake/ml/model/_client/model/model_version_impl.py +109 -28
- snowflake/ml/model/_client/ops/model_ops.py +32 -6
- snowflake/ml/model/_client/ops/service_ops.py +9 -4
- snowflake/ml/model/_client/sql/service.py +69 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_signatures/core.py +305 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +25 -215
- snowflake/ml/model/type_hints.py +5 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/METADATA +94 -7
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/RECORD +52 -48
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/top_level.txt +0 -0

snowflake/ml/model/_client/ops/model_ops.py:

@@ -7,7 +7,9 @@ import warnings
 from typing import Any, Literal, Optional, TypedDict, Union, cast, overload

 import yaml
+from typing_extensions import NotRequired

+from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
 from snowflake.ml.model import model_signature, type_hints

@@ -42,6 +44,8 @@ class ServiceInfo(TypedDict):
     name: str
     status: str
     inference_endpoint: Optional[str]
+    internal_endpoint: Optional[str]
+    autocapture_enabled: NotRequired[bool]


 class ModelOperator:
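The `ServiceInfo` entries returned by `show_services` now carry an `internal_endpoint` and, when the platform parameter allows it, an `autocapture_enabled` flag. A minimal sketch of how the `NotRequired` field behaves; the class mirrors the TypedDict above, while the concrete values are invented for illustration:

```python
from typing import Optional, TypedDict

from typing_extensions import NotRequired


class ServiceInfo(TypedDict):
    name: str
    status: str
    inference_endpoint: Optional[str]
    internal_endpoint: Optional[str]
    autocapture_enabled: NotRequired[bool]  # may be absent entirely


# Valid without the NotRequired key...
info = ServiceInfo(
    name="MY_DB.MY_SCHEMA.MY_SERVICE",              # made-up value
    status="RUNNING",
    inference_endpoint=None,
    internal_endpoint="http://internal-dns:5000",   # made-up value
)
# ...and the key can be attached later, which is how show_services populates it.
info["autocapture_enabled"] = True
```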
snowflake/ml/model/_client/ops/model_ops.py (continued):

@@ -651,6 +655,13 @@ class ModelOperator:
         url_str = str(url_value)
         return url_str if ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING in url_str else None

+    def _extract_and_validate_port(self, res_row: "row.Row") -> Optional[int]:
+        """Extract and validate port from endpoint row."""
+        port_value = res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME]
+        if port_value is None:
+            return None
+        return int(port_value)
+
     def show_services(
         self,
         *,

@@ -684,8 +695,12 @@ class ModelOperator:

         result: list[ServiceInfo] = []
         is_privatelink_connection = self._is_privatelink_connection()
+        is_autocapture_param_enabled = (
+            platform_capabilities.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
+        )

         for fully_qualified_service_name in fully_qualified_service_names:
+            port: Optional[int] = None
             inference_endpoint: Optional[str] = None
             db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
             statuses = self._service_client.get_service_container_statuses(

@@ -695,6 +710,11 @@
                 return result

             service_status = statuses[0].service_status
+            service_description = self._service_client.describe_service(
+                database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+            )
+            internal_dns = str(service_description[self._service_client.DESC_SERVICE_INTERNAL_DNS_COL_NAME])
+
             for res_row in self._service_client.show_endpoints(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             ):

@@ -706,19 +726,25 @@

                 ingress_url = self._extract_and_validate_ingress_url(res_row)
                 privatelink_ingress_url = self._extract_and_validate_privatelink_url(res_row)
+                port = self._extract_and_validate_port(res_row)

                 if is_privatelink_connection and privatelink_ingress_url is not None:
                     inference_endpoint = privatelink_ingress_url
                 else:
                     inference_endpoint = ingress_url

-
-
-
-
-
-                )
+            service_info = ServiceInfo(
+                name=fully_qualified_service_name,
+                status=service_status.value,
+                inference_endpoint=inference_endpoint,
+                internal_endpoint=f"http://{internal_dns}:{port}" if port is not None else None,
             )
+            if is_autocapture_param_enabled and self._service_client.DESC_SERVICE_SPEC_COL_NAME in service_description:
+                # Include column only if parameter is enabled and spec exists for service owner caller
+                autocapture_enabled = self._service_client.get_proxy_container_autocapture(service_description)
+                service_info["autocapture_enabled"] = autocapture_enabled
+
+            result.append(service_info)

         return result

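Taken together, these hunks mean each entry from `ModelOperator.show_services` now composes an internal endpoint from the service's DNS name and endpoint port, and optionally reports autocapture. A rough sketch of what one returned entry might look like; the field names come from the diff, while the service name and values are invented:

```python
# Hypothetical show_services entry for one inference service; values are illustrative.
service_entry = {
    "name": "MY_DB.MY_SCHEMA.MY_MODEL_SERVICE",
    "status": "RUNNING",
    "inference_endpoint": "xyz.snowflakecomputing.app",            # ingress or privatelink URL
    "internal_endpoint": "http://my-service.internal.dns:5000",    # f"http://{internal_dns}:{port}"
    # Present only when the inference-autocapture platform capability is enabled
    # and the caller can see the service spec:
    "autocapture_enabled": False,
}
```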
snowflake/ml/model/_client/ops/service_ops.py:

@@ -11,9 +11,9 @@ import warnings
 from typing import Any, Optional, Union, cast

 from snowflake import snowpark
-from snowflake.ml import jobs
 from snowflake.ml._internal import file_utils, platform_capabilities as pc
 from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
+from snowflake.ml.jobs import job
 from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
 from snowflake.ml.model._client.model import batch_inference_specs
 from snowflake.ml.model._client.service import model_deployment_spec

@@ -171,6 +171,7 @@ class ServiceOperator:
             self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
                 workspace_path=pathlib.Path(self._workspace.name)
             )
+        self._inference_autocapture_enabled = pc.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()

     def __eq__(self, __value: object) -> bool:
         if not isinstance(__value, ServiceOperator):

@@ -207,7 +208,7 @@
         # inference engine model
         inference_engine_args: Optional[InferenceEngineArgs] = None,
         # inference table
-        autocapture:
+        autocapture: bool = False,
     ) -> Union[str, async_job.AsyncJob]:

         # Generate operation ID for this deployment

@@ -231,6 +232,10 @@
             progress_status.update("preparing deployment artifacts...")
             progress_status.increment()

+            # If autocapture param is disabled, don't allow create service with autocapture
+            if not self._inference_autocapture_enabled and autocapture:
+                raise ValueError("Invalid Argument: Autocapture feature is not supported.")
+
             if self._workspace:
                 stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
             else:
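The deploy path now refuses an `autocapture=True` request when the platform capability is off. A small self-contained sketch of that guard in isolation; the stand-in class below is invented for illustration and is not part of the package:

```python
class _AutocaptureGate:
    """Illustrative stand-in for the check ServiceOperator now performs."""

    def __init__(self, inference_autocapture_enabled: bool) -> None:
        self._inference_autocapture_enabled = inference_autocapture_enabled

    def check(self, autocapture: bool) -> None:
        # Mirrors the new guard: reject autocapture when the platform
        # parameter is not enabled for the account.
        if not self._inference_autocapture_enabled and autocapture:
            raise ValueError("Invalid Argument: Autocapture feature is not supported.")


gate = _AutocaptureGate(inference_autocapture_enabled=False)
try:
    gate.check(autocapture=True)
except ValueError as exc:
    print(exc)  # Invalid Argument: Autocapture feature is not supported.
```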
snowflake/ml/model/_client/ops/service_ops.py (continued):

@@ -976,7 +981,7 @@
         gpu_requests: Optional[str],
         replicas: Optional[int],
         statement_params: Optional[dict[str, Any]] = None,
-    ) ->
+    ) -> job.MLJob[Any]:
         database_name = self._database_name
         schema_name = self._schema_name

@@ -1045,7 +1050,7 @@
         # Block until the async job is done
         async_job.result()

-        return
+        return job.MLJob(
             id=sql_identifier.get_fully_qualified_name(job_database_name, job_schema_name, job_name),
             session=self._session,
         )
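With the import switched from `snowflake.ml import jobs` to `snowflake.ml.jobs import job`, this helper is now annotated to return a `job.MLJob[Any]` handle built from the job's fully qualified name. A sketch of that construction with made-up names; an existing Snowpark `session` is assumed:

```python
from typing import Any

from snowflake.ml.jobs import job

# `session` is assumed to be an existing snowflake.snowpark.Session.
ml_job: "job.MLJob[Any]" = job.MLJob(
    id="MY_DB.MY_SCHEMA.MY_BATCH_INFERENCE_JOB",  # fully qualified job name (made up)
    session=session,
)
```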
snowflake/ml/model/_client/sql/service.py:

@@ -3,7 +3,9 @@ import dataclasses
 import enum
 import logging
 import textwrap
-from typing import Any, Generator, Optional
+from typing import Any, Generator, Optional, cast
+
+import yaml

 from snowflake import snowpark
 from snowflake.ml._internal.utils import (

@@ -68,6 +70,7 @@ class ServiceStatusInfo:

 class ServiceSQLClient(_base._BaseSQLClient):
     MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
+    MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME = "port"
     MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
     MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME = "privatelink_ingress_url"
     SERVICE_STATUS = "service_status"

@@ -75,6 +78,14 @@ class ServiceSQLClient(_base._BaseSQLClient):
     INSTANCE_STATUS = "instance_status"
     CONTAINER_STATUS = "status"
     MESSAGE = "message"
+    DESC_SERVICE_INTERNAL_DNS_COL_NAME = "dns_name"
+    DESC_SERVICE_SPEC_COL_NAME = "spec"
+    DESC_SERVICE_CONTAINERS_SPEC_NAME = "containers"
+    DESC_SERVICE_NAME_SPEC_NAME = "name"
+    DESC_SERVICE_PROXY_SPEC_ENV_NAME = "env"
+    PROXY_CONTAINER_NAME = "proxy"
+    MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME = "SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED"
+    FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"

     @contextlib.contextmanager
     def _qmark_paramstyle(self) -> Generator[None, None, None]:

@@ -233,7 +244,15 @@ class ServiceSQLClient(_base._BaseSQLClient):
     ) -> list[ServiceStatusInfo]:
         fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
         query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
-        rows =
+        rows = (
+            query_result_checker.SqlResultValidator(self._session, query, statement_params=statement_params)
+            .has_column(ServiceSQLClient.INSTANCE_STATUS)
+            .has_column(ServiceSQLClient.CONTAINER_STATUS)
+            .has_column(ServiceSQLClient.SERVICE_STATUS)
+            .has_column(ServiceSQLClient.INSTANCE_ID)
+            .has_column(ServiceSQLClient.MESSAGE)
+            .validate()
+        )
         statuses = []
         for r in rows:
             instance_status, container_status = None, None

@@ -252,6 +271,53 @@ class ServiceSQLClient(_base._BaseSQLClient):
         )
         return statuses

+    def describe_service(
+        self,
+        *,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        service_name: sql_identifier.SqlIdentifier,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> row.Row:
+        fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        query = f"DESCRIBE SERVICE {fully_qualified_object_name}"
+        rows = (
+            query_result_checker.SqlResultValidator(self._session, query, statement_params=statement_params)
+            .has_dimensions(expected_rows=1)
+            .has_column(ServiceSQLClient.DESC_SERVICE_INTERNAL_DNS_COL_NAME)
+            .validate()
+        )
+        return rows[0]
+
+    def get_proxy_container_autocapture(self, row: row.Row) -> bool:
+        """Extract whether service has autocapture enabled from proxy container spec.
+
+        Args:
+            row: A row.Row object from DESCRIBE SERVICE containing the service YAML spec.
+
+        Returns:
+            True if autocapture is enabled in proxy spec
+            False if disabled or not set in proxy spec
+            False if service doesn't have proxy container
+        """
+        try:
+            spec_raw = yaml.safe_load(row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME])
+            spec = cast(dict[str, Any], spec_raw)
+
+            proxy_container_spec = next(
+                container
+                for container in spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
+                    ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
+                ]
+                if container[ServiceSQLClient.DESC_SERVICE_NAME_SPEC_NAME] == ServiceSQLClient.PROXY_CONTAINER_NAME
+            )
+            env = proxy_container_spec.get(ServiceSQLClient.DESC_SERVICE_PROXY_SPEC_ENV_NAME, {})
+            autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
+            return str(autocapture_enabled).lower() == "true"
+
+        except StopIteration:
+            return False
+
     def drop_service(
         self,
         *,
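The new `get_proxy_container_autocapture` reads the `spec` column returned by `DESCRIBE SERVICE`, finds the container named `proxy`, and checks its `SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED` env var. A standalone sketch of that parsing logic against a hand-written spec; the YAML below is an illustrative shape, not an actual Snowflake service spec:

```python
from typing import Any

import yaml

# Illustrative service spec; real DESCRIBE SERVICE output may carry more fields.
SPEC_YAML = """
spec:
  containers:
  - name: model-inference
    env:
      SOME_OTHER_VAR: "1"
  - name: proxy
    env:
      SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED: "true"
"""


def autocapture_enabled_from_spec(spec_yaml: str) -> bool:
    """Mirror of the lookup the SQL client performs on the proxy container."""
    spec: dict[str, Any] = yaml.safe_load(spec_yaml)
    try:
        proxy = next(c for c in spec["spec"]["containers"] if c["name"] == "proxy")
    except StopIteration:
        return False  # no proxy container -> autocapture not enabled
    env = proxy.get("env", {}) or {}
    return str(env.get("SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED", "false")).lower() == "true"


print(autocapture_enabled_from_spec(SPEC_YAML))  # True
```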
snowflake/ml/model/_client/sql/service.py (continued):

@@ -282,6 +348,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
                 statement_params=statement_params,
             )
             .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
+            .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME, allow_empty=True)
             .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
         )

snowflake/ml/model/_packager/model_handler.py:

@@ -1,5 +1,6 @@
 import functools
 import importlib
+import logging
 import pkgutil
 from types import ModuleType
 from typing import Any, Callable, Optional, TypeVar, cast

@@ -11,6 +12,8 @@ _HANDLERS_BASE = "snowflake.ml.model._packager.model_handlers"
 _MODEL_HANDLER_REGISTRY: dict[str, type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
 _IS_HANDLER_LOADED = False

+logger = logging.getLogger(__name__)
+

 def _register_handlers() -> None:
     """

@@ -56,8 +59,11 @@ def find_handler(
     model: model_types.SupportedModelType,
 ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
     for handler in _MODEL_HANDLER_REGISTRY.values():
-
-
+        try:
+            if handler.can_handle(model):
+                return handler
+        except Exception:
+            logger.error(f"Error in {handler.__name__} `can_handle` method for model {type(model)}", exc_info=True)
     return None

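The handler lookup is now defensive: a handler whose `can_handle` raises no longer aborts model-type detection; the error is logged and the next handler is tried. A self-contained sketch of the same pattern; the toy handlers and registry below are invented for illustration:

```python
import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)


class _GoodHandler:
    @staticmethod
    def can_handle(model: Any) -> bool:
        return isinstance(model, dict)


class _BrokenHandler:
    @staticmethod
    def can_handle(model: Any) -> bool:
        raise RuntimeError("optional dependency missing")


_REGISTRY = {"broken": _BrokenHandler, "good": _GoodHandler}


def find_handler(model: Any) -> Optional[type]:
    # Mirrors the new behavior: log the failing handler and keep probing.
    for handler in _REGISTRY.values():
        try:
            if handler.can_handle(model):
                return handler
        except Exception:
            logger.error(f"Error in {handler.__name__} `can_handle` method for model {type(model)}", exc_info=True)
    return None


print(find_handler({"weights": []}))  # <class '__main__._GoodHandler'>
```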