snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,53 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional
|
2
2
|
|
3
|
-
from
|
3
|
+
from pydantic import BaseModel
|
4
4
|
|
5
5
|
|
6
|
-
class
|
7
|
-
name:
|
8
|
-
version:
|
6
|
+
class Model(BaseModel):
|
7
|
+
name: str
|
8
|
+
version: str
|
9
9
|
|
10
10
|
|
11
|
-
class
|
12
|
-
compute_pool:
|
13
|
-
image_repo:
|
14
|
-
force_rebuild:
|
15
|
-
external_access_integrations:
|
11
|
+
class ImageBuild(BaseModel):
|
12
|
+
compute_pool: str
|
13
|
+
image_repo: str
|
14
|
+
force_rebuild: bool
|
15
|
+
external_access_integrations: Optional[list[str]] = None
|
16
16
|
|
17
17
|
|
18
|
-
class
|
19
|
-
name:
|
20
|
-
compute_pool:
|
21
|
-
ingress_enabled:
|
22
|
-
max_instances:
|
23
|
-
cpu:
|
24
|
-
memory:
|
25
|
-
gpu:
|
26
|
-
num_workers:
|
27
|
-
max_batch_rows:
|
18
|
+
class Service(BaseModel):
|
19
|
+
name: str
|
20
|
+
compute_pool: str
|
21
|
+
ingress_enabled: bool
|
22
|
+
max_instances: int
|
23
|
+
cpu: Optional[str] = None
|
24
|
+
memory: Optional[str] = None
|
25
|
+
gpu: Optional[str] = None
|
26
|
+
num_workers: Optional[int] = None
|
27
|
+
max_batch_rows: Optional[int] = None
|
28
28
|
|
29
29
|
|
30
|
-
class
|
31
|
-
|
32
|
-
|
33
|
-
|
30
|
+
class Job(BaseModel):
|
31
|
+
name: str
|
32
|
+
compute_pool: str
|
33
|
+
cpu: Optional[str] = None
|
34
|
+
memory: Optional[str] = None
|
35
|
+
gpu: Optional[str] = None
|
36
|
+
num_workers: Optional[int] = None
|
37
|
+
max_batch_rows: Optional[int] = None
|
38
|
+
warehouse: str
|
39
|
+
target_method: str
|
40
|
+
input_table_name: str
|
41
|
+
output_table_name: str
|
42
|
+
|
43
|
+
|
44
|
+
class ModelServiceDeploymentSpec(BaseModel):
|
45
|
+
models: list[Model]
|
46
|
+
image_build: ImageBuild
|
47
|
+
service: Service
|
48
|
+
|
49
|
+
|
50
|
+
class ModelJobDeploymentSpec(BaseModel):
|
51
|
+
models: list[Model]
|
52
|
+
image_build: ImageBuild
|
53
|
+
job: Job
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Optional
|
2
2
|
|
3
3
|
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
4
|
from snowflake.ml.model._client.sql import _base
|
@@ -24,8 +24,8 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
24
24
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
25
25
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
26
26
|
validate_result: bool = True,
|
27
|
-
statement_params: Optional[
|
28
|
-
) ->
|
27
|
+
statement_params: Optional[dict[str, Any]] = None,
|
28
|
+
) -> list[row.Row]:
|
29
29
|
actual_database_name = database_name or self._database_name
|
30
30
|
actual_schema_name = schema_name or self._schema_name
|
31
31
|
fully_qualified_schema_name = ".".join([actual_database_name.identifier(), actual_schema_name.identifier()])
|
@@ -57,8 +57,8 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
57
57
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
58
58
|
validate_result: bool = True,
|
59
59
|
check_model_details: bool = False,
|
60
|
-
statement_params: Optional[
|
61
|
-
) ->
|
60
|
+
statement_params: Optional[dict[str, Any]] = None,
|
61
|
+
) -> list[row.Row]:
|
62
62
|
like_sql = ""
|
63
63
|
if version_name:
|
64
64
|
like_sql = f" LIKE '{version_name.resolved()}'"
|
@@ -90,7 +90,7 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
90
90
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
91
91
|
model_name: sql_identifier.SqlIdentifier,
|
92
92
|
comment: str,
|
93
|
-
statement_params: Optional[
|
93
|
+
statement_params: Optional[dict[str, Any]] = None,
|
94
94
|
) -> None:
|
95
95
|
query_result_checker.SqlResultValidator(
|
96
96
|
self._session,
|
@@ -107,7 +107,7 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
107
107
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
108
108
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
109
109
|
model_name: sql_identifier.SqlIdentifier,
|
110
|
-
statement_params: Optional[
|
110
|
+
statement_params: Optional[dict[str, Any]] = None,
|
111
111
|
) -> None:
|
112
112
|
query_result_checker.SqlResultValidator(
|
113
113
|
self._session,
|
@@ -124,7 +124,7 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
124
124
|
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
125
125
|
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
126
126
|
new_model_name: sql_identifier.SqlIdentifier,
|
127
|
-
statement_params: Optional[
|
127
|
+
statement_params: Optional[dict[str, Any]] = None,
|
128
128
|
) -> None:
|
129
129
|
# Use registry's database and schema if a non fully qualified new model name is provided.
|
130
130
|
new_fully_qualified_name = self.fully_qualified_object_name(new_model_db, new_model_schema, new_model_name)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import pathlib
|
3
3
|
import textwrap
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional
|
5
5
|
from urllib.parse import ParseResult
|
6
6
|
|
7
7
|
from snowflake.ml._internal.utils import (
|
@@ -34,7 +34,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
34
34
|
model_name: sql_identifier.SqlIdentifier,
|
35
35
|
version_name: sql_identifier.SqlIdentifier,
|
36
36
|
stage_path: str,
|
37
|
-
statement_params: Optional[
|
37
|
+
statement_params: Optional[dict[str, Any]] = None,
|
38
38
|
) -> None:
|
39
39
|
query_result_checker.SqlResultValidator(
|
40
40
|
self._session,
|
@@ -56,7 +56,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
56
56
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
57
57
|
model_name: sql_identifier.SqlIdentifier,
|
58
58
|
version_name: sql_identifier.SqlIdentifier,
|
59
|
-
statement_params: Optional[
|
59
|
+
statement_params: Optional[dict[str, Any]] = None,
|
60
60
|
) -> None:
|
61
61
|
fq_source_model_name = self.fully_qualified_object_name(
|
62
62
|
source_database_name, source_schema_name, source_model_name
|
@@ -78,7 +78,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
78
78
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
79
79
|
model_name: sql_identifier.SqlIdentifier,
|
80
80
|
version_name: sql_identifier.SqlIdentifier,
|
81
|
-
statement_params: Optional[
|
81
|
+
statement_params: Optional[dict[str, Any]] = None,
|
82
82
|
) -> None:
|
83
83
|
sql = (
|
84
84
|
f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
@@ -97,7 +97,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
97
97
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
98
98
|
model_name: sql_identifier.SqlIdentifier,
|
99
99
|
version_name: sql_identifier.SqlIdentifier,
|
100
|
-
statement_params: Optional[
|
100
|
+
statement_params: Optional[dict[str, Any]] = None,
|
101
101
|
) -> None:
|
102
102
|
sql = (
|
103
103
|
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
@@ -116,7 +116,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
116
116
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
117
117
|
model_name: sql_identifier.SqlIdentifier,
|
118
118
|
version_name: sql_identifier.SqlIdentifier,
|
119
|
-
statement_params: Optional[
|
119
|
+
statement_params: Optional[dict[str, Any]] = None,
|
120
120
|
) -> None:
|
121
121
|
sql = (
|
122
122
|
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
@@ -138,7 +138,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
138
138
|
model_name: sql_identifier.SqlIdentifier,
|
139
139
|
version_name: sql_identifier.SqlIdentifier,
|
140
140
|
stage_path: str,
|
141
|
-
statement_params: Optional[
|
141
|
+
statement_params: Optional[dict[str, Any]] = None,
|
142
142
|
) -> None:
|
143
143
|
query_result_checker.SqlResultValidator(
|
144
144
|
self._session,
|
@@ -160,7 +160,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
160
160
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
161
161
|
model_name: sql_identifier.SqlIdentifier,
|
162
162
|
version_name: sql_identifier.SqlIdentifier,
|
163
|
-
statement_params: Optional[
|
163
|
+
statement_params: Optional[dict[str, Any]] = None,
|
164
164
|
) -> None:
|
165
165
|
fq_source_model_name = self.fully_qualified_object_name(
|
166
166
|
source_database_name, source_schema_name, source_model_name
|
@@ -182,7 +182,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
182
182
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
183
183
|
model_name: sql_identifier.SqlIdentifier,
|
184
184
|
version_name: sql_identifier.SqlIdentifier,
|
185
|
-
statement_params: Optional[
|
185
|
+
statement_params: Optional[dict[str, Any]] = None,
|
186
186
|
) -> None:
|
187
187
|
query_result_checker.SqlResultValidator(
|
188
188
|
self._session,
|
@@ -201,7 +201,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
201
201
|
model_name: sql_identifier.SqlIdentifier,
|
202
202
|
version_name: sql_identifier.SqlIdentifier,
|
203
203
|
alias_name: sql_identifier.SqlIdentifier,
|
204
|
-
statement_params: Optional[
|
204
|
+
statement_params: Optional[dict[str, Any]] = None,
|
205
205
|
) -> None:
|
206
206
|
query_result_checker.SqlResultValidator(
|
207
207
|
self._session,
|
@@ -219,7 +219,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
219
219
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
220
220
|
model_name: sql_identifier.SqlIdentifier,
|
221
221
|
version_or_alias_name: sql_identifier.SqlIdentifier,
|
222
|
-
statement_params: Optional[
|
222
|
+
statement_params: Optional[dict[str, Any]] = None,
|
223
223
|
) -> None:
|
224
224
|
query_result_checker.SqlResultValidator(
|
225
225
|
self._session,
|
@@ -239,8 +239,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
239
239
|
version_name: sql_identifier.SqlIdentifier,
|
240
240
|
file_path: pathlib.PurePosixPath,
|
241
241
|
is_dir: bool = False,
|
242
|
-
statement_params: Optional[
|
243
|
-
) ->
|
242
|
+
statement_params: Optional[dict[str, Any]] = None,
|
243
|
+
) -> list[row.Row]:
|
244
244
|
# Workaround for snowURL bug.
|
245
245
|
trailing_slash = "/" if is_dir else ""
|
246
246
|
|
@@ -276,7 +276,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
276
276
|
version_name: sql_identifier.SqlIdentifier,
|
277
277
|
file_path: pathlib.PurePosixPath,
|
278
278
|
target_path: pathlib.Path,
|
279
|
-
statement_params: Optional[
|
279
|
+
statement_params: Optional[dict[str, Any]] = None,
|
280
280
|
) -> pathlib.Path:
|
281
281
|
stage_location = pathlib.PurePosixPath(
|
282
282
|
self.fully_qualified_object_name(database_name, schema_name, model_name),
|
@@ -310,8 +310,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
310
310
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
311
311
|
model_name: sql_identifier.SqlIdentifier,
|
312
312
|
version_name: sql_identifier.SqlIdentifier,
|
313
|
-
statement_params: Optional[
|
314
|
-
) ->
|
313
|
+
statement_params: Optional[dict[str, Any]] = None,
|
314
|
+
) -> list[row.Row]:
|
315
315
|
res = query_result_checker.SqlResultValidator(
|
316
316
|
self._session,
|
317
317
|
(
|
@@ -331,7 +331,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
331
331
|
model_name: sql_identifier.SqlIdentifier,
|
332
332
|
version_name: sql_identifier.SqlIdentifier,
|
333
333
|
comment: str,
|
334
|
-
statement_params: Optional[
|
334
|
+
statement_params: Optional[dict[str, Any]] = None,
|
335
335
|
) -> None:
|
336
336
|
query_result_checker.SqlResultValidator(
|
337
337
|
self._session,
|
@@ -351,9 +351,9 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
351
351
|
version_name: sql_identifier.SqlIdentifier,
|
352
352
|
method_name: sql_identifier.SqlIdentifier,
|
353
353
|
input_df: dataframe.DataFrame,
|
354
|
-
input_args:
|
355
|
-
returns:
|
356
|
-
statement_params: Optional[
|
354
|
+
input_args: list[sql_identifier.SqlIdentifier],
|
355
|
+
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
356
|
+
statement_params: Optional[dict[str, Any]] = None,
|
357
357
|
) -> dataframe.DataFrame:
|
358
358
|
with_statements = []
|
359
359
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
@@ -433,10 +433,10 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
433
433
|
version_name: sql_identifier.SqlIdentifier,
|
434
434
|
method_name: sql_identifier.SqlIdentifier,
|
435
435
|
input_df: dataframe.DataFrame,
|
436
|
-
input_args:
|
437
|
-
returns:
|
436
|
+
input_args: list[sql_identifier.SqlIdentifier],
|
437
|
+
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
438
438
|
partition_column: Optional[sql_identifier.SqlIdentifier],
|
439
|
-
statement_params: Optional[
|
439
|
+
statement_params: Optional[dict[str, Any]] = None,
|
440
440
|
is_partitioned: bool = True,
|
441
441
|
) -> dataframe.DataFrame:
|
442
442
|
with_statements = []
|
@@ -529,13 +529,13 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
529
529
|
|
530
530
|
def set_metadata(
|
531
531
|
self,
|
532
|
-
metadata_dict:
|
532
|
+
metadata_dict: dict[str, Any],
|
533
533
|
*,
|
534
534
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
535
535
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
536
536
|
model_name: sql_identifier.SqlIdentifier,
|
537
537
|
version_name: sql_identifier.SqlIdentifier,
|
538
|
-
statement_params: Optional[
|
538
|
+
statement_params: Optional[dict[str, Any]] = None,
|
539
539
|
) -> None:
|
540
540
|
json_metadata = json.dumps(metadata_dict)
|
541
541
|
query_result_checker.SqlResultValidator(
|
@@ -554,7 +554,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
554
554
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
555
555
|
model_name: sql_identifier.SqlIdentifier,
|
556
556
|
version_name: sql_identifier.SqlIdentifier,
|
557
|
-
statement_params: Optional[
|
557
|
+
statement_params: Optional[dict[str, Any]] = None,
|
558
558
|
) -> None:
|
559
559
|
query_result_checker.SqlResultValidator(
|
560
560
|
self._session,
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import enum
|
2
2
|
import json
|
3
3
|
import textwrap
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional, Union
|
5
5
|
|
6
6
|
from snowflake import snowpark
|
7
7
|
from snowflake.ml._internal import platform_capabilities
|
@@ -47,7 +47,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
47
47
|
gpu: Optional[Union[str, int]],
|
48
48
|
force_rebuild: bool,
|
49
49
|
external_access_integration: sql_identifier.SqlIdentifier,
|
50
|
-
statement_params: Optional[
|
50
|
+
statement_params: Optional[dict[str, Any]] = None,
|
51
51
|
) -> None:
|
52
52
|
actual_image_repo_database = image_repo_database_name or self._database_name
|
53
53
|
actual_image_repo_schema = image_repo_schema_name or self._schema_name
|
@@ -73,13 +73,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
73
73
|
def deploy_model(
|
74
74
|
self,
|
75
75
|
*,
|
76
|
-
stage_path: str,
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
76
|
+
stage_path: Optional[str] = None,
|
77
|
+
model_deployment_spec_yaml_str: Optional[str] = None,
|
78
|
+
model_deployment_spec_file_rel_path: Optional[str] = None,
|
79
|
+
statement_params: Optional[dict[str, Any]] = None,
|
80
|
+
) -> tuple[str, snowpark.AsyncJob]:
|
81
|
+
assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
|
82
|
+
if model_deployment_spec_yaml_str:
|
83
|
+
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
|
84
|
+
else:
|
85
|
+
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
|
86
|
+
async_job = self._session.sql(sql_str).collect(block=False, statement_params=statement_params)
|
83
87
|
assert isinstance(async_job, snowpark.AsyncJob)
|
84
88
|
return async_job.query_id, async_job
|
85
89
|
|
@@ -91,9 +95,9 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
91
95
|
service_name: sql_identifier.SqlIdentifier,
|
92
96
|
method_name: sql_identifier.SqlIdentifier,
|
93
97
|
input_df: dataframe.DataFrame,
|
94
|
-
input_args:
|
95
|
-
returns:
|
96
|
-
statement_params: Optional[
|
98
|
+
input_args: list[sql_identifier.SqlIdentifier],
|
99
|
+
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
100
|
+
statement_params: Optional[dict[str, Any]] = None,
|
97
101
|
) -> dataframe.DataFrame:
|
98
102
|
with_statements = []
|
99
103
|
actual_database_name = database_name or self._database_name
|
@@ -177,7 +181,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
177
181
|
service_name: sql_identifier.SqlIdentifier,
|
178
182
|
instance_id: str = "0",
|
179
183
|
container_name: str,
|
180
|
-
statement_params: Optional[
|
184
|
+
statement_params: Optional[dict[str, Any]] = None,
|
181
185
|
) -> str:
|
182
186
|
system_func = "SYSTEM$GET_SERVICE_LOGS"
|
183
187
|
rows = (
|
@@ -202,8 +206,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
202
206
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
203
207
|
service_name: sql_identifier.SqlIdentifier,
|
204
208
|
include_message: bool = False,
|
205
|
-
statement_params: Optional[
|
206
|
-
) ->
|
209
|
+
statement_params: Optional[dict[str, Any]] = None,
|
210
|
+
) -> tuple[ServiceStatus, Optional[str]]:
|
207
211
|
system_func = "SYSTEM$GET_SERVICE_STATUS"
|
208
212
|
rows = (
|
209
213
|
query_result_checker.SqlResultValidator(
|
@@ -227,7 +231,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
227
231
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
228
232
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
229
233
|
service_name: sql_identifier.SqlIdentifier,
|
230
|
-
statement_params: Optional[
|
234
|
+
statement_params: Optional[dict[str, Any]] = None,
|
231
235
|
) -> None:
|
232
236
|
query_result_checker.SqlResultValidator(
|
233
237
|
self._session,
|
@@ -241,8 +245,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
241
245
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
242
246
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
243
247
|
service_name: sql_identifier.SqlIdentifier,
|
244
|
-
statement_params: Optional[
|
245
|
-
) ->
|
248
|
+
statement_params: Optional[dict[str, Any]] = None,
|
249
|
+
) -> list[row.Row]:
|
246
250
|
fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
|
247
251
|
res = (
|
248
252
|
query_result_checker.SqlResultValidator(
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Optional
|
2
2
|
|
3
3
|
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
4
|
from snowflake.ml.model._client.sql import _base
|
@@ -11,7 +11,7 @@ class StageSQLClient(_base._BaseSQLClient):
|
|
11
11
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
12
12
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
13
13
|
stage_name: sql_identifier.SqlIdentifier,
|
14
|
-
statement_params: Optional[
|
14
|
+
statement_params: Optional[dict[str, Any]] = None,
|
15
15
|
) -> None:
|
16
16
|
query_result_checker.SqlResultValidator(
|
17
17
|
self._session,
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Optional
|
2
2
|
|
3
3
|
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
4
|
from snowflake.ml.model._client.sql import _base
|
@@ -16,7 +16,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
|
|
16
16
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
17
17
|
tag_name: sql_identifier.SqlIdentifier,
|
18
18
|
tag_value: str,
|
19
|
-
statement_params: Optional[
|
19
|
+
statement_params: Optional[dict[str, Any]] = None,
|
20
20
|
) -> None:
|
21
21
|
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
22
22
|
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
@@ -35,7 +35,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
|
|
35
35
|
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
36
36
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
37
37
|
tag_name: sql_identifier.SqlIdentifier,
|
38
|
-
statement_params: Optional[
|
38
|
+
statement_params: Optional[dict[str, Any]] = None,
|
39
39
|
) -> None:
|
40
40
|
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
41
41
|
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
@@ -54,7 +54,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
|
|
54
54
|
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
55
55
|
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
56
56
|
tag_name: sql_identifier.SqlIdentifier,
|
57
|
-
statement_params: Optional[
|
57
|
+
statement_params: Optional[dict[str, Any]] = None,
|
58
58
|
) -> row.Row:
|
59
59
|
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
60
60
|
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
@@ -75,8 +75,8 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
|
|
75
75
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
76
76
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
77
77
|
model_name: sql_identifier.SqlIdentifier,
|
78
|
-
statement_params: Optional[
|
79
|
-
) ->
|
78
|
+
statement_params: Optional[dict[str, Any]] = None,
|
79
|
+
) -> list[row.Row]:
|
80
80
|
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
81
81
|
actual_database_name = database_name or self._database_name
|
82
82
|
return (
|
@@ -3,13 +3,14 @@ import tempfile
|
|
3
3
|
import uuid
|
4
4
|
import warnings
|
5
5
|
from types import ModuleType
|
6
|
-
from typing import Any,
|
6
|
+
from typing import Any, Optional, Union
|
7
7
|
from urllib import parse
|
8
8
|
|
9
9
|
from absl import logging
|
10
10
|
from packaging import requirements
|
11
11
|
|
12
12
|
from snowflake import snowpark
|
13
|
+
from snowflake.ml import version as snowml_version
|
13
14
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
14
15
|
from snowflake.ml._internal.lineage import lineage_utils
|
15
16
|
from snowflake.ml.data import data_source
|
@@ -43,7 +44,8 @@ class ModelComposer:
|
|
43
44
|
session: Session,
|
44
45
|
stage_path: str,
|
45
46
|
*,
|
46
|
-
statement_params: Optional[
|
47
|
+
statement_params: Optional[dict[str, Any]] = None,
|
48
|
+
save_location: Optional[str] = None,
|
47
49
|
) -> None:
|
48
50
|
self.session = session
|
49
51
|
self.stage_path: Union[pathlib.PurePosixPath, parse.ParseResult] = None # type: ignore[assignment]
|
@@ -54,10 +56,29 @@ class ModelComposer:
|
|
54
56
|
# The stage path is a user stage path
|
55
57
|
self.stage_path = pathlib.PurePosixPath(stage_path)
|
56
58
|
|
57
|
-
|
58
|
-
self.
|
59
|
+
# Set up workspace based on save_location if provided, otherwise use temporary directory
|
60
|
+
self.save_location = save_location
|
61
|
+
if save_location:
|
62
|
+
# Use the save_location directory directly
|
63
|
+
self._workspace_path = pathlib.Path(save_location)
|
64
|
+
self._workspace_path.mkdir(exist_ok=True)
|
65
|
+
# ensure that the directory is empty
|
66
|
+
if any(self._workspace_path.iterdir()):
|
67
|
+
raise ValueError(f"The directory {self._workspace_path} is not empty.")
|
68
|
+
self._workspace = None
|
69
|
+
|
70
|
+
self._packager_workspace_path = self._workspace_path / ModelComposer.MODEL_DIR_REL_PATH
|
71
|
+
self._packager_workspace_path.mkdir(exist_ok=True)
|
72
|
+
self._packager_workspace = None
|
73
|
+
else:
|
74
|
+
# Use a temporary directory
|
75
|
+
self._workspace = tempfile.TemporaryDirectory()
|
76
|
+
self._workspace_path = pathlib.Path(self._workspace.name)
|
77
|
+
|
78
|
+
self._packager_workspace_path = self._workspace_path / ModelComposer.MODEL_DIR_REL_PATH
|
79
|
+
self._packager_workspace_path.mkdir(exist_ok=True)
|
59
80
|
|
60
|
-
self.packager = model_packager.ModelPackager(local_dir_path=str(self.
|
81
|
+
self.packager = model_packager.ModelPackager(local_dir_path=str(self.packager_workspace_path))
|
61
82
|
self.manifest = model_manifest.ModelManifest(workspace_path=self.workspace_path)
|
62
83
|
|
63
84
|
self.model_file_rel_path = f"model-{uuid.uuid4().hex}.zip"
|
@@ -65,16 +86,16 @@ class ModelComposer:
|
|
65
86
|
self._statement_params = statement_params
|
66
87
|
|
67
88
|
def __del__(self) -> None:
|
68
|
-
self._workspace
|
69
|
-
|
89
|
+
if self._workspace:
|
90
|
+
self._workspace.cleanup()
|
70
91
|
|
71
92
|
@property
|
72
93
|
def workspace_path(self) -> pathlib.Path:
|
73
|
-
return
|
94
|
+
return self._workspace_path
|
74
95
|
|
75
96
|
@property
|
76
|
-
def
|
77
|
-
return
|
97
|
+
def packager_workspace_path(self) -> pathlib.Path:
|
98
|
+
return self._packager_workspace_path
|
78
99
|
|
79
100
|
@property
|
80
101
|
def model_stage_path(self) -> str:
|
@@ -102,17 +123,18 @@ class ModelComposer:
|
|
102
123
|
*,
|
103
124
|
name: str,
|
104
125
|
model: model_types.SupportedModelType,
|
105
|
-
signatures: Optional[
|
126
|
+
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
106
127
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
107
|
-
metadata: Optional[
|
108
|
-
conda_dependencies: Optional[
|
109
|
-
pip_requirements: Optional[
|
110
|
-
artifact_repository_map: Optional[
|
111
|
-
|
128
|
+
metadata: Optional[dict[str, str]] = None,
|
129
|
+
conda_dependencies: Optional[list[str]] = None,
|
130
|
+
pip_requirements: Optional[list[str]] = None,
|
131
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
132
|
+
resource_constraint: Optional[dict[str, str]] = None,
|
133
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None,
|
112
134
|
python_version: Optional[str] = None,
|
113
|
-
user_files: Optional[
|
114
|
-
ext_modules: Optional[
|
115
|
-
code_paths: Optional[
|
135
|
+
user_files: Optional[dict[str, list[str]]] = None,
|
136
|
+
ext_modules: Optional[list[ModuleType]] = None,
|
137
|
+
code_paths: Optional[list[str]] = None,
|
116
138
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
117
139
|
options: Optional[model_types.ModelSaveOption] = None,
|
118
140
|
) -> model_meta.ModelMetadata:
|
@@ -146,14 +168,14 @@ class ModelComposer:
|
|
146
168
|
if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
147
169
|
snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
|
148
170
|
self.session,
|
149
|
-
reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={
|
171
|
+
reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
|
150
172
|
python_version=python_version or snowml_env.PYTHON_VERSION,
|
151
173
|
statement_params=self._statement_params,
|
152
174
|
).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
|
153
175
|
|
154
176
|
if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
|
155
177
|
logging.info(
|
156
|
-
f"Local snowflake-ml-python library has version {
|
178
|
+
f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
|
157
179
|
" which is not available in the Snowflake server, embedding local ML library automatically."
|
158
180
|
)
|
159
181
|
options["embed_local_ml_library"] = True
|
@@ -167,6 +189,8 @@ class ModelComposer:
|
|
167
189
|
conda_dependencies=conda_dependencies,
|
168
190
|
pip_requirements=pip_requirements,
|
169
191
|
artifact_repository_map=artifact_repository_map,
|
192
|
+
resource_constraint=resource_constraint,
|
193
|
+
target_platforms=target_platforms,
|
170
194
|
python_version=python_version,
|
171
195
|
ext_modules=ext_modules,
|
172
196
|
code_paths=code_paths,
|
@@ -175,9 +199,6 @@ class ModelComposer:
|
|
175
199
|
)
|
176
200
|
assert self.packager.meta is not None
|
177
201
|
|
178
|
-
file_utils.copytree(
|
179
|
-
str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
|
180
|
-
)
|
181
202
|
self.manifest.save(
|
182
203
|
model_meta=self.packager.meta,
|
183
204
|
model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
|
@@ -208,7 +229,7 @@ class ModelComposer:
|
|
208
229
|
|
209
230
|
def _get_data_sources(
|
210
231
|
self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
|
211
|
-
) -> Optional[
|
232
|
+
) -> Optional[list[data_source.DataSource]]:
|
212
233
|
data_sources = lineage_utils.get_data_sources(model)
|
213
234
|
if not data_sources and sample_input_data is not None:
|
214
235
|
data_sources = lineage_utils.get_data_sources(sample_input_data)
|