snowflake-ml-python 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. snowflake/ml/_internal/utils/url.py +42 -0
  2. snowflake/ml/jobs/__init__.py +2 -0
  3. snowflake/ml/jobs/_utils/constants.py +2 -0
  4. snowflake/ml/jobs/_utils/payload_utils.py +38 -18
  5. snowflake/ml/jobs/_utils/query_helper.py +8 -1
  6. snowflake/ml/jobs/_utils/runtime_env_utils.py +58 -4
  7. snowflake/ml/jobs/_utils/spec_utils.py +0 -31
  8. snowflake/ml/jobs/_utils/stage_utils.py +2 -2
  9. snowflake/ml/jobs/_utils/types.py +22 -2
  10. snowflake/ml/jobs/job_definition.py +232 -0
  11. snowflake/ml/jobs/manager.py +16 -177
  12. snowflake/ml/lineage/lineage_node.py +1 -1
  13. snowflake/ml/model/__init__.py +6 -0
  14. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  15. snowflake/ml/model/_client/model/model_version_impl.py +109 -32
  16. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +45 -2
  18. snowflake/ml/model/_client/ops/param_utils.py +124 -0
  19. snowflake/ml/model/_client/ops/service_ops.py +81 -61
  20. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +24 -9
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +4 -0
  23. snowflake/ml/model/_client/sql/model_version.py +30 -6
  24. snowflake/ml/model/_client/sql/service.py +30 -29
  25. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +62 -2
  31. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  32. snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
  33. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  36. snowflake/ml/model/_packager/model_packager.py +1 -1
  37. snowflake/ml/model/_signatures/core.py +85 -0
  38. snowflake/ml/model/_signatures/utils.py +55 -0
  39. snowflake/ml/model/code_path.py +104 -0
  40. snowflake/ml/model/custom_model.py +55 -13
  41. snowflake/ml/model/model_signature.py +13 -1
  42. snowflake/ml/model/openai_signatures.py +97 -0
  43. snowflake/ml/model/type_hints.py +2 -0
  44. snowflake/ml/registry/_manager/model_manager.py +230 -15
  45. snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
  46. snowflake/ml/registry/registry.py +4 -4
  47. snowflake/ml/version.py +1 -1
  48. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +95 -1
  49. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +52 -46
  50. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
  51. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  52. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0

snowflake/ml/model/_client/ops/model_ops.py

@@ -1,5 +1,6 @@
  import enum
  import json
+ import logging
  import os
  import pathlib
  import tempfile
@@ -11,9 +12,9 @@ from typing_extensions import NotRequired

  from snowflake.ml._internal import platform_capabilities
  from snowflake.ml._internal.exceptions import error_codes, exceptions
- from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
+ from snowflake.ml._internal.utils import formatting, identifier, sql_identifier, url
  from snowflake.ml.model import model_signature, type_hints
- from snowflake.ml.model._client.ops import metadata_ops
+ from snowflake.ml.model._client.ops import deployment_step, metadata_ops, param_utils
  from snowflake.ml.model._client.sql import (
  model as model_sql,
  model_version as model_version_sql,
@@ -33,6 +34,8 @@ from snowflake.ml.model._signatures import snowpark_handler
  from snowflake.snowpark import dataframe, row, session
  from snowflake.snowpark._internal import utils as snowpark_utils

+ logger = logging.getLogger(__name__)
+

  # An enum class to represent Create Or Alter Model SQL command.
  class ModelAction(enum.Enum):
@@ -986,6 +989,7 @@ class ModelOperator:
  statement_params: Optional[dict[str, str]] = None,
  is_partitioned: Optional[bool] = None,
  explain_case_sensitive: bool = False,
+ params: Optional[dict[str, Any]] = None,
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
  ...

@@ -1002,6 +1006,7 @@ class ModelOperator:
  strict_input_validation: bool = False,
  statement_params: Optional[dict[str, str]] = None,
  explain_case_sensitive: bool = False,
+ params: Optional[dict[str, Any]] = None,
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
  ...

@@ -1022,6 +1027,7 @@ class ModelOperator:
  statement_params: Optional[dict[str, str]] = None,
  is_partitioned: Optional[bool] = None,
  explain_case_sensitive: bool = False,
+ params: Optional[dict[str, Any]] = None,
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
  identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED

@@ -1057,6 +1063,8 @@ class ModelOperator:
  col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
  input_args.append(col_name)

+ method_parameters = param_utils.validate_and_resolve_params(params, signature.params)
+
  returns = []
  for output_feature in signature.outputs:
  output_name = identifier_rule.get_sql_identifier_from_feature(output_feature.name)
@@ -1075,6 +1083,7 @@ class ModelOperator:
  schema_name=schema_name,
  service_name=service_name,
  statement_params=statement_params,
+ params=method_parameters,
  )
  else:
  assert model_name is not None
@@ -1090,6 +1099,7 @@ class ModelOperator:
  model_name=model_name,
  version_name=version_name,
  statement_params=statement_params,
+ params=method_parameters,
  )
  elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
  df_res = self._model_version_client.invoke_table_function_method(
@@ -1105,6 +1115,7 @@
  statement_params=statement_params,
  is_partitioned=is_partitioned or False,
  explain_case_sensitive=explain_case_sensitive,
+ params=method_parameters,
  )

  if keep_order:
@@ -1238,3 +1249,35 @@ class ModelOperator:
  target_path=local_file_dir,
  statement_params=statement_params,
  )
+
+ def run_import_model_query(
+ self,
+ *,
+ database_name: str,
+ schema_name: str,
+ yaml_content: str,
+ statement_params: Optional[dict[str, Any]] = None,
+ ) -> None:
+ yaml_content_escaped = snowpark_utils.escape_single_quotes(yaml_content) # type: ignore[no-untyped-call]
+
+ async_job = self._session.sql(
+ f"SELECT SYSTEM$IMPORT_MODEL('{yaml_content_escaped}')",
+ ).collect(block=False, statement_params=statement_params)
+ query_id = async_job.query_id # type: ignore[attr-defined]
+
+ logger.info(f"Remotely importing model, with the query id: {query_id}")
+ model_logger_service_name = sql_identifier.SqlIdentifier(
+ deployment_step.get_service_id_from_deployment_step(
+ query_id,
+ deployment_step.DeploymentStep.MODEL_LOGGING,
+ )
+ )
+
+ logger_name = model_logger_service_name.identifier()
+ job_url = f"{url.JOB_URL_PREFIX}/{database_name}/{schema_name}/{logger_name}"
+ snowflake_url = url.get_snowflake_url(session=self._session, url_path=job_url)
+ logger.info(
+ f"To monitor the progress of the model logging job, head to the job monitoring page {snowflake_url}"
+ )
+
+ async_job.result() # type: ignore[attr-defined]

snowflake/ml/model/_client/ops/param_utils.py (new file)

@@ -0,0 +1,124 @@
+ """Utility functions for model parameter validation and resolution."""
+
+ from typing import Any, Optional, Sequence
+
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
+ from snowflake.ml._internal.utils import sql_identifier
+ from snowflake.ml.model._signatures import core
+
+
+ def validate_params(
+ params: Optional[dict[str, Any]],
+ signature_params: Optional[Sequence[core.BaseParamSpec]],
+ ) -> None:
+ """Validate user-provided params against signature params.
+
+ Args:
+ params: User-provided parameter dictionary (runtime values).
+ signature_params: Parameter specifications from the model signature.
+
+ Raises:
+ SnowflakeMLException: If params are provided but signature has no params,
+ or if unknown params are provided, or if param types are invalid,
+ or if duplicate params are provided with different cases.
+ """
+ # Params provided but signature has no params defined
+ if params and not signature_params:
+ raise exceptions.SnowflakeMLException(
+ error_code=error_codes.INVALID_ARGUMENT,
+ original_exception=ValueError(
+ f"Parameters were provided ({sorted(params.keys())}), "
+ "but this method does not accept any parameters."
+ ),
+ )
+
+ if not signature_params or not params:
+ return
+
+ # Case-insensitive lookup: normalized_name -> param_spec
+ param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+ # Check for duplicate params with different cases (e.g., "temperature" and "TEMPERATURE")
+ normalized_names = [name.upper() for name in params]
+ if len(normalized_names) != len(set(normalized_names)):
+ # Find the duplicate params to raise an error
+ param_seen: dict[str, list[str]] = {}
+ for param_name in params:
+ param_seen.setdefault(param_name.upper(), []).append(param_name)
+ duplicate_param_names = [param_names for param_names in param_seen.values() if len(param_names) > 1]
+ raise exceptions.SnowflakeMLException(
+ error_code=error_codes.INVALID_ARGUMENT,
+ original_exception=ValueError(
+ f"Duplicate parameter(s) provided with different cases: {duplicate_param_names}. "
+ "Parameter names are case-insensitive."
+ ),
+ )
+
+ # Validate user-provided params exist (case-insensitive)
+ invalid_params = [name for name in params if name.upper() not in param_spec_lookup]
+ if invalid_params:
+ raise exceptions.SnowflakeMLException(
+ error_code=error_codes.INVALID_ARGUMENT,
+ original_exception=ValueError(
+ f"Unknown parameter(s): {sorted(invalid_params)}. "
+ f"Valid parameters are: {sorted(ps.name for ps in signature_params)}"
+ ),
+ )
+
+ # Validate types for each provided param
+ for param_name, default_value in params.items():
+ param_spec = param_spec_lookup[param_name.upper()]
+ if isinstance(param_spec, core.ParamSpec):
+ core.ParamSpec._validate_default_value(param_spec.dtype, default_value, param_spec.shape)
+
+
+ def resolve_params(
+ params: Optional[dict[str, Any]],
+ signature_params: Sequence[core.BaseParamSpec],
+ ) -> list[tuple[sql_identifier.SqlIdentifier, Any]]:
+ """Resolve final method parameters by applying user-provided params over signature defaults.
+
+ Args:
+ params: User-provided parameter dictionary (runtime values).
+ signature_params: Parameter specifications from the model signature.
+
+ Returns:
+ List of tuples (SqlIdentifier, value) for method invocation.
+ """
+ # Case-insensitive lookup: normalized_name -> param_spec
+ param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+ # Start with defaults from signature
+ final_params: dict[str, Any] = {}
+ for param_spec in signature_params:
+ if hasattr(param_spec, "default_value"):
+ final_params[param_spec.name] = param_spec.default_value
+
+ # Override with provided runtime parameters (using signature's original param names)
+ if params:
+ for param_name, override_value in params.items():
+ canonical_name = param_spec_lookup[param_name.upper()].name
+ final_params[canonical_name] = override_value
+
+ return [(sql_identifier.SqlIdentifier(param_name), param_value) for param_name, param_value in final_params.items()]
+
+
+ def validate_and_resolve_params(
+ params: Optional[dict[str, Any]],
+ signature_params: Optional[Sequence[core.BaseParamSpec]],
+ ) -> Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]]:
+ """Validate user-provided params against signature params and return method parameters.
+
+ Args:
+ params: User-provided parameter dictionary (runtime values).
+ signature_params: Parameter specifications from the model signature.
+
+ Returns:
+ List of tuples (SqlIdentifier, value) for method invocation, or None if no params.
+ """
+ validate_params(params, signature_params)
+
+ if not signature_params:
+ return None
+
+ return resolve_params(params, signature_params)
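
The module above resolves parameters case-insensitively: defaults from the signature are applied first, then user-supplied values override them under the signature's canonical names. Below is a minimal standalone sketch of that resolution behavior, using a hypothetical FakeParamSpec stand-in rather than the real core.ParamSpec (whose constructor is not shown in this diff):

    # Illustrative sketch only; FakeParamSpec is a hypothetical stand-in for core.BaseParamSpec.
    from dataclasses import dataclass
    from typing import Any, Optional


    @dataclass
    class FakeParamSpec:
        name: str
        default_value: Optional[Any] = None


    def resolve(params: Optional[dict[str, Any]], specs: list[FakeParamSpec]) -> dict[str, Any]:
        # Defaults first, then case-insensitive overrides keyed by the canonical spec name.
        lookup = {spec.name.upper(): spec for spec in specs}
        final = {spec.name: spec.default_value for spec in specs}
        for name, value in (params or {}).items():
            final[lookup[name.upper()].name] = value
        return final


    specs = [FakeParamSpec("temperature", 0.7), FakeParamSpec("max_tokens", 256)]
    print(resolve({"TEMPERATURE": 0.1}, specs))  # {'temperature': 0.1, 'max_tokens': 256}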

snowflake/ml/model/_client/ops/service_ops.py

@@ -1,6 +1,6 @@
+ import base64
  import dataclasses
- import enum
- import hashlib
+ import json
  import logging
  import pathlib
  import re
@@ -8,7 +8,9 @@ import tempfile
  import threading
  import time
  import warnings
- from typing import Any, Optional, Union, cast
+ from typing import Any, Optional, Sequence, Union, cast
+
+ from pydantic import TypeAdapter

  from snowflake import snowpark
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
@@ -16,8 +18,10 @@ from snowflake.ml._internal.utils import identifier, service_logger, sql_identif
  from snowflake.ml.jobs import job
  from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
  from snowflake.ml.model._client.model import batch_inference_specs
+ from snowflake.ml.model._client.ops import deployment_step, param_utils
  from snowflake.ml.model._client.service import model_deployment_spec
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
+ from snowflake.ml.model._signatures import core
  from snowflake.snowpark import async_job, exceptions, row, session
  from snowflake.snowpark._internal import utils as snowpark_utils

@@ -25,32 +29,12 @@ module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY
  module_logger.propagate = False


- class DeploymentStep(enum.Enum):
- MODEL_BUILD = ("model-build", "model_build_")
- MODEL_INFERENCE = ("model-inference", None)
- MODEL_LOGGING = ("model-logging", "model_logging_")
-
- def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
- self._container_name = container_name
- self._service_name_prefix = service_name_prefix
-
- @property
- def container_name(self) -> str:
- """Get the container name for the deployment step."""
- return self._container_name
-
- @property
- def service_name_prefix(self) -> Optional[str]:
- """Get the service name prefix for the deployment step."""
- return self._service_name_prefix
-
-
  @dataclasses.dataclass
  class ServiceLogInfo:
  database_name: Optional[sql_identifier.SqlIdentifier]
  schema_name: Optional[sql_identifier.SqlIdentifier]
  service_name: sql_identifier.SqlIdentifier
- deployment_step: DeploymentStep
+ deployment_step: deployment_step.DeploymentStep
  instance_id: str = "0"
  log_color: service_logger.LogColor = service_logger.LogColor.GREY

@@ -353,13 +337,16 @@ class ServiceOperator:
  if is_enable_image_build:
  # stream service logs in a thread
  model_build_service_name = sql_identifier.SqlIdentifier(
- self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_BUILD)
+ deployment_step.get_service_id_from_deployment_step(
+ query_id,
+ deployment_step.DeploymentStep.MODEL_BUILD,
+ )
  )
  model_build_service = ServiceLogInfo(
  database_name=service_database_name,
  schema_name=service_schema_name,
  service_name=model_build_service_name,
- deployment_step=DeploymentStep.MODEL_BUILD,
+ deployment_step=deployment_step.DeploymentStep.MODEL_BUILD,
  log_color=service_logger.LogColor.GREEN,
  )

@@ -367,21 +354,23 @@ class ServiceOperator:
  database_name=service_database_name,
  schema_name=service_schema_name,
  service_name=service_name,
- deployment_step=DeploymentStep.MODEL_INFERENCE,
+ deployment_step=deployment_step.DeploymentStep.MODEL_INFERENCE,
  log_color=service_logger.LogColor.BLUE,
  )

  model_logger_service: Optional[ServiceLogInfo] = None
  if hf_model_args:
  model_logger_service_name = sql_identifier.SqlIdentifier(
- self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_LOGGING)
+ deployment_step.get_service_id_from_deployment_step(
+ query_id, deployment_step.DeploymentStep.MODEL_LOGGING
+ )
  )

  model_logger_service = ServiceLogInfo(
  database_name=service_database_name,
  schema_name=service_schema_name,
  service_name=model_logger_service_name,
- deployment_step=DeploymentStep.MODEL_LOGGING,
+ deployment_step=deployment_step.DeploymentStep.MODEL_LOGGING,
  log_color=service_logger.LogColor.ORANGE,
  )

@@ -536,7 +525,7 @@ class ServiceOperator:
  service = service_log_meta.service
  # check if using an existing model build image
  if (
- service.deployment_step == DeploymentStep.MODEL_BUILD
+ service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD
  and not force_rebuild
  and service_log_meta.is_model_logger_service_done
  and not service_log_meta.is_model_build_service_done
@@ -582,31 +571,26 @@ class ServiceOperator:
  if (service_status != service_sql.ServiceStatus.RUNNING) or (service_status != service_log_meta.service_status):
  service_log_meta.service_status = service_status

- if service.deployment_step == DeploymentStep.MODEL_BUILD:
+ if service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
  module_logger.info(
  f"Image build service {service.display_service_name} is "
  f"{service_log_meta.service_status.value}."
  )
- elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
  module_logger.info(
  f"Inference service {service.display_service_name} is {service_log_meta.service_status.value}."
  )
- elif service.deployment_step == DeploymentStep.MODEL_LOGGING:
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
  module_logger.info(
  f"Model logger service {service.display_service_name} is "
  f"{service_log_meta.service_status.value}."
  )
  for status in statuses:
  if status.instance_id is not None:
- instance_status, container_status = None, None
- if status.instance_status is not None:
- instance_status = status.instance_status.value
- if status.container_status is not None:
- container_status = status.container_status.value
  module_logger.info(
  f"Instance[{status.instance_id}]: "
- f"instance status: {instance_status}, "
- f"container status: {container_status}, "
+ f"instance status: {status.instance_status}, "
+ f"container status: {status.container_status}, "
  f"message: {status.message}"
  )
  time.sleep(5)
@@ -627,7 +611,7 @@ class ServiceOperator:
  if service_status == service_sql.ServiceStatus.DONE:
  # check if model logger service is done
  # and transition the service log metadata to the model image build service
- if service.deployment_step == DeploymentStep.MODEL_LOGGING:
+ if service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
  if model_build_service:
  # building the inference image, transition to the model build service
  service_log_meta.transition_service_log_metadata(
@@ -648,7 +632,7 @@ class ServiceOperator:
  )
  # check if model build service is done
  # and transition the service log metadata to the model inference service
- elif service.deployment_step == DeploymentStep.MODEL_BUILD:
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
  service_log_meta.transition_service_log_metadata(
  model_inference_service,
  f"Image build service {service.display_service_name} complete.",
@@ -656,7 +640,7 @@ class ServiceOperator:
  is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
  operation_id=operation_id,
  )
- elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
+ elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
  module_logger.info(f"Inference service {service.display_service_name} is deployed.")
  else:
  module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
@@ -916,19 +900,6 @@ class ServiceOperator:

  time.sleep(2) # Poll every 2 seconds

- @staticmethod
- def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
- """Get the service ID through the server-side logic."""
- uuid = query_id.replace("-", "")
- big_int = int(uuid, 16)
- md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
- identifier = md5_hash[:8]
- service_name_prefix = deployment_step.service_name_prefix
- if service_name_prefix is None:
- # raise an exception if the service name prefix is None
- raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
- return (service_name_prefix + identifier).upper()
-
  def _check_if_service_exists(
  self,
  database_name: Optional[sql_identifier.SqlIdentifier],
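
The static helper removed in the hunk above now lives in the new deployment_step module imported at the top of this file. Its derivation, exactly as shown in the deleted lines, can be reproduced standalone; the "model_logging_" prefix below comes from the deleted DeploymentStep enum, and the query id is a made-up example:

    import hashlib


    def service_id_from_query_id(query_id: str, service_name_prefix: str) -> str:
        # Mirror of the removed helper: strip dashes, read the UUID as a hex integer,
        # md5-hash its decimal string, and keep the first 8 hex chars as the suffix.
        uuid = query_id.replace("-", "")
        big_int = int(uuid, 16)
        md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
        return (service_name_prefix + md5_hash[:8]).upper()


    # e.g. the model logger service spawned by a deployment query
    print(service_id_from_query_id("01b2c3d4-0000-0000-0000-0123456789ab", "model_logging_"))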
@@ -959,6 +930,38 @@ class ServiceOperator:
  except exceptions.SnowparkSQLException:
  return False

+ @staticmethod
+ def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
+ """Encode params dictionary to a base64 string.
+
+ Args:
+ params: Optional dictionary of model inference parameters.
+
+ Returns:
+ Base64 encoded JSON string of the params, or None if input is None.
+ """
+ if params is None:
+ return None
+ return base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")
+
+ @staticmethod
+ def _encode_column_handling(
+ column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+ ) -> Optional[str]:
+ """Validate and encode column_handling to a base64 string.
+
+ Args:
+ column_handling: Optional dictionary mapping column names to file encoding options.
+
+ Returns:
+ Base64 encoded JSON string of the column handling options, or None if input is None.
+ """
+ if column_handling is None:
+ return None
+ adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
+ validated_input = adapter.validate_python(column_handling)
+ return base64.b64encode(adapter.dump_json(validated_input)).decode("utf-8")
+

  def invoke_batch_job_method(
  self,
@@ -971,6 +974,9 @@
  image_repo_name: Optional[str],
  input_stage_location: str,
  input_file_pattern: str,
+ column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+ params: Optional[dict[str, Any]],
+ signature_params: Optional[Sequence[core.BaseParamSpec]],
  output_stage_location: str,
  completion_filename: str,
  force_rebuild: bool,
@@ -981,7 +987,13 @@
  gpu_requests: Optional[str],
  replicas: Optional[int],
  statement_params: Optional[dict[str, Any]] = None,
+ inference_engine_args: Optional[InferenceEngineArgs] = None,
  ) -> job.MLJob[Any]:
+ # Validate and encode params
+ param_utils.validate_params(params, signature_params)
+ params_encoded = self._encode_params(params)
+ column_handling_encoded = self._encode_column_handling(column_handling)
+
  database_name = self._database_name
  schema_name = self._schema_name

@@ -1007,6 +1019,8 @@
  max_batch_rows=max_batch_rows,
  input_stage_location=input_stage_location,
  input_file_pattern=input_file_pattern,
+ column_handling=column_handling_encoded,
+ params=params_encoded,
  output_stage_location=output_stage_location,
  completion_filename=completion_filename,
  function_name=function_name,
@@ -1017,11 +1031,17 @@
  replicas=replicas,
  )

- self._model_deployment_spec.add_image_build_spec(
- image_build_compute_pool_name=compute_pool_name,
- fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
- force_rebuild=force_rebuild,
- )
+ if inference_engine_args:
+ self._model_deployment_spec.add_inference_engine_spec(
+ inference_engine=inference_engine_args.inference_engine,
+ inference_engine_args=inference_engine_args.inference_engine_args_override,
+ )
+ else:
+ self._model_deployment_spec.add_image_build_spec(
+ image_build_compute_pool_name=compute_pool_name,
+ fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
+ force_rebuild=force_rebuild,
+ )

  spec_yaml_str_or_path = self._model_deployment_spec.save()


snowflake/ml/model/_client/service/import_model_spec_schema.py (new file)

@@ -0,0 +1,23 @@
+ from typing import Optional
+
+ from pydantic import BaseModel
+
+ from snowflake.ml.model._client.service import model_deployment_spec_schema
+
+ BaseModel.model_config["protected_namespaces"] = ()
+
+
+ class ModelName(BaseModel):
+ model_name: str
+ version_name: str
+
+
+ class ModelSpec(BaseModel):
+ name: ModelName
+ hf_model: Optional[model_deployment_spec_schema.HuggingFaceModel] = None
+ log_model_args: Optional[model_deployment_spec_schema.LogModelArgs] = None
+
+
+ class ImportModelSpec(BaseModel):
+ compute_pool: str
+ models: list[ModelSpec]
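
This schema appears to model the payload consumed by the new ModelOperator.run_import_model_query above, which hands the content to SYSTEM$IMPORT_MODEL as YAML. Below is a minimal sketch of constructing and serializing such a spec with pydantic v2 (which the package already relies on, per the TypeAdapter import in service_ops.py); the compute pool and model names are placeholders, and the module is private, so its import path may change between releases:

    # Assumes snowflake-ml-python 1.23.0 is installed.
    from snowflake.ml.model._client.service import import_model_spec_schema as schema

    spec = schema.ImportModelSpec(
        compute_pool="MY_COMPUTE_POOL",  # placeholder
        models=[
            schema.ModelSpec(
                name=schema.ModelName(model_name="MY_MODEL", version_name="V1"),  # placeholders
            )
        ],
    )
    # JSON shown for brevity; the import query itself takes YAML content.
    print(spec.model_dump_json(exclude_none=True))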

snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -195,6 +195,7 @@ class ModelDeploymentSpec:

  def add_job_spec(
  self,
+ *,
  job_name: sql_identifier.SqlIdentifier,
  inference_compute_pool_name: sql_identifier.SqlIdentifier,
  function_name: str,
@@ -202,6 +203,8 @@ class ModelDeploymentSpec:
  output_stage_location: str,
  completion_filename: str,
  input_file_pattern: str,
+ column_handling: Optional[str] = None,
+ params: Optional[str] = None,
  warehouse: sql_identifier.SqlIdentifier,
  job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
  job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
@@ -217,14 +220,16 @@ class ModelDeploymentSpec:
  Args:
  job_name: Name of the job.
  inference_compute_pool_name: Compute pool for inference.
- warehouse: Warehouse for the job.
  function_name: Function name.
  input_stage_location: Stage location for input data.
  output_stage_location: Stage location for output data.
+ completion_filename: Name of completion file (default: "completion.txt").
+ input_file_pattern: Pattern for input files (optional).
+ column_handling: Column handling mode for input data.
+ params: Additional parameters for the job.
+ warehouse: Warehouse for the job.
  job_database_name: Database name for the job.
  job_schema_name: Schema name for the job.
- input_file_pattern: Pattern for input files (optional).
- completion_filename: Name of completion file (default: "completion.txt").
  cpu: CPU requirement.
  memory: Memory requirement.
  gpu: GPU requirement.
@@ -259,7 +264,10 @@ class ModelDeploymentSpec:
  warehouse=warehouse.identifier() if warehouse else None,
  function_name=function_name,
  input=model_deployment_spec_schema.Input(
- input_stage_location=input_stage_location, input_file_pattern=input_file_pattern
+ input_stage_location=input_stage_location,
+ input_file_pattern=input_file_pattern,
+ column_handling=column_handling,
+ params=params,
  ),
  output=model_deployment_spec_schema.Output(
  output_stage_location=output_stage_location,
@@ -355,7 +363,7 @@ class ModelDeploymentSpec:
  inference_engine: inference_engine_module.InferenceEngine,
  inference_engine_args: Optional[list[str]] = None,
  ) -> "ModelDeploymentSpec":
- """Add inference engine specification. This must be called after self.add_service_spec().
+ """Add inference engine specification. This must be called after self.add_service_spec() or self.add_job_spec().

  Args:
  inference_engine: Inference engine.
@@ -368,9 +376,10 @@ class ModelDeploymentSpec:
  ValueError: If inference engine specification is called before add_service_spec().
  ValueError: If the argument does not have a '--' prefix.
  """
- # TODO: needs to eventually support job deployment spec
- if self._service is None:
- raise ValueError("Inference engine specification must be called after add_service_spec().")
+ if self._service is None and self._job is None:
+ raise ValueError(
+ "Inference engine specification must be called after add_service_spec() or add_job_spec()."
+ )

  if inference_engine_args is None:
  inference_engine_args = []
@@ -423,11 +432,17 @@ class ModelDeploymentSpec:

  inference_engine_args = filtered_args

- self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
+ inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
  # convert to string to be saved in the deployment spec
  inference_engine_name=inference_engine.value,
  inference_engine_args=inference_engine_args,
  )
+
+ if self._service:
+ self._service.inference_engine_spec = inference_engine_spec
+ elif self._job:
+ self._job.inference_engine_spec = inference_engine_spec
+
  return self

  def save(self) -> str:
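
The `params` and `column_handling` strings that `add_job_spec` now accepts are the base64 values produced by the new `ServiceOperator._encode_params` and `_encode_column_handling` helpers earlier in this diff: JSON serialization followed by base64. They can be reproduced, or decoded when debugging a generated spec, with the standard library alone; the parameter names below are illustrative:

    import base64
    import json

    params = {"temperature": 0.1, "max_tokens": 256}  # illustrative inference parameters
    encoded = base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")

    # Round-trip check: the deployment spec carries only the encoded string.
    assert json.loads(base64.b64decode(encoded).decode("utf-8")) == params
    print(encoded)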

snowflake/ml/model/_client/service/model_deployment_spec_schema.py

@@ -39,6 +39,8 @@ class Service(BaseModel):
  class Input(BaseModel):
  input_stage_location: str
  input_file_pattern: str
+ column_handling: Optional[str] = None
+ params: Optional[str] = None


  class Output(BaseModel):
@@ -59,6 +61,7 @@ class Job(BaseModel):
  input: Input
  output: Output
  replicas: Optional[int] = None
+ inference_engine_spec: Optional[InferenceEngineSpec] = None


  class LogModelArgs(BaseModel):
@@ -74,6 +77,7 @@ class HuggingFaceModel(BaseModel):
  task: Optional[str] = None
  tokenizer: Optional[str] = None
  token: Optional[str] = None
+ token_secret_object: Optional[str] = None
  trust_remote_code: Optional[bool] = False
  revision: Optional[str] = None
  hf_model_kwargs: Optional[str] = "{}"