snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/model/model_version_impl.py

@@ -1,4 +1,6 @@
+import base64
 import enum
+import json
 import pathlib
 import tempfile
 import uuid
@@ -6,11 +8,12 @@ import warnings
 from typing import Any, Callable, Optional, Union, overload

 import pandas as pd
+from pydantic import TypeAdapter

 from snowflake import snowpark
-from snowflake.ml import jobs
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.jobs import job
 from snowflake.ml.lineage import lineage_node
 from snowflake.ml.model import openai_signatures, task, type_hints
 from snowflake.ml.model._client.model import (
@@ -30,6 +33,7 @@ _TELEMETRY_PROJECT = "MLOps"
 _TELEMETRY_SUBPROJECT = "ModelManagement"
 _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
 _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
+_UTF8_ENCODING = "utf-8"


 class ExportMode(enum.Enum):
@@ -46,6 +50,7 @@ class ModelVersion(lineage_node.LineageNode):
     _version_name: sql_identifier.SqlIdentifier
     _functions: list[model_manifest_schema.ModelFunctionInfo]
     _model_spec: Optional[model_meta_schema.ModelMetadataDict]
+    _model_manifest: Optional[model_manifest_schema.ModelManifestDict]

     def __init__(self) -> None:
         raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -156,6 +161,7 @@ class ModelVersion(lineage_node.LineageNode):
         self._version_name = version_name
         self._functions = self._get_functions()
         self._model_spec = None
+        self._model_manifest = None
         super(cls, cls).__init__(
             self,
             session=model_ops._session,
@@ -463,6 +469,28 @@ class ModelVersion(lineage_node.LineageNode):
             )
         return self._model_spec

+    def _get_model_manifest(
+        self, statement_params: Optional[dict[str, Any]] = None
+    ) -> model_manifest_schema.ModelManifestDict:
+        """Fetch and cache the model manifest for this model version.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model manifest.
+
+        Returns:
+            The model manifest as a dictionary for this model version.
+        """
+        if self._model_manifest is None:
+            self._model_manifest = self._model_ops.get_model_version_manifest(
+                database_name=None,
+                schema_name=None,
+                model_name=self._model_name,
+                version_name=self._version_name,
+                statement_params=statement_params,
+            )
+        return self._model_manifest
+
     @overload
     def run(
         self,
@@ -471,6 +499,7 @@
         function_name: Optional[str] = None,
         partition_column: Optional[str] = None,
         strict_input_validation: bool = False,
+        params: Optional[dict[str, Any]] = None,
     ) -> Union[pd.DataFrame, dataframe.DataFrame]:
         """Invoke a method in a model version object.

@@ -481,6 +510,8 @@
             partition_column: The partition column name to partition by.
             strict_input_validation: Enable stricter validation for the input data. This will result value range based
                 type validation to make sure your input data won't overflow when providing to the model.
+            params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
+                These are passed as keyword arguments to the model's inference method. Defaults to None.
         """
         ...

@@ -492,6 +523,7 @@
         service_name: str,
         function_name: Optional[str] = None,
         strict_input_validation: bool = False,
+        params: Optional[dict[str, Any]] = None,
     ) -> Union[pd.DataFrame, dataframe.DataFrame]:
         """Invoke a method in a model version object via a service.

@@ -501,6 +533,8 @@
             function_name: The function name to run. It is the name used to call a function in SQL.
             strict_input_validation: Enable stricter validation for the input data. This will result value range based
                 type validation to make sure your input data won't overflow when providing to the model.
+            params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
+                These are passed as keyword arguments to the model's inference method. Defaults to None.
         """
         ...

@@ -517,6 +551,7 @@
         function_name: Optional[str] = None,
         partition_column: Optional[str] = None,
         strict_input_validation: bool = False,
+        params: Optional[dict[str, Any]] = None,
     ) -> Union[pd.DataFrame, "dataframe.DataFrame"]:
         """Invoke a method in a model version object via the warehouse or a service.

@@ -528,9 +563,14 @@
             partition_column: The partition column name to partition by.
             strict_input_validation: Enable stricter validation for the input data. This will result value range based
                 type validation to make sure your input data won't overflow when providing to the model.
+            params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
+                These are passed as keyword arguments to the model's inference method. Defaults to None.

         Returns:
             The prediction data. It would be the same type dataframe as your input.
+
+        Raises:
+            ValueError: When the model does not support running on warehouse and no service name is provided.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -555,8 +595,30 @@
                 service_name=service_name_id,
                 strict_input_validation=strict_input_validation,
                 statement_params=statement_params,
+                params=params,
             )
         else:
+            manifest = self._get_model_manifest(statement_params=statement_params)
+            target_platforms = manifest.get("target_platforms", None)
+            if (
+                target_platforms is not None
+                and len(target_platforms) > 0
+                and type_hints.TargetPlatform.WAREHOUSE.value not in target_platforms
+            ):
+                raise ValueError(
+                    f"The model {self.fully_qualified_model_name} version {self.version_name} "
+                    "is not logged for inference in Warehouse. "
+                    "To run the model in Warehouse, please log the model again using `log_model` API with "
+                    '`target_platforms=["WAREHOUSE"]` or '
+                    '`target_platforms=["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"]` and rerun the command. '
+                    "To run the model in Snowpark Container Services, the `service_name` argument must be provided. "
+                    "You can create a service using the `create_service` API. "
+                    "For inference in Warehouse, see https://docs.snowflake.com/en/developer-guide/"
+                    "snowflake-ml/model-registry/warehouse#inference-from-python. "
+                    "For inference in Snowpark Container Services, see https://docs.snowflake.com/en/developer-guide/"
+                    "snowflake-ml/model-registry/container#python."
+                )
+
             explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)

             return self._model_ops.invoke_method(
@@ -573,6 +635,7 @@
                 statement_params=statement_params,
                 is_partitioned=target_function_info["is_partitioned"],
                 explain_case_sensitive=explain_case_sensitive,
+                params=params,
             )

     def _determine_explain_case_sensitivity(
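
The new `params` argument flows from every `run()` overload into `_model_ops.invoke_method`, and the manifest check above rejects warehouse execution for models not logged with the WAREHOUSE target platform. A minimal usage sketch, assuming an already-retrieved `ModelVersion` named `mv` and a pandas DataFrame `df` (both names illustrative):

    # Hedged example: pass per-call inference parameters (e.g., LLM sampling
    # settings) through the new `params` argument of ModelVersion.run().
    result = mv.run(
        df,
        function_name="predict",                   # a function exposed by the model version
        params={"temperature": 0.2, "top_k": 40},  # forwarded as keyword args to the model
    )

If the model was logged without the WAREHOUSE target platform and no `service_name` is given, the check above raises the ValueError immediately instead of failing later in SQL.
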
@@ -586,6 +649,41 @@
             method_options, target_function_info["name"]
         )

+    @staticmethod
+    def _encode_column_handling(
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+    ) -> Optional[str]:
+        """Validate and encode column_handling to a base64 string.
+
+        Args:
+            column_handling: Optional dictionary mapping column names to file encoding options.
+
+        Returns:
+            Base64 encoded JSON string of the column handling options, or None if input is None.
+        """
+        # TODO: validation for column names
+        if column_handling is None:
+            return None
+        adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
+        # TODO: throw error if the validate_python function fails
+        validated_input = adapter.validate_python(column_handling)
+        return base64.b64encode(adapter.dump_json(validated_input)).decode(_UTF8_ENCODING)
+
+    @staticmethod
+    def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
+        """Encode params dictionary to a base64 string.
+
+        Args:
+            params: Optional dictionary of model inference parameters.
+
+        Returns:
+            Base64 encoded JSON string of the params, or None if input is None.
+        """
+        if params is None:
+            return None
+        # TODO: validation for param names, types
+        return base64.b64encode(json.dumps(params).encode(_UTF8_ENCODING)).decode(_UTF8_ENCODING)
+
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
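
Both helpers serialize to JSON and then base64-encode so the options survive transport as a single spec string; `_encode_column_handling` additionally validates through a pydantic `TypeAdapter`. A standalone sketch of the same round trip using only the standard library (mirrors `_encode_params`):

    import base64
    import json

    def encode_params(params: dict | None) -> str | None:
        # JSON-serialize, then base64-encode, as _encode_params does above.
        if params is None:
            return None
        return base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")

    encoded = encode_params({"temperature": 0.2, "top_k": 40})
    assert encoded is not None
    assert json.loads(base64.b64decode(encoded)) == {"temperature": 0.2, "top_k": 40}
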
@@ -603,7 +701,9 @@
         input_spec: dataframe.DataFrame,
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
-    ) -> jobs.MLJob[Any]:
+        params: Optional[dict[str, Any]] = None,
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]] = None,
+    ) -> job.MLJob[Any]:
         """Execute batch inference on datasets as an SPCS job.

         Args:
@@ -616,9 +716,15 @@
             job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
                 execution parameters such as compute resources, worker counts, and job naming.
                 If None, default values will be used.
+            params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
+                (e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
+                model's inference method. Defaults to None.
+            column_handling (Optional[dict[str, batch_inference_specs.FileEncoding]]): Optional dictionary
+                specifying how to handle specific columns during file I/O. Maps column names to their
+                file encoding configuration.

         Returns:
-            jobs.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
+            job.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
                 lifecycle.

         Raises:
@@ -671,6 +777,9 @@
             subproject=_TELEMETRY_SUBPROJECT,
         )

+        column_handling_as_string = self._encode_column_handling(column_handling)
+        params_as_string = self._encode_params(params)
+
         if job_spec is None:
             job_spec = batch_inference_specs.JobSpec()

@@ -718,6 +827,8 @@
             # input and output
             input_stage_location=input_stage_location,
             input_file_pattern="*",
+            column_handling=column_handling_as_string,
+            params=params_as_string,
             output_stage_location=output_stage_location,
             completion_filename="_SUCCESS",
             # misc
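
A hedged sketch of launching batch inference with the two new arguments; only parameters visible in this diff are shown, and other required arguments of `run_batch_job` are omitted:

    # Assumed objects: `mv` (ModelVersion), `input_df` (Snowpark DataFrame),
    # `output_spec` (batch_inference_specs.OutputSpec), and `image_options`
    # (batch_inference_specs.ColumnHandlingOptions).
    ml_job = mv.run_batch_job(
        input_spec=input_df,
        output_spec=output_spec,
        job_spec=None,                                 # defaults are filled in when None
        params={"temperature": 0.2},                   # base64-JSON encoded before submission
        column_handling={"IMAGE_COL": image_options},  # per-column file I/O handling
    )
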
@@ -1019,6 +1130,7 @@
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
         block: bool = True,
+        autocapture: bool = False,
         inference_engine_options: Optional[dict[str, Any]] = None,
         experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
@@ -1053,13 +1165,13 @@
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
+            autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
+                captured in the model inference table.
             inference_engine_options: Options for the service creation with custom inference engine.
                 Supports `engine` and `engine_args_override`.
                 `engine` is the type of the inference engine to use.
                 `engine_args_override` is a list of string arguments to pass to the inference engine.
             experimental_options: Experimental options for the service creation.
-                Currently only `autocapture` is supported.
-                `autocapture` is a boolean to enable/disable inference table.
         """
         ...

@@ -1081,6 +1193,7 @@
         force_rebuild: bool = False,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        autocapture: bool = False,
         inference_engine_options: Optional[dict[str, Any]] = None,
         experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
@@ -1115,13 +1228,13 @@
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
+            autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
+                captured in the model inference table.
             inference_engine_options: Options for the service creation with custom inference engine.
                 Supports `engine` and `engine_args_override`.
                 `engine` is the type of the inference engine to use.
                 `engine_args_override` is a list of string arguments to pass to the inference engine.
             experimental_options: Experimental options for the service creation.
-                Currently only `autocapture` is supported.
-                `autocapture` is a boolean to enable/disable inference table.
         """
         ...

@@ -1158,6 +1271,7 @@
         build_external_access_integration: Optional[str] = None,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        autocapture: bool = False,
         inference_engine_options: Optional[dict[str, Any]] = None,
         experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
@@ -1194,13 +1308,13 @@
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is False, this function executes the underlying service creation asynchronously
                 and returns an AsyncJob.
+            autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
+                captured in the model inference table.
             inference_engine_options: Options for the service creation with custom inference engine.
                 Supports `engine` and `engine_args_override`.
                 `engine` is the type of the inference engine to use.
                 `engine_args_override` is a list of string arguments to pass to the inference engine.
             experimental_options: Experimental options for the service creation.
-                Currently only `autocapture` is supported.
-                `autocapture` is a boolean to enable/disable inference table.


         Raises:
@@ -1251,9 +1365,6 @@
             gpu_requests,
         )

-        # Extract autocapture from experimental_options
-        autocapture = experimental_options.get("autocapture") if experimental_options else None
-
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions

@@ -1320,8 +1431,13 @@
         """List all the service names using this model version.

         Returns:
-            List of service_names: The name of the service, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+            List of details about all the services associated with this model version. The details include:
+                name: The name of the service.
+                status: The status of the service.
+                inference_endpoint: The public endpoint of the service, if enabled and services is not in PENDING state.
+                    This will give privatelink endpoint if the session is created with privatelink connection
+                internal_endpoint: The internal endpoint of the service, if services is not in PENDING state.
+                autocapture_enabled: Whether service has autocapture enabled, if it is set in service proxy spec.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
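
`autocapture` is now a first-class `create_service` keyword, replacing the `experimental_options["autocapture"]` extraction removed above. A hedged migration sketch (service identifiers are illustrative, and other required `create_service` arguments are omitted):

    # 1.20.x style (no longer documented):
    #   mv.create_service(..., experimental_options={"autocapture": True})
    # 1.22.0 style:
    service = mv.create_service(
        service_name="MY_DB.MY_SCHEMA.MY_SERVICE",  # illustrative name
        service_compute_pool="MY_POOL",             # illustrative pool
        autocapture=True,  # capture inference data in the model inference table
    )
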
snowflake/ml/model/_client/ops/deployment_step.py (new file)

@@ -0,0 +1,36 @@
+import enum
+import hashlib
+from typing import Optional
+
+
+class DeploymentStep(enum.Enum):
+    MODEL_BUILD = ("model-build", "model_build_")
+    MODEL_INFERENCE = ("model-inference", None)
+    MODEL_LOGGING = ("model-logging", "model_logging_")
+
+    def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
+        self._container_name = container_name
+        self._service_name_prefix = service_name_prefix
+
+    @property
+    def container_name(self) -> str:
+        """Get the container name for the deployment step."""
+        return self._container_name
+
+    @property
+    def service_name_prefix(self) -> Optional[str]:
+        """Get the service name prefix for the deployment step."""
+        return self._service_name_prefix
+
+
+def get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
+    """Get the service ID through the server-side logic."""
+    uuid = query_id.replace("-", "")
+    big_int = int(uuid, 16)
+    md5_hash = hashlib.md5(str(big_int).encode(), usedforsecurity=False).hexdigest()
+    identifier = md5_hash[:8]
+    service_name_prefix = deployment_step.service_name_prefix
+    if service_name_prefix is None:
+        # raise an exception if the service name prefix is None
+        raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
+    return (service_name_prefix + identifier).upper()
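
The service name is derived client-side exactly as the server derives it: strip dashes from the query UUID, parse it as a hex integer, MD5 the integer's decimal string, and keep the first eight hex digits after the step's prefix. A worked sketch with a made-up query ID:

    from snowflake.ml.model._client.ops.deployment_step import (
        DeploymentStep,
        get_service_id_from_deployment_step,
    )

    query_id = "01b2c3d4-0000-1111-2222-333344445555"  # made-up query ID
    service_id = get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_LOGGING)
    # Deterministic result: "MODEL_LOGGING_" + first 8 hex chars of
    # md5(str(int(query_id.replace("-", ""), 16))), upper-cased.
    # DeploymentStep.MODEL_INFERENCE has no prefix and raises ValueError here.
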
snowflake/ml/model/_client/ops/model_ops.py

@@ -1,5 +1,6 @@
 import enum
 import json
+import logging
 import os
 import pathlib
 import tempfile
@@ -7,11 +8,13 @@ import warnings
 from typing import Any, Literal, Optional, TypedDict, Union, cast, overload

 import yaml
+from typing_extensions import NotRequired

+from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.exceptions import error_codes, exceptions
-from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
+from snowflake.ml._internal.utils import formatting, identifier, sql_identifier, url
 from snowflake.ml.model import model_signature, type_hints
-from snowflake.ml.model._client.ops import metadata_ops
+from snowflake.ml.model._client.ops import deployment_step, metadata_ops
 from snowflake.ml.model._client.sql import (
     model as model_sql,
     model_version as model_version_sql,
@@ -31,6 +34,8 @@ from snowflake.ml.model._signatures import snowpark_handler
 from snowflake.snowpark import dataframe, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils

+logger = logging.getLogger(__name__)
+

 # An enum class to represent Create Or Alter Model SQL command.
 class ModelAction(enum.Enum):
@@ -42,6 +47,8 @@ class ServiceInfo(TypedDict):
     name: str
     status: str
     inference_endpoint: Optional[str]
+    internal_endpoint: Optional[str]
+    autocapture_enabled: NotRequired[bool]


 class ModelOperator:
@@ -651,6 +658,13 @@
            url_str = str(url_value)
        return url_str if ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING in url_str else None

+    def _extract_and_validate_port(self, res_row: "row.Row") -> Optional[int]:
+        """Extract and validate port from endpoint row."""
+        port_value = res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME]
+        if port_value is None:
+            return None
+        return int(port_value)
+
     def show_services(
         self,
         *,
@@ -684,8 +698,12 @@

         result: list[ServiceInfo] = []
         is_privatelink_connection = self._is_privatelink_connection()
+        is_autocapture_param_enabled = (
+            platform_capabilities.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
+        )

         for fully_qualified_service_name in fully_qualified_service_names:
+            port: Optional[int] = None
             inference_endpoint: Optional[str] = None
             db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
             statuses = self._service_client.get_service_container_statuses(
@@ -695,6 +713,11 @@
                 return result

             service_status = statuses[0].service_status
+            service_description = self._service_client.describe_service(
+                database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+            )
+            internal_dns = str(service_description[self._service_client.DESC_SERVICE_INTERNAL_DNS_COL_NAME])
+
             for res_row in self._service_client.show_endpoints(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             ):
@@ -706,19 +729,25 @@

                 ingress_url = self._extract_and_validate_ingress_url(res_row)
                 privatelink_ingress_url = self._extract_and_validate_privatelink_url(res_row)
+                port = self._extract_and_validate_port(res_row)

                 if is_privatelink_connection and privatelink_ingress_url is not None:
                     inference_endpoint = privatelink_ingress_url
                 else:
                     inference_endpoint = ingress_url

-            result.append(
-                ServiceInfo(
-                    name=fully_qualified_service_name,
-                    status=service_status.value,
-                    inference_endpoint=inference_endpoint,
-                )
+            service_info = ServiceInfo(
+                name=fully_qualified_service_name,
+                status=service_status.value,
+                inference_endpoint=inference_endpoint,
+                internal_endpoint=f"http://{internal_dns}:{port}" if port is not None else None,
             )
+            if is_autocapture_param_enabled and self._service_client.DESC_SERVICE_SPEC_COL_NAME in service_description:
+                # Include column only if parameter is enabled and spec exists for service owner caller
+                autocapture_enabled = self._service_client.get_proxy_container_autocapture(service_description)
+                service_info["autocapture_enabled"] = autocapture_enabled
+
+            result.append(service_info)

         return result

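Each `ServiceInfo` now carries an `internal_endpoint` (built from the service's internal DNS name and the endpoint port) and, only when the platform capability is on and the service spec is visible to the caller, a `NotRequired` `autocapture_enabled` flag. A consumption sketch, assuming these dicts are surfaced through `ModelVersion.list_services()` as documented above:

    for svc in mv.list_services():
        print(svc["name"], svc["status"])
        print("  public:  ", svc["inference_endpoint"])  # None while PENDING or ingress disabled
        print("  internal:", svc["internal_endpoint"])   # "http://<internal-dns>:<port>" or None
        if "autocapture_enabled" in svc:                  # NotRequired key; may be absent
            print("  autocapture:", svc["autocapture_enabled"])
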
@@ -960,6 +989,7 @@
        statement_params: Optional[dict[str, str]] = None,
        is_partitioned: Optional[bool] = None,
        explain_case_sensitive: bool = False,
+        params: Optional[dict[str, Any]] = None,
    ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
        ...

@@ -976,6 +1006,7 @@
        strict_input_validation: bool = False,
        statement_params: Optional[dict[str, str]] = None,
        explain_case_sensitive: bool = False,
+        params: Optional[dict[str, Any]] = None,
    ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
        ...

@@ -996,6 +1027,7 @@
        statement_params: Optional[dict[str, str]] = None,
        is_partitioned: Optional[bool] = None,
        explain_case_sensitive: bool = False,
+        params: Optional[dict[str, Any]] = None,
    ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
        identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED

@@ -1031,6 +1063,24 @@
            col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
            input_args.append(col_name)

+        method_parameters: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None
+        if signature.params:
+            # Start with defaults from signature
+            final_params = {}
+            for param_spec in signature.params:
+                if hasattr(param_spec, "default_value"):
+                    final_params[param_spec.name] = param_spec.default_value
+
+            # Override with provided runtime parameters
+            if params:
+                final_params.update(params)
+
+            # Convert to list of tuples with SqlIdentifier for parameter names
+            method_parameters = [
+                (sql_identifier.SqlIdentifier(param_name), param_value)
+                for param_name, param_value in final_params.items()
+            ]
+
        returns = []
        for output_feature in signature.outputs:
            output_name = identifier_rule.get_sql_identifier_from_feature(output_feature.name)
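
Resolution is defaults-first: `default_value`s declared on the signature's parameter specs seed the mapping, caller-supplied `params` override them, and the merged pairs become `(SqlIdentifier, value)` tuples for the SQL layer. The merge itself is ordinary dict semantics:

    # Standalone illustration of the defaults-then-overrides merge used above.
    signature_defaults = {"temperature": 1.0, "top_k": 50}  # from parameter spec defaults
    runtime_params = {"temperature": 0.2}                   # caller-supplied `params`

    final_params = dict(signature_defaults)
    final_params.update(runtime_params)
    assert final_params == {"temperature": 0.2, "top_k": 50}
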
@@ -1049,6 +1099,7 @@
                schema_name=schema_name,
                service_name=service_name,
                statement_params=statement_params,
+                params=method_parameters,
            )
        else:
            assert model_name is not None
@@ -1064,6 +1115,7 @@
                    model_name=model_name,
                    version_name=version_name,
                    statement_params=statement_params,
+                    params=method_parameters,
                )
            elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
                df_res = self._model_version_client.invoke_table_function_method(
@@ -1079,6 +1131,7 @@
                    statement_params=statement_params,
                    is_partitioned=is_partitioned or False,
                    explain_case_sensitive=explain_case_sensitive,
+                    params=method_parameters,
                )


@@ -1212,3 +1265,35 @@
            target_path=local_file_dir,
            statement_params=statement_params,
        )
+
+    def run_import_model_query(
+        self,
+        *,
+        database_name: str,
+        schema_name: str,
+        yaml_content: str,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        yaml_content_escaped = snowpark_utils.escape_single_quotes(yaml_content)  # type: ignore[no-untyped-call]
+
+        async_job = self._session.sql(
+            f"SELECT SYSTEM$IMPORT_MODEL('{yaml_content_escaped}')",
+        ).collect(block=False, statement_params=statement_params)
+        query_id = async_job.query_id  # type: ignore[attr-defined]
+
+        logger.info(f"Remotely importing model, with the query id: {query_id}")
+        model_logger_service_name = sql_identifier.SqlIdentifier(
+            deployment_step.get_service_id_from_deployment_step(
+                query_id,
+                deployment_step.DeploymentStep.MODEL_LOGGING,
+            )
+        )
+
+        logger_name = model_logger_service_name.identifier()
+        job_url = f"{url.JOB_URL_PREFIX}/{database_name}/{schema_name}/{logger_name}"
+        snowflake_url = url.get_snowflake_url(session=self._session, url_path=job_url)
+        logger.info(
+            f"To monitor the progress of the model logging job, head to the job monitoring page {snowflake_url}"
+        )
+
+        async_job.result()  # type: ignore[attr-defined]
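
`run_import_model_query` submits `SYSTEM$IMPORT_MODEL` without blocking, derives the model-logging service name from the async query ID via the new `deployment_step` helper, logs a monitoring URL, and only then waits on `result()`. A hedged call sketch, assuming a `ModelOperator` instance `ops` and a prepared import-spec YAML string:

    # `yaml_content` would follow the schema added in
    # snowflake/ml/model/_client/service/import_model_spec_schema.py.
    ops.run_import_model_query(
        database_name="MY_DB",      # used to build the monitoring URL
        schema_name="MY_SCHEMA",
        yaml_content=yaml_content,
        statement_params=None,
    )
    # Blocks until SYSTEM$IMPORT_MODEL completes; progress is visible on the
    # logged job page for the derived MODEL_LOGGING_* service.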