snowflake-ml-python 1.21.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/experiment/callback/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/spec_utils.py +0 -31
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +63 -0
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +61 -2
- snowflake/ml/model/_client/ops/service_ops.py +23 -48
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +26 -4
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +85 -0
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +29 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +36 -32
- snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/service_ops.py (+23 -48)

@@ -1,6 +1,4 @@
 import dataclasses
-import enum
-import hashlib
 import logging
 import pathlib
 import re
@@ -16,6 +14,7 @@ from snowflake.ml._internal.utils import identifier, service_logger, sql_identif
 from snowflake.ml.jobs import job
 from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
 from snowflake.ml.model._client.model import batch_inference_specs
+from snowflake.ml.model._client.ops import deployment_step
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
 from snowflake.snowpark import async_job, exceptions, row, session
@@ -25,32 +24,12 @@ module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY
 module_logger.propagate = False


-class DeploymentStep(enum.Enum):
-    MODEL_BUILD = ("model-build", "model_build_")
-    MODEL_INFERENCE = ("model-inference", None)
-    MODEL_LOGGING = ("model-logging", "model_logging_")
-
-    def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
-        self._container_name = container_name
-        self._service_name_prefix = service_name_prefix
-
-    @property
-    def container_name(self) -> str:
-        """Get the container name for the deployment step."""
-        return self._container_name
-
-    @property
-    def service_name_prefix(self) -> Optional[str]:
-        """Get the service name prefix for the deployment step."""
-        return self._service_name_prefix
-
-
 @dataclasses.dataclass
 class ServiceLogInfo:
     database_name: Optional[sql_identifier.SqlIdentifier]
     schema_name: Optional[sql_identifier.SqlIdentifier]
     service_name: sql_identifier.SqlIdentifier
-    deployment_step: DeploymentStep
+    deployment_step: deployment_step.DeploymentStep
     instance_id: str = "0"
     log_color: service_logger.LogColor = service_logger.LogColor.GREY

@@ -353,13 +332,16 @@ class ServiceOperator:
         if is_enable_image_build:
             # stream service logs in a thread
             model_build_service_name = sql_identifier.SqlIdentifier(
-
+                deployment_step.get_service_id_from_deployment_step(
+                    query_id,
+                    deployment_step.DeploymentStep.MODEL_BUILD,
+                )
             )
             model_build_service = ServiceLogInfo(
                 database_name=service_database_name,
                 schema_name=service_schema_name,
                 service_name=model_build_service_name,
-                deployment_step=DeploymentStep.MODEL_BUILD,
+                deployment_step=deployment_step.DeploymentStep.MODEL_BUILD,
                 log_color=service_logger.LogColor.GREEN,
             )

@@ -367,21 +349,23 @@ class ServiceOperator:
            database_name=service_database_name,
            schema_name=service_schema_name,
            service_name=service_name,
-           deployment_step=DeploymentStep.MODEL_INFERENCE,
+           deployment_step=deployment_step.DeploymentStep.MODEL_INFERENCE,
            log_color=service_logger.LogColor.BLUE,
         )

         model_logger_service: Optional[ServiceLogInfo] = None
         if hf_model_args:
             model_logger_service_name = sql_identifier.SqlIdentifier(
-
+                deployment_step.get_service_id_from_deployment_step(
+                    query_id, deployment_step.DeploymentStep.MODEL_LOGGING
+                )
             )

             model_logger_service = ServiceLogInfo(
                 database_name=service_database_name,
                 schema_name=service_schema_name,
                 service_name=model_logger_service_name,
-                deployment_step=DeploymentStep.MODEL_LOGGING,
+                deployment_step=deployment_step.DeploymentStep.MODEL_LOGGING,
                 log_color=service_logger.LogColor.ORANGE,
             )

@@ -536,7 +520,7 @@ class ServiceOperator:
            service = service_log_meta.service
            # check if using an existing model build image
            if (
-               service.deployment_step == DeploymentStep.MODEL_BUILD
+               service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD
                and not force_rebuild
                and service_log_meta.is_model_logger_service_done
                and not service_log_meta.is_model_build_service_done
@@ -582,16 +566,16 @@ class ServiceOperator:
            if (service_status != service_sql.ServiceStatus.RUNNING) or (service_status != service_log_meta.service_status):
                service_log_meta.service_status = service_status

-               if service.deployment_step == DeploymentStep.MODEL_BUILD:
+               if service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
                    module_logger.info(
                        f"Image build service {service.display_service_name} is "
                        f"{service_log_meta.service_status.value}."
                    )
-               elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
+               elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
                    module_logger.info(
                        f"Inference service {service.display_service_name} is {service_log_meta.service_status.value}."
                    )
-               elif service.deployment_step == DeploymentStep.MODEL_LOGGING:
+               elif service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
                    module_logger.info(
                        f"Model logger service {service.display_service_name} is "
                        f"{service_log_meta.service_status.value}."
@@ -627,7 +611,7 @@ class ServiceOperator:
            if service_status == service_sql.ServiceStatus.DONE:
                # check if model logger service is done
                # and transition the service log metadata to the model image build service
-               if service.deployment_step == DeploymentStep.MODEL_LOGGING:
+               if service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
                    if model_build_service:
                        # building the inference image, transition to the model build service
                        service_log_meta.transition_service_log_metadata(
@@ -648,7 +632,7 @@ class ServiceOperator:
                        )
                # check if model build service is done
                # and transition the service log metadata to the model inference service
-               elif service.deployment_step == DeploymentStep.MODEL_BUILD:
+               elif service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
                    service_log_meta.transition_service_log_metadata(
                        model_inference_service,
                        f"Image build service {service.display_service_name} complete.",
@@ -656,7 +640,7 @@ class ServiceOperator:
                        is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
                        operation_id=operation_id,
                    )
-               elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
+               elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
                    module_logger.info(f"Inference service {service.display_service_name} is deployed.")
                else:
                    module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
@@ -916,19 +900,6 @@ class ServiceOperator:

            time.sleep(2)  # Poll every 2 seconds

-    @staticmethod
-    def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
-        """Get the service ID through the server-side logic."""
-        uuid = query_id.replace("-", "")
-        big_int = int(uuid, 16)
-        md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
-        identifier = md5_hash[:8]
-        service_name_prefix = deployment_step.service_name_prefix
-        if service_name_prefix is None:
-            # raise an exception if the service name prefix is None
-            raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
-        return (service_name_prefix + identifier).upper()
-
     def _check_if_service_exists(
         self,
         database_name: Optional[sql_identifier.SqlIdentifier],
@@ -971,6 +942,8 @@ class ServiceOperator:
        image_repo_name: Optional[str],
        input_stage_location: str,
        input_file_pattern: str,
+       column_handling: Optional[str],
+       params: Optional[str],
        output_stage_location: str,
        completion_filename: str,
        force_rebuild: bool,
@@ -1007,6 +980,8 @@ class ServiceOperator:
            max_batch_rows=max_batch_rows,
            input_stage_location=input_stage_location,
            input_file_pattern=input_file_pattern,
+           column_handling=column_handling,
+           params=params,
            output_stage_location=output_stage_location,
            completion_filename=completion_filename,
            function_name=function_name,
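The `_get_service_id_from_deployment_step` helper removed above derived a deterministic service name from the submitting query ID; per this diff, that responsibility moves to `deployment_step.get_service_id_from_deployment_step` in the new `deployment_step` module. A minimal standalone sketch of the removed hashing logic, assuming the relocated helper keeps the same scheme (the `model_build_` prefix below comes from the removed `DeploymentStep` enum; the function name here is illustrative):

```python
# Standalone sketch of the removed hashing logic; the real helper now lives in
# snowflake/ml/model/_client/ops/deployment_step.py and may differ in detail.
import hashlib


def service_id_from_query_id(query_id: str, service_name_prefix: str) -> str:
    """Derive a deterministic service name from a UUID-shaped Snowflake query ID."""
    uuid_hex = query_id.replace("-", "")  # strip dashes from the UUID
    big_int = int(uuid_hex, 16)           # interpret the hex digits as an integer
    md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
    return (service_name_prefix + md5_hash[:8]).upper()


# Example with the "model_build_" prefix from the removed enum.
print(service_id_from_query_id("01b2c3d4-0000-0000-0000-000000000000", "model_build_"))
```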
snowflake/ml/model/_client/service/import_model_spec_schema.py (+23 -0, new file)

@@ -0,0 +1,23 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+from snowflake.ml.model._client.service import model_deployment_spec_schema
+
+BaseModel.model_config["protected_namespaces"] = ()
+
+
+class ModelName(BaseModel):
+    model_name: str
+    version_name: str
+
+
+class ModelSpec(BaseModel):
+    name: ModelName
+    hf_model: Optional[model_deployment_spec_schema.HuggingFaceModel] = None
+    log_model_args: Optional[model_deployment_spec_schema.LogModelArgs] = None
+
+
+class ImportModelSpec(BaseModel):
+    compute_pool: str
+    models: list[ModelSpec]
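Since these are plain pydantic models, the new import spec can be constructed and serialized directly. A hedged usage sketch, with illustrative field values and `hf_model`/`log_model_args` left at their `None` defaults (the client code that consumes `ImportModelSpec` is not shown in this diff):

```python
# Illustrative only: builds the new spec models shown above and dumps them as JSON
# using pydantic v2 (the module itself sets pydantic v2 model_config).
from snowflake.ml.model._client.service import import_model_spec_schema as schema

spec = schema.ImportModelSpec(
    compute_pool="MY_COMPUTE_POOL",
    models=[
        schema.ModelSpec(
            name=schema.ModelName(model_name="MY_MODEL", version_name="V1"),
        )
    ],
)
print(spec.model_dump_json(indent=2))
```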
snowflake/ml/model/_client/service/model_deployment_spec.py (+12 -4)

@@ -195,6 +195,7 @@ class ModelDeploymentSpec:

     def add_job_spec(
         self,
+        *,
         job_name: sql_identifier.SqlIdentifier,
         inference_compute_pool_name: sql_identifier.SqlIdentifier,
         function_name: str,
@@ -202,6 +203,8 @@ class ModelDeploymentSpec:
         output_stage_location: str,
         completion_filename: str,
         input_file_pattern: str,
+        column_handling: Optional[str] = None,
+        params: Optional[str] = None,
         warehouse: sql_identifier.SqlIdentifier,
         job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
         job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
@@ -217,14 +220,16 @@ class ModelDeploymentSpec:
         Args:
             job_name: Name of the job.
             inference_compute_pool_name: Compute pool for inference.
-            warehouse: Warehouse for the job.
             function_name: Function name.
             input_stage_location: Stage location for input data.
             output_stage_location: Stage location for output data.
+            completion_filename: Name of completion file (default: "completion.txt").
+            input_file_pattern: Pattern for input files (optional).
+            column_handling: Column handling mode for input data.
+            params: Additional parameters for the job.
+            warehouse: Warehouse for the job.
             job_database_name: Database name for the job.
             job_schema_name: Schema name for the job.
-            input_file_pattern: Pattern for input files (optional).
-            completion_filename: Name of completion file (default: "completion.txt").
             cpu: CPU requirement.
             memory: Memory requirement.
             gpu: GPU requirement.
@@ -259,7 +264,10 @@ class ModelDeploymentSpec:
            warehouse=warehouse.identifier() if warehouse else None,
            function_name=function_name,
            input=model_deployment_spec_schema.Input(
-               input_stage_location=input_stage_location,
+               input_stage_location=input_stage_location,
+               input_file_pattern=input_file_pattern,
+               column_handling=column_handling,
+               params=params,
            ),
            output=model_deployment_spec_schema.Output(
                output_stage_location=output_stage_location,
snowflake/ml/model/_client/service/model_deployment_spec_schema.py (+3 -0)

@@ -39,6 +39,8 @@ class Service(BaseModel):
 class Input(BaseModel):
     input_stage_location: str
     input_file_pattern: str
+    column_handling: Optional[str] = None
+    params: Optional[str] = None


 class Output(BaseModel):
@@ -74,6 +76,7 @@ class HuggingFaceModel(BaseModel):
     task: Optional[str] = None
     tokenizer: Optional[str] = None
     token: Optional[str] = None
+    token_secret_object: Optional[str] = None
     trust_remote_code: Optional[bool] = False
     revision: Optional[str] = None
     hf_model_kwargs: Optional[str] = "{}"
snowflake/ml/model/_client/sql/model_version.py (+30 -6)

@@ -22,6 +22,14 @@ def _normalize_url_for_sql(url: str) -> str:
     return f"'{url}'"


+def _format_param_value(value: Any) -> str:
+    if isinstance(value, str):
+        return f"'{snowpark_utils.escape_single_quotes(value)}'"  # type: ignore[no-untyped-call]
+    elif value is None:
+        return "NULL"
+    return str(value)
+
+
 class ModelVersionSQLClient(_base._BaseSQLClient):
     FUNCTION_NAME_COL_NAME = "name"
     FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
@@ -354,6 +362,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
        input_args: list[sql_identifier.SqlIdentifier],
        returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
        statement_params: Optional[dict[str, Any]] = None,
+       params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
    ) -> dataframe.DataFrame:
        with_statements = []
        if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -392,10 +401,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):

        args_sql = ", ".join(args_sql_list)

-
+       if params:
+           param_sql = ", ".join(_format_param_value(val) for _, val in params)
+           args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
+
+       total_args = len(input_args) + (len(params) if params else 0)
+       wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
        if wide_input:
-
-
+           parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
+           if params:
+               parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
+           args_sql = f"object_construct_keep_null({', '.join(parts)})"

        sql = textwrap.dedent(
            f"""WITH {','.join(with_statements)}
@@ -439,6 +455,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
        statement_params: Optional[dict[str, Any]] = None,
        is_partitioned: bool = True,
        explain_case_sensitive: bool = False,
+       params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
    ) -> dataframe.DataFrame:
        with_statements = []
        if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -477,10 +494,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):

        args_sql = ", ".join(args_sql_list)

-
+       if params:
+           param_sql = ", ".join(_format_param_value(val) for _, val in params)
+           args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
+
+       total_args = len(input_args) + (len(params) if params else 0)
+       wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
        if wide_input:
-
-
+           parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
+           if params:
+               parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
+           args_sql = f"object_construct_keep_null({', '.join(parts)})"

        sql = textwrap.dedent(
            f"""WITH {','.join(with_statements)}
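To illustrate what the new `params` plumbing produces, here is a simplified, standalone re-creation of the argument-list assembly above. Plain strings stand in for `SqlIdentifier`, a local quote-doubling helper stands in for `snowpark_utils.escape_single_quotes`, and the wide-input threshold is hard-coded; all three are assumptions made only for this sketch.

```python
# Simplified sketch of the args_sql assembly shown in the diff above.
from typing import Any, Optional


def format_param_value(value: Any) -> str:
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"  # stand-in for escape_single_quotes
    if value is None:
        return "NULL"
    return str(value)


def build_args_sql(
    input_args: list[str],
    params: Optional[list[tuple[str, Any]]] = None,
    col_limit: int = 500,  # assumption: stand-in for constants.SNOWPARK_UDF_INPUT_COL_LIMIT
) -> str:
    args_sql = ", ".join(input_args)
    if params:
        param_sql = ", ".join(format_param_value(v) for _, v in params)
        args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
    # Past the column limit, everything is packed into one OBJECT argument instead.
    total_args = len(input_args) + (len(params) if params else 0)
    if total_args > col_limit:
        parts = [f"'{a}', {a}" for a in input_args]
        if params:
            parts.extend(f"'{n}', {format_param_value(v)}" for n, v in params)
        args_sql = f"object_construct_keep_null({', '.join(parts)})"
    return args_sql


print(build_args_sql(["SENTENCE"], [("TEMPERATURE", 0.2), ("MAX_TOKENS", 128)]))
# -> SENTENCE, 0.2, 128
```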
snowflake/ml/model/_client/sql/service.py (+26 -4)

@@ -20,6 +20,15 @@ from snowflake.snowpark._internal import utils as snowpark_utils

 logger = logging.getLogger(__name__)

+
+def _format_param_value(value: Any) -> str:
+    if isinstance(value, str):
+        return f"'{snowpark_utils.escape_single_quotes(value)}'"  # type: ignore[no-untyped-call]
+    elif value is None:
+        return "NULL"
+    return str(value)
+
+
 # Using this token instead of '?' to avoid escaping issues
 # After quotes are escaped, we replace this token with '|| ? ||'
 QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
@@ -140,6 +149,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
        input_args: list[sql_identifier.SqlIdentifier],
        returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
        statement_params: Optional[dict[str, Any]] = None,
+       params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
    ) -> dataframe.DataFrame:
        with_statements = []
        actual_database_name = database_name or self._database_name
@@ -170,10 +180,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
            args_sql_list.append(input_arg_value)
        args_sql = ", ".join(args_sql_list)

-
+       if params:
+           param_sql = ", ".join(_format_param_value(val) for _, val in params)
+           args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
+
+       total_args = len(input_args) + (len(params) if params else 0)
+       wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
        if wide_input:
-
-
+           parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
+           if params:
+               parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
+           args_sql = f"object_construct_keep_null({', '.join(parts)})"

        fully_qualified_service_name = self.fully_qualified_object_name(
            actual_database_name, actual_schema_name, service_name
@@ -301,7 +318,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
            False if service doesn't have proxy container
        """
        try:
-
+           spec_yaml = row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME]
+           if spec_yaml is None:
+               return False
+           spec_raw = yaml.safe_load(spec_yaml)
+           if spec_raw is None:
+               return False
            spec = cast(dict[str, Any], spec_raw)

            proxy_container_spec = next(
snowflake/ml/model/_model_composer/model_composer.py (+1 -1)

@@ -131,7 +131,7 @@ class ModelComposer:
        python_version: Optional[str] = None,
        user_files: Optional[dict[str, list[str]]] = None,
        ext_modules: Optional[list[ModuleType]] = None,
-       code_paths: Optional[list[
+       code_paths: Optional[list[model_types.CodePathLike]] = None,
        task: model_types.Task = model_types.Task.UNKNOWN,
        experiment_info: Optional["ExperimentInfo"] = None,
        options: Optional[model_types.ModelSaveOption] = None,
snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py (+5 -0)

@@ -39,6 +39,10 @@ class ModelMethodSignatureFieldWithName(ModelMethodSignatureField):
     name: Required[str]


+class ModelMethodSignatureFieldWithNameAndDefault(ModelMethodSignatureFieldWithName):
+    default: Required[Any]
+
+
 class ModelFunctionMethodDict(TypedDict):
     name: Required[str]
     runtime: Required[str]
@@ -46,6 +50,7 @@ class ModelFunctionMethodDict(TypedDict):
     handler: Required[str]
     inputs: Required[list[ModelMethodSignatureFieldWithName]]
     outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
+    params: NotRequired[list[ModelMethodSignatureFieldWithNameAndDefault]]
     volatility: NotRequired[str]

snowflake/ml/model/_model_composer/model_method/model_method.py (+61 -2)

@@ -105,7 +105,7 @@ class ModelMethod:
        except ValueError as e:
            raise ValueError(
                f"Your target method {self.target_method} cannot be resolved as valid SQL identifier. "
-               "Try
+               "Try specifying `case_sensitive` as True."
            ) from e

        if self.target_method not in self.model_meta.signatures.keys():
@@ -127,12 +127,41 @@ class ModelMethod:
        except ValueError as e:
            raise ValueError(
                f"Your feature {feature.name} cannot be resolved as valid SQL identifier. "
-               "Try
+               "Try specifying `case_sensitive` as True."
            ) from e
        return model_manifest_schema.ModelMethodSignatureFieldWithName(
            name=feature_name.resolved(), type=type_utils.convert_sp_to_sf_type(feature.as_snowpark_type())
        )

+    @staticmethod
+    def _flatten_params(params: list[model_signature.BaseParamSpec]) -> list[model_signature.ParamSpec]:
+        """Flatten ParamGroupSpec into leaf ParamSpec items."""
+        result: list[model_signature.ParamSpec] = []
+        for param in params:
+            if isinstance(param, model_signature.ParamSpec):
+                result.append(param)
+            elif isinstance(param, model_signature.ParamGroupSpec):
+                result.extend(ModelMethod._flatten_params(param.specs))
+        return result
+
+    @staticmethod
+    def _get_method_arg_from_param(
+        param_spec: model_signature.ParamSpec,
+        case_sensitive: bool = False,
+    ) -> model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault:
+        try:
+            param_name = sql_identifier.SqlIdentifier(param_spec.name, case_sensitive=case_sensitive)
+        except ValueError as e:
+            raise ValueError(
+                f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
+                "Try specifying `case_sensitive` as True."
+            ) from e
+        return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
+            name=param_name.resolved(),
+            type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
+            default=param_spec.default_value,
+        )
+
     def save(
         self, workspace_path: pathlib.Path, options: Optional[function_generator.FunctionGenerateOptions] = None
     ) -> model_manifest_schema.ModelMethodDict:
@@ -182,6 +211,36 @@ class ModelMethod:
            inputs=input_list,
            outputs=outputs,
        )
+
+        # Add parameters if signature has parameters
+        if self.model_meta.signatures[self.target_method].params:
+            flat_params = ModelMethod._flatten_params(list(self.model_meta.signatures[self.target_method].params))
+            param_list = [
+                ModelMethod._get_method_arg_from_param(
+                    param_spec, case_sensitive=self.options.get("case_sensitive", False)
+                )
+                for param_spec in flat_params
+            ]
+            param_name_counter = collections.Counter([param_info["name"] for param_info in param_list])
+            dup_param_names = [k for k, v in param_name_counter.items() if v > 1]
+            if dup_param_names:
+                raise ValueError(
+                    f"Found duplicate parameter named resolved as {', '.join(dup_param_names)} in the method"
+                    f" {self.target_method}. This might be because you have parameters with same letters but "
+                    "different cases. In this case, set case_sensitive as True for those methods to distinguish them."
+                )
+
+            # Check for name collisions between parameters and inputs using existing counters
+            collision_names = [name for name in param_name_counter if name in input_name_counter]
+            if collision_names:
+                raise ValueError(
+                    f"Found parameter(s) with the same name as input feature(s): {', '.join(sorted(collision_names))} "
+                    f"in the method {self.target_method}. Parameters and inputs must have distinct names. "
+                    "Try using case_sensitive=True if the names differ only by case."
+                )
+
+            method_dict["params"] = param_list
+
        should_set_volatility = (
            platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
        )
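The manifest changes above assume model signatures can now carry parameter specs. A hedged sketch of building such a signature with the public `model_signature` API: the `ParamSpec` constructor arguments (`name`, `dtype`, `default_value`) and the `params=` keyword on `ModelSignature` are taken from this diff, while the feature names and dtypes are purely illustrative.

```python
# Sketch only: a ModelSignature that carries a ParamSpec, mirroring the params
# handling added above. Feature names/dtypes are made up for illustration.
from snowflake.ml.model import model_signature

sig = model_signature.ModelSignature(
    inputs=[model_signature.FeatureSpec(name="PROMPT", dtype=model_signature.DataType.STRING)],
    outputs=[model_signature.FeatureSpec(name="COMPLETION", dtype=model_signature.DataType.STRING)],
    params=[
        model_signature.ParamSpec(
            name="temperature",
            dtype=model_signature.DataType.FLOAT,
            default_value=0.7,
        )
    ],
)
```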
snowflake/ml/model/_packager/model_handlers/custom.py (+52 -0)

@@ -86,6 +86,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
            get_prediction_fn=get_prediction,
        )

+        # Add parameters extracted from custom model inference methods to signatures
+        cls._add_method_parameters_to_signatures(model, model_meta)
+
        model_blob_path = os.path.join(model_blobs_dir_path, name)
        os.makedirs(model_blob_path, exist_ok=True)
        if model.context.artifacts:
@@ -188,6 +191,55 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
        assert isinstance(model, custom_model.CustomModel)
        return model

+    @classmethod
+    def _add_method_parameters_to_signatures(
+        cls,
+        model: "custom_model.CustomModel",
+        model_meta: model_meta_api.ModelMetadata,
+    ) -> None:
+        """Extract parameters from custom model inference methods and add them to signatures.
+
+        For each inference method, if the signature doesn't already have parameters and the method
+        has keyword-only parameters with defaults, create ParamSpecs and add them to the signature.
+
+        Args:
+            model: The custom model instance.
+            model_meta: The model metadata containing signatures to augment.
+        """
+        for method in model._get_infer_methods():
+            method_name = method.__name__
+            if method_name not in model_meta.signatures:
+                continue
+
+            sig = model_meta.signatures[method_name]
+
+            # Skip if the signature already has parameters (user-provided or previously set)
+            if sig.params:
+                continue
+
+            # Extract parameters from the method
+            method_params = custom_model.get_method_parameters(method)
+            if not method_params:
+                continue
+
+            # Create ParamSpecs from the method parameters
+            param_specs = []
+            for param_name, param_type, param_default in method_params:
+                dtype = model_signature.DataType.from_python_type(param_type)
+                param_spec = model_signature.ParamSpec(
+                    name=param_name,
+                    dtype=dtype,
+                    default_value=param_default,
+                )
+                param_specs.append(param_spec)
+
+            # Create a new signature with parameters
+            model_meta.signatures[method_name] = model_signature.ModelSignature(
+                inputs=sig.inputs,
+                outputs=sig.outputs,
+                params=param_specs,
+            )
+
     @classmethod
     def convert_as_custom_model(
         cls,
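Per the docstring of `_add_method_parameters_to_signatures` above, keyword-only arguments with defaults on inference methods are lifted into `ParamSpec`s at save time. A hypothetical custom model that would exercise that path, using the existing `snowflake.ml.model.custom_model` classes and decorator; the class name, method body, and the `scale` parameter are illustrative, and acceptance of keyword-only arguments on inference methods is assumed from the handler change above rather than confirmed here:

```python
# Hypothetical example: the keyword-only `scale` argument (with a default) is the
# kind of parameter the new handler logic above extracts into the signature.
import pandas as pd
from snowflake.ml.model import custom_model


class ScalingModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame, *, scale: float = 1.0) -> pd.DataFrame:
        return X * scale
```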
snowflake/ml/model/_packager/model_handlers/xgboost.py (+26 -1)

@@ -194,7 +194,18 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X

        if kwargs.get("use_gpu", False):
            assert type(kwargs.get("use_gpu", False)) == bool
-
+           from packaging import version
+
+           xgb_version = version.parse(xgboost.__version__)
+           if xgb_version >= version.parse("3.1.0"):
+               # XGBoost 3.1.0+: Use device="cuda" for GPU acceleration
+               # gpu_hist and gpu_predictor were removed in XGBoost 3.1.0
+               # See: https://xgboost.readthedocs.io/en/latest/changes/v3.1.0.html
+               gpu_params = {"tree_method": "hist", "device": "cuda"}
+           else:
+               # XGBoost < 3.1.0: Use legacy gpu_hist tree_method
+               gpu_params = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
+
            if isinstance(m, xgboost.Booster):
                m.set_param(gpu_params)
            elif isinstance(m, xgboost.XGBModel):
@@ -256,6 +267,20 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
        @custom_model.inference_api
        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
            import shap
+           from packaging import version
+
+           xgb_version = version.parse(xgboost.__version__)
+           shap_version = version.parse(shap.__version__)
+
+           # SHAP < 0.50.0 is incompatible with XGBoost >= 3.1.0 due to base_score format change
+           # (base_score is now stored as a vector for multi-output models)
+           # See: https://xgboost.readthedocs.io/en/latest/changes/v3.1.0.html
+           if xgb_version >= version.parse("3.1.0") and shap_version < version.parse("0.50.0"):
+               raise RuntimeError(
+                   f"SHAP version {shap.__version__} is incompatible with XGBoost version "
+                   f"{xgboost.__version__}. XGBoost 3.1+ changed the model format which requires "
+                   f"SHAP >= 0.50.0. Please upgrade SHAP or use XGBoost < 3.1."
+               )

            explainer = shap.TreeExplainer(raw_model)
            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))