snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68):
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,4 @@
1
1
  import dataclasses
2
- import enum
3
- import hashlib
4
2
  import logging
5
3
  import pathlib
6
4
  import re
@@ -11,11 +9,12 @@ import warnings
11
9
  from typing import Any, Optional, Union, cast
12
10
 
13
11
  from snowflake import snowpark
14
- from snowflake.ml import jobs
15
12
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
16
13
  from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
14
+ from snowflake.ml.jobs import job
17
15
  from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
18
16
  from snowflake.ml.model._client.model import batch_inference_specs
17
+ from snowflake.ml.model._client.ops import deployment_step
19
18
  from snowflake.ml.model._client.service import model_deployment_spec
20
19
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
21
20
  from snowflake.snowpark import async_job, exceptions, row, session
@@ -25,32 +24,12 @@ module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY
25
24
  module_logger.propagate = False
26
25
 
27
26
 
28
- class DeploymentStep(enum.Enum):
29
- MODEL_BUILD = ("model-build", "model_build_")
30
- MODEL_INFERENCE = ("model-inference", None)
31
- MODEL_LOGGING = ("model-logging", "model_logging_")
32
-
33
- def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
34
- self._container_name = container_name
35
- self._service_name_prefix = service_name_prefix
36
-
37
- @property
38
- def container_name(self) -> str:
39
- """Get the container name for the deployment step."""
40
- return self._container_name
41
-
42
- @property
43
- def service_name_prefix(self) -> Optional[str]:
44
- """Get the service name prefix for the deployment step."""
45
- return self._service_name_prefix
46
-
47
-
48
27
  @dataclasses.dataclass
49
28
  class ServiceLogInfo:
50
29
  database_name: Optional[sql_identifier.SqlIdentifier]
51
30
  schema_name: Optional[sql_identifier.SqlIdentifier]
52
31
  service_name: sql_identifier.SqlIdentifier
53
- deployment_step: DeploymentStep
32
+ deployment_step: deployment_step.DeploymentStep
54
33
  instance_id: str = "0"
55
34
  log_color: service_logger.LogColor = service_logger.LogColor.GREY
56
35
 
@@ -171,6 +150,7 @@ class ServiceOperator:
171
150
  self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
172
151
  workspace_path=pathlib.Path(self._workspace.name)
173
152
  )
153
+ self._inference_autocapture_enabled = pc.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
174
154
 
175
155
  def __eq__(self, __value: object) -> bool:
176
156
  if not isinstance(__value, ServiceOperator):
@@ -207,7 +187,7 @@ class ServiceOperator:
207
187
  # inference engine model
208
188
  inference_engine_args: Optional[InferenceEngineArgs] = None,
209
189
  # inference table
210
- autocapture: Optional[bool] = None,
190
+ autocapture: bool = False,
211
191
  ) -> Union[str, async_job.AsyncJob]:
212
192
 
213
193
  # Generate operation ID for this deployment
@@ -231,6 +211,10 @@ class ServiceOperator:
231
211
  progress_status.update("preparing deployment artifacts...")
232
212
  progress_status.increment()
233
213
 
214
+ # If autocapture param is disabled, don't allow create service with autocapture
215
+ if not self._inference_autocapture_enabled and autocapture:
216
+ raise ValueError("Invalid Argument: Autocapture feature is not supported.")
217
+
234
218
  if self._workspace:
235
219
  stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
236
220
  else:
@@ -348,13 +332,16 @@ class ServiceOperator:
348
332
  if is_enable_image_build:
349
333
  # stream service logs in a thread
350
334
  model_build_service_name = sql_identifier.SqlIdentifier(
351
- self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_BUILD)
335
+ deployment_step.get_service_id_from_deployment_step(
336
+ query_id,
337
+ deployment_step.DeploymentStep.MODEL_BUILD,
338
+ )
352
339
  )
353
340
  model_build_service = ServiceLogInfo(
354
341
  database_name=service_database_name,
355
342
  schema_name=service_schema_name,
356
343
  service_name=model_build_service_name,
357
- deployment_step=DeploymentStep.MODEL_BUILD,
344
+ deployment_step=deployment_step.DeploymentStep.MODEL_BUILD,
358
345
  log_color=service_logger.LogColor.GREEN,
359
346
  )
360
347
 
@@ -362,21 +349,23 @@ class ServiceOperator:
362
349
  database_name=service_database_name,
363
350
  schema_name=service_schema_name,
364
351
  service_name=service_name,
365
- deployment_step=DeploymentStep.MODEL_INFERENCE,
352
+ deployment_step=deployment_step.DeploymentStep.MODEL_INFERENCE,
366
353
  log_color=service_logger.LogColor.BLUE,
367
354
  )
368
355
 
369
356
  model_logger_service: Optional[ServiceLogInfo] = None
370
357
  if hf_model_args:
371
358
  model_logger_service_name = sql_identifier.SqlIdentifier(
372
- self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_LOGGING)
359
+ deployment_step.get_service_id_from_deployment_step(
360
+ query_id, deployment_step.DeploymentStep.MODEL_LOGGING
361
+ )
373
362
  )
374
363
 
375
364
  model_logger_service = ServiceLogInfo(
376
365
  database_name=service_database_name,
377
366
  schema_name=service_schema_name,
378
367
  service_name=model_logger_service_name,
379
- deployment_step=DeploymentStep.MODEL_LOGGING,
368
+ deployment_step=deployment_step.DeploymentStep.MODEL_LOGGING,
380
369
  log_color=service_logger.LogColor.ORANGE,
381
370
  )
382
371
 
@@ -531,7 +520,7 @@ class ServiceOperator:
531
520
  service = service_log_meta.service
532
521
  # check if using an existing model build image
533
522
  if (
534
- service.deployment_step == DeploymentStep.MODEL_BUILD
523
+ service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD
535
524
  and not force_rebuild
536
525
  and service_log_meta.is_model_logger_service_done
537
526
  and not service_log_meta.is_model_build_service_done
@@ -577,16 +566,16 @@ class ServiceOperator:
577
566
  if (service_status != service_sql.ServiceStatus.RUNNING) or (service_status != service_log_meta.service_status):
578
567
  service_log_meta.service_status = service_status
579
568
 
580
- if service.deployment_step == DeploymentStep.MODEL_BUILD:
569
+ if service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
581
570
  module_logger.info(
582
571
  f"Image build service {service.display_service_name} is "
583
572
  f"{service_log_meta.service_status.value}."
584
573
  )
585
- elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
574
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
586
575
  module_logger.info(
587
576
  f"Inference service {service.display_service_name} is {service_log_meta.service_status.value}."
588
577
  )
589
- elif service.deployment_step == DeploymentStep.MODEL_LOGGING:
578
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
590
579
  module_logger.info(
591
580
  f"Model logger service {service.display_service_name} is "
592
581
  f"{service_log_meta.service_status.value}."
@@ -622,7 +611,7 @@ class ServiceOperator:
622
611
  if service_status == service_sql.ServiceStatus.DONE:
623
612
  # check if model logger service is done
624
613
  # and transition the service log metadata to the model image build service
625
- if service.deployment_step == DeploymentStep.MODEL_LOGGING:
614
+ if service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
626
615
  if model_build_service:
627
616
  # building the inference image, transition to the model build service
628
617
  service_log_meta.transition_service_log_metadata(
@@ -643,7 +632,7 @@ class ServiceOperator:
643
632
  )
644
633
  # check if model build service is done
645
634
  # and transition the service log metadata to the model inference service
646
- elif service.deployment_step == DeploymentStep.MODEL_BUILD:
635
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
647
636
  service_log_meta.transition_service_log_metadata(
648
637
  model_inference_service,
649
638
  f"Image build service {service.display_service_name} complete.",
@@ -651,7 +640,7 @@ class ServiceOperator:
651
640
  is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
652
641
  operation_id=operation_id,
653
642
  )
654
- elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
643
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
655
644
  module_logger.info(f"Inference service {service.display_service_name} is deployed.")
656
645
  else:
657
646
  module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
@@ -911,19 +900,6 @@ class ServiceOperator:
911
900
 
912
901
  time.sleep(2) # Poll every 2 seconds
913
902
 
914
- @staticmethod
915
- def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
916
- """Get the service ID through the server-side logic."""
917
- uuid = query_id.replace("-", "")
918
- big_int = int(uuid, 16)
919
- md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
920
- identifier = md5_hash[:8]
921
- service_name_prefix = deployment_step.service_name_prefix
922
- if service_name_prefix is None:
923
- # raise an exception if the service name prefix is None
924
- raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
925
- return (service_name_prefix + identifier).upper()
926
-
927
903
  def _check_if_service_exists(
928
904
  self,
929
905
  database_name: Optional[sql_identifier.SqlIdentifier],
@@ -966,6 +942,8 @@ class ServiceOperator:
966
942
  image_repo_name: Optional[str],
967
943
  input_stage_location: str,
968
944
  input_file_pattern: str,
945
+ column_handling: Optional[str],
946
+ params: Optional[str],
969
947
  output_stage_location: str,
970
948
  completion_filename: str,
971
949
  force_rebuild: bool,
@@ -976,7 +954,7 @@ class ServiceOperator:
976
954
  gpu_requests: Optional[str],
977
955
  replicas: Optional[int],
978
956
  statement_params: Optional[dict[str, Any]] = None,
979
- ) -> jobs.MLJob[Any]:
957
+ ) -> job.MLJob[Any]:
980
958
  database_name = self._database_name
981
959
  schema_name = self._schema_name
982
960
 
@@ -1002,6 +980,8 @@ class ServiceOperator:
1002
980
  max_batch_rows=max_batch_rows,
1003
981
  input_stage_location=input_stage_location,
1004
982
  input_file_pattern=input_file_pattern,
983
+ column_handling=column_handling,
984
+ params=params,
1005
985
  output_stage_location=output_stage_location,
1006
986
  completion_filename=completion_filename,
1007
987
  function_name=function_name,
@@ -1045,7 +1025,7 @@ class ServiceOperator:
1045
1025
  # Block until the async job is done
1046
1026
  async_job.result()
1047
1027
 
1048
- return jobs.MLJob(
1028
+ return job.MLJob(
1049
1029
  id=sql_identifier.get_fully_qualified_name(job_database_name, job_schema_name, job_name),
1050
1030
  session=self._session,
1051
1031
  )
@@ -0,0 +1,23 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from snowflake.ml.model._client.service import model_deployment_spec_schema
6
+
7
+ BaseModel.model_config["protected_namespaces"] = ()
8
+
9
+
10
+ class ModelName(BaseModel):
11
+ model_name: str
12
+ version_name: str
13
+
14
+
15
+ class ModelSpec(BaseModel):
16
+ name: ModelName
17
+ hf_model: Optional[model_deployment_spec_schema.HuggingFaceModel] = None
18
+ log_model_args: Optional[model_deployment_spec_schema.LogModelArgs] = None
19
+
20
+
21
+ class ImportModelSpec(BaseModel):
22
+ compute_pool: str
23
+ models: list[ModelSpec]
@@ -195,6 +195,7 @@ class ModelDeploymentSpec:
195
195
 
196
196
  def add_job_spec(
197
197
  self,
198
+ *,
198
199
  job_name: sql_identifier.SqlIdentifier,
199
200
  inference_compute_pool_name: sql_identifier.SqlIdentifier,
200
201
  function_name: str,
@@ -202,6 +203,8 @@ class ModelDeploymentSpec:
202
203
  output_stage_location: str,
203
204
  completion_filename: str,
204
205
  input_file_pattern: str,
206
+ column_handling: Optional[str] = None,
207
+ params: Optional[str] = None,
205
208
  warehouse: sql_identifier.SqlIdentifier,
206
209
  job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
207
210
  job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
@@ -217,14 +220,16 @@ class ModelDeploymentSpec:
217
220
  Args:
218
221
  job_name: Name of the job.
219
222
  inference_compute_pool_name: Compute pool for inference.
220
- warehouse: Warehouse for the job.
221
223
  function_name: Function name.
222
224
  input_stage_location: Stage location for input data.
223
225
  output_stage_location: Stage location for output data.
226
+ completion_filename: Name of completion file (default: "completion.txt").
227
+ input_file_pattern: Pattern for input files (optional).
228
+ column_handling: Column handling mode for input data.
229
+ params: Additional parameters for the job.
230
+ warehouse: Warehouse for the job.
224
231
  job_database_name: Database name for the job.
225
232
  job_schema_name: Schema name for the job.
226
- input_file_pattern: Pattern for input files (optional).
227
- completion_filename: Name of completion file (default: "completion.txt").
228
233
  cpu: CPU requirement.
229
234
  memory: Memory requirement.
230
235
  gpu: GPU requirement.
@@ -259,7 +264,10 @@ class ModelDeploymentSpec:
259
264
  warehouse=warehouse.identifier() if warehouse else None,
260
265
  function_name=function_name,
261
266
  input=model_deployment_spec_schema.Input(
262
- input_stage_location=input_stage_location, input_file_pattern=input_file_pattern
267
+ input_stage_location=input_stage_location,
268
+ input_file_pattern=input_file_pattern,
269
+ column_handling=column_handling,
270
+ params=params,
263
271
  ),
264
272
  output=model_deployment_spec_schema.Output(
265
273
  output_stage_location=output_stage_location,
@@ -39,6 +39,8 @@ class Service(BaseModel):
39
39
  class Input(BaseModel):
40
40
  input_stage_location: str
41
41
  input_file_pattern: str
42
+ column_handling: Optional[str] = None
43
+ params: Optional[str] = None
42
44
 
43
45
 
44
46
  class Output(BaseModel):
@@ -74,6 +76,7 @@ class HuggingFaceModel(BaseModel):
74
76
  task: Optional[str] = None
75
77
  tokenizer: Optional[str] = None
76
78
  token: Optional[str] = None
79
+ token_secret_object: Optional[str] = None
77
80
  trust_remote_code: Optional[bool] = False
78
81
  revision: Optional[str] = None
79
82
  hf_model_kwargs: Optional[str] = "{}"
@@ -22,6 +22,14 @@ def _normalize_url_for_sql(url: str) -> str:
22
22
  return f"'{url}'"
23
23
 
24
24
 
25
+ def _format_param_value(value: Any) -> str:
26
+ if isinstance(value, str):
27
+ return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
28
+ elif value is None:
29
+ return "NULL"
30
+ return str(value)
31
+
32
+
25
33
  class ModelVersionSQLClient(_base._BaseSQLClient):
26
34
  FUNCTION_NAME_COL_NAME = "name"
27
35
  FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
@@ -354,6 +362,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
354
362
  input_args: list[sql_identifier.SqlIdentifier],
355
363
  returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
356
364
  statement_params: Optional[dict[str, Any]] = None,
365
+ params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
357
366
  ) -> dataframe.DataFrame:
358
367
  with_statements = []
359
368
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -392,10 +401,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
392
401
 
393
402
  args_sql = ", ".join(args_sql_list)
394
403
 
395
- wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
404
+ if params:
405
+ param_sql = ", ".join(_format_param_value(val) for _, val in params)
406
+ args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
407
+
408
+ total_args = len(input_args) + (len(params) if params else 0)
409
+ wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
396
410
  if wide_input:
397
- input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
398
- args_sql = f"object_construct_keep_null({input_args_sql})"
411
+ parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
412
+ if params:
413
+ parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
414
+ args_sql = f"object_construct_keep_null({', '.join(parts)})"
399
415
 
400
416
  sql = textwrap.dedent(
401
417
  f"""WITH {','.join(with_statements)}
@@ -439,6 +455,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
439
455
  statement_params: Optional[dict[str, Any]] = None,
440
456
  is_partitioned: bool = True,
441
457
  explain_case_sensitive: bool = False,
458
+ params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
442
459
  ) -> dataframe.DataFrame:
443
460
  with_statements = []
444
461
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -477,10 +494,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
477
494
 
478
495
  args_sql = ", ".join(args_sql_list)
479
496
 
480
- wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
497
+ if params:
498
+ param_sql = ", ".join(_format_param_value(val) for _, val in params)
499
+ args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
500
+
501
+ total_args = len(input_args) + (len(params) if params else 0)
502
+ wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
481
503
  if wide_input:
482
- input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
483
- args_sql = f"object_construct_keep_null({input_args_sql})"
504
+ parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
505
+ if params:
506
+ parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
507
+ args_sql = f"object_construct_keep_null({', '.join(parts)})"
484
508
 
485
509
  sql = textwrap.dedent(
486
510
  f"""WITH {','.join(with_statements)}
@@ -3,7 +3,9 @@ import dataclasses
3
3
  import enum
4
4
  import logging
5
5
  import textwrap
6
- from typing import Any, Generator, Optional
6
+ from typing import Any, Generator, Optional, cast
7
+
8
+ import yaml
7
9
 
8
10
  from snowflake import snowpark
9
11
  from snowflake.ml._internal.utils import (
@@ -18,6 +20,15 @@ from snowflake.snowpark._internal import utils as snowpark_utils
18
20
 
19
21
  logger = logging.getLogger(__name__)
20
22
 
23
+
24
+ def _format_param_value(value: Any) -> str:
25
+ if isinstance(value, str):
26
+ return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
27
+ elif value is None:
28
+ return "NULL"
29
+ return str(value)
30
+
31
+
21
32
  # Using this token instead of '?' to avoid escaping issues
22
33
  # After quotes are escaped, we replace this token with '|| ? ||'
23
34
  QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
@@ -68,6 +79,7 @@ class ServiceStatusInfo:
68
79
 
69
80
  class ServiceSQLClient(_base._BaseSQLClient):
70
81
  MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
82
+ MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME = "port"
71
83
  MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
72
84
  MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME = "privatelink_ingress_url"
73
85
  SERVICE_STATUS = "service_status"
@@ -75,6 +87,14 @@ class ServiceSQLClient(_base._BaseSQLClient):
75
87
  INSTANCE_STATUS = "instance_status"
76
88
  CONTAINER_STATUS = "status"
77
89
  MESSAGE = "message"
90
+ DESC_SERVICE_INTERNAL_DNS_COL_NAME = "dns_name"
91
+ DESC_SERVICE_SPEC_COL_NAME = "spec"
92
+ DESC_SERVICE_CONTAINERS_SPEC_NAME = "containers"
93
+ DESC_SERVICE_NAME_SPEC_NAME = "name"
94
+ DESC_SERVICE_PROXY_SPEC_ENV_NAME = "env"
95
+ PROXY_CONTAINER_NAME = "proxy"
96
+ MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME = "SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED"
97
+ FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"
78
98
 
79
99
  @contextlib.contextmanager
80
100
  def _qmark_paramstyle(self) -> Generator[None, None, None]:
@@ -129,6 +149,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
129
149
  input_args: list[sql_identifier.SqlIdentifier],
130
150
  returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
131
151
  statement_params: Optional[dict[str, Any]] = None,
152
+ params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
132
153
  ) -> dataframe.DataFrame:
133
154
  with_statements = []
134
155
  actual_database_name = database_name or self._database_name
@@ -159,10 +180,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
159
180
  args_sql_list.append(input_arg_value)
160
181
  args_sql = ", ".join(args_sql_list)
161
182
 
162
- wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
183
+ if params:
184
+ param_sql = ", ".join(_format_param_value(val) for _, val in params)
185
+ args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
186
+
187
+ total_args = len(input_args) + (len(params) if params else 0)
188
+ wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
163
189
  if wide_input:
164
- input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
165
- args_sql = f"object_construct_keep_null({input_args_sql})"
190
+ parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
191
+ if params:
192
+ parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
193
+ args_sql = f"object_construct_keep_null({', '.join(parts)})"
166
194
 
167
195
  fully_qualified_service_name = self.fully_qualified_object_name(
168
196
  actual_database_name, actual_schema_name, service_name
@@ -233,7 +261,15 @@ class ServiceSQLClient(_base._BaseSQLClient):
233
261
  ) -> list[ServiceStatusInfo]:
234
262
  fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
235
263
  query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
236
- rows = self._session.sql(query).collect(statement_params=statement_params)
264
+ rows = (
265
+ query_result_checker.SqlResultValidator(self._session, query, statement_params=statement_params)
266
+ .has_column(ServiceSQLClient.INSTANCE_STATUS)
267
+ .has_column(ServiceSQLClient.CONTAINER_STATUS)
268
+ .has_column(ServiceSQLClient.SERVICE_STATUS)
269
+ .has_column(ServiceSQLClient.INSTANCE_ID)
270
+ .has_column(ServiceSQLClient.MESSAGE)
271
+ .validate()
272
+ )
237
273
  statuses = []
238
274
  for r in rows:
239
275
  instance_status, container_status = None, None
@@ -252,6 +288,58 @@ class ServiceSQLClient(_base._BaseSQLClient):
252
288
  )
253
289
  return statuses
254
290
 
291
+ def describe_service(
292
+ self,
293
+ *,
294
+ database_name: Optional[sql_identifier.SqlIdentifier],
295
+ schema_name: Optional[sql_identifier.SqlIdentifier],
296
+ service_name: sql_identifier.SqlIdentifier,
297
+ statement_params: Optional[dict[str, Any]] = None,
298
+ ) -> row.Row:
299
+ fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
300
+ query = f"DESCRIBE SERVICE {fully_qualified_object_name}"
301
+ rows = (
302
+ query_result_checker.SqlResultValidator(self._session, query, statement_params=statement_params)
303
+ .has_dimensions(expected_rows=1)
304
+ .has_column(ServiceSQLClient.DESC_SERVICE_INTERNAL_DNS_COL_NAME)
305
+ .validate()
306
+ )
307
+ return rows[0]
308
+
309
+ def get_proxy_container_autocapture(self, row: row.Row) -> bool:
310
+ """Extract whether service has autocapture enabled from proxy container spec.
311
+
312
+ Args:
313
+ row: A row.Row object from DESCRIBE SERVICE containing the service YAML spec.
314
+
315
+ Returns:
316
+ True if autocapture is enabled in proxy spec
317
+ False if disabled or not set in proxy spec
318
+ False if service doesn't have proxy container
319
+ """
320
+ try:
321
+ spec_yaml = row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME]
322
+ if spec_yaml is None:
323
+ return False
324
+ spec_raw = yaml.safe_load(spec_yaml)
325
+ if spec_raw is None:
326
+ return False
327
+ spec = cast(dict[str, Any], spec_raw)
328
+
329
+ proxy_container_spec = next(
330
+ container
331
+ for container in spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
332
+ ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
333
+ ]
334
+ if container[ServiceSQLClient.DESC_SERVICE_NAME_SPEC_NAME] == ServiceSQLClient.PROXY_CONTAINER_NAME
335
+ )
336
+ env = proxy_container_spec.get(ServiceSQLClient.DESC_SERVICE_PROXY_SPEC_ENV_NAME, {})
337
+ autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
338
+ return str(autocapture_enabled).lower() == "true"
339
+
340
+ except StopIteration:
341
+ return False
342
+
255
343
  def drop_service(
256
344
  self,
257
345
  *,
@@ -282,6 +370,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
282
370
  statement_params=statement_params,
283
371
  )
284
372
  .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
373
+ .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME, allow_empty=True)
285
374
  .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
286
375
  )
287
376
 
@@ -131,7 +131,7 @@ class ModelComposer:
131
131
  python_version: Optional[str] = None,
132
132
  user_files: Optional[dict[str, list[str]]] = None,
133
133
  ext_modules: Optional[list[ModuleType]] = None,
134
- code_paths: Optional[list[str]] = None,
134
+ code_paths: Optional[list[model_types.CodePathLike]] = None,
135
135
  task: model_types.Task = model_types.Task.UNKNOWN,
136
136
  experiment_info: Optional["ExperimentInfo"] = None,
137
137
  options: Optional[model_types.ModelSaveOption] = None,
@@ -39,6 +39,10 @@ class ModelMethodSignatureFieldWithName(ModelMethodSignatureField):
39
39
  name: Required[str]
40
40
 
41
41
 
42
+ class ModelMethodSignatureFieldWithNameAndDefault(ModelMethodSignatureFieldWithName):
43
+ default: Required[Any]
44
+
45
+
42
46
  class ModelFunctionMethodDict(TypedDict):
43
47
  name: Required[str]
44
48
  runtime: Required[str]
@@ -46,6 +50,7 @@ class ModelFunctionMethodDict(TypedDict):
46
50
  handler: Required[str]
47
51
  inputs: Required[list[ModelMethodSignatureFieldWithName]]
48
52
  outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
53
+ params: NotRequired[list[ModelMethodSignatureFieldWithNameAndDefault]]
49
54
  volatility: NotRequired[str]
50
55
 
51
56