snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -2,10 +2,11 @@ import enum
|
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
4
|
import warnings
|
5
|
-
from typing import Any, Callable,
|
5
|
+
from typing import Any, Callable, Optional, Union, overload
|
6
6
|
|
7
7
|
import pandas as pd
|
8
8
|
|
9
|
+
from snowflake import snowpark
|
9
10
|
from snowflake.ml._internal import telemetry
|
10
11
|
from snowflake.ml._internal.utils import sql_identifier
|
11
12
|
from snowflake.ml.lineage import lineage_node
|
@@ -32,7 +33,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
32
33
|
_service_ops: service_ops.ServiceOperator
|
33
34
|
_model_name: sql_identifier.SqlIdentifier
|
34
35
|
_version_name: sql_identifier.SqlIdentifier
|
35
|
-
_functions:
|
36
|
+
_functions: list[model_manifest_schema.ModelFunctionInfo]
|
36
37
|
|
37
38
|
def __init__(self) -> None:
|
38
39
|
raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
|
@@ -152,7 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
152
153
|
project=_TELEMETRY_PROJECT,
|
153
154
|
subproject=_TELEMETRY_SUBPROJECT,
|
154
155
|
)
|
155
|
-
def show_metrics(self) ->
|
156
|
+
def show_metrics(self) -> dict[str, Any]:
|
156
157
|
"""Show all metrics logged with the model version.
|
157
158
|
|
158
159
|
Returns:
|
@@ -293,7 +294,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
293
294
|
statement_params=statement_params,
|
294
295
|
)
|
295
296
|
|
296
|
-
def _get_functions(self) ->
|
297
|
+
def _get_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
|
297
298
|
statement_params = telemetry.get_statement_params(
|
298
299
|
project=_TELEMETRY_PROJECT,
|
299
300
|
subproject=_TELEMETRY_SUBPROJECT,
|
@@ -327,7 +328,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
327
328
|
project=_TELEMETRY_PROJECT,
|
328
329
|
subproject=_TELEMETRY_SUBPROJECT,
|
329
330
|
)
|
330
|
-
def show_functions(self) ->
|
331
|
+
def show_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
|
331
332
|
"""Show all functions information in a model version that is callable.
|
332
333
|
|
333
334
|
Returns:
|
@@ -405,11 +406,6 @@ class ModelVersion(lineage_node.LineageNode):
|
|
405
406
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
406
407
|
type validation to make sure your input data won't overflow when providing to the model.
|
407
408
|
|
408
|
-
Raises:
|
409
|
-
ValueError: When no method with the corresponding name is available.
|
410
|
-
ValueError: When there are more than 1 target methods available in the model but no function name specified.
|
411
|
-
ValueError: When the partition column is not a valid Snowflake identifier.
|
412
|
-
|
413
409
|
Returns:
|
414
410
|
The prediction data. It would be the same type dataframe as your input.
|
415
411
|
"""
|
@@ -422,29 +418,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
422
418
|
# Partition column must be a valid identifier
|
423
419
|
partition_column = sql_identifier.SqlIdentifier(partition_column)
|
424
420
|
|
425
|
-
|
426
|
-
|
427
|
-
if function_name:
|
428
|
-
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
429
|
-
find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
|
430
|
-
lambda method: method["name"] == req_method_name
|
431
|
-
)
|
432
|
-
target_function_info = next(
|
433
|
-
filter(find_method, functions),
|
434
|
-
None,
|
435
|
-
)
|
436
|
-
if target_function_info is None:
|
437
|
-
raise ValueError(
|
438
|
-
f"There is no method with name {function_name} available in the model"
|
439
|
-
f" {self.fully_qualified_model_name} version {self.version_name}"
|
440
|
-
)
|
441
|
-
elif len(functions) != 1:
|
442
|
-
raise ValueError(
|
443
|
-
f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
|
444
|
-
f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
|
445
|
-
)
|
446
|
-
else:
|
447
|
-
target_function_info = functions[0]
|
421
|
+
target_function_info = self._get_function_info(function_name=function_name)
|
448
422
|
|
449
423
|
if service_name:
|
450
424
|
database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
|
@@ -475,6 +449,33 @@ class ModelVersion(lineage_node.LineageNode):
|
|
475
449
|
is_partitioned=target_function_info["is_partitioned"],
|
476
450
|
)
|
477
451
|
|
452
|
+
def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
|
453
|
+
functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
|
454
|
+
|
455
|
+
if function_name:
|
456
|
+
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
457
|
+
find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
|
458
|
+
lambda method: method["name"] == req_method_name
|
459
|
+
)
|
460
|
+
target_function_info = next(
|
461
|
+
filter(find_method, functions),
|
462
|
+
None,
|
463
|
+
)
|
464
|
+
if target_function_info is None:
|
465
|
+
raise ValueError(
|
466
|
+
f"There is no method with name {function_name} available in the model"
|
467
|
+
f" {self.fully_qualified_model_name} version {self.version_name}"
|
468
|
+
)
|
469
|
+
elif len(functions) != 1:
|
470
|
+
raise ValueError(
|
471
|
+
f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
|
472
|
+
f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
|
473
|
+
)
|
474
|
+
else:
|
475
|
+
target_function_info = functions[0]
|
476
|
+
|
477
|
+
return target_function_info
|
478
|
+
|
478
479
|
@telemetry.send_api_usage_telemetry(
|
479
480
|
project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
|
480
481
|
)
|
@@ -684,7 +685,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
684
685
|
num_workers: Optional[int] = None,
|
685
686
|
max_batch_rows: Optional[int] = None,
|
686
687
|
force_rebuild: bool = False,
|
687
|
-
build_external_access_integrations: Optional[
|
688
|
+
build_external_access_integrations: Optional[list[str]] = None,
|
688
689
|
block: bool = True,
|
689
690
|
) -> Union[str, async_job.AsyncJob]:
|
690
691
|
"""Create an inference service with the given spec.
|
@@ -751,7 +752,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
751
752
|
max_batch_rows: Optional[int] = None,
|
752
753
|
force_rebuild: bool = False,
|
753
754
|
build_external_access_integration: Optional[str] = None,
|
754
|
-
build_external_access_integrations: Optional[
|
755
|
+
build_external_access_integrations: Optional[list[str]] = None,
|
755
756
|
block: bool = True,
|
756
757
|
) -> Union[str, async_job.AsyncJob]:
|
757
758
|
"""Create an inference service with the given spec.
|
@@ -914,5 +915,72 @@ class ModelVersion(lineage_node.LineageNode):
|
|
914
915
|
statement_params=statement_params,
|
915
916
|
)
|
916
917
|
|
918
|
+
@snowpark._internal.utils.private_preview(version="1.8.3")
|
919
|
+
@telemetry.send_api_usage_telemetry(
|
920
|
+
project=_TELEMETRY_PROJECT,
|
921
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
922
|
+
)
|
923
|
+
def run_job(
|
924
|
+
self,
|
925
|
+
X: Union[pd.DataFrame, "dataframe.DataFrame"],
|
926
|
+
*,
|
927
|
+
job_name: str,
|
928
|
+
compute_pool: str,
|
929
|
+
image_repo: str,
|
930
|
+
output_table_name: str,
|
931
|
+
function_name: Optional[str] = None,
|
932
|
+
cpu_requests: Optional[str] = None,
|
933
|
+
memory_requests: Optional[str] = None,
|
934
|
+
gpu_requests: Optional[Union[str, int]] = None,
|
935
|
+
num_workers: Optional[int] = None,
|
936
|
+
max_batch_rows: Optional[int] = None,
|
937
|
+
force_rebuild: bool = False,
|
938
|
+
build_external_access_integrations: Optional[list[str]] = None,
|
939
|
+
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
940
|
+
statement_params = telemetry.get_statement_params(
|
941
|
+
project=_TELEMETRY_PROJECT,
|
942
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
943
|
+
)
|
944
|
+
target_function_info = self._get_function_info(function_name=function_name)
|
945
|
+
job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
|
946
|
+
image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
|
947
|
+
output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
|
948
|
+
output_table_name
|
949
|
+
)
|
950
|
+
warehouse = self._service_ops._session.get_current_warehouse()
|
951
|
+
assert warehouse, "No active warehouse selected in the current session."
|
952
|
+
return self._service_ops.invoke_job_method(
|
953
|
+
target_method=target_function_info["target_method"],
|
954
|
+
signature=target_function_info["signature"],
|
955
|
+
X=X,
|
956
|
+
database_name=None,
|
957
|
+
schema_name=None,
|
958
|
+
model_name=self._model_name,
|
959
|
+
version_name=self._version_name,
|
960
|
+
job_database_name=job_db_id,
|
961
|
+
job_schema_name=job_schema_id,
|
962
|
+
job_name=job_id,
|
963
|
+
compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
|
964
|
+
warehouse_name=sql_identifier.SqlIdentifier(warehouse),
|
965
|
+
image_repo_database_name=image_repo_db_id,
|
966
|
+
image_repo_schema_name=image_repo_schema_id,
|
967
|
+
image_repo_name=image_repo_id,
|
968
|
+
output_table_database_name=output_table_db_id,
|
969
|
+
output_table_schema_name=output_table_schema_id,
|
970
|
+
output_table_name=output_table_id,
|
971
|
+
cpu_requests=cpu_requests,
|
972
|
+
memory_requests=memory_requests,
|
973
|
+
gpu_requests=gpu_requests,
|
974
|
+
num_workers=num_workers,
|
975
|
+
max_batch_rows=max_batch_rows,
|
976
|
+
force_rebuild=force_rebuild,
|
977
|
+
build_external_access_integrations=(
|
978
|
+
None
|
979
|
+
if build_external_access_integrations is None
|
980
|
+
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
981
|
+
),
|
982
|
+
statement_params=statement_params,
|
983
|
+
)
|
984
|
+
|
917
985
|
|
918
986
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, TypedDict
|
3
3
|
|
4
4
|
from typing_extensions import NotRequired
|
5
5
|
|
@@ -14,7 +14,7 @@ MODEL_VERSION_METADATA_SCHEMA_VERSION = "2024-01-01"
|
|
14
14
|
|
15
15
|
|
16
16
|
class ModelVersionMetadataSchema(TypedDict):
|
17
|
-
metrics: NotRequired[
|
17
|
+
metrics: NotRequired[dict[str, Any]]
|
18
18
|
|
19
19
|
|
20
20
|
class MetadataOperator:
|
@@ -44,7 +44,7 @@ class MetadataOperator:
|
|
44
44
|
)
|
45
45
|
|
46
46
|
@staticmethod
|
47
|
-
def _parse(metadata_dict:
|
47
|
+
def _parse(metadata_dict: dict[str, Any]) -> ModelVersionMetadataSchema:
|
48
48
|
loaded_metadata_schema_version = metadata_dict.get("snowpark_ml_schema_version", None)
|
49
49
|
if loaded_metadata_schema_version is None:
|
50
50
|
return ModelVersionMetadataSchema(metrics={})
|
@@ -65,8 +65,8 @@ class MetadataOperator:
|
|
65
65
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
66
66
|
model_name: sql_identifier.SqlIdentifier,
|
67
67
|
version_name: sql_identifier.SqlIdentifier,
|
68
|
-
statement_params: Optional[
|
69
|
-
) ->
|
68
|
+
statement_params: Optional[dict[str, Any]] = None,
|
69
|
+
) -> dict[str, Any]:
|
70
70
|
version_info_list = self._model_client.show_versions(
|
71
71
|
database_name=database_name,
|
72
72
|
schema_name=schema_name,
|
@@ -89,7 +89,7 @@ class MetadataOperator:
|
|
89
89
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
90
90
|
model_name: sql_identifier.SqlIdentifier,
|
91
91
|
version_name: sql_identifier.SqlIdentifier,
|
92
|
-
statement_params: Optional[
|
92
|
+
statement_params: Optional[dict[str, Any]] = None,
|
93
93
|
) -> ModelVersionMetadataSchema:
|
94
94
|
metadata_dict = self._get_current_metadata_dict(
|
95
95
|
database_name=database_name,
|
@@ -108,7 +108,7 @@ class MetadataOperator:
|
|
108
108
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
109
109
|
model_name: sql_identifier.SqlIdentifier,
|
110
110
|
version_name: sql_identifier.SqlIdentifier,
|
111
|
-
statement_params: Optional[
|
111
|
+
statement_params: Optional[dict[str, Any]] = None,
|
112
112
|
) -> None:
|
113
113
|
metadata_dict = self._get_current_metadata_dict(
|
114
114
|
database_name=database_name,
|
@@ -4,7 +4,7 @@ import os
|
|
4
4
|
import pathlib
|
5
5
|
import tempfile
|
6
6
|
import warnings
|
7
|
-
from typing import Any,
|
7
|
+
from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
|
8
8
|
|
9
9
|
import yaml
|
10
10
|
|
@@ -104,7 +104,7 @@ class ModelOperator:
|
|
104
104
|
*,
|
105
105
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
106
106
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
107
|
-
statement_params: Optional[
|
107
|
+
statement_params: Optional[dict[str, Any]] = None,
|
108
108
|
) -> str:
|
109
109
|
stage_name = sql_identifier.SqlIdentifier(
|
110
110
|
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
@@ -137,7 +137,7 @@ class ModelOperator:
|
|
137
137
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
138
138
|
model_name: sql_identifier.SqlIdentifier,
|
139
139
|
version_name: sql_identifier.SqlIdentifier,
|
140
|
-
statement_params: Optional[
|
140
|
+
statement_params: Optional[dict[str, Any]] = None,
|
141
141
|
) -> ModelAction:
|
142
142
|
if self.validate_existence(
|
143
143
|
database_name=database_name,
|
@@ -169,7 +169,7 @@ class ModelOperator:
|
|
169
169
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
170
170
|
model_name: sql_identifier.SqlIdentifier,
|
171
171
|
version_name: sql_identifier.SqlIdentifier,
|
172
|
-
statement_params: Optional[
|
172
|
+
statement_params: Optional[dict[str, Any]] = None,
|
173
173
|
) -> None:
|
174
174
|
model_action = self.get_model_action_from_model_name_and_version(
|
175
175
|
database_name=database_name,
|
@@ -205,7 +205,7 @@ class ModelOperator:
|
|
205
205
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
206
206
|
model_name: sql_identifier.SqlIdentifier,
|
207
207
|
version_name: sql_identifier.SqlIdentifier,
|
208
|
-
statement_params: Optional[
|
208
|
+
statement_params: Optional[dict[str, Any]] = None,
|
209
209
|
use_live_commit: Optional[bool] = False,
|
210
210
|
) -> None:
|
211
211
|
|
@@ -263,7 +263,7 @@ class ModelOperator:
|
|
263
263
|
model_name: sql_identifier.SqlIdentifier,
|
264
264
|
version_name: sql_identifier.SqlIdentifier,
|
265
265
|
model_exists: bool,
|
266
|
-
statement_params: Optional[
|
266
|
+
statement_params: Optional[dict[str, Any]] = None,
|
267
267
|
) -> None:
|
268
268
|
if model_exists:
|
269
269
|
return self._model_version_client.add_version_from_model_version(
|
@@ -296,8 +296,8 @@ class ModelOperator:
|
|
296
296
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
297
297
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
298
298
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
299
|
-
statement_params: Optional[
|
300
|
-
) ->
|
299
|
+
statement_params: Optional[dict[str, Any]] = None,
|
300
|
+
) -> list[row.Row]:
|
301
301
|
if model_name:
|
302
302
|
return self._model_client.show_versions(
|
303
303
|
database_name=database_name,
|
@@ -320,8 +320,8 @@ class ModelOperator:
|
|
320
320
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
321
321
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
322
322
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
323
|
-
statement_params: Optional[
|
324
|
-
) ->
|
323
|
+
statement_params: Optional[dict[str, Any]] = None,
|
324
|
+
) -> list[sql_identifier.SqlIdentifier]:
|
325
325
|
res = self.show_models_or_versions(
|
326
326
|
database_name=database_name,
|
327
327
|
schema_name=schema_name,
|
@@ -341,7 +341,7 @@ class ModelOperator:
|
|
341
341
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
342
342
|
model_name: sql_identifier.SqlIdentifier,
|
343
343
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
344
|
-
statement_params: Optional[
|
344
|
+
statement_params: Optional[dict[str, Any]] = None,
|
345
345
|
) -> bool:
|
346
346
|
if version_name:
|
347
347
|
res = self._model_client.show_versions(
|
@@ -369,7 +369,7 @@ class ModelOperator:
|
|
369
369
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
370
370
|
model_name: sql_identifier.SqlIdentifier,
|
371
371
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
372
|
-
statement_params: Optional[
|
372
|
+
statement_params: Optional[dict[str, Any]] = None,
|
373
373
|
) -> str:
|
374
374
|
if version_name:
|
375
375
|
res = self._model_client.show_versions(
|
@@ -398,7 +398,7 @@ class ModelOperator:
|
|
398
398
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
399
399
|
model_name: sql_identifier.SqlIdentifier,
|
400
400
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
401
|
-
statement_params: Optional[
|
401
|
+
statement_params: Optional[dict[str, Any]] = None,
|
402
402
|
) -> None:
|
403
403
|
if version_name:
|
404
404
|
self._model_version_client.set_comment(
|
@@ -426,7 +426,7 @@ class ModelOperator:
|
|
426
426
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
427
427
|
model_name: sql_identifier.SqlIdentifier,
|
428
428
|
version_name: sql_identifier.SqlIdentifier,
|
429
|
-
statement_params: Optional[
|
429
|
+
statement_params: Optional[dict[str, Any]] = None,
|
430
430
|
) -> None:
|
431
431
|
self._model_version_client.set_alias(
|
432
432
|
alias_name=alias_name,
|
@@ -444,7 +444,7 @@ class ModelOperator:
|
|
444
444
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
445
445
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
446
446
|
model_name: sql_identifier.SqlIdentifier,
|
447
|
-
statement_params: Optional[
|
447
|
+
statement_params: Optional[dict[str, Any]] = None,
|
448
448
|
) -> None:
|
449
449
|
self._model_version_client.unset_alias(
|
450
450
|
database_name=database_name,
|
@@ -461,7 +461,7 @@ class ModelOperator:
|
|
461
461
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
462
462
|
model_name: sql_identifier.SqlIdentifier,
|
463
463
|
version_name: sql_identifier.SqlIdentifier,
|
464
|
-
statement_params: Optional[
|
464
|
+
statement_params: Optional[dict[str, Any]] = None,
|
465
465
|
) -> None:
|
466
466
|
if not self.validate_existence(
|
467
467
|
database_name=database_name,
|
@@ -485,7 +485,7 @@ class ModelOperator:
|
|
485
485
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
486
486
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
487
487
|
model_name: sql_identifier.SqlIdentifier,
|
488
|
-
statement_params: Optional[
|
488
|
+
statement_params: Optional[dict[str, Any]] = None,
|
489
489
|
) -> sql_identifier.SqlIdentifier:
|
490
490
|
res = self._model_client.show_models(
|
491
491
|
database_name=database_name,
|
@@ -504,7 +504,7 @@ class ModelOperator:
|
|
504
504
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
505
505
|
model_name: sql_identifier.SqlIdentifier,
|
506
506
|
alias_name: sql_identifier.SqlIdentifier,
|
507
|
-
statement_params: Optional[
|
507
|
+
statement_params: Optional[dict[str, Any]] = None,
|
508
508
|
) -> Optional[sql_identifier.SqlIdentifier]:
|
509
509
|
res = self._model_client.show_versions(
|
510
510
|
database_name=database_name,
|
@@ -528,7 +528,7 @@ class ModelOperator:
|
|
528
528
|
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
529
529
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
530
530
|
tag_name: sql_identifier.SqlIdentifier,
|
531
|
-
statement_params: Optional[
|
531
|
+
statement_params: Optional[dict[str, Any]] = None,
|
532
532
|
) -> Optional[str]:
|
533
533
|
r = self._tag_client.get_tag_value(
|
534
534
|
database_name=database_name,
|
@@ -550,15 +550,15 @@ class ModelOperator:
|
|
550
550
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
551
551
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
552
552
|
model_name: sql_identifier.SqlIdentifier,
|
553
|
-
statement_params: Optional[
|
554
|
-
) ->
|
553
|
+
statement_params: Optional[dict[str, Any]] = None,
|
554
|
+
) -> dict[str, str]:
|
555
555
|
tags_info = self._tag_client.get_tag_list(
|
556
556
|
database_name=database_name,
|
557
557
|
schema_name=schema_name,
|
558
558
|
model_name=model_name,
|
559
559
|
statement_params=statement_params,
|
560
560
|
)
|
561
|
-
res:
|
561
|
+
res: dict[str, str] = {
|
562
562
|
identifier.get_schema_level_object_identifier(
|
563
563
|
sql_identifier.SqlIdentifier(r.TAG_DATABASE, case_sensitive=True),
|
564
564
|
sql_identifier.SqlIdentifier(r.TAG_SCHEMA, case_sensitive=True),
|
@@ -578,7 +578,7 @@ class ModelOperator:
|
|
578
578
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
579
579
|
tag_name: sql_identifier.SqlIdentifier,
|
580
580
|
tag_value: str,
|
581
|
-
statement_params: Optional[
|
581
|
+
statement_params: Optional[dict[str, Any]] = None,
|
582
582
|
) -> None:
|
583
583
|
self._tag_client.set_tag_on_model(
|
584
584
|
database_name=database_name,
|
@@ -600,7 +600,7 @@ class ModelOperator:
|
|
600
600
|
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
601
601
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
602
602
|
tag_name: sql_identifier.SqlIdentifier,
|
603
|
-
statement_params: Optional[
|
603
|
+
statement_params: Optional[dict[str, Any]] = None,
|
604
604
|
) -> None:
|
605
605
|
self._tag_client.unset_tag_on_model(
|
606
606
|
database_name=database_name,
|
@@ -619,8 +619,8 @@ class ModelOperator:
|
|
619
619
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
620
620
|
model_name: sql_identifier.SqlIdentifier,
|
621
621
|
version_name: sql_identifier.SqlIdentifier,
|
622
|
-
statement_params: Optional[
|
623
|
-
) ->
|
622
|
+
statement_params: Optional[dict[str, Any]] = None,
|
623
|
+
) -> list[ServiceInfo]:
|
624
624
|
res = self._model_client.show_versions(
|
625
625
|
database_name=database_name,
|
626
626
|
schema_name=schema_name,
|
@@ -682,7 +682,7 @@ class ModelOperator:
|
|
682
682
|
service_database_name: Optional[sql_identifier.SqlIdentifier],
|
683
683
|
service_schema_name: Optional[sql_identifier.SqlIdentifier],
|
684
684
|
service_name: sql_identifier.SqlIdentifier,
|
685
|
-
statement_params: Optional[
|
685
|
+
statement_params: Optional[dict[str, Any]] = None,
|
686
686
|
) -> None:
|
687
687
|
services = self.show_services(
|
688
688
|
database_name=database_name,
|
@@ -724,7 +724,7 @@ class ModelOperator:
|
|
724
724
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
725
725
|
model_name: sql_identifier.SqlIdentifier,
|
726
726
|
version_name: sql_identifier.SqlIdentifier,
|
727
|
-
statement_params: Optional[
|
727
|
+
statement_params: Optional[dict[str, Any]] = None,
|
728
728
|
) -> model_manifest_schema.ModelManifestDict:
|
729
729
|
with tempfile.TemporaryDirectory() as tmpdir:
|
730
730
|
self._model_version_client.get_file(
|
@@ -741,9 +741,9 @@ class ModelOperator:
|
|
741
741
|
|
742
742
|
@staticmethod
|
743
743
|
def _match_model_spec_with_sql_functions(
|
744
|
-
sql_functions_names:
|
745
|
-
) ->
|
746
|
-
res:
|
744
|
+
sql_functions_names: list[sql_identifier.SqlIdentifier], target_methods: list[str]
|
745
|
+
) -> dict[sql_identifier.SqlIdentifier, str]:
|
746
|
+
res: dict[sql_identifier.SqlIdentifier, str] = {}
|
747
747
|
|
748
748
|
for target_method in target_methods:
|
749
749
|
# Here we need to find the SQL function corresponding to the Python function.
|
@@ -766,7 +766,7 @@ class ModelOperator:
|
|
766
766
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
767
767
|
model_name: sql_identifier.SqlIdentifier,
|
768
768
|
version_name: sql_identifier.SqlIdentifier,
|
769
|
-
statement_params: Optional[
|
769
|
+
statement_params: Optional[dict[str, Any]] = None,
|
770
770
|
) -> model_meta_schema.ModelMetadataDict:
|
771
771
|
raw_model_spec_res = self._model_client.show_versions(
|
772
772
|
database_name=database_name,
|
@@ -787,7 +787,7 @@ class ModelOperator:
|
|
787
787
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
788
788
|
model_name: sql_identifier.SqlIdentifier,
|
789
789
|
version_name: sql_identifier.SqlIdentifier,
|
790
|
-
statement_params: Optional[
|
790
|
+
statement_params: Optional[dict[str, Any]] = None,
|
791
791
|
) -> type_hints.Task:
|
792
792
|
model_version = self._model_client.show_versions(
|
793
793
|
database_name=database_name,
|
@@ -809,8 +809,8 @@ class ModelOperator:
|
|
809
809
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
810
810
|
model_name: sql_identifier.SqlIdentifier,
|
811
811
|
version_name: sql_identifier.SqlIdentifier,
|
812
|
-
statement_params: Optional[
|
813
|
-
) ->
|
812
|
+
statement_params: Optional[dict[str, Any]] = None,
|
813
|
+
) -> list[model_manifest_schema.ModelFunctionInfo]:
|
814
814
|
model_spec = self._fetch_model_spec(
|
815
815
|
database_name=database_name,
|
816
816
|
schema_name=schema_name,
|
@@ -907,7 +907,7 @@ class ModelOperator:
|
|
907
907
|
version_name: sql_identifier.SqlIdentifier,
|
908
908
|
strict_input_validation: bool = False,
|
909
909
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
910
|
-
statement_params: Optional[
|
910
|
+
statement_params: Optional[dict[str, str]] = None,
|
911
911
|
is_partitioned: Optional[bool] = None,
|
912
912
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
913
913
|
...
|
@@ -923,7 +923,7 @@ class ModelOperator:
|
|
923
923
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
924
924
|
service_name: sql_identifier.SqlIdentifier,
|
925
925
|
strict_input_validation: bool = False,
|
926
|
-
statement_params: Optional[
|
926
|
+
statement_params: Optional[dict[str, str]] = None,
|
927
927
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
928
928
|
...
|
929
929
|
|
@@ -941,7 +941,7 @@ class ModelOperator:
|
|
941
941
|
service_name: Optional[sql_identifier.SqlIdentifier] = None,
|
942
942
|
strict_input_validation: bool = False,
|
943
943
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
944
|
-
statement_params: Optional[
|
944
|
+
statement_params: Optional[dict[str, str]] = None,
|
945
945
|
is_partitioned: Optional[bool] = None,
|
946
946
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
947
947
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
@@ -1059,7 +1059,7 @@ class ModelOperator:
|
|
1059
1059
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
1060
1060
|
model_name: sql_identifier.SqlIdentifier,
|
1061
1061
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
1062
|
-
statement_params: Optional[
|
1062
|
+
statement_params: Optional[dict[str, Any]] = None,
|
1063
1063
|
) -> None:
|
1064
1064
|
if version_name:
|
1065
1065
|
self._model_version_client.drop_version(
|
@@ -1086,7 +1086,7 @@ class ModelOperator:
|
|
1086
1086
|
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
1087
1087
|
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
1088
1088
|
new_model_name: sql_identifier.SqlIdentifier,
|
1089
|
-
statement_params: Optional[
|
1089
|
+
statement_params: Optional[dict[str, Any]] = None,
|
1090
1090
|
) -> None:
|
1091
1091
|
self._model_client.rename(
|
1092
1092
|
database_name=database_name,
|
@@ -1121,7 +1121,7 @@ class ModelOperator:
|
|
1121
1121
|
version_name: sql_identifier.SqlIdentifier,
|
1122
1122
|
target_path: pathlib.Path,
|
1123
1123
|
mode: Literal["full", "model", "minimal"] = "model",
|
1124
|
-
statement_params: Optional[
|
1124
|
+
statement_params: Optional[dict[str, Any]] = None,
|
1125
1125
|
) -> None:
|
1126
1126
|
for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
|
1127
1127
|
list_file_res = self._model_version_client.list_file(
|