snowflake-ml-python 1.21.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. snowflake/ml/_internal/utils/url.py +42 -0
  2. snowflake/ml/experiment/callback/__init__.py +0 -0
  3. snowflake/ml/jobs/_utils/constants.py +1 -0
  4. snowflake/ml/jobs/_utils/spec_utils.py +0 -31
  5. snowflake/ml/lineage/lineage_node.py +1 -1
  6. snowflake/ml/model/__init__.py +6 -0
  7. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  8. snowflake/ml/model/_client/model/model_version_impl.py +63 -0
  9. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  10. snowflake/ml/model/_client/ops/model_ops.py +61 -2
  11. snowflake/ml/model/_client/ops/service_ops.py +23 -48
  12. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  13. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  14. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  15. snowflake/ml/model/_client/sql/model_version.py +30 -6
  16. snowflake/ml/model/_client/sql/service.py +26 -4
  17. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  20. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  21. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  22. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  23. snowflake/ml/model/_packager/model_packager.py +1 -1
  24. snowflake/ml/model/_signatures/core.py +85 -0
  25. snowflake/ml/model/code_path.py +104 -0
  26. snowflake/ml/model/custom_model.py +55 -13
  27. snowflake/ml/model/model_signature.py +13 -1
  28. snowflake/ml/model/type_hints.py +2 -0
  29. snowflake/ml/registry/_manager/model_manager.py +230 -15
  30. snowflake/ml/registry/registry.py +4 -4
  31. snowflake/ml/version.py +1 -1
  32. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +29 -1
  33. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +36 -32
  34. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  35. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  36. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  37. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,4 @@
1
1
  import dataclasses
2
- import enum
3
- import hashlib
4
2
  import logging
5
3
  import pathlib
6
4
  import re
@@ -16,6 +14,7 @@ from snowflake.ml._internal.utils import identifier, service_logger, sql_identif
16
14
  from snowflake.ml.jobs import job
17
15
  from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
18
16
  from snowflake.ml.model._client.model import batch_inference_specs
17
+ from snowflake.ml.model._client.ops import deployment_step
19
18
  from snowflake.ml.model._client.service import model_deployment_spec
20
19
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
21
20
  from snowflake.snowpark import async_job, exceptions, row, session
@@ -25,32 +24,12 @@ module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY
25
24
  module_logger.propagate = False
26
25
 
27
26
 
28
- class DeploymentStep(enum.Enum):
29
- MODEL_BUILD = ("model-build", "model_build_")
30
- MODEL_INFERENCE = ("model-inference", None)
31
- MODEL_LOGGING = ("model-logging", "model_logging_")
32
-
33
- def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
34
- self._container_name = container_name
35
- self._service_name_prefix = service_name_prefix
36
-
37
- @property
38
- def container_name(self) -> str:
39
- """Get the container name for the deployment step."""
40
- return self._container_name
41
-
42
- @property
43
- def service_name_prefix(self) -> Optional[str]:
44
- """Get the service name prefix for the deployment step."""
45
- return self._service_name_prefix
46
-
47
-
48
27
  @dataclasses.dataclass
49
28
  class ServiceLogInfo:
50
29
  database_name: Optional[sql_identifier.SqlIdentifier]
51
30
  schema_name: Optional[sql_identifier.SqlIdentifier]
52
31
  service_name: sql_identifier.SqlIdentifier
53
- deployment_step: DeploymentStep
32
+ deployment_step: deployment_step.DeploymentStep
54
33
  instance_id: str = "0"
55
34
  log_color: service_logger.LogColor = service_logger.LogColor.GREY
56
35
 
@@ -353,13 +332,16 @@ class ServiceOperator:
353
332
  if is_enable_image_build:
354
333
  # stream service logs in a thread
355
334
  model_build_service_name = sql_identifier.SqlIdentifier(
356
- self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_BUILD)
335
+ deployment_step.get_service_id_from_deployment_step(
336
+ query_id,
337
+ deployment_step.DeploymentStep.MODEL_BUILD,
338
+ )
357
339
  )
358
340
  model_build_service = ServiceLogInfo(
359
341
  database_name=service_database_name,
360
342
  schema_name=service_schema_name,
361
343
  service_name=model_build_service_name,
362
- deployment_step=DeploymentStep.MODEL_BUILD,
344
+ deployment_step=deployment_step.DeploymentStep.MODEL_BUILD,
363
345
  log_color=service_logger.LogColor.GREEN,
364
346
  )
365
347
 
@@ -367,21 +349,23 @@ class ServiceOperator:
367
349
  database_name=service_database_name,
368
350
  schema_name=service_schema_name,
369
351
  service_name=service_name,
370
- deployment_step=DeploymentStep.MODEL_INFERENCE,
352
+ deployment_step=deployment_step.DeploymentStep.MODEL_INFERENCE,
371
353
  log_color=service_logger.LogColor.BLUE,
372
354
  )
373
355
 
374
356
  model_logger_service: Optional[ServiceLogInfo] = None
375
357
  if hf_model_args:
376
358
  model_logger_service_name = sql_identifier.SqlIdentifier(
377
- self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_LOGGING)
359
+ deployment_step.get_service_id_from_deployment_step(
360
+ query_id, deployment_step.DeploymentStep.MODEL_LOGGING
361
+ )
378
362
  )
379
363
 
380
364
  model_logger_service = ServiceLogInfo(
381
365
  database_name=service_database_name,
382
366
  schema_name=service_schema_name,
383
367
  service_name=model_logger_service_name,
384
- deployment_step=DeploymentStep.MODEL_LOGGING,
368
+ deployment_step=deployment_step.DeploymentStep.MODEL_LOGGING,
385
369
  log_color=service_logger.LogColor.ORANGE,
386
370
  )
387
371
 
@@ -536,7 +520,7 @@ class ServiceOperator:
536
520
  service = service_log_meta.service
537
521
  # check if using an existing model build image
538
522
  if (
539
- service.deployment_step == DeploymentStep.MODEL_BUILD
523
+ service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD
540
524
  and not force_rebuild
541
525
  and service_log_meta.is_model_logger_service_done
542
526
  and not service_log_meta.is_model_build_service_done
@@ -582,16 +566,16 @@ class ServiceOperator:
582
566
  if (service_status != service_sql.ServiceStatus.RUNNING) or (service_status != service_log_meta.service_status):
583
567
  service_log_meta.service_status = service_status
584
568
 
585
- if service.deployment_step == DeploymentStep.MODEL_BUILD:
569
+ if service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
586
570
  module_logger.info(
587
571
  f"Image build service {service.display_service_name} is "
588
572
  f"{service_log_meta.service_status.value}."
589
573
  )
590
- elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
574
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
591
575
  module_logger.info(
592
576
  f"Inference service {service.display_service_name} is {service_log_meta.service_status.value}."
593
577
  )
594
- elif service.deployment_step == DeploymentStep.MODEL_LOGGING:
578
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
595
579
  module_logger.info(
596
580
  f"Model logger service {service.display_service_name} is "
597
581
  f"{service_log_meta.service_status.value}."
@@ -627,7 +611,7 @@ class ServiceOperator:
627
611
  if service_status == service_sql.ServiceStatus.DONE:
628
612
  # check if model logger service is done
629
613
  # and transition the service log metadata to the model image build service
630
- if service.deployment_step == DeploymentStep.MODEL_LOGGING:
614
+ if service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
631
615
  if model_build_service:
632
616
  # building the inference image, transition to the model build service
633
617
  service_log_meta.transition_service_log_metadata(
@@ -648,7 +632,7 @@ class ServiceOperator:
648
632
  )
649
633
  # check if model build service is done
650
634
  # and transition the service log metadata to the model inference service
651
- elif service.deployment_step == DeploymentStep.MODEL_BUILD:
635
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
652
636
  service_log_meta.transition_service_log_metadata(
653
637
  model_inference_service,
654
638
  f"Image build service {service.display_service_name} complete.",
@@ -656,7 +640,7 @@ class ServiceOperator:
656
640
  is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
657
641
  operation_id=operation_id,
658
642
  )
659
- elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
643
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
660
644
  module_logger.info(f"Inference service {service.display_service_name} is deployed.")
661
645
  else:
662
646
  module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
@@ -916,19 +900,6 @@ class ServiceOperator:
916
900
 
917
901
  time.sleep(2) # Poll every 2 seconds
918
902
 
919
- @staticmethod
920
- def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
921
- """Get the service ID through the server-side logic."""
922
- uuid = query_id.replace("-", "")
923
- big_int = int(uuid, 16)
924
- md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
925
- identifier = md5_hash[:8]
926
- service_name_prefix = deployment_step.service_name_prefix
927
- if service_name_prefix is None:
928
- # raise an exception if the service name prefix is None
929
- raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
930
- return (service_name_prefix + identifier).upper()
931
-
932
903
  def _check_if_service_exists(
933
904
  self,
934
905
  database_name: Optional[sql_identifier.SqlIdentifier],
@@ -971,6 +942,8 @@ class ServiceOperator:
971
942
  image_repo_name: Optional[str],
972
943
  input_stage_location: str,
973
944
  input_file_pattern: str,
945
+ column_handling: Optional[str],
946
+ params: Optional[str],
974
947
  output_stage_location: str,
975
948
  completion_filename: str,
976
949
  force_rebuild: bool,
@@ -1007,6 +980,8 @@ class ServiceOperator:
1007
980
  max_batch_rows=max_batch_rows,
1008
981
  input_stage_location=input_stage_location,
1009
982
  input_file_pattern=input_file_pattern,
983
+ column_handling=column_handling,
984
+ params=params,
1010
985
  output_stage_location=output_stage_location,
1011
986
  completion_filename=completion_filename,
1012
987
  function_name=function_name,
@@ -0,0 +1,23 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from snowflake.ml.model._client.service import model_deployment_spec_schema
6
+
7
+ BaseModel.model_config["protected_namespaces"] = ()
8
+
9
+
10
+ class ModelName(BaseModel):
11
+ model_name: str
12
+ version_name: str
13
+
14
+
15
+ class ModelSpec(BaseModel):
16
+ name: ModelName
17
+ hf_model: Optional[model_deployment_spec_schema.HuggingFaceModel] = None
18
+ log_model_args: Optional[model_deployment_spec_schema.LogModelArgs] = None
19
+
20
+
21
+ class ImportModelSpec(BaseModel):
22
+ compute_pool: str
23
+ models: list[ModelSpec]
@@ -195,6 +195,7 @@ class ModelDeploymentSpec:
195
195
 
196
196
  def add_job_spec(
197
197
  self,
198
+ *,
198
199
  job_name: sql_identifier.SqlIdentifier,
199
200
  inference_compute_pool_name: sql_identifier.SqlIdentifier,
200
201
  function_name: str,
@@ -202,6 +203,8 @@ class ModelDeploymentSpec:
202
203
  output_stage_location: str,
203
204
  completion_filename: str,
204
205
  input_file_pattern: str,
206
+ column_handling: Optional[str] = None,
207
+ params: Optional[str] = None,
205
208
  warehouse: sql_identifier.SqlIdentifier,
206
209
  job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
207
210
  job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
@@ -217,14 +220,16 @@ class ModelDeploymentSpec:
217
220
  Args:
218
221
  job_name: Name of the job.
219
222
  inference_compute_pool_name: Compute pool for inference.
220
- warehouse: Warehouse for the job.
221
223
  function_name: Function name.
222
224
  input_stage_location: Stage location for input data.
223
225
  output_stage_location: Stage location for output data.
226
+ completion_filename: Name of completion file (default: "completion.txt").
227
+ input_file_pattern: Pattern for input files (optional).
228
+ column_handling: Column handling mode for input data.
229
+ params: Additional parameters for the job.
230
+ warehouse: Warehouse for the job.
224
231
  job_database_name: Database name for the job.
225
232
  job_schema_name: Schema name for the job.
226
- input_file_pattern: Pattern for input files (optional).
227
- completion_filename: Name of completion file (default: "completion.txt").
228
233
  cpu: CPU requirement.
229
234
  memory: Memory requirement.
230
235
  gpu: GPU requirement.
@@ -259,7 +264,10 @@ class ModelDeploymentSpec:
259
264
  warehouse=warehouse.identifier() if warehouse else None,
260
265
  function_name=function_name,
261
266
  input=model_deployment_spec_schema.Input(
262
- input_stage_location=input_stage_location, input_file_pattern=input_file_pattern
267
+ input_stage_location=input_stage_location,
268
+ input_file_pattern=input_file_pattern,
269
+ column_handling=column_handling,
270
+ params=params,
263
271
  ),
264
272
  output=model_deployment_spec_schema.Output(
265
273
  output_stage_location=output_stage_location,
@@ -39,6 +39,8 @@ class Service(BaseModel):
39
39
  class Input(BaseModel):
40
40
  input_stage_location: str
41
41
  input_file_pattern: str
42
+ column_handling: Optional[str] = None
43
+ params: Optional[str] = None
42
44
 
43
45
 
44
46
  class Output(BaseModel):
@@ -74,6 +76,7 @@ class HuggingFaceModel(BaseModel):
74
76
  task: Optional[str] = None
75
77
  tokenizer: Optional[str] = None
76
78
  token: Optional[str] = None
79
+ token_secret_object: Optional[str] = None
77
80
  trust_remote_code: Optional[bool] = False
78
81
  revision: Optional[str] = None
79
82
  hf_model_kwargs: Optional[str] = "{}"
@@ -22,6 +22,14 @@ def _normalize_url_for_sql(url: str) -> str:
22
22
  return f"'{url}'"
23
23
 
24
24
 
25
+ def _format_param_value(value: Any) -> str:
26
+ if isinstance(value, str):
27
+ return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
28
+ elif value is None:
29
+ return "NULL"
30
+ return str(value)
31
+
32
+
25
33
  class ModelVersionSQLClient(_base._BaseSQLClient):
26
34
  FUNCTION_NAME_COL_NAME = "name"
27
35
  FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
@@ -354,6 +362,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
354
362
  input_args: list[sql_identifier.SqlIdentifier],
355
363
  returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
356
364
  statement_params: Optional[dict[str, Any]] = None,
365
+ params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
357
366
  ) -> dataframe.DataFrame:
358
367
  with_statements = []
359
368
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -392,10 +401,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
392
401
 
393
402
  args_sql = ", ".join(args_sql_list)
394
403
 
395
- wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
404
+ if params:
405
+ param_sql = ", ".join(_format_param_value(val) for _, val in params)
406
+ args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
407
+
408
+ total_args = len(input_args) + (len(params) if params else 0)
409
+ wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
396
410
  if wide_input:
397
- input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
398
- args_sql = f"object_construct_keep_null({input_args_sql})"
411
+ parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
412
+ if params:
413
+ parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
414
+ args_sql = f"object_construct_keep_null({', '.join(parts)})"
399
415
 
400
416
  sql = textwrap.dedent(
401
417
  f"""WITH {','.join(with_statements)}
@@ -439,6 +455,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
439
455
  statement_params: Optional[dict[str, Any]] = None,
440
456
  is_partitioned: bool = True,
441
457
  explain_case_sensitive: bool = False,
458
+ params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
442
459
  ) -> dataframe.DataFrame:
443
460
  with_statements = []
444
461
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -477,10 +494,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
477
494
 
478
495
  args_sql = ", ".join(args_sql_list)
479
496
 
480
- wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
497
+ if params:
498
+ param_sql = ", ".join(_format_param_value(val) for _, val in params)
499
+ args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
500
+
501
+ total_args = len(input_args) + (len(params) if params else 0)
502
+ wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
481
503
  if wide_input:
482
- input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
483
- args_sql = f"object_construct_keep_null({input_args_sql})"
504
+ parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
505
+ if params:
506
+ parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
507
+ args_sql = f"object_construct_keep_null({', '.join(parts)})"
484
508
 
485
509
  sql = textwrap.dedent(
486
510
  f"""WITH {','.join(with_statements)}
@@ -20,6 +20,15 @@ from snowflake.snowpark._internal import utils as snowpark_utils
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
23
+
24
+ def _format_param_value(value: Any) -> str:
25
+ if isinstance(value, str):
26
+ return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
27
+ elif value is None:
28
+ return "NULL"
29
+ return str(value)
30
+
31
+
23
32
  # Using this token instead of '?' to avoid escaping issues
24
33
  # After quotes are escaped, we replace this token with '|| ? ||'
25
34
  QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
@@ -140,6 +149,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
140
149
  input_args: list[sql_identifier.SqlIdentifier],
141
150
  returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
142
151
  statement_params: Optional[dict[str, Any]] = None,
152
+ params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
143
153
  ) -> dataframe.DataFrame:
144
154
  with_statements = []
145
155
  actual_database_name = database_name or self._database_name
@@ -170,10 +180,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
170
180
  args_sql_list.append(input_arg_value)
171
181
  args_sql = ", ".join(args_sql_list)
172
182
 
173
- wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
183
+ if params:
184
+ param_sql = ", ".join(_format_param_value(val) for _, val in params)
185
+ args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
186
+
187
+ total_args = len(input_args) + (len(params) if params else 0)
188
+ wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
174
189
  if wide_input:
175
- input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
176
- args_sql = f"object_construct_keep_null({input_args_sql})"
190
+ parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
191
+ if params:
192
+ parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
193
+ args_sql = f"object_construct_keep_null({', '.join(parts)})"
177
194
 
178
195
  fully_qualified_service_name = self.fully_qualified_object_name(
179
196
  actual_database_name, actual_schema_name, service_name
@@ -301,7 +318,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
301
318
  False if service doesn't have proxy container
302
319
  """
303
320
  try:
304
- spec_raw = yaml.safe_load(row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME])
321
+ spec_yaml = row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME]
322
+ if spec_yaml is None:
323
+ return False
324
+ spec_raw = yaml.safe_load(spec_yaml)
325
+ if spec_raw is None:
326
+ return False
305
327
  spec = cast(dict[str, Any], spec_raw)
306
328
 
307
329
  proxy_container_spec = next(
@@ -131,7 +131,7 @@ class ModelComposer:
131
131
  python_version: Optional[str] = None,
132
132
  user_files: Optional[dict[str, list[str]]] = None,
133
133
  ext_modules: Optional[list[ModuleType]] = None,
134
- code_paths: Optional[list[str]] = None,
134
+ code_paths: Optional[list[model_types.CodePathLike]] = None,
135
135
  task: model_types.Task = model_types.Task.UNKNOWN,
136
136
  experiment_info: Optional["ExperimentInfo"] = None,
137
137
  options: Optional[model_types.ModelSaveOption] = None,
@@ -39,6 +39,10 @@ class ModelMethodSignatureFieldWithName(ModelMethodSignatureField):
39
39
  name: Required[str]
40
40
 
41
41
 
42
+ class ModelMethodSignatureFieldWithNameAndDefault(ModelMethodSignatureFieldWithName):
43
+ default: Required[Any]
44
+
45
+
42
46
  class ModelFunctionMethodDict(TypedDict):
43
47
  name: Required[str]
44
48
  runtime: Required[str]
@@ -46,6 +50,7 @@ class ModelFunctionMethodDict(TypedDict):
46
50
  handler: Required[str]
47
51
  inputs: Required[list[ModelMethodSignatureFieldWithName]]
48
52
  outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
53
+ params: NotRequired[list[ModelMethodSignatureFieldWithNameAndDefault]]
49
54
  volatility: NotRequired[str]
50
55
 
51
56
 
@@ -105,7 +105,7 @@ class ModelMethod:
105
105
  except ValueError as e:
106
106
  raise ValueError(
107
107
  f"Your target method {self.target_method} cannot be resolved as valid SQL identifier. "
108
- "Try specify `case_sensitive` as True."
108
+ "Try specifying `case_sensitive` as True."
109
109
  ) from e
110
110
 
111
111
  if self.target_method not in self.model_meta.signatures.keys():
@@ -127,12 +127,41 @@ class ModelMethod:
127
127
  except ValueError as e:
128
128
  raise ValueError(
129
129
  f"Your feature {feature.name} cannot be resolved as valid SQL identifier. "
130
- "Try specify `case_sensitive` as True."
130
+ "Try specifying `case_sensitive` as True."
131
131
  ) from e
132
132
  return model_manifest_schema.ModelMethodSignatureFieldWithName(
133
133
  name=feature_name.resolved(), type=type_utils.convert_sp_to_sf_type(feature.as_snowpark_type())
134
134
  )
135
135
 
136
+ @staticmethod
137
+ def _flatten_params(params: list[model_signature.BaseParamSpec]) -> list[model_signature.ParamSpec]:
138
+ """Flatten ParamGroupSpec into leaf ParamSpec items."""
139
+ result: list[model_signature.ParamSpec] = []
140
+ for param in params:
141
+ if isinstance(param, model_signature.ParamSpec):
142
+ result.append(param)
143
+ elif isinstance(param, model_signature.ParamGroupSpec):
144
+ result.extend(ModelMethod._flatten_params(param.specs))
145
+ return result
146
+
147
+ @staticmethod
148
+ def _get_method_arg_from_param(
149
+ param_spec: model_signature.ParamSpec,
150
+ case_sensitive: bool = False,
151
+ ) -> model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault:
152
+ try:
153
+ param_name = sql_identifier.SqlIdentifier(param_spec.name, case_sensitive=case_sensitive)
154
+ except ValueError as e:
155
+ raise ValueError(
156
+ f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
157
+ "Try specifying `case_sensitive` as True."
158
+ ) from e
159
+ return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
160
+ name=param_name.resolved(),
161
+ type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
162
+ default=param_spec.default_value,
163
+ )
164
+
136
165
  def save(
137
166
  self, workspace_path: pathlib.Path, options: Optional[function_generator.FunctionGenerateOptions] = None
138
167
  ) -> model_manifest_schema.ModelMethodDict:
@@ -182,6 +211,36 @@ class ModelMethod:
182
211
  inputs=input_list,
183
212
  outputs=outputs,
184
213
  )
214
+
215
+ # Add parameters if signature has parameters
216
+ if self.model_meta.signatures[self.target_method].params:
217
+ flat_params = ModelMethod._flatten_params(list(self.model_meta.signatures[self.target_method].params))
218
+ param_list = [
219
+ ModelMethod._get_method_arg_from_param(
220
+ param_spec, case_sensitive=self.options.get("case_sensitive", False)
221
+ )
222
+ for param_spec in flat_params
223
+ ]
224
+ param_name_counter = collections.Counter([param_info["name"] for param_info in param_list])
225
+ dup_param_names = [k for k, v in param_name_counter.items() if v > 1]
226
+ if dup_param_names:
227
+ raise ValueError(
228
+ f"Found duplicate parameter named resolved as {', '.join(dup_param_names)} in the method"
229
+ f" {self.target_method}. This might be because you have parameters with same letters but "
230
+ "different cases. In this case, set case_sensitive as True for those methods to distinguish them."
231
+ )
232
+
233
+ # Check for name collisions between parameters and inputs using existing counters
234
+ collision_names = [name for name in param_name_counter if name in input_name_counter]
235
+ if collision_names:
236
+ raise ValueError(
237
+ f"Found parameter(s) with the same name as input feature(s): {', '.join(sorted(collision_names))} "
238
+ f"in the method {self.target_method}. Parameters and inputs must have distinct names. "
239
+ "Try using case_sensitive=True if the names differ only by case."
240
+ )
241
+
242
+ method_dict["params"] = param_list
243
+
185
244
  should_set_volatility = (
186
245
  platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
187
246
  )
@@ -86,6 +86,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
86
86
  get_prediction_fn=get_prediction,
87
87
  )
88
88
 
89
+ # Add parameters extracted from custom model inference methods to signatures
90
+ cls._add_method_parameters_to_signatures(model, model_meta)
91
+
89
92
  model_blob_path = os.path.join(model_blobs_dir_path, name)
90
93
  os.makedirs(model_blob_path, exist_ok=True)
91
94
  if model.context.artifacts:
@@ -188,6 +191,55 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
188
191
  assert isinstance(model, custom_model.CustomModel)
189
192
  return model
190
193
 
194
+ @classmethod
195
+ def _add_method_parameters_to_signatures(
196
+ cls,
197
+ model: "custom_model.CustomModel",
198
+ model_meta: model_meta_api.ModelMetadata,
199
+ ) -> None:
200
+ """Extract parameters from custom model inference methods and add them to signatures.
201
+
202
+ For each inference method, if the signature doesn't already have parameters and the method
203
+ has keyword-only parameters with defaults, create ParamSpecs and add them to the signature.
204
+
205
+ Args:
206
+ model: The custom model instance.
207
+ model_meta: The model metadata containing signatures to augment.
208
+ """
209
+ for method in model._get_infer_methods():
210
+ method_name = method.__name__
211
+ if method_name not in model_meta.signatures:
212
+ continue
213
+
214
+ sig = model_meta.signatures[method_name]
215
+
216
+ # Skip if the signature already has parameters (user-provided or previously set)
217
+ if sig.params:
218
+ continue
219
+
220
+ # Extract parameters from the method
221
+ method_params = custom_model.get_method_parameters(method)
222
+ if not method_params:
223
+ continue
224
+
225
+ # Create ParamSpecs from the method parameters
226
+ param_specs = []
227
+ for param_name, param_type, param_default in method_params:
228
+ dtype = model_signature.DataType.from_python_type(param_type)
229
+ param_spec = model_signature.ParamSpec(
230
+ name=param_name,
231
+ dtype=dtype,
232
+ default_value=param_default,
233
+ )
234
+ param_specs.append(param_spec)
235
+
236
+ # Create a new signature with parameters
237
+ model_meta.signatures[method_name] = model_signature.ModelSignature(
238
+ inputs=sig.inputs,
239
+ outputs=sig.outputs,
240
+ params=param_specs,
241
+ )
242
+
191
243
  @classmethod
192
244
  def convert_as_custom_model(
193
245
  cls,
@@ -194,7 +194,18 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
194
194
 
195
195
  if kwargs.get("use_gpu", False):
196
196
  assert type(kwargs.get("use_gpu", False)) == bool
197
- gpu_params = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
197
+ from packaging import version
198
+
199
+ xgb_version = version.parse(xgboost.__version__)
200
+ if xgb_version >= version.parse("3.1.0"):
201
+ # XGBoost 3.1.0+: Use device="cuda" for GPU acceleration
202
+ # gpu_hist and gpu_predictor were removed in XGBoost 3.1.0
203
+ # See: https://xgboost.readthedocs.io/en/latest/changes/v3.1.0.html
204
+ gpu_params = {"tree_method": "hist", "device": "cuda"}
205
+ else:
206
+ # XGBoost < 3.1.0: Use legacy gpu_hist tree_method
207
+ gpu_params = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
208
+
198
209
  if isinstance(m, xgboost.Booster):
199
210
  m.set_param(gpu_params)
200
211
  elif isinstance(m, xgboost.XGBModel):
@@ -256,6 +267,20 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
256
267
  @custom_model.inference_api
257
268
  def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
258
269
  import shap
270
+ from packaging import version
271
+
272
+ xgb_version = version.parse(xgboost.__version__)
273
+ shap_version = version.parse(shap.__version__)
274
+
275
+ # SHAP < 0.50.0 is incompatible with XGBoost >= 3.1.0 due to base_score format change
276
+ # (base_score is now stored as a vector for multi-output models)
277
+ # See: https://xgboost.readthedocs.io/en/latest/changes/v3.1.0.html
278
+ if xgb_version >= version.parse("3.1.0") and shap_version < version.parse("0.50.0"):
279
+ raise RuntimeError(
280
+ f"SHAP version {shap.__version__} is incompatible with XGBoost version "
281
+ f"{xgboost.__version__}. XGBoost 3.1+ changed the model format which requires "
282
+ f"SHAP >= 0.50.0. Please upgrade SHAP or use XGBoost < 3.1."
283
+ )
259
284
 
260
285
  explainer = shap.TreeExplainer(raw_model)
261
286
  df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))