snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +66 -35
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +26 -2
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
- snowflake/ml/data/data_connector.py +186 -0
- snowflake/ml/data/data_ingestor.py +45 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/data/ingestor_utils.py +62 -0
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +23 -117
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/example_helper.py +278 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
- snowflake/ml/feature_store/feature_store.py +637 -76
- snowflake/ml/feature_store/feature_view.py +316 -9
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +14 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +7 -4
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import enum
|
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
4
|
import warnings
|
5
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Union, overload
|
6
6
|
|
7
7
|
import pandas as pd
|
8
8
|
|
@@ -10,7 +10,7 @@ from snowflake.ml._internal import telemetry
|
|
10
10
|
from snowflake.ml._internal.utils import sql_identifier
|
11
11
|
from snowflake.ml.lineage import lineage_node
|
12
12
|
from snowflake.ml.model import type_hints as model_types
|
13
|
-
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
13
|
+
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
15
15
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
16
16
|
from snowflake.ml.model._packager.model_handlers import snowmlmodel
|
@@ -29,6 +29,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
29
29
|
"""Model Version Object representing a specific version of the model that could be run."""
|
30
30
|
|
31
31
|
_model_ops: model_ops.ModelOperator
|
32
|
+
_service_ops: service_ops.ServiceOperator
|
32
33
|
_model_name: sql_identifier.SqlIdentifier
|
33
34
|
_version_name: sql_identifier.SqlIdentifier
|
34
35
|
_functions: List[model_manifest_schema.ModelFunctionInfo]
|
@@ -41,11 +42,13 @@ class ModelVersion(lineage_node.LineageNode):
|
|
41
42
|
cls,
|
42
43
|
model_ops: model_ops.ModelOperator,
|
43
44
|
*,
|
45
|
+
service_ops: service_ops.ServiceOperator,
|
44
46
|
model_name: sql_identifier.SqlIdentifier,
|
45
47
|
version_name: sql_identifier.SqlIdentifier,
|
46
48
|
) -> "ModelVersion":
|
47
49
|
self: "ModelVersion" = object.__new__(cls)
|
48
50
|
self._model_ops = model_ops
|
51
|
+
self._service_ops = service_ops
|
49
52
|
self._model_name = model_name
|
50
53
|
self._version_name = version_name
|
51
54
|
self._functions = self._get_functions()
|
@@ -65,6 +68,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
65
68
|
return False
|
66
69
|
return (
|
67
70
|
self._model_ops == __value._model_ops
|
71
|
+
and self._service_ops == __value._service_ops
|
68
72
|
and self._model_name == __value._model_name
|
69
73
|
and self._version_name == __value._version_name
|
70
74
|
)
|
@@ -318,10 +322,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
318
322
|
"""
|
319
323
|
return self._functions
|
320
324
|
|
321
|
-
@telemetry.send_api_usage_telemetry(
|
322
|
-
project=_TELEMETRY_PROJECT,
|
323
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
324
|
-
)
|
325
|
+
@overload
|
325
326
|
def run(
|
326
327
|
self,
|
327
328
|
X: Union[pd.DataFrame, dataframe.DataFrame],
|
@@ -339,6 +340,53 @@ class ModelVersion(lineage_node.LineageNode):
|
|
339
340
|
partition_column: The partition column name to partition by.
|
340
341
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
341
342
|
type validation to make sure your input data won't overflow when providing to the model.
|
343
|
+
"""
|
344
|
+
...
|
345
|
+
|
346
|
+
@overload
|
347
|
+
def run(
|
348
|
+
self,
|
349
|
+
X: Union[pd.DataFrame, dataframe.DataFrame],
|
350
|
+
*,
|
351
|
+
service_name: str,
|
352
|
+
function_name: Optional[str] = None,
|
353
|
+
strict_input_validation: bool = False,
|
354
|
+
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
355
|
+
"""Invoke a method in a model version object via a service.
|
356
|
+
|
357
|
+
Args:
|
358
|
+
X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
|
359
|
+
service_name: The service name.
|
360
|
+
function_name: The function name to run. It is the name used to call a function in SQL.
|
361
|
+
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
362
|
+
type validation to make sure your input data won't overflow when providing to the model.
|
363
|
+
"""
|
364
|
+
...
|
365
|
+
|
366
|
+
@telemetry.send_api_usage_telemetry(
|
367
|
+
project=_TELEMETRY_PROJECT,
|
368
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
369
|
+
func_params_to_log=["function_name", "service_name"],
|
370
|
+
)
|
371
|
+
def run(
|
372
|
+
self,
|
373
|
+
X: Union[pd.DataFrame, "dataframe.DataFrame"],
|
374
|
+
*,
|
375
|
+
service_name: Optional[str] = None,
|
376
|
+
function_name: Optional[str] = None,
|
377
|
+
partition_column: Optional[str] = None,
|
378
|
+
strict_input_validation: bool = False,
|
379
|
+
) -> Union[pd.DataFrame, "dataframe.DataFrame"]:
|
380
|
+
"""Invoke a method in a model version object via the warehouse or a service.
|
381
|
+
|
382
|
+
Args:
|
383
|
+
X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
|
384
|
+
service_name: The service name. If None, the function is invoked via the warehouse. Otherwise, the function
|
385
|
+
is invoked via the given service.
|
386
|
+
function_name: The function name to run. It is the name used to call a function in SQL.
|
387
|
+
partition_column: The partition column name to partition by.
|
388
|
+
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
389
|
+
type validation to make sure your input data won't overflow when providing to the model.
|
342
390
|
|
343
391
|
Raises:
|
344
392
|
ValueError: When no method with the corresponding name is available.
|
@@ -375,23 +423,37 @@ class ModelVersion(lineage_node.LineageNode):
|
|
375
423
|
elif len(functions) != 1:
|
376
424
|
raise ValueError(
|
377
425
|
f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
|
378
|
-
f" version {self.version_name}. Please specify a `
|
426
|
+
f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
|
379
427
|
)
|
380
428
|
else:
|
381
429
|
target_function_info = functions[0]
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
430
|
+
|
431
|
+
if service_name:
|
432
|
+
return self._model_ops.invoke_method(
|
433
|
+
method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
|
434
|
+
signature=target_function_info["signature"],
|
435
|
+
X=X,
|
436
|
+
database_name=None,
|
437
|
+
schema_name=None,
|
438
|
+
service_name=sql_identifier.SqlIdentifier(service_name),
|
439
|
+
strict_input_validation=strict_input_validation,
|
440
|
+
statement_params=statement_params,
|
441
|
+
)
|
442
|
+
else:
|
443
|
+
return self._model_ops.invoke_method(
|
444
|
+
method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
|
445
|
+
method_function_type=target_function_info["target_method_function_type"],
|
446
|
+
signature=target_function_info["signature"],
|
447
|
+
X=X,
|
448
|
+
database_name=None,
|
449
|
+
schema_name=None,
|
450
|
+
model_name=self._model_name,
|
451
|
+
version_name=self._version_name,
|
452
|
+
strict_input_validation=strict_input_validation,
|
453
|
+
partition_column=partition_column,
|
454
|
+
statement_params=statement_params,
|
455
|
+
is_partitioned=target_function_info["is_partitioned"],
|
456
|
+
)
|
395
457
|
|
396
458
|
@telemetry.send_api_usage_telemetry(
|
397
459
|
project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
|
@@ -525,9 +587,98 @@ class ModelVersion(lineage_node.LineageNode):
|
|
525
587
|
database_name=database_name_id,
|
526
588
|
schema_name=schema_name_id,
|
527
589
|
),
|
590
|
+
service_ops=service_ops.ServiceOperator(
|
591
|
+
session,
|
592
|
+
database_name=database_name_id,
|
593
|
+
schema_name=schema_name_id,
|
594
|
+
),
|
528
595
|
model_name=model_name_id,
|
529
596
|
version_name=sql_identifier.SqlIdentifier(version),
|
530
597
|
)
|
531
598
|
|
599
|
+
@telemetry.send_api_usage_telemetry(
|
600
|
+
project=_TELEMETRY_PROJECT,
|
601
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
602
|
+
func_params_to_log=[
|
603
|
+
"service_name",
|
604
|
+
"image_build_compute_pool",
|
605
|
+
"service_compute_pool",
|
606
|
+
"image_repo_database",
|
607
|
+
"image_repo_schema",
|
608
|
+
"image_repo",
|
609
|
+
"image_name",
|
610
|
+
"gpu_requests",
|
611
|
+
],
|
612
|
+
)
|
613
|
+
def create_service(
|
614
|
+
self,
|
615
|
+
*,
|
616
|
+
service_name: str,
|
617
|
+
image_build_compute_pool: Optional[str] = None,
|
618
|
+
service_compute_pool: str,
|
619
|
+
image_repo: str,
|
620
|
+
image_name: Optional[str] = None,
|
621
|
+
ingress_enabled: bool = False,
|
622
|
+
min_instances: int = 1,
|
623
|
+
max_instances: int = 1,
|
624
|
+
gpu_requests: Optional[str] = None,
|
625
|
+
force_rebuild: bool = False,
|
626
|
+
build_external_access_integration: str,
|
627
|
+
) -> str:
|
628
|
+
"""Create an inference service with the given spec.
|
629
|
+
|
630
|
+
Args:
|
631
|
+
service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
|
632
|
+
schema of the model will be used.
|
633
|
+
image_build_compute_pool: The name of the compute pool used to build the model inference image. Use
|
634
|
+
the service compute pool if None.
|
635
|
+
service_compute_pool: The name of the compute pool used to run the inference service.
|
636
|
+
image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
|
637
|
+
or schema of the model will be used.
|
638
|
+
image_name: The name of the model inference image. Use a generated name if None.
|
639
|
+
ingress_enabled: Whether to enable ingress.
|
640
|
+
min_instances: The minimum number of inference service instances to run.
|
641
|
+
max_instances: The maximum number of inference service instances to run.
|
642
|
+
gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
|
643
|
+
if None.
|
644
|
+
force_rebuild: Whether to force a model inference image rebuild.
|
645
|
+
build_external_access_integration: The external access integration for image build.
|
646
|
+
|
647
|
+
Returns:
|
648
|
+
The service name.
|
649
|
+
"""
|
650
|
+
statement_params = telemetry.get_statement_params(
|
651
|
+
project=_TELEMETRY_PROJECT,
|
652
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
653
|
+
)
|
654
|
+
service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
|
655
|
+
image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
|
656
|
+
return self._service_ops.create_service(
|
657
|
+
database_name=None,
|
658
|
+
schema_name=None,
|
659
|
+
model_name=self._model_name,
|
660
|
+
version_name=self._version_name,
|
661
|
+
service_database_name=service_db_id,
|
662
|
+
service_schema_name=service_schema_id,
|
663
|
+
service_name=service_id,
|
664
|
+
image_build_compute_pool_name=(
|
665
|
+
sql_identifier.SqlIdentifier(image_build_compute_pool)
|
666
|
+
if image_build_compute_pool
|
667
|
+
else sql_identifier.SqlIdentifier(service_compute_pool)
|
668
|
+
),
|
669
|
+
service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
|
670
|
+
image_repo_database_name=image_repo_db_id,
|
671
|
+
image_repo_schema_name=image_repo_schema_id,
|
672
|
+
image_repo_name=image_repo_id,
|
673
|
+
image_name=sql_identifier.SqlIdentifier(image_name) if image_name else None,
|
674
|
+
ingress_enabled=ingress_enabled,
|
675
|
+
min_instances=min_instances,
|
676
|
+
max_instances=max_instances,
|
677
|
+
gpu_requests=gpu_requests,
|
678
|
+
force_rebuild=force_rebuild,
|
679
|
+
build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration),
|
680
|
+
statement_params=statement_params,
|
681
|
+
)
|
682
|
+
|
532
683
|
|
533
684
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
@@ -2,7 +2,7 @@ import os
|
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
4
|
import warnings
|
5
|
-
from typing import Any, Dict, List, Literal, Optional, Union, cast
|
5
|
+
from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
|
6
6
|
|
7
7
|
import yaml
|
8
8
|
|
@@ -12,6 +12,7 @@ from snowflake.ml.model._client.ops import metadata_ops
|
|
12
12
|
from snowflake.ml.model._client.sql import (
|
13
13
|
model as model_sql,
|
14
14
|
model_version as model_version_sql,
|
15
|
+
service as service_sql,
|
15
16
|
stage as stage_sql,
|
16
17
|
tag as tag_sql,
|
17
18
|
)
|
@@ -21,7 +22,7 @@ from snowflake.ml.model._model_composer.model_manifest import (
|
|
21
22
|
model_manifest_schema,
|
22
23
|
)
|
23
24
|
from snowflake.ml.model._packager.model_env import model_env
|
24
|
-
from snowflake.ml.model._packager.model_meta import model_meta
|
25
|
+
from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
|
25
26
|
from snowflake.ml.model._packager.model_runtime import model_runtime
|
26
27
|
from snowflake.ml.model._signatures import snowpark_handler
|
27
28
|
from snowflake.snowpark import dataframe, row, session
|
@@ -60,6 +61,11 @@ class ModelOperator:
|
|
60
61
|
database_name=database_name,
|
61
62
|
schema_name=schema_name,
|
62
63
|
)
|
64
|
+
self._service_client = service_sql.ServiceSQLClient(
|
65
|
+
session,
|
66
|
+
database_name=database_name,
|
67
|
+
schema_name=schema_name,
|
68
|
+
)
|
63
69
|
self._metadata_ops = metadata_ops.MetadataOperator(
|
64
70
|
session,
|
65
71
|
database_name=database_name,
|
@@ -597,16 +603,38 @@ class ModelOperator:
|
|
597
603
|
function_names, list(signatures.keys())
|
598
604
|
)
|
599
605
|
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
+
model_func_info = []
|
607
|
+
|
608
|
+
for function_name, function_type in function_names_and_types:
|
609
|
+
|
610
|
+
target_method = function_name_mapping[function_name]
|
611
|
+
|
612
|
+
is_partitioned = False
|
613
|
+
if function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
614
|
+
# better to set default True here because worse case it will be slow but not error out
|
615
|
+
is_partitioned = (
|
616
|
+
(
|
617
|
+
model_spec["function_properties"]
|
618
|
+
.get(target_method, {})
|
619
|
+
.get(model_meta_schema.FunctionProperties.PARTITIONED.value, True)
|
620
|
+
)
|
621
|
+
if "function_properties" in model_spec
|
622
|
+
else True
|
623
|
+
)
|
624
|
+
|
625
|
+
model_func_info.append(
|
626
|
+
model_manifest_schema.ModelFunctionInfo(
|
627
|
+
name=function_name.identifier(),
|
628
|
+
target_method=target_method,
|
629
|
+
target_method_function_type=function_type,
|
630
|
+
signature=model_signature.ModelSignature.from_dict(signatures[target_method]),
|
631
|
+
is_partitioned=is_partitioned,
|
632
|
+
)
|
606
633
|
)
|
607
|
-
for function_name, function_type in function_names_and_types
|
608
|
-
]
|
609
634
|
|
635
|
+
return model_func_info
|
636
|
+
|
637
|
+
@overload
|
610
638
|
def invoke_method(
|
611
639
|
self,
|
612
640
|
*,
|
@@ -621,6 +649,41 @@ class ModelOperator:
|
|
621
649
|
strict_input_validation: bool = False,
|
622
650
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
623
651
|
statement_params: Optional[Dict[str, str]] = None,
|
652
|
+
is_partitioned: Optional[bool] = None,
|
653
|
+
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
654
|
+
...
|
655
|
+
|
656
|
+
@overload
|
657
|
+
def invoke_method(
|
658
|
+
self,
|
659
|
+
*,
|
660
|
+
method_name: sql_identifier.SqlIdentifier,
|
661
|
+
signature: model_signature.ModelSignature,
|
662
|
+
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
663
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
664
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
665
|
+
service_name: sql_identifier.SqlIdentifier,
|
666
|
+
strict_input_validation: bool = False,
|
667
|
+
statement_params: Optional[Dict[str, str]] = None,
|
668
|
+
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
669
|
+
...
|
670
|
+
|
671
|
+
def invoke_method(
|
672
|
+
self,
|
673
|
+
*,
|
674
|
+
method_name: sql_identifier.SqlIdentifier,
|
675
|
+
method_function_type: Optional[str] = None,
|
676
|
+
signature: model_signature.ModelSignature,
|
677
|
+
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
678
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
679
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
680
|
+
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
681
|
+
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
682
|
+
service_name: Optional[sql_identifier.SqlIdentifier] = None,
|
683
|
+
strict_input_validation: bool = False,
|
684
|
+
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
685
|
+
statement_params: Optional[Dict[str, str]] = None,
|
686
|
+
is_partitioned: Optional[bool] = None,
|
624
687
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
625
688
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
626
689
|
|
@@ -657,31 +720,46 @@ class ModelOperator:
|
|
657
720
|
if output_name in original_cols:
|
658
721
|
original_cols.remove(output_name)
|
659
722
|
|
660
|
-
if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
|
661
|
-
df_res = self._model_version_client.invoke_function_method(
|
723
|
+
if service_name:
|
724
|
+
df_res = self._service_client.invoke_function_method(
|
662
725
|
method_name=method_name,
|
663
726
|
input_df=s_df,
|
664
727
|
input_args=input_args,
|
665
728
|
returns=returns,
|
666
729
|
database_name=database_name,
|
667
730
|
schema_name=schema_name,
|
668
|
-
|
669
|
-
version_name=version_name,
|
670
|
-
statement_params=statement_params,
|
671
|
-
)
|
672
|
-
elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
673
|
-
df_res = self._model_version_client.invoke_table_function_method(
|
674
|
-
method_name=method_name,
|
675
|
-
input_df=s_df,
|
676
|
-
input_args=input_args,
|
677
|
-
partition_column=partition_column,
|
678
|
-
returns=returns,
|
679
|
-
database_name=database_name,
|
680
|
-
schema_name=schema_name,
|
681
|
-
model_name=model_name,
|
682
|
-
version_name=version_name,
|
731
|
+
service_name=service_name,
|
683
732
|
statement_params=statement_params,
|
684
733
|
)
|
734
|
+
else:
|
735
|
+
assert model_name is not None
|
736
|
+
assert version_name is not None
|
737
|
+
if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
|
738
|
+
df_res = self._model_version_client.invoke_function_method(
|
739
|
+
method_name=method_name,
|
740
|
+
input_df=s_df,
|
741
|
+
input_args=input_args,
|
742
|
+
returns=returns,
|
743
|
+
database_name=database_name,
|
744
|
+
schema_name=schema_name,
|
745
|
+
model_name=model_name,
|
746
|
+
version_name=version_name,
|
747
|
+
statement_params=statement_params,
|
748
|
+
)
|
749
|
+
elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
750
|
+
df_res = self._model_version_client.invoke_table_function_method(
|
751
|
+
method_name=method_name,
|
752
|
+
input_df=s_df,
|
753
|
+
input_args=input_args,
|
754
|
+
partition_column=partition_column,
|
755
|
+
returns=returns,
|
756
|
+
database_name=database_name,
|
757
|
+
schema_name=schema_name,
|
758
|
+
model_name=model_name,
|
759
|
+
version_name=version_name,
|
760
|
+
statement_params=statement_params,
|
761
|
+
is_partitioned=is_partitioned or False,
|
762
|
+
)
|
685
763
|
|
686
764
|
if keep_order:
|
687
765
|
# if it's a partitioned table function, _ID will be null and we won't be able to sort.
|
@@ -0,0 +1,121 @@
|
|
1
|
+
import pathlib
|
2
|
+
import tempfile
|
3
|
+
from typing import Any, Dict, Optional
|
4
|
+
|
5
|
+
from snowflake.ml._internal import file_utils
|
6
|
+
from snowflake.ml._internal.utils import sql_identifier
|
7
|
+
from snowflake.ml.model._client.service import model_deployment_spec
|
8
|
+
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
9
|
+
from snowflake.snowpark import session
|
10
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
11
|
+
|
12
|
+
|
13
|
+
class ServiceOperator:
|
14
|
+
"""Service operator for container services logic."""
|
15
|
+
|
16
|
+
def __init__(
|
17
|
+
self,
|
18
|
+
session: session.Session,
|
19
|
+
*,
|
20
|
+
database_name: sql_identifier.SqlIdentifier,
|
21
|
+
schema_name: sql_identifier.SqlIdentifier,
|
22
|
+
) -> None:
|
23
|
+
self._session = session
|
24
|
+
self._database_name = database_name
|
25
|
+
self._schema_name = schema_name
|
26
|
+
self._workspace = tempfile.TemporaryDirectory()
|
27
|
+
self._service_client = service_sql.ServiceSQLClient(
|
28
|
+
session,
|
29
|
+
database_name=database_name,
|
30
|
+
schema_name=schema_name,
|
31
|
+
)
|
32
|
+
self._stage_client = stage_sql.StageSQLClient(
|
33
|
+
session,
|
34
|
+
database_name=database_name,
|
35
|
+
schema_name=schema_name,
|
36
|
+
)
|
37
|
+
self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
|
38
|
+
workspace_path=pathlib.Path(self._workspace.name)
|
39
|
+
)
|
40
|
+
|
41
|
+
def __eq__(self, __value: object) -> bool:
|
42
|
+
if not isinstance(__value, ServiceOperator):
|
43
|
+
return False
|
44
|
+
return self._service_client == __value._service_client
|
45
|
+
|
46
|
+
@property
|
47
|
+
def workspace_path(self) -> pathlib.Path:
|
48
|
+
return pathlib.Path(self._workspace.name)
|
49
|
+
|
50
|
+
def create_service(
|
51
|
+
self,
|
52
|
+
*,
|
53
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
54
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
55
|
+
model_name: sql_identifier.SqlIdentifier,
|
56
|
+
version_name: sql_identifier.SqlIdentifier,
|
57
|
+
service_database_name: Optional[sql_identifier.SqlIdentifier],
|
58
|
+
service_schema_name: Optional[sql_identifier.SqlIdentifier],
|
59
|
+
service_name: sql_identifier.SqlIdentifier,
|
60
|
+
image_build_compute_pool_name: sql_identifier.SqlIdentifier,
|
61
|
+
service_compute_pool_name: sql_identifier.SqlIdentifier,
|
62
|
+
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
63
|
+
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
64
|
+
image_repo_name: sql_identifier.SqlIdentifier,
|
65
|
+
image_name: Optional[sql_identifier.SqlIdentifier],
|
66
|
+
ingress_enabled: bool,
|
67
|
+
min_instances: int,
|
68
|
+
max_instances: int,
|
69
|
+
gpu_requests: Optional[str],
|
70
|
+
force_rebuild: bool,
|
71
|
+
build_external_access_integration: sql_identifier.SqlIdentifier,
|
72
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
73
|
+
) -> str:
|
74
|
+
# create a temp stage
|
75
|
+
stage_name = sql_identifier.SqlIdentifier(
|
76
|
+
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
77
|
+
)
|
78
|
+
self._stage_client.create_tmp_stage(
|
79
|
+
database_name=database_name,
|
80
|
+
schema_name=schema_name,
|
81
|
+
stage_name=stage_name,
|
82
|
+
statement_params=statement_params,
|
83
|
+
)
|
84
|
+
stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
|
85
|
+
|
86
|
+
self._model_deployment_spec.save(
|
87
|
+
database_name=database_name or self._database_name,
|
88
|
+
schema_name=schema_name or self._schema_name,
|
89
|
+
model_name=model_name,
|
90
|
+
version_name=version_name,
|
91
|
+
service_database_name=service_database_name,
|
92
|
+
service_schema_name=service_schema_name,
|
93
|
+
service_name=service_name,
|
94
|
+
image_build_compute_pool_name=image_build_compute_pool_name,
|
95
|
+
service_compute_pool_name=service_compute_pool_name,
|
96
|
+
image_repo_database_name=image_repo_database_name,
|
97
|
+
image_repo_schema_name=image_repo_schema_name,
|
98
|
+
image_repo_name=image_repo_name,
|
99
|
+
image_name=image_name,
|
100
|
+
ingress_enabled=ingress_enabled,
|
101
|
+
min_instances=min_instances,
|
102
|
+
max_instances=max_instances,
|
103
|
+
gpu=gpu_requests,
|
104
|
+
force_rebuild=force_rebuild,
|
105
|
+
external_access_integration=build_external_access_integration,
|
106
|
+
)
|
107
|
+
file_utils.upload_directory_to_stage(
|
108
|
+
self._session,
|
109
|
+
local_path=self.workspace_path,
|
110
|
+
stage_path=pathlib.PurePosixPath(stage_path),
|
111
|
+
statement_params=statement_params,
|
112
|
+
)
|
113
|
+
|
114
|
+
# deploy the model service
|
115
|
+
self._service_client.deploy_model(
|
116
|
+
stage_path=stage_path,
|
117
|
+
model_deployment_spec_file_rel_path=model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH,
|
118
|
+
statement_params=statement_params,
|
119
|
+
)
|
120
|
+
|
121
|
+
return service_name
|
@@ -0,0 +1,95 @@
|
|
1
|
+
import pathlib
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
import yaml
|
5
|
+
|
6
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
7
|
+
from snowflake.ml.model._client.service import model_deployment_spec_schema
|
8
|
+
|
9
|
+
|
10
|
+
class ModelDeploymentSpec:
|
11
|
+
"""Class to construct deploy.yml file for Model container services deployment.
|
12
|
+
|
13
|
+
Attributes:
|
14
|
+
workspace_path: A local path where model related files should be dumped to.
|
15
|
+
"""
|
16
|
+
|
17
|
+
DEPLOY_SPEC_FILE_REL_PATH = "deploy.yml"
|
18
|
+
|
19
|
+
def __init__(self, workspace_path: pathlib.Path) -> None:
|
20
|
+
self.workspace_path = workspace_path
|
21
|
+
|
22
|
+
def save(
|
23
|
+
self,
|
24
|
+
*,
|
25
|
+
database_name: sql_identifier.SqlIdentifier,
|
26
|
+
schema_name: sql_identifier.SqlIdentifier,
|
27
|
+
model_name: sql_identifier.SqlIdentifier,
|
28
|
+
version_name: sql_identifier.SqlIdentifier,
|
29
|
+
service_database_name: Optional[sql_identifier.SqlIdentifier],
|
30
|
+
service_schema_name: Optional[sql_identifier.SqlIdentifier],
|
31
|
+
service_name: sql_identifier.SqlIdentifier,
|
32
|
+
image_build_compute_pool_name: sql_identifier.SqlIdentifier,
|
33
|
+
service_compute_pool_name: sql_identifier.SqlIdentifier,
|
34
|
+
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
35
|
+
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
36
|
+
image_repo_name: sql_identifier.SqlIdentifier,
|
37
|
+
image_name: Optional[sql_identifier.SqlIdentifier],
|
38
|
+
ingress_enabled: bool,
|
39
|
+
min_instances: int,
|
40
|
+
max_instances: int,
|
41
|
+
gpu: Optional[str],
|
42
|
+
force_rebuild: bool,
|
43
|
+
external_access_integration: sql_identifier.SqlIdentifier,
|
44
|
+
) -> None:
|
45
|
+
# create the deployment spec
|
46
|
+
# models spec
|
47
|
+
fq_model_name = identifier.get_schema_level_object_identifier(
|
48
|
+
database_name.identifier(), schema_name.identifier(), model_name.identifier()
|
49
|
+
)
|
50
|
+
model_dict = model_deployment_spec_schema.ModelDict(name=fq_model_name, version=version_name.identifier())
|
51
|
+
|
52
|
+
# image_build spec
|
53
|
+
saved_image_repo_database = image_repo_database_name or database_name
|
54
|
+
saved_image_repo_schema = image_repo_schema_name or schema_name
|
55
|
+
fq_image_repo_name = identifier.get_schema_level_object_identifier(
|
56
|
+
saved_image_repo_database.identifier(), saved_image_repo_schema.identifier(), image_repo_name.identifier()
|
57
|
+
)
|
58
|
+
image_build_dict = model_deployment_spec_schema.ImageBuildDict(
|
59
|
+
compute_pool=image_build_compute_pool_name.identifier(),
|
60
|
+
image_repo=fq_image_repo_name,
|
61
|
+
force_rebuild=force_rebuild,
|
62
|
+
external_access_integrations=[external_access_integration.identifier()],
|
63
|
+
)
|
64
|
+
if image_name:
|
65
|
+
image_build_dict["image_name"] = image_name.identifier()
|
66
|
+
|
67
|
+
# service spec
|
68
|
+
saved_service_database = service_database_name or database_name
|
69
|
+
saved_service_schema = service_schema_name or schema_name
|
70
|
+
fq_service_name = identifier.get_schema_level_object_identifier(
|
71
|
+
saved_service_database.identifier(), saved_service_schema.identifier(), service_name.identifier()
|
72
|
+
)
|
73
|
+
service_dict = model_deployment_spec_schema.ServiceDict(
|
74
|
+
name=fq_service_name,
|
75
|
+
compute_pool=service_compute_pool_name.identifier(),
|
76
|
+
ingress_enabled=ingress_enabled,
|
77
|
+
min_instances=min_instances,
|
78
|
+
max_instances=max_instances,
|
79
|
+
)
|
80
|
+
if gpu:
|
81
|
+
service_dict["gpu"] = gpu
|
82
|
+
|
83
|
+
# model deployment spec
|
84
|
+
model_deployment_spec_dict = model_deployment_spec_schema.ModelDeploymentSpecDict(
|
85
|
+
models=[model_dict],
|
86
|
+
image_build=image_build_dict,
|
87
|
+
service=service_dict,
|
88
|
+
)
|
89
|
+
|
90
|
+
# save the yaml
|
91
|
+
file_path = self.workspace_path / self.DEPLOY_SPEC_FILE_REL_PATH
|
92
|
+
with file_path.open("w", encoding="utf-8") as f:
|
93
|
+
# Anchors are not supported in the server, avoid that.
|
94
|
+
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
95
|
+
yaml.safe_dump(model_deployment_spec_dict, f)
|