snowflake-ml-python 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. snowflake/ml/_internal/env_utils.py +16 -0
  2. snowflake/ml/_internal/platform_capabilities.py +36 -0
  3. snowflake/ml/_internal/telemetry.py +56 -7
  4. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  5. snowflake/ml/data/data_connector.py +103 -1
  6. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  7. snowflake/ml/experiment/_entities/run.py +15 -0
  8. snowflake/ml/experiment/callback/keras.py +25 -2
  9. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  10. snowflake/ml/experiment/callback/xgboost.py +25 -2
  11. snowflake/ml/experiment/experiment_tracking.py +123 -13
  12. snowflake/ml/experiment/utils.py +6 -0
  13. snowflake/ml/feature_store/access_manager.py +1 -0
  14. snowflake/ml/feature_store/feature_store.py +1 -1
  15. snowflake/ml/feature_store/feature_view.py +34 -24
  16. snowflake/ml/jobs/_interop/protocols.py +3 -0
  17. snowflake/ml/jobs/_utils/feature_flags.py +1 -0
  18. snowflake/ml/jobs/_utils/payload_utils.py +360 -357
  19. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  20. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  21. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  22. snowflake/ml/jobs/_utils/spec_utils.py +2 -406
  23. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  24. snowflake/ml/jobs/_utils/types.py +14 -7
  25. snowflake/ml/jobs/job.py +8 -9
  26. snowflake/ml/jobs/manager.py +64 -129
  27. snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
  28. snowflake/ml/model/_client/model/model_version_impl.py +109 -28
  29. snowflake/ml/model/_client/ops/model_ops.py +32 -6
  30. snowflake/ml/model/_client/ops/service_ops.py +9 -4
  31. snowflake/ml/model/_client/sql/service.py +69 -2
  32. snowflake/ml/model/_packager/model_handler.py +8 -2
  33. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  34. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  35. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  36. snowflake/ml/model/_signatures/core.py +305 -8
  37. snowflake/ml/model/_signatures/utils.py +13 -4
  38. snowflake/ml/model/compute_pool.py +2 -0
  39. snowflake/ml/model/models/huggingface.py +285 -0
  40. snowflake/ml/model/models/huggingface_pipeline.py +25 -215
  41. snowflake/ml/model/type_hints.py +5 -1
  42. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  43. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  44. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  45. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  46. snowflake/ml/utils/html_utils.py +67 -1
  47. snowflake/ml/version.py +1 -1
  48. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/METADATA +94 -7
  49. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/RECORD +52 -48
  50. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/WHEEL +0 -0
  51. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/licenses/LICENSE.txt +0 -0
  52. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,9 @@ import warnings
7
7
  from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
8
8
 
9
9
  import yaml
10
+ from typing_extensions import NotRequired
10
11
 
12
+ from snowflake.ml._internal import platform_capabilities
11
13
  from snowflake.ml._internal.exceptions import error_codes, exceptions
12
14
  from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
13
15
  from snowflake.ml.model import model_signature, type_hints
@@ -42,6 +44,8 @@ class ServiceInfo(TypedDict):
42
44
  name: str
43
45
  status: str
44
46
  inference_endpoint: Optional[str]
47
+ internal_endpoint: Optional[str]
48
+ autocapture_enabled: NotRequired[bool]
45
49
 
46
50
 
47
51
  class ModelOperator:
@@ -651,6 +655,13 @@ class ModelOperator:
651
655
  url_str = str(url_value)
652
656
  return url_str if ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING in url_str else None
653
657
 
658
+ def _extract_and_validate_port(self, res_row: "row.Row") -> Optional[int]:
659
+ """Extract and validate port from endpoint row."""
660
+ port_value = res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME]
661
+ if port_value is None:
662
+ return None
663
+ return int(port_value)
664
+
654
665
  def show_services(
655
666
  self,
656
667
  *,
@@ -684,8 +695,12 @@ class ModelOperator:
684
695
 
685
696
  result: list[ServiceInfo] = []
686
697
  is_privatelink_connection = self._is_privatelink_connection()
698
+ is_autocapture_param_enabled = (
699
+ platform_capabilities.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
700
+ )
687
701
 
688
702
  for fully_qualified_service_name in fully_qualified_service_names:
703
+ port: Optional[int] = None
689
704
  inference_endpoint: Optional[str] = None
690
705
  db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
691
706
  statuses = self._service_client.get_service_container_statuses(
@@ -695,6 +710,11 @@ class ModelOperator:
695
710
  return result
696
711
 
697
712
  service_status = statuses[0].service_status
713
+ service_description = self._service_client.describe_service(
714
+ database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
715
+ )
716
+ internal_dns = str(service_description[self._service_client.DESC_SERVICE_INTERNAL_DNS_COL_NAME])
717
+
698
718
  for res_row in self._service_client.show_endpoints(
699
719
  database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
700
720
  ):
@@ -706,19 +726,25 @@ class ModelOperator:
706
726
 
707
727
  ingress_url = self._extract_and_validate_ingress_url(res_row)
708
728
  privatelink_ingress_url = self._extract_and_validate_privatelink_url(res_row)
729
+ port = self._extract_and_validate_port(res_row)
709
730
 
710
731
  if is_privatelink_connection and privatelink_ingress_url is not None:
711
732
  inference_endpoint = privatelink_ingress_url
712
733
  else:
713
734
  inference_endpoint = ingress_url
714
735
 
715
- result.append(
716
- ServiceInfo(
717
- name=fully_qualified_service_name,
718
- status=service_status.value,
719
- inference_endpoint=inference_endpoint,
720
- )
736
+ service_info = ServiceInfo(
737
+ name=fully_qualified_service_name,
738
+ status=service_status.value,
739
+ inference_endpoint=inference_endpoint,
740
+ internal_endpoint=f"http://{internal_dns}:{port}" if port is not None else None,
721
741
  )
742
+ if is_autocapture_param_enabled and self._service_client.DESC_SERVICE_SPEC_COL_NAME in service_description:
743
+ # Include column only if parameter is enabled and spec exists for service owner caller
744
+ autocapture_enabled = self._service_client.get_proxy_container_autocapture(service_description)
745
+ service_info["autocapture_enabled"] = autocapture_enabled
746
+
747
+ result.append(service_info)
722
748
 
723
749
  return result
724
750
 
@@ -11,9 +11,9 @@ import warnings
11
11
  from typing import Any, Optional, Union, cast
12
12
 
13
13
  from snowflake import snowpark
14
- from snowflake.ml import jobs
15
14
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
16
15
  from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
16
+ from snowflake.ml.jobs import job
17
17
  from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
18
18
  from snowflake.ml.model._client.model import batch_inference_specs
19
19
  from snowflake.ml.model._client.service import model_deployment_spec
@@ -171,6 +171,7 @@ class ServiceOperator:
171
171
  self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
172
172
  workspace_path=pathlib.Path(self._workspace.name)
173
173
  )
174
+ self._inference_autocapture_enabled = pc.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
174
175
 
175
176
  def __eq__(self, __value: object) -> bool:
176
177
  if not isinstance(__value, ServiceOperator):
@@ -207,7 +208,7 @@ class ServiceOperator:
207
208
  # inference engine model
208
209
  inference_engine_args: Optional[InferenceEngineArgs] = None,
209
210
  # inference table
210
- autocapture: Optional[bool] = None,
211
+ autocapture: bool = False,
211
212
  ) -> Union[str, async_job.AsyncJob]:
212
213
 
213
214
  # Generate operation ID for this deployment
@@ -231,6 +232,10 @@ class ServiceOperator:
231
232
  progress_status.update("preparing deployment artifacts...")
232
233
  progress_status.increment()
233
234
 
235
+ # If autocapture param is disabled, don't allow create service with autocapture
236
+ if not self._inference_autocapture_enabled and autocapture:
237
+ raise ValueError("Invalid Argument: Autocapture feature is not supported.")
238
+
234
239
  if self._workspace:
235
240
  stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
236
241
  else:
@@ -976,7 +981,7 @@ class ServiceOperator:
976
981
  gpu_requests: Optional[str],
977
982
  replicas: Optional[int],
978
983
  statement_params: Optional[dict[str, Any]] = None,
979
- ) -> jobs.MLJob[Any]:
984
+ ) -> job.MLJob[Any]:
980
985
  database_name = self._database_name
981
986
  schema_name = self._schema_name
982
987
 
@@ -1045,7 +1050,7 @@ class ServiceOperator:
1045
1050
  # Block until the async job is done
1046
1051
  async_job.result()
1047
1052
 
1048
- return jobs.MLJob(
1053
+ return job.MLJob(
1049
1054
  id=sql_identifier.get_fully_qualified_name(job_database_name, job_schema_name, job_name),
1050
1055
  session=self._session,
1051
1056
  )
@@ -3,7 +3,9 @@ import dataclasses
3
3
  import enum
4
4
  import logging
5
5
  import textwrap
6
- from typing import Any, Generator, Optional
6
+ from typing import Any, Generator, Optional, cast
7
+
8
+ import yaml
7
9
 
8
10
  from snowflake import snowpark
9
11
  from snowflake.ml._internal.utils import (
@@ -68,6 +70,7 @@ class ServiceStatusInfo:
68
70
 
69
71
  class ServiceSQLClient(_base._BaseSQLClient):
70
72
  MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
73
+ MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME = "port"
71
74
  MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
72
75
  MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME = "privatelink_ingress_url"
73
76
  SERVICE_STATUS = "service_status"
@@ -75,6 +78,14 @@ class ServiceSQLClient(_base._BaseSQLClient):
75
78
  INSTANCE_STATUS = "instance_status"
76
79
  CONTAINER_STATUS = "status"
77
80
  MESSAGE = "message"
81
+ DESC_SERVICE_INTERNAL_DNS_COL_NAME = "dns_name"
82
+ DESC_SERVICE_SPEC_COL_NAME = "spec"
83
+ DESC_SERVICE_CONTAINERS_SPEC_NAME = "containers"
84
+ DESC_SERVICE_NAME_SPEC_NAME = "name"
85
+ DESC_SERVICE_PROXY_SPEC_ENV_NAME = "env"
86
+ PROXY_CONTAINER_NAME = "proxy"
87
+ MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME = "SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED"
88
+ FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"
78
89
 
79
90
  @contextlib.contextmanager
80
91
  def _qmark_paramstyle(self) -> Generator[None, None, None]:
@@ -233,7 +244,15 @@ class ServiceSQLClient(_base._BaseSQLClient):
233
244
  ) -> list[ServiceStatusInfo]:
234
245
  fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
235
246
  query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
236
- rows = self._session.sql(query).collect(statement_params=statement_params)
247
+ rows = (
248
+ query_result_checker.SqlResultValidator(self._session, query, statement_params=statement_params)
249
+ .has_column(ServiceSQLClient.INSTANCE_STATUS)
250
+ .has_column(ServiceSQLClient.CONTAINER_STATUS)
251
+ .has_column(ServiceSQLClient.SERVICE_STATUS)
252
+ .has_column(ServiceSQLClient.INSTANCE_ID)
253
+ .has_column(ServiceSQLClient.MESSAGE)
254
+ .validate()
255
+ )
237
256
  statuses = []
238
257
  for r in rows:
239
258
  instance_status, container_status = None, None
@@ -252,6 +271,53 @@ class ServiceSQLClient(_base._BaseSQLClient):
252
271
  )
253
272
  return statuses
254
273
 
274
+ def describe_service(
275
+ self,
276
+ *,
277
+ database_name: Optional[sql_identifier.SqlIdentifier],
278
+ schema_name: Optional[sql_identifier.SqlIdentifier],
279
+ service_name: sql_identifier.SqlIdentifier,
280
+ statement_params: Optional[dict[str, Any]] = None,
281
+ ) -> row.Row:
282
+ fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
283
+ query = f"DESCRIBE SERVICE {fully_qualified_object_name}"
284
+ rows = (
285
+ query_result_checker.SqlResultValidator(self._session, query, statement_params=statement_params)
286
+ .has_dimensions(expected_rows=1)
287
+ .has_column(ServiceSQLClient.DESC_SERVICE_INTERNAL_DNS_COL_NAME)
288
+ .validate()
289
+ )
290
+ return rows[0]
291
+
292
+ def get_proxy_container_autocapture(self, row: row.Row) -> bool:
293
+ """Extract whether service has autocapture enabled from proxy container spec.
294
+
295
+ Args:
296
+ row: A row.Row object from DESCRIBE SERVICE containing the service YAML spec.
297
+
298
+ Returns:
299
+ True if autocapture is enabled in proxy spec
300
+ False if disabled or not set in proxy spec
301
+ False if service doesn't have proxy container
302
+ """
303
+ try:
304
+ spec_raw = yaml.safe_load(row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME])
305
+ spec = cast(dict[str, Any], spec_raw)
306
+
307
+ proxy_container_spec = next(
308
+ container
309
+ for container in spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
310
+ ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
311
+ ]
312
+ if container[ServiceSQLClient.DESC_SERVICE_NAME_SPEC_NAME] == ServiceSQLClient.PROXY_CONTAINER_NAME
313
+ )
314
+ env = proxy_container_spec.get(ServiceSQLClient.DESC_SERVICE_PROXY_SPEC_ENV_NAME, {})
315
+ autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
316
+ return str(autocapture_enabled).lower() == "true"
317
+
318
+ except StopIteration:
319
+ return False
320
+
255
321
  def drop_service(
256
322
  self,
257
323
  *,
@@ -282,6 +348,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
282
348
  statement_params=statement_params,
283
349
  )
284
350
  .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
351
+ .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME, allow_empty=True)
285
352
  .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
286
353
  )
287
354
 
@@ -1,5 +1,6 @@
1
1
  import functools
2
2
  import importlib
3
+ import logging
3
4
  import pkgutil
4
5
  from types import ModuleType
5
6
  from typing import Any, Callable, Optional, TypeVar, cast
@@ -11,6 +12,8 @@ _HANDLERS_BASE = "snowflake.ml.model._packager.model_handlers"
11
12
  _MODEL_HANDLER_REGISTRY: dict[str, type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
12
13
  _IS_HANDLER_LOADED = False
13
14
 
15
+ logger = logging.getLogger(__name__)
16
+
14
17
 
15
18
  def _register_handlers() -> None:
16
19
  """
@@ -56,8 +59,11 @@ def find_handler(
56
59
  model: model_types.SupportedModelType,
57
60
  ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
58
61
  for handler in _MODEL_HANDLER_REGISTRY.values():
59
- if handler.can_handle(model):
60
- return handler
62
+ try:
63
+ if handler.can_handle(model):
64
+ return handler
65
+ except Exception:
66
+ logger.error(f"Error in {handler.__name__} `can_handle` method for model {type(model)}", exc_info=True)
61
67
  return None
62
68
 
63
69