snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/callback/__init__.py +0 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +93 -3
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +354 -356
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +1 -445
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +2 -8
- snowflake/ml/jobs/manager.py +57 -135
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +130 -14
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +93 -8
- snowflake/ml/model/_client/ops/service_ops.py +32 -52
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +94 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +390 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +19 -208
- snowflake/ml/model/type_hints.py +7 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
- snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,4 @@
|
|
|
1
1
|
import dataclasses
|
|
2
|
-
import enum
|
|
3
|
-
import hashlib
|
|
4
2
|
import logging
|
|
5
3
|
import pathlib
|
|
6
4
|
import re
|
|
@@ -11,11 +9,12 @@ import warnings
|
|
|
11
9
|
from typing import Any, Optional, Union, cast
|
|
12
10
|
|
|
13
11
|
from snowflake import snowpark
|
|
14
|
-
from snowflake.ml import jobs
|
|
15
12
|
from snowflake.ml._internal import file_utils, platform_capabilities as pc
|
|
16
13
|
from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
|
|
14
|
+
from snowflake.ml.jobs import job
|
|
17
15
|
from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
|
|
18
16
|
from snowflake.ml.model._client.model import batch_inference_specs
|
|
17
|
+
from snowflake.ml.model._client.ops import deployment_step
|
|
19
18
|
from snowflake.ml.model._client.service import model_deployment_spec
|
|
20
19
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
|
21
20
|
from snowflake.snowpark import async_job, exceptions, row, session
|
|
@@ -25,32 +24,12 @@ module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY
|
|
|
25
24
|
module_logger.propagate = False
|
|
26
25
|
|
|
27
26
|
|
|
28
|
-
class DeploymentStep(enum.Enum):
|
|
29
|
-
MODEL_BUILD = ("model-build", "model_build_")
|
|
30
|
-
MODEL_INFERENCE = ("model-inference", None)
|
|
31
|
-
MODEL_LOGGING = ("model-logging", "model_logging_")
|
|
32
|
-
|
|
33
|
-
def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
|
|
34
|
-
self._container_name = container_name
|
|
35
|
-
self._service_name_prefix = service_name_prefix
|
|
36
|
-
|
|
37
|
-
@property
|
|
38
|
-
def container_name(self) -> str:
|
|
39
|
-
"""Get the container name for the deployment step."""
|
|
40
|
-
return self._container_name
|
|
41
|
-
|
|
42
|
-
@property
|
|
43
|
-
def service_name_prefix(self) -> Optional[str]:
|
|
44
|
-
"""Get the service name prefix for the deployment step."""
|
|
45
|
-
return self._service_name_prefix
|
|
46
|
-
|
|
47
|
-
|
|
48
27
|
@dataclasses.dataclass
|
|
49
28
|
class ServiceLogInfo:
|
|
50
29
|
database_name: Optional[sql_identifier.SqlIdentifier]
|
|
51
30
|
schema_name: Optional[sql_identifier.SqlIdentifier]
|
|
52
31
|
service_name: sql_identifier.SqlIdentifier
|
|
53
|
-
deployment_step: DeploymentStep
|
|
32
|
+
deployment_step: deployment_step.DeploymentStep
|
|
54
33
|
instance_id: str = "0"
|
|
55
34
|
log_color: service_logger.LogColor = service_logger.LogColor.GREY
|
|
56
35
|
|
|
@@ -171,6 +150,7 @@ class ServiceOperator:
|
|
|
171
150
|
self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
|
|
172
151
|
workspace_path=pathlib.Path(self._workspace.name)
|
|
173
152
|
)
|
|
153
|
+
self._inference_autocapture_enabled = pc.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
|
|
174
154
|
|
|
175
155
|
def __eq__(self, __value: object) -> bool:
|
|
176
156
|
if not isinstance(__value, ServiceOperator):
|
|
@@ -207,7 +187,7 @@ class ServiceOperator:
|
|
|
207
187
|
# inference engine model
|
|
208
188
|
inference_engine_args: Optional[InferenceEngineArgs] = None,
|
|
209
189
|
# inference table
|
|
210
|
-
autocapture:
|
|
190
|
+
autocapture: bool = False,
|
|
211
191
|
) -> Union[str, async_job.AsyncJob]:
|
|
212
192
|
|
|
213
193
|
# Generate operation ID for this deployment
|
|
@@ -231,6 +211,10 @@ class ServiceOperator:
|
|
|
231
211
|
progress_status.update("preparing deployment artifacts...")
|
|
232
212
|
progress_status.increment()
|
|
233
213
|
|
|
214
|
+
# If autocapture param is disabled, don't allow create service with autocapture
|
|
215
|
+
if not self._inference_autocapture_enabled and autocapture:
|
|
216
|
+
raise ValueError("Invalid Argument: Autocapture feature is not supported.")
|
|
217
|
+
|
|
234
218
|
if self._workspace:
|
|
235
219
|
stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
|
|
236
220
|
else:
|
|
@@ -348,13 +332,16 @@ class ServiceOperator:
|
|
|
348
332
|
if is_enable_image_build:
|
|
349
333
|
# stream service logs in a thread
|
|
350
334
|
model_build_service_name = sql_identifier.SqlIdentifier(
|
|
351
|
-
|
|
335
|
+
deployment_step.get_service_id_from_deployment_step(
|
|
336
|
+
query_id,
|
|
337
|
+
deployment_step.DeploymentStep.MODEL_BUILD,
|
|
338
|
+
)
|
|
352
339
|
)
|
|
353
340
|
model_build_service = ServiceLogInfo(
|
|
354
341
|
database_name=service_database_name,
|
|
355
342
|
schema_name=service_schema_name,
|
|
356
343
|
service_name=model_build_service_name,
|
|
357
|
-
deployment_step=DeploymentStep.MODEL_BUILD,
|
|
344
|
+
deployment_step=deployment_step.DeploymentStep.MODEL_BUILD,
|
|
358
345
|
log_color=service_logger.LogColor.GREEN,
|
|
359
346
|
)
|
|
360
347
|
|
|
@@ -362,21 +349,23 @@ class ServiceOperator:
|
|
|
362
349
|
database_name=service_database_name,
|
|
363
350
|
schema_name=service_schema_name,
|
|
364
351
|
service_name=service_name,
|
|
365
|
-
deployment_step=DeploymentStep.MODEL_INFERENCE,
|
|
352
|
+
deployment_step=deployment_step.DeploymentStep.MODEL_INFERENCE,
|
|
366
353
|
log_color=service_logger.LogColor.BLUE,
|
|
367
354
|
)
|
|
368
355
|
|
|
369
356
|
model_logger_service: Optional[ServiceLogInfo] = None
|
|
370
357
|
if hf_model_args:
|
|
371
358
|
model_logger_service_name = sql_identifier.SqlIdentifier(
|
|
372
|
-
|
|
359
|
+
deployment_step.get_service_id_from_deployment_step(
|
|
360
|
+
query_id, deployment_step.DeploymentStep.MODEL_LOGGING
|
|
361
|
+
)
|
|
373
362
|
)
|
|
374
363
|
|
|
375
364
|
model_logger_service = ServiceLogInfo(
|
|
376
365
|
database_name=service_database_name,
|
|
377
366
|
schema_name=service_schema_name,
|
|
378
367
|
service_name=model_logger_service_name,
|
|
379
|
-
deployment_step=DeploymentStep.MODEL_LOGGING,
|
|
368
|
+
deployment_step=deployment_step.DeploymentStep.MODEL_LOGGING,
|
|
380
369
|
log_color=service_logger.LogColor.ORANGE,
|
|
381
370
|
)
|
|
382
371
|
|
|
@@ -531,7 +520,7 @@ class ServiceOperator:
|
|
|
531
520
|
service = service_log_meta.service
|
|
532
521
|
# check if using an existing model build image
|
|
533
522
|
if (
|
|
534
|
-
service.deployment_step == DeploymentStep.MODEL_BUILD
|
|
523
|
+
service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD
|
|
535
524
|
and not force_rebuild
|
|
536
525
|
and service_log_meta.is_model_logger_service_done
|
|
537
526
|
and not service_log_meta.is_model_build_service_done
|
|
@@ -577,16 +566,16 @@ class ServiceOperator:
|
|
|
577
566
|
if (service_status != service_sql.ServiceStatus.RUNNING) or (service_status != service_log_meta.service_status):
|
|
578
567
|
service_log_meta.service_status = service_status
|
|
579
568
|
|
|
580
|
-
if service.deployment_step == DeploymentStep.MODEL_BUILD:
|
|
569
|
+
if service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
|
|
581
570
|
module_logger.info(
|
|
582
571
|
f"Image build service {service.display_service_name} is "
|
|
583
572
|
f"{service_log_meta.service_status.value}."
|
|
584
573
|
)
|
|
585
|
-
elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
|
|
574
|
+
elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
|
|
586
575
|
module_logger.info(
|
|
587
576
|
f"Inference service {service.display_service_name} is {service_log_meta.service_status.value}."
|
|
588
577
|
)
|
|
589
|
-
elif service.deployment_step == DeploymentStep.MODEL_LOGGING:
|
|
578
|
+
elif service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
|
|
590
579
|
module_logger.info(
|
|
591
580
|
f"Model logger service {service.display_service_name} is "
|
|
592
581
|
f"{service_log_meta.service_status.value}."
|
|
@@ -622,7 +611,7 @@ class ServiceOperator:
|
|
|
622
611
|
if service_status == service_sql.ServiceStatus.DONE:
|
|
623
612
|
# check if model logger service is done
|
|
624
613
|
# and transition the service log metadata to the model image build service
|
|
625
|
-
if service.deployment_step == DeploymentStep.MODEL_LOGGING:
|
|
614
|
+
if service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
|
|
626
615
|
if model_build_service:
|
|
627
616
|
# building the inference image, transition to the model build service
|
|
628
617
|
service_log_meta.transition_service_log_metadata(
|
|
@@ -643,7 +632,7 @@ class ServiceOperator:
|
|
|
643
632
|
)
|
|
644
633
|
# check if model build service is done
|
|
645
634
|
# and transition the service log metadata to the model inference service
|
|
646
|
-
elif service.deployment_step == DeploymentStep.MODEL_BUILD:
|
|
635
|
+
elif service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
|
|
647
636
|
service_log_meta.transition_service_log_metadata(
|
|
648
637
|
model_inference_service,
|
|
649
638
|
f"Image build service {service.display_service_name} complete.",
|
|
@@ -651,7 +640,7 @@ class ServiceOperator:
|
|
|
651
640
|
is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
|
|
652
641
|
operation_id=operation_id,
|
|
653
642
|
)
|
|
654
|
-
elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
|
|
643
|
+
elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
|
|
655
644
|
module_logger.info(f"Inference service {service.display_service_name} is deployed.")
|
|
656
645
|
else:
|
|
657
646
|
module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
|
|
@@ -911,19 +900,6 @@ class ServiceOperator:
|
|
|
911
900
|
|
|
912
901
|
time.sleep(2) # Poll every 2 seconds
|
|
913
902
|
|
|
914
|
-
@staticmethod
|
|
915
|
-
def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
|
|
916
|
-
"""Get the service ID through the server-side logic."""
|
|
917
|
-
uuid = query_id.replace("-", "")
|
|
918
|
-
big_int = int(uuid, 16)
|
|
919
|
-
md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
|
|
920
|
-
identifier = md5_hash[:8]
|
|
921
|
-
service_name_prefix = deployment_step.service_name_prefix
|
|
922
|
-
if service_name_prefix is None:
|
|
923
|
-
# raise an exception if the service name prefix is None
|
|
924
|
-
raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
|
|
925
|
-
return (service_name_prefix + identifier).upper()
|
|
926
|
-
|
|
927
903
|
def _check_if_service_exists(
|
|
928
904
|
self,
|
|
929
905
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
|
@@ -966,6 +942,8 @@ class ServiceOperator:
|
|
|
966
942
|
image_repo_name: Optional[str],
|
|
967
943
|
input_stage_location: str,
|
|
968
944
|
input_file_pattern: str,
|
|
945
|
+
column_handling: Optional[str],
|
|
946
|
+
params: Optional[str],
|
|
969
947
|
output_stage_location: str,
|
|
970
948
|
completion_filename: str,
|
|
971
949
|
force_rebuild: bool,
|
|
@@ -976,7 +954,7 @@ class ServiceOperator:
|
|
|
976
954
|
gpu_requests: Optional[str],
|
|
977
955
|
replicas: Optional[int],
|
|
978
956
|
statement_params: Optional[dict[str, Any]] = None,
|
|
979
|
-
) ->
|
|
957
|
+
) -> job.MLJob[Any]:
|
|
980
958
|
database_name = self._database_name
|
|
981
959
|
schema_name = self._schema_name
|
|
982
960
|
|
|
@@ -1002,6 +980,8 @@ class ServiceOperator:
|
|
|
1002
980
|
max_batch_rows=max_batch_rows,
|
|
1003
981
|
input_stage_location=input_stage_location,
|
|
1004
982
|
input_file_pattern=input_file_pattern,
|
|
983
|
+
column_handling=column_handling,
|
|
984
|
+
params=params,
|
|
1005
985
|
output_stage_location=output_stage_location,
|
|
1006
986
|
completion_filename=completion_filename,
|
|
1007
987
|
function_name=function_name,
|
|
@@ -1045,7 +1025,7 @@ class ServiceOperator:
|
|
|
1045
1025
|
# Block until the async job is done
|
|
1046
1026
|
async_job.result()
|
|
1047
1027
|
|
|
1048
|
-
return
|
|
1028
|
+
return job.MLJob(
|
|
1049
1029
|
id=sql_identifier.get_fully_qualified_name(job_database_name, job_schema_name, job_name),
|
|
1050
1030
|
session=self._session,
|
|
1051
1031
|
)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
from snowflake.ml.model._client.service import model_deployment_spec_schema
|
|
6
|
+
|
|
7
|
+
BaseModel.model_config["protected_namespaces"] = ()
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ModelName(BaseModel):
|
|
11
|
+
model_name: str
|
|
12
|
+
version_name: str
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ModelSpec(BaseModel):
|
|
16
|
+
name: ModelName
|
|
17
|
+
hf_model: Optional[model_deployment_spec_schema.HuggingFaceModel] = None
|
|
18
|
+
log_model_args: Optional[model_deployment_spec_schema.LogModelArgs] = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ImportModelSpec(BaseModel):
|
|
22
|
+
compute_pool: str
|
|
23
|
+
models: list[ModelSpec]
|
|
@@ -195,6 +195,7 @@ class ModelDeploymentSpec:
|
|
|
195
195
|
|
|
196
196
|
def add_job_spec(
|
|
197
197
|
self,
|
|
198
|
+
*,
|
|
198
199
|
job_name: sql_identifier.SqlIdentifier,
|
|
199
200
|
inference_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
200
201
|
function_name: str,
|
|
@@ -202,6 +203,8 @@ class ModelDeploymentSpec:
|
|
|
202
203
|
output_stage_location: str,
|
|
203
204
|
completion_filename: str,
|
|
204
205
|
input_file_pattern: str,
|
|
206
|
+
column_handling: Optional[str] = None,
|
|
207
|
+
params: Optional[str] = None,
|
|
205
208
|
warehouse: sql_identifier.SqlIdentifier,
|
|
206
209
|
job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
207
210
|
job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
@@ -217,14 +220,16 @@ class ModelDeploymentSpec:
|
|
|
217
220
|
Args:
|
|
218
221
|
job_name: Name of the job.
|
|
219
222
|
inference_compute_pool_name: Compute pool for inference.
|
|
220
|
-
warehouse: Warehouse for the job.
|
|
221
223
|
function_name: Function name.
|
|
222
224
|
input_stage_location: Stage location for input data.
|
|
223
225
|
output_stage_location: Stage location for output data.
|
|
226
|
+
completion_filename: Name of completion file (default: "completion.txt").
|
|
227
|
+
input_file_pattern: Pattern for input files (optional).
|
|
228
|
+
column_handling: Column handling mode for input data.
|
|
229
|
+
params: Additional parameters for the job.
|
|
230
|
+
warehouse: Warehouse for the job.
|
|
224
231
|
job_database_name: Database name for the job.
|
|
225
232
|
job_schema_name: Schema name for the job.
|
|
226
|
-
input_file_pattern: Pattern for input files (optional).
|
|
227
|
-
completion_filename: Name of completion file (default: "completion.txt").
|
|
228
233
|
cpu: CPU requirement.
|
|
229
234
|
memory: Memory requirement.
|
|
230
235
|
gpu: GPU requirement.
|
|
@@ -259,7 +264,10 @@ class ModelDeploymentSpec:
|
|
|
259
264
|
warehouse=warehouse.identifier() if warehouse else None,
|
|
260
265
|
function_name=function_name,
|
|
261
266
|
input=model_deployment_spec_schema.Input(
|
|
262
|
-
input_stage_location=input_stage_location,
|
|
267
|
+
input_stage_location=input_stage_location,
|
|
268
|
+
input_file_pattern=input_file_pattern,
|
|
269
|
+
column_handling=column_handling,
|
|
270
|
+
params=params,
|
|
263
271
|
),
|
|
264
272
|
output=model_deployment_spec_schema.Output(
|
|
265
273
|
output_stage_location=output_stage_location,
|
|
@@ -39,6 +39,8 @@ class Service(BaseModel):
|
|
|
39
39
|
class Input(BaseModel):
|
|
40
40
|
input_stage_location: str
|
|
41
41
|
input_file_pattern: str
|
|
42
|
+
column_handling: Optional[str] = None
|
|
43
|
+
params: Optional[str] = None
|
|
42
44
|
|
|
43
45
|
|
|
44
46
|
class Output(BaseModel):
|
|
@@ -74,6 +76,7 @@ class HuggingFaceModel(BaseModel):
|
|
|
74
76
|
task: Optional[str] = None
|
|
75
77
|
tokenizer: Optional[str] = None
|
|
76
78
|
token: Optional[str] = None
|
|
79
|
+
token_secret_object: Optional[str] = None
|
|
77
80
|
trust_remote_code: Optional[bool] = False
|
|
78
81
|
revision: Optional[str] = None
|
|
79
82
|
hf_model_kwargs: Optional[str] = "{}"
|
|
@@ -22,6 +22,14 @@ def _normalize_url_for_sql(url: str) -> str:
|
|
|
22
22
|
return f"'{url}'"
|
|
23
23
|
|
|
24
24
|
|
|
25
|
+
def _format_param_value(value: Any) -> str:
|
|
26
|
+
if isinstance(value, str):
|
|
27
|
+
return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
|
|
28
|
+
elif value is None:
|
|
29
|
+
return "NULL"
|
|
30
|
+
return str(value)
|
|
31
|
+
|
|
32
|
+
|
|
25
33
|
class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
26
34
|
FUNCTION_NAME_COL_NAME = "name"
|
|
27
35
|
FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
|
|
@@ -354,6 +362,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
354
362
|
input_args: list[sql_identifier.SqlIdentifier],
|
|
355
363
|
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
|
356
364
|
statement_params: Optional[dict[str, Any]] = None,
|
|
365
|
+
params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
|
|
357
366
|
) -> dataframe.DataFrame:
|
|
358
367
|
with_statements = []
|
|
359
368
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
|
@@ -392,10 +401,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
392
401
|
|
|
393
402
|
args_sql = ", ".join(args_sql_list)
|
|
394
403
|
|
|
395
|
-
|
|
404
|
+
if params:
|
|
405
|
+
param_sql = ", ".join(_format_param_value(val) for _, val in params)
|
|
406
|
+
args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
|
|
407
|
+
|
|
408
|
+
total_args = len(input_args) + (len(params) if params else 0)
|
|
409
|
+
wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
|
|
396
410
|
if wide_input:
|
|
397
|
-
|
|
398
|
-
|
|
411
|
+
parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
|
|
412
|
+
if params:
|
|
413
|
+
parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
|
|
414
|
+
args_sql = f"object_construct_keep_null({', '.join(parts)})"
|
|
399
415
|
|
|
400
416
|
sql = textwrap.dedent(
|
|
401
417
|
f"""WITH {','.join(with_statements)}
|
|
@@ -439,6 +455,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
439
455
|
statement_params: Optional[dict[str, Any]] = None,
|
|
440
456
|
is_partitioned: bool = True,
|
|
441
457
|
explain_case_sensitive: bool = False,
|
|
458
|
+
params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
|
|
442
459
|
) -> dataframe.DataFrame:
|
|
443
460
|
with_statements = []
|
|
444
461
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
|
@@ -477,10 +494,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
477
494
|
|
|
478
495
|
args_sql = ", ".join(args_sql_list)
|
|
479
496
|
|
|
480
|
-
|
|
497
|
+
if params:
|
|
498
|
+
param_sql = ", ".join(_format_param_value(val) for _, val in params)
|
|
499
|
+
args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
|
|
500
|
+
|
|
501
|
+
total_args = len(input_args) + (len(params) if params else 0)
|
|
502
|
+
wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
|
|
481
503
|
if wide_input:
|
|
482
|
-
|
|
483
|
-
|
|
504
|
+
parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
|
|
505
|
+
if params:
|
|
506
|
+
parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
|
|
507
|
+
args_sql = f"object_construct_keep_null({', '.join(parts)})"
|
|
484
508
|
|
|
485
509
|
sql = textwrap.dedent(
|
|
486
510
|
f"""WITH {','.join(with_statements)}
|
|
@@ -3,7 +3,9 @@ import dataclasses
|
|
|
3
3
|
import enum
|
|
4
4
|
import logging
|
|
5
5
|
import textwrap
|
|
6
|
-
from typing import Any, Generator, Optional
|
|
6
|
+
from typing import Any, Generator, Optional, cast
|
|
7
|
+
|
|
8
|
+
import yaml
|
|
7
9
|
|
|
8
10
|
from snowflake import snowpark
|
|
9
11
|
from snowflake.ml._internal.utils import (
|
|
@@ -18,6 +20,15 @@ from snowflake.snowpark._internal import utils as snowpark_utils
|
|
|
18
20
|
|
|
19
21
|
logger = logging.getLogger(__name__)
|
|
20
22
|
|
|
23
|
+
|
|
24
|
+
def _format_param_value(value: Any) -> str:
|
|
25
|
+
if isinstance(value, str):
|
|
26
|
+
return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
|
|
27
|
+
elif value is None:
|
|
28
|
+
return "NULL"
|
|
29
|
+
return str(value)
|
|
30
|
+
|
|
31
|
+
|
|
21
32
|
# Using this token instead of '?' to avoid escaping issues
|
|
22
33
|
# After quotes are escaped, we replace this token with '|| ? ||'
|
|
23
34
|
QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
|
|
@@ -68,6 +79,7 @@ class ServiceStatusInfo:
|
|
|
68
79
|
|
|
69
80
|
class ServiceSQLClient(_base._BaseSQLClient):
|
|
70
81
|
MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
|
|
82
|
+
MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME = "port"
|
|
71
83
|
MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
|
|
72
84
|
MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME = "privatelink_ingress_url"
|
|
73
85
|
SERVICE_STATUS = "service_status"
|
|
@@ -75,6 +87,14 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
75
87
|
INSTANCE_STATUS = "instance_status"
|
|
76
88
|
CONTAINER_STATUS = "status"
|
|
77
89
|
MESSAGE = "message"
|
|
90
|
+
DESC_SERVICE_INTERNAL_DNS_COL_NAME = "dns_name"
|
|
91
|
+
DESC_SERVICE_SPEC_COL_NAME = "spec"
|
|
92
|
+
DESC_SERVICE_CONTAINERS_SPEC_NAME = "containers"
|
|
93
|
+
DESC_SERVICE_NAME_SPEC_NAME = "name"
|
|
94
|
+
DESC_SERVICE_PROXY_SPEC_ENV_NAME = "env"
|
|
95
|
+
PROXY_CONTAINER_NAME = "proxy"
|
|
96
|
+
MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME = "SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED"
|
|
97
|
+
FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"
|
|
78
98
|
|
|
79
99
|
@contextlib.contextmanager
|
|
80
100
|
def _qmark_paramstyle(self) -> Generator[None, None, None]:
|
|
@@ -129,6 +149,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
129
149
|
input_args: list[sql_identifier.SqlIdentifier],
|
|
130
150
|
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
|
131
151
|
statement_params: Optional[dict[str, Any]] = None,
|
|
152
|
+
params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
|
|
132
153
|
) -> dataframe.DataFrame:
|
|
133
154
|
with_statements = []
|
|
134
155
|
actual_database_name = database_name or self._database_name
|
|
@@ -159,10 +180,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
159
180
|
args_sql_list.append(input_arg_value)
|
|
160
181
|
args_sql = ", ".join(args_sql_list)
|
|
161
182
|
|
|
162
|
-
|
|
183
|
+
if params:
|
|
184
|
+
param_sql = ", ".join(_format_param_value(val) for _, val in params)
|
|
185
|
+
args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
|
|
186
|
+
|
|
187
|
+
total_args = len(input_args) + (len(params) if params else 0)
|
|
188
|
+
wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
|
|
163
189
|
if wide_input:
|
|
164
|
-
|
|
165
|
-
|
|
190
|
+
parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
|
|
191
|
+
if params:
|
|
192
|
+
parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
|
|
193
|
+
args_sql = f"object_construct_keep_null({', '.join(parts)})"
|
|
166
194
|
|
|
167
195
|
fully_qualified_service_name = self.fully_qualified_object_name(
|
|
168
196
|
actual_database_name, actual_schema_name, service_name
|
|
@@ -233,7 +261,15 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
233
261
|
) -> list[ServiceStatusInfo]:
|
|
234
262
|
fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
|
|
235
263
|
query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
|
|
236
|
-
rows =
|
|
264
|
+
rows = (
|
|
265
|
+
query_result_checker.SqlResultValidator(self._session, query, statement_params=statement_params)
|
|
266
|
+
.has_column(ServiceSQLClient.INSTANCE_STATUS)
|
|
267
|
+
.has_column(ServiceSQLClient.CONTAINER_STATUS)
|
|
268
|
+
.has_column(ServiceSQLClient.SERVICE_STATUS)
|
|
269
|
+
.has_column(ServiceSQLClient.INSTANCE_ID)
|
|
270
|
+
.has_column(ServiceSQLClient.MESSAGE)
|
|
271
|
+
.validate()
|
|
272
|
+
)
|
|
237
273
|
statuses = []
|
|
238
274
|
for r in rows:
|
|
239
275
|
instance_status, container_status = None, None
|
|
@@ -252,6 +288,58 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
252
288
|
)
|
|
253
289
|
return statuses
|
|
254
290
|
|
|
291
|
+
def describe_service(
|
|
292
|
+
self,
|
|
293
|
+
*,
|
|
294
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
|
295
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
296
|
+
service_name: sql_identifier.SqlIdentifier,
|
|
297
|
+
statement_params: Optional[dict[str, Any]] = None,
|
|
298
|
+
) -> row.Row:
|
|
299
|
+
fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
|
|
300
|
+
query = f"DESCRIBE SERVICE {fully_qualified_object_name}"
|
|
301
|
+
rows = (
|
|
302
|
+
query_result_checker.SqlResultValidator(self._session, query, statement_params=statement_params)
|
|
303
|
+
.has_dimensions(expected_rows=1)
|
|
304
|
+
.has_column(ServiceSQLClient.DESC_SERVICE_INTERNAL_DNS_COL_NAME)
|
|
305
|
+
.validate()
|
|
306
|
+
)
|
|
307
|
+
return rows[0]
|
|
308
|
+
|
|
309
|
+
def get_proxy_container_autocapture(self, row: row.Row) -> bool:
|
|
310
|
+
"""Extract whether service has autocapture enabled from proxy container spec.
|
|
311
|
+
|
|
312
|
+
Args:
|
|
313
|
+
row: A row.Row object from DESCRIBE SERVICE containing the service YAML spec.
|
|
314
|
+
|
|
315
|
+
Returns:
|
|
316
|
+
True if autocapture is enabled in proxy spec
|
|
317
|
+
False if disabled or not set in proxy spec
|
|
318
|
+
False if service doesn't have proxy container
|
|
319
|
+
"""
|
|
320
|
+
try:
|
|
321
|
+
spec_yaml = row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME]
|
|
322
|
+
if spec_yaml is None:
|
|
323
|
+
return False
|
|
324
|
+
spec_raw = yaml.safe_load(spec_yaml)
|
|
325
|
+
if spec_raw is None:
|
|
326
|
+
return False
|
|
327
|
+
spec = cast(dict[str, Any], spec_raw)
|
|
328
|
+
|
|
329
|
+
proxy_container_spec = next(
|
|
330
|
+
container
|
|
331
|
+
for container in spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
|
|
332
|
+
ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
|
|
333
|
+
]
|
|
334
|
+
if container[ServiceSQLClient.DESC_SERVICE_NAME_SPEC_NAME] == ServiceSQLClient.PROXY_CONTAINER_NAME
|
|
335
|
+
)
|
|
336
|
+
env = proxy_container_spec.get(ServiceSQLClient.DESC_SERVICE_PROXY_SPEC_ENV_NAME, {})
|
|
337
|
+
autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
|
|
338
|
+
return str(autocapture_enabled).lower() == "true"
|
|
339
|
+
|
|
340
|
+
except StopIteration:
|
|
341
|
+
return False
|
|
342
|
+
|
|
255
343
|
def drop_service(
|
|
256
344
|
self,
|
|
257
345
|
*,
|
|
@@ -282,6 +370,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
282
370
|
statement_params=statement_params,
|
|
283
371
|
)
|
|
284
372
|
.has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
|
|
373
|
+
.has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME, allow_empty=True)
|
|
285
374
|
.has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
|
|
286
375
|
)
|
|
287
376
|
|
|
@@ -131,7 +131,7 @@ class ModelComposer:
|
|
|
131
131
|
python_version: Optional[str] = None,
|
|
132
132
|
user_files: Optional[dict[str, list[str]]] = None,
|
|
133
133
|
ext_modules: Optional[list[ModuleType]] = None,
|
|
134
|
-
code_paths: Optional[list[
|
|
134
|
+
code_paths: Optional[list[model_types.CodePathLike]] = None,
|
|
135
135
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
|
136
136
|
experiment_info: Optional["ExperimentInfo"] = None,
|
|
137
137
|
options: Optional[model_types.ModelSaveOption] = None,
|
|
@@ -39,6 +39,10 @@ class ModelMethodSignatureFieldWithName(ModelMethodSignatureField):
|
|
|
39
39
|
name: Required[str]
|
|
40
40
|
|
|
41
41
|
|
|
42
|
+
class ModelMethodSignatureFieldWithNameAndDefault(ModelMethodSignatureFieldWithName):
|
|
43
|
+
default: Required[Any]
|
|
44
|
+
|
|
45
|
+
|
|
42
46
|
class ModelFunctionMethodDict(TypedDict):
|
|
43
47
|
name: Required[str]
|
|
44
48
|
runtime: Required[str]
|
|
@@ -46,6 +50,7 @@ class ModelFunctionMethodDict(TypedDict):
|
|
|
46
50
|
handler: Required[str]
|
|
47
51
|
inputs: Required[list[ModelMethodSignatureFieldWithName]]
|
|
48
52
|
outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
|
|
53
|
+
params: NotRequired[list[ModelMethodSignatureFieldWithNameAndDefault]]
|
|
49
54
|
volatility: NotRequired[str]
|
|
50
55
|
|
|
51
56
|
|