snowflake-ml-python 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. snowflake/ml/_internal/telemetry.py +3 -2
  2. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +18 -19
  3. snowflake/ml/experiment/callback/keras.py +3 -0
  4. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  5. snowflake/ml/experiment/callback/xgboost.py +3 -0
  6. snowflake/ml/experiment/experiment_tracking.py +50 -70
  7. snowflake/ml/feature_store/feature_store.py +299 -69
  8. snowflake/ml/feature_store/feature_view.py +12 -6
  9. snowflake/ml/fileset/stage_fs.py +12 -1
  10. snowflake/ml/jobs/_utils/constants.py +12 -1
  11. snowflake/ml/jobs/_utils/payload_utils.py +7 -1
  12. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  13. snowflake/ml/jobs/_utils/types.py +5 -0
  14. snowflake/ml/jobs/job.py +19 -5
  15. snowflake/ml/jobs/manager.py +18 -7
  16. snowflake/ml/model/__init__.py +19 -0
  17. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  18. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  19. snowflake/ml/model/_client/model/model_version_impl.py +129 -11
  20. snowflake/ml/model/_client/ops/model_ops.py +11 -4
  21. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  22. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  23. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  25. snowflake/ml/model/_model_composer/model_method/model_method.py +4 -1
  26. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  27. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  28. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -0
  29. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  30. snowflake/ml/model/type_hints.py +16 -0
  31. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  32. snowflake/ml/monitoring/explain_visualize.py +3 -1
  33. snowflake/ml/version.py +1 -1
  34. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/METADATA +50 -4
  35. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/RECORD +38 -37
  36. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/WHEEL +0 -0
  37. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/licenses/LICENSE.txt +0 -0
  38. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py CHANGED
@@ -488,10 +488,13 @@ class JobPayload:
             " comment = 'Created by snowflake.ml.jobs Python API'",
             params=[stage_name],
         )
-
+        payload_name = None
         # Upload payload to stage - organize into app/ subdirectory
         app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
         if not isinstance(source, types.PayloadPath):
+            if isinstance(source, function_payload_utils.FunctionPayload):
+                payload_name = source.function.__name__
+
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
@@ -502,12 +505,14 @@ class JobPayload:
                 source = Path(entrypoint.file_path.parent)

         elif isinstance(source, stage_utils.StagePath):
+            payload_name = entrypoint.file_path.stem
             # copy payload to stage
             if source == entrypoint.file_path:
                 source = source.parent
             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))

         elif isinstance(source, Path):
+            payload_name = entrypoint.file_path.stem
             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
             if source.is_file():
                 source = source.parent
@@ -562,6 +567,7 @@ class JobPayload:
                 *python_entrypoint,
             ],
             env_vars=env_vars,
+            payload_name=payload_name,
         )

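Taken together, these payload_utils.py hunks derive a human-readable payload name: the function's `__name__` for function payloads, otherwise the entrypoint file's stem. A minimal sketch of that rule, using a plain callable and `pathlib.Path` as stand-ins for `FunctionPayload`, `StagePath`, and the entrypoint type:

```python
from pathlib import Path
from typing import Callable, Optional, Union


def derive_payload_name(source: Union[Callable[..., object], Path], entrypoint: Path) -> Optional[str]:
    if callable(source):
        # FunctionPayload case: name the job after the submitted function
        return source.__name__
    # StagePath / local Path case: use the entrypoint file's stem
    return entrypoint.stem


def train() -> None:
    ...


assert derive_payload_name(train, Path("app/train.py")) == "train"
assert derive_payload_name(Path("app"), Path("app/train.py")) == "train"
```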
 
snowflake/ml/jobs/_utils/stage_utils.py CHANGED
@@ -32,6 +32,10 @@ class StagePath:
         self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
         self._path = Path(relpath or "")

+    @property
+    def stem(self) -> str:
+        return self._path.stem
+
     @property
     def parts(self) -> tuple[str, ...]:
         return self._path.parts
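The new `stem` property delegates to `pathlib`, so a `StagePath` exposes the entrypoint filename without its suffix, which payload_utils.py above uses as the payload name. For reference, the underlying `pathlib` behavior:

```python
from pathlib import Path

# StagePath.stem returns self._path.stem, the filename without its suffix.
relative = Path("app/train.py")
print(relative.stem)  # "train"
```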
snowflake/ml/jobs/_utils/types.py CHANGED
@@ -23,6 +23,10 @@ class PayloadPath(Protocol):
     def name(self) -> str:
         ...

+    @property
+    def stem(self) -> str:
+        ...
+
     @property
     def suffix(self) -> str:
         ...
@@ -92,6 +96,7 @@ class UploadedPayload:
     stage_path: PurePath
     entrypoint: list[Union[str, PurePath]]
     env_vars: dict[str, str] = field(default_factory=dict)
+    payload_name: Optional[str] = None


 @dataclass(frozen=True)
snowflake/ml/jobs/job.py CHANGED
@@ -109,18 +109,18 @@ class MLJob(Generic[T], SerializableSessionMixin):
         return cast(dict[str, Any], container_spec)

     @property
-    def _stage_path(self) -> str:
+    def _stage_path(self) -> Optional[str]:
         """Get the job's artifact storage stage location."""
         volumes = self._service_spec["spec"]["volumes"]
-        stage_path = next(v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME)["source"]
-        return cast(str, stage_path)
+        stage_volume = next((v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME), None)
+        return cast(str, stage_volume["source"]) if stage_volume else None

     @property
     def _result_path(self) -> str:
         """Get the job's result file location."""
         result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
         if result_path_str is None:
-            raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
+            raise NotImplementedError(f"Job {self.name} doesn't have a result path configured")

         return self._transform_path(result_path_str)
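`_stage_path` now uses `next()` with a `None` default instead of letting `StopIteration` escape when no stage volume is mounted. A minimal illustration of the pattern (the volume names are made up):

```python
from typing import Any, Optional

volumes: list[dict[str, Any]] = [{"name": "other-volume", "source": "@stage/other"}]

# The old expression raised StopIteration when no stage volume was mounted:
#     next(v for v in volumes if v["name"] == "stage-volume")["source"]
stage_volume = next((v for v in volumes if v["name"] == "stage-volume"), None)
stage_path: Optional[str] = stage_volume["source"] if stage_volume else None
print(stage_path)  # None -- callers such as delete_job() below now check for this
```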
 
@@ -229,8 +229,22 @@ class MLJob(Generic[T], SerializableSessionMixin):
         Raises:
             TimeoutError: If the job does not complete within the specified timeout.
         """
-        delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Start with 100ms delay
         start_time = time.monotonic()
+        try:
+            # spcs_wait_for() is a synchronous query; for long-running jobs it is more effective
+            # to poll with exponential backoff. We use a hybrid: spcs_wait_for() for the first
+            # 30 seconds, then switch to polling for long-running jobs.
+            min_timeout = (
+                int(min(timeout, constants.JOB_SPCS_TIMEOUT_SECONDS))
+                if timeout >= 0
+                else constants.JOB_SPCS_TIMEOUT_SECONDS
+            )
+            query_helper.run_query(self._session, f"call {self.id}!spcs_wait_for('DONE', {min_timeout})")
+            return self.status
+        except SnowparkSQLException:
+            # the function may not be supported in this environment
+            pass
+        delay: float = float(constants.JOB_POLL_INITIAL_DELAY_SECONDS)  # Start with 5s delay
         warning_shown = False
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
             elapsed = time.monotonic() - start_time
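The new wait logic is a hybrid: a blocking server-side `spcs_wait_for()` call for the first ~30 seconds (`JOB_SPCS_TIMEOUT_SECONDS`), with a fallback to client-side polling and exponential backoff when that system function is unavailable. A self-contained sketch of the control flow, where `run_sync_wait()` and `get_status()` are hypothetical stand-ins for the SQL call and the status query:

```python
import time


def run_sync_wait(seconds: int) -> None:
    """Stand-in for the spcs_wait_for() SQL call; unsupported here by design."""
    raise RuntimeError("spcs_wait_for not supported in this environment")


def get_status() -> str:
    """Stand-in for the job status query."""
    return "DONE"


def wait(timeout: float = -1.0) -> str:
    start = time.monotonic()
    try:
        # Server-side blocking wait, capped at ~30s (JOB_SPCS_TIMEOUT_SECONDS).
        run_sync_wait(int(min(timeout, 30)) if timeout >= 0 else 30)
        return get_status()
    except RuntimeError:
        pass  # fall back to client-side polling
    delay = 5.0  # JOB_POLL_INITIAL_DELAY_SECONDS
    while (status := get_status()) not in {"DONE", "FAILED"}:
        if 0 <= timeout <= time.monotonic() - start:
            raise TimeoutError(f"job did not reach a terminal state within {timeout}s")
        time.sleep(delay)
        delay = min(delay * 2, 30.0)  # exponential backoff with a cap
    return status


print(wait(timeout=60))  # DONE
```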
snowflake/ml/jobs/manager.py CHANGED
@@ -192,12 +192,12 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
     """Delete a job service from the backend. Status and logs will be lost."""
     job = job if isinstance(job, jb.MLJob) else get_job(job, session=session)
     session = job._session
-    try:
-        stage_path = job._stage_path
-        session.sql(f"REMOVE {stage_path}/").collect()
-        logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
-    except Exception as e:
-        logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
+    if job._stage_path:
+        try:
+            session.sql(f"REMOVE {job._stage_path}/").collect()
+            logger.debug(f"Successfully cleaned up stage files for job {job.id} at {job._stage_path}")
+        except Exception as e:
+            logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
     query_helper.run_query(session, "DROP SERVICE IDENTIFIER(?)", params=(job.id,))

@@ -697,10 +697,21 @@ def _do_submit_job_v2(
         "MIN_INSTANCES": min_instances,
         "ASYNC": use_async,
     }
+    if payload.payload_name:
+        job_options["GENERATE_SUFFIX"] = True
     job_options = {k: v for k, v in job_options.items() if v is not None}

     query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
-    params = [job_id, compute_pool, json.dumps(spec_options), json.dumps(job_options)]
+    if job_id:
+        database, schema, _ = identifier.parse_schema_level_object_identifier(job_id)
+        params = [
+            job_id
+            if payload.payload_name is None
+            else identifier.get_schema_level_object_identifier(database, schema, payload.payload_name) + "_",
+            compute_pool,
+            json.dumps(spec_options),
+            json.dumps(job_options),
+        ]
     actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]

     return get_job(actual_job_id, session=session)
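The effect of `GENERATE_SUFFIX` together with the new `params` construction: when a payload name is known, the submitted identifier becomes `<db>.<schema>.<payload_name>_` and the server appends a unique suffix. A sketch of the naming rule, with `split` and an f-string standing in for the `identifier` helpers:

```python
from typing import Optional


def build_job_name(job_id: str, payload_name: Optional[str]) -> str:
    if payload_name is None:
        return job_id
    # With a payload name, GENERATE_SUFFIX=True asks the server to append
    # a unique suffix to "<db>.<schema>.<name>_".
    database, schema, _ = job_id.split(".")
    return f"{database}.{schema}.{payload_name}_"


print(build_job_name("MY_DB.PUBLIC.MLJOB_ABC123", "train"))  # MY_DB.PUBLIC.train_
```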
snowflake/ml/model/__init__.py CHANGED
@@ -1,3 +1,6 @@
+import sys
+import warnings
+
 from snowflake.ml.model._client.model.batch_inference_specs import (
     JobSpec,
     OutputSpec,
@@ -18,3 +21,19 @@ __all__ = [
     "SaveMode",
     "Volatility",
 ]
+
+_deprecation_warning_msg_for_3_9 = (
+    "Python 3.9 is deprecated in snowflake-ml-python. Please upgrade to Python 3.10 or greater."
+)
+
+warnings.filterwarnings(
+    "once",
+    message=_deprecation_warning_msg_for_3_9,
+)
+
+if sys.version_info.major == 3 and sys.version_info.minor == 9:
+    warnings.warn(
+        _deprecation_warning_msg_for_3_9,
+        category=DeprecationWarning,
+        stacklevel=2,
+    )
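The filter is registered with `"once"`, so the warning fires at most a single time per process when the package is imported on Python 3.9. The same gate-and-warn pattern in isolation, as a runnable sketch:

```python
import sys
import warnings

# Emit once per process, and only when the interpreter is exactly 3.9.
MSG = "Python 3.9 is deprecated in snowflake-ml-python. Please upgrade to Python 3.10 or greater."
warnings.filterwarnings("once", message=MSG)

if sys.version_info[:2] == (3, 9):
    warnings.warn(MSG, category=DeprecationWarning, stacklevel=2)

# CI on 3.9 can turn the warning into a hard failure:
#     python -W error::DeprecationWarning -c "import snowflake.ml.model"
```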
snowflake/ml/model/_client/model/batch_inference_specs.py CHANGED
@@ -19,11 +19,74 @@ class SaveMode(str, Enum):


 class OutputSpec(BaseModel):
+    """Specification for batch inference output.
+
+    Defines where the inference results should be written and how to handle
+    existing files at the output location.
+
+    Attributes:
+        stage_location (str): The stage path where batch inference results will be saved.
+            This should be a full path including the stage with @ prefix. For example,
+            '@My_DB.PUBLIC.MY_STAGE/someth/path/'. A non-existent directory will be created.
+            Only Snowflake internal stages are supported at this moment.
+        mode (SaveMode): The save mode that determines behavior when files already exist
+            at the output location. Defaults to SaveMode.ERROR, which raises an error
+            if files exist. Can be set to SaveMode.OVERWRITE to replace existing files.
+
+    Example:
+        >>> output_spec = OutputSpec(
+        ...     stage_location="@My_DB.PUBLIC.MY_STAGE/someth/path/",
+        ...     mode=SaveMode.OVERWRITE
+        ... )
+    """
+
     stage_location: str
     mode: SaveMode = SaveMode.ERROR


 class JobSpec(BaseModel):
+    """Specification for batch inference job execution.
+
+    Defines the compute resources, job settings, and execution parameters
+    for running batch inference jobs in Snowflake.
+
+    Attributes:
+        image_repo (Optional[str]): Container image repository for the inference job.
+            If not specified, uses the default repository.
+        job_name (Optional[str]): Custom name for the batch inference job.
+            If not provided, a name will be auto-generated in the form "BATCH_INFERENCE_<UUID>".
+        num_workers (Optional[int]): The number of workers to run the inference service for handling
+            requests in parallel within an instance of the service. By default, it is set to 2*vCPU+1
+            of the node for CPU-based inference and 1 for GPU-based inference. For GPU-based inference,
+            please see best practices before tuning this value.
+        function_name (Optional[str]): Name of the specific function to call for inference.
+            Required when the model has multiple inference functions.
+        force_rebuild (bool): Whether to force rebuilding the container image even if
+            it already exists. Defaults to False.
+        max_batch_rows (int): Maximum number of rows to process in a single batch.
+            Defaults to 1024. Larger values may improve throughput.
+        warehouse (Optional[str]): Snowflake warehouse to use for the batch inference job.
+            If not specified, uses the session's current warehouse.
+        cpu_requests (Optional[str]): The CPU limit for CPU-based inference. Can be an integer,
+            fractional, or string value. If None, attempts to utilize all the vCPUs of the node.
+        memory_requests (Optional[str]): The memory limit for inference. Can be an integer
+            or fractional value, but requires a unit (GiB, MiB). If None, attempts to utilize all
+            the memory of the node.
+        gpu_requests (Optional[str]): The GPU limit for GPU-based inference. Can be an integer or
+            string value. Uses CPU if None.
+        replicas (Optional[int]): Number of job replicas to run for high availability.
+            If not specified, defaults to 1 replica.
+
+    Example:
+        >>> job_spec = JobSpec(
+        ...     job_name="my_inference_job",
+        ...     num_workers=4,
+        ...     cpu_requests="2",
+        ...     memory_requests="8Gi",
+        ...     max_batch_rows=2048
+        ... )
+    """
+
     image_repo: Optional[str] = None
     job_name: Optional[str] = None
     num_workers: Optional[int] = None
snowflake/ml/model/_client/model/inference_engine_utils.py CHANGED
@@ -6,13 +6,9 @@ from snowflake.ml.model._client.ops import service_ops
 def _get_inference_engine_args(
     experimental_options: Optional[dict[str, Any]],
 ) -> Optional[service_ops.InferenceEngineArgs]:
-
-    if not experimental_options:
+    if not experimental_options or "inference_engine" not in experimental_options:
         return None

-    if "inference_engine" not in experimental_options:
-        raise ValueError("inference_engine is required in experimental_options")
-
     return service_ops.InferenceEngineArgs(
         inference_engine=experimental_options["inference_engine"],
         inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
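With the relaxed guard, a missing `inference_engine` key now means "no engine requested" rather than an error, which lets options like `autocapture` be passed alone. A behavior sketch using a placeholder dataclass instead of `service_ops.InferenceEngineArgs` (the engine name below is illustrative):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class EngineArgs:  # placeholder for service_ops.InferenceEngineArgs
    inference_engine: str
    args_override: Optional[list[str]] = None


def get_engine_args(opts: Optional[dict[str, Any]]) -> Optional[EngineArgs]:
    if not opts or "inference_engine" not in opts:
        return None  # previously this raised ValueError when the key was absent
    return EngineArgs(opts["inference_engine"], opts.get("inference_engine_args_override"))


assert get_engine_args({"autocapture": True}) is None  # no longer an error
assert get_engine_args({"inference_engine": "vllm"}) is not None  # "vllm" is illustrative
```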
snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Callable, Optional, Union, overload

 import pandas as pd

+from snowflake import snowpark
 from snowflake.ml import jobs
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
@@ -593,7 +594,8 @@ class ModelVersion(lineage_node.LineageNode):
             "job_spec",
         ],
     )
-    def _run_batch(
+    @snowpark._internal.utils.private_preview(version="1.18.0")
+    def run_batch(
         self,
         *,
         compute_pool: str,
@@ -601,6 +603,68 @@ class ModelVersion(lineage_node.LineageNode):
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
     ) -> jobs.MLJob[Any]:
+        """Execute batch inference on datasets as an SPCS job.
+
+        Args:
+            compute_pool (str): Name of the compute pool to use for building the image containers and batch
+                inference execution.
+            input_spec (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
+                The DataFrame should contain all required features for model prediction and passthrough columns.
+            output_spec (batch_inference_specs.OutputSpec): Configuration for where and how to save
+                the inference results. Specifies the stage location and file handling behavior.
+            job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
+                execution parameters such as compute resources, worker counts, and job naming.
+                If None, default values will be used.
+
+        Returns:
+            jobs.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
+                lifecycle.
+
+        Raises:
+            ValueError: If warehouse is not set in job_spec and no current warehouse is available.
+            RuntimeError: If the input_spec cannot be processed or written to the staging location.
+
+        Example:
+            >>> # Prepare input data - Example 1: From a table
+            >>> input_df = session.table("my_input_table")
+            >>>
+            >>> # Prepare input data - Example 2: From a SQL query
+            >>> input_df = session.sql(
+            ...     "SELECT id, feature_1, feature_2 FROM feature_table WHERE feature_1 > 100"
+            ... )
+            >>>
+            >>> # Prepare input data - Example 3: From Parquet files in a stage
+            >>> input_df = session.read.option("pattern", ".*\\.parquet").parquet(
+            ...     "@my_stage/input_data/"
+            ... ).select("id", "feature_1", "feature_2")
+            >>>
+            >>> # Configure output location
+            >>> output_spec = OutputSpec(
+            ...     stage_location='@My_DB.PUBLIC.MY_STAGE/someth/path/',
+            ...     mode=SaveMode.OVERWRITE
+            ... )
+            >>>
+            >>> # Configure job parameters
+            >>> job_spec = JobSpec(
+            ...     job_name="my_batch_inference",
+            ...     num_workers=4,
+            ...     cpu_requests="2",
+            ...     memory_requests="8Gi"
+            ... )
+            >>>
+            >>> # Run batch inference
+            >>> job = model_version.run_batch(
+            ...     compute_pool="my_compute_pool",
+            ...     input_spec=input_df,
+            ...     output_spec=output_spec,
+            ...     job_spec=job_spec
+            ... )
+
+        Note:
+            This method is currently in private preview and requires snowflake-ml-python 1.18.0 or later.
+            The input data is temporarily stored in the output stage location under /_temporary before
+            inference execution.
+        """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
@@ -827,6 +891,51 @@ class ModelVersion(lineage_node.LineageNode):
             version_name=sql_identifier.SqlIdentifier(version),
         )

+    def _can_run_on_gpu(
+        self,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> bool:
+        """Check if the model has GPU runtime support.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model spec.
+
+        Returns:
+            True if the model has a GPU runtime configured, False otherwise.
+        """
+        # Fetch model spec
+        model_spec = self._get_model_spec(statement_params)
+
+        # Check if the runtimes section exists and has a gpu runtime
+        runtimes = model_spec.get("runtimes", {})
+        return "gpu" in runtimes
+
+    def _throw_error_if_gpu_is_not_supported(
+        self,
+        gpu_requests: Optional[Union[str, int]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """Raise if GPU resources are requested but the model lacks GPU runtime support.
+
+        Args:
+            gpu_requests: The GPU limit for GPU-based inference. Can be an integer, fractional, or string
+                value. Uses CPU if None.
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model spec.
+
+        Raises:
+            ValueError: If the model does not have GPU runtime support.
+        """
+        if gpu_requests is not None and not self._can_run_on_gpu(statement_params):
+            raise ValueError(
+                f"GPU resources requested (gpu_requests={gpu_requests}), but the model "
+                f"{self.fully_qualified_model_name} version {self.version_name} does not have GPU runtime support. "
+                "Please ensure the model was logged with GPU runtime configuration or do not provide gpu_requests. "
+                "To log the model with GPU runtime configuration, provide `cuda_version` in the `options` while calling"
+                " the `log_model` function."
+            )
+
     def _check_huggingface_text_generation_model(
         self,
         statement_params: Optional[dict[str, Any]] = None,
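The validation reads the model spec and looks for a `gpu` entry under `runtimes`; per the error message, that entry exists only when the model was logged with `cuda_version` in its options. A hedged sketch of the check against a plain dict spec:

```python
from typing import Any, Optional, Union


def throw_if_gpu_unsupported(model_spec: dict[str, Any], gpu_requests: Optional[Union[str, int]]) -> None:
    # Mirrors _throw_error_if_gpu_is_not_supported: a model logged with
    # cuda_version gets a "gpu" entry under "runtimes" in its spec.
    if gpu_requests is not None and "gpu" not in model_spec.get("runtimes", {}):
        raise ValueError("model has no GPU runtime; log it with options={'cuda_version': ...}")


throw_if_gpu_unsupported({"runtimes": {"gpu": {}, "cpu": {}}}, gpu_requests=1)  # passes
# throw_if_gpu_unsupported({"runtimes": {"cpu": {}}}, gpu_requests=1)  # raises ValueError
```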
@@ -926,9 +1035,10 @@ class ModelVersion(lineage_node.LineageNode):
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
             experimental_options: Experimental options for the service creation with custom inference engine.
-                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
                 `inference_engine` is the name of the inference engine to use.
                 `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+                `autocapture` is a boolean to enable/disable the inference table.
         """
         ...

@@ -984,9 +1094,10 @@ class ModelVersion(lineage_node.LineageNode):
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
             experimental_options: Experimental options for the service creation with custom inference engine.
-                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
                 `inference_engine` is the name of the inference engine to use.
                 `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+                `autocapture` is a boolean to enable/disable the inference table.
         """
         ...

@@ -1059,21 +1170,20 @@ class ModelVersion(lineage_node.LineageNode):
                 When it is False, this function executes the underlying service creation asynchronously
                 and returns an AsyncJob.
             experimental_options: Experimental options for the service creation with custom inference engine.
-                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
                 `inference_engine` is the name of the inference engine to use.
                 `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+                `autocapture` is a boolean to enable/disable the inference table.


         Raises:
-            ValueError: Illegal external access integration arguments.
+            ValueError: Illegal external access integration arguments, or if GPU resources are requested
+                but the model does not have GPU runtime support.
             exceptions.SnowparkSQLException: if service already exists.

         Returns:
             If `block=True`, return result information about service creation from server.
             Otherwise, return the service creation AsyncJob.
-
-        Raises:
-            ValueError: Illegal external access integration arguments.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -1096,12 +1206,16 @@ class ModelVersion(lineage_node.LineageNode):

         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)

-        # Check if model is HuggingFace text-generation before doing inference engine checks
-        if experimental_options:
-            self._check_huggingface_text_generation_model(statement_params)
+        # Validate GPU support if GPU resources are requested
+        self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)

         inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)

+        # Check if the model is HuggingFace text-generation before doing inference engine checks.
+        # Only validate if an inference engine is actually specified.
+        if inference_engine_args is not None:
+            self._check_huggingface_text_generation_model(statement_params)
+
         # Enrich inference engine args if inference engine is specified
         if inference_engine_args is not None:
             inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
@@ -1109,6 +1223,9 @@ class ModelVersion(lineage_node.LineageNode):
                 gpu_requests,
             )

+        # Extract autocapture from experimental_options
+        autocapture = experimental_options.get("autocapture") if experimental_options else None
+
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions

@@ -1148,6 +1265,7 @@ class ModelVersion(lineage_node.LineageNode):
                     statement_params=statement_params,
                     progress_status=status,
                     inference_engine_args=inference_engine_args,
+                    autocapture=autocapture,
                 )
                 status.update(label="Model service created successfully", state="complete", expanded=False)
                 return result
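Putting the pieces together, `autocapture` travels from `experimental_options` through `create_service` into the deployment spec below. A hedged usage sketch against an existing `ModelVersion` handle `mv`; the pool, repo, and engine names are illustrative, not defaults:

```python
# Assumes `mv` is a ModelVersion obtained from the registry; names are made up.
service = mv.create_service(
    service_name="MY_DB.PUBLIC.MY_SERVICE",
    service_compute_pool="MY_GPU_POOL",
    image_repo="MY_DB.PUBLIC.MY_REPO",
    gpu_requests="1",  # validated against the model's GPU runtime (see above)
    experimental_options={
        "inference_engine": "vllm",  # illustrative engine name
        "inference_engine_args_override": ["--max-model-len", "4096"],
        "autocapture": True,  # enable the inference table
    },
)
```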
snowflake/ml/model/_client/ops/model_ops.py CHANGED
@@ -515,10 +515,17 @@ class ModelOperator:
             statement_params=statement_params,
         )
         for r in res:
-            if alias_name in r[self._model_client.MODEL_VERSION_ALIASES_COL_NAME]:
-                return sql_identifier.SqlIdentifier(
-                    r[self._model_client.MODEL_VERSION_NAME_COL_NAME], case_sensitive=True
-                )
+            aliases_data = r[self._model_client.MODEL_VERSION_ALIASES_COL_NAME]
+            if aliases_data:
+                aliases_list = json.loads(aliases_data)
+
+                # Compare using Snowflake identifier semantics for exact match
+                for alias in aliases_list:
+                    if sql_identifier.SqlIdentifier(alias) == alias_name:
+                        return sql_identifier.SqlIdentifier(
+                            r[self._model_client.MODEL_VERSION_NAME_COL_NAME], case_sensitive=True
+                        )
+
         return None

     def get_tag_value(
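The replacement parses the aliases column as JSON and compares each entry through `SqlIdentifier`, so an unquoted `prod` matches the stored alias `PROD` exactly, where the old membership test compared raw strings against the column value. A rough sketch, with `normalize()` approximating Snowflake identifier semantics:

```python
import json


def normalize(name: str) -> str:
    # Rough approximation of Snowflake identifier semantics: unquoted names
    # fold to upper case, double-quoted names stay case-sensitive.
    if name.startswith('"') and name.endswith('"'):
        return name[1:-1]
    return name.upper()


aliases_data = '["FIRST", "PROD"]'  # VERSION_ALIASES column value (a JSON array)
target = "prod"                     # user-supplied, unquoted alias

match = any(normalize(a) == normalize(target) for a in json.loads(aliases_data))
print(match)  # True -- the old `in` test compared raw strings against the column value
```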
snowflake/ml/model/_client/ops/service_ops.py CHANGED
@@ -206,6 +206,8 @@ class ServiceOperator:
         hf_model_args: Optional[HFModelArgs] = None,
         # inference engine model
         inference_engine_args: Optional[InferenceEngineArgs] = None,
+        # inference table
+        autocapture: Optional[bool] = None,
     ) -> Union[str, async_job.AsyncJob]:

         # Generate operation ID for this deployment
@@ -261,6 +263,7 @@ class ServiceOperator:
             gpu=gpu_requests,
             num_workers=num_workers,
             max_batch_rows=max_batch_rows,
+            autocapture=autocapture,
         )
         if hf_model_args:
             # hf model
snowflake/ml/model/_client/service/model_deployment_spec.py CHANGED
@@ -146,6 +146,7 @@ class ModelDeploymentSpec:
         gpu: Optional[Union[str, int]] = None,
         num_workers: Optional[int] = None,
         max_batch_rows: Optional[int] = None,
+        autocapture: Optional[bool] = None,
     ) -> "ModelDeploymentSpec":
         """Add service specification to the deployment spec.

@@ -161,6 +162,7 @@ class ModelDeploymentSpec:
             gpu: GPU requirement.
             num_workers: Number of workers.
             max_batch_rows: Maximum batch rows for inference.
+            autocapture: Whether to enable the inference table.

         Raises:
             ValueError: If a job spec already exists.
@@ -186,6 +188,7 @@ class ModelDeploymentSpec:
             compute_pool=inference_compute_pool_name.identifier(),
             ingress_enabled=ingress_enabled,
             max_instances=max_instances,
+            autocapture=autocapture,
             **self._inference_spec,
         )
         return self
snowflake/ml/model/_client/service/model_deployment_spec_schema.py CHANGED
@@ -32,6 +32,7 @@ class Service(BaseModel):
     gpu: Optional[str] = None
     num_workers: Optional[int] = None
     max_batch_rows: Optional[int] = None
+    autocapture: Optional[bool] = None
     inference_engine_spec: Optional[InferenceEngineSpec] = None

snowflake/ml/model/_model_composer/model_manifest/model_manifest.py CHANGED
@@ -87,7 +87,9 @@ class ModelManifest:
                     model_meta_schema.FunctionProperties.PARTITIONED.value, False
                 ),
                 wide_input=len(model_meta.signatures[target_method].inputs) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT,
-                options=model_method.get_model_method_options_from_options(options, target_method),
+                options=model_method.get_model_method_options_from_options(
+                    options, target_method, model_meta.model_type
+                ),
             )

             self.methods.append(method)
snowflake/ml/model/_model_composer/model_method/model_method.py CHANGED
@@ -32,7 +32,7 @@ class ModelMethodOptions(TypedDict):


 def get_model_method_options_from_options(
-    options: type_hints.ModelSaveOption, target_method: str
+    options: type_hints.ModelSaveOption, target_method: str, model_type: Optional[str] = None
 ) -> ModelMethodOptions:
     default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
     method_option = options.get("method_options", {}).get(target_method, {})
@@ -42,6 +42,9 @@ def get_model_method_options_from_options(
         case_sensitive = utils.determine_explain_case_sensitive_from_method_options(
             options.get("method_options", {}), target_method
         )
+    elif model_type == "prophet":
+        # Prophet models always require TABLE_FUNCTION because they need the entire time-series context
+        default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
     global_function_type = options.get("function_type", default_function_type)
     function_type = method_option.get("function_type", global_function_type)
     if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
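The precedence for choosing a method's function type stays the same: per-method option, then the global `function_type` option, then the model-type-dependent default that this hunk changes for Prophet. A small sketch of that selection order, with string constants standing in for `ModelMethodFunctionTypes`:

```python
from typing import Any, Optional

# Stand-ins for model_manifest_schema.ModelMethodFunctionTypes values.
FUNCTION, TABLE_FUNCTION = "FUNCTION", "TABLE_FUNCTION"


def pick_function_type(options: dict[str, Any], target_method: str, model_type: Optional[str]) -> str:
    method_option = options.get("method_options", {}).get(target_method, {})
    default = FUNCTION
    if model_type == "prophet":
        # Prophet needs the whole series per partition, hence a table function.
        default = TABLE_FUNCTION
    global_type = options.get("function_type", default)
    return method_option.get("function_type", global_type)


print(pick_function_type({}, "predict", "prophet"))  # TABLE_FUNCTION
print(pick_function_type({"function_type": "FUNCTION"}, "predict", "prophet"))  # FUNCTION
```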