snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +54 -42
- snowflake/ml/_internal/utils/service_logger.py +105 -3
- snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
- snowflake/ml/data/data_connector.py +13 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +2 -1
- snowflake/ml/dataset/dataset_reader.py +14 -4
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/callback.py +121 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +156 -54
- snowflake/ml/jobs/_utils/query_helper.py +16 -5
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
- snowflake/ml/jobs/_utils/spec_utils.py +23 -8
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +70 -75
- snowflake/ml/jobs/manager.py +59 -31
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/service_ops.py +336 -137
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +11 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +32 -15
- snowflake/ml/registry/registry.py +48 -80
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/sql/service.py:

@@ -2,7 +2,7 @@ import dataclasses
 import enum
 import logging
 import textwrap
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import (
@@ -69,43 +69,6 @@ class ServiceSQLClient(_base._BaseSQLClient):
     CONTAINER_STATUS = "status"
     MESSAGE = "message"
 
-    def build_model_container(
-        self,
-        *,
-        database_name: Optional[sql_identifier.SqlIdentifier],
-        schema_name: Optional[sql_identifier.SqlIdentifier],
-        model_name: sql_identifier.SqlIdentifier,
-        version_name: sql_identifier.SqlIdentifier,
-        compute_pool_name: sql_identifier.SqlIdentifier,
-        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_name: sql_identifier.SqlIdentifier,
-        gpu: Optional[Union[str, int]],
-        force_rebuild: bool,
-        external_access_integration: sql_identifier.SqlIdentifier,
-        statement_params: Optional[dict[str, Any]] = None,
-    ) -> None:
-        actual_image_repo_database = image_repo_database_name or self._database_name
-        actual_image_repo_schema = image_repo_schema_name or self._schema_name
-        actual_model_database = database_name or self._database_name
-        actual_model_schema = schema_name or self._schema_name
-        fq_model_name = self.fully_qualified_object_name(actual_model_database, actual_model_schema, model_name)
-        fq_image_repo_name = identifier.get_schema_level_object_identifier(
-            actual_image_repo_database.identifier(),
-            actual_image_repo_schema.identifier(),
-            image_repo_name.identifier(),
-        )
-        is_gpu_str = "TRUE" if gpu else "FALSE"
-        force_rebuild_str = "TRUE" if force_rebuild else "FALSE"
-        query_result_checker.SqlResultValidator(
-            self._session,
-            (
-                f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
-                f" '{fq_image_repo_name}', '{is_gpu_str}', '{force_rebuild_str}', '', '{external_access_integration}')"
-            ),
-            statement_params=statement_params,
-        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
-
     def deploy_model(
         self,
         *,
snowflake/ml/model/_model_composer/model_composer.py:

@@ -3,7 +3,7 @@ import tempfile
 import uuid
 import warnings
 from types import ModuleType
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 from urllib import parse
 
 from absl import logging
@@ -21,6 +21,9 @@ from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import Session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
+if TYPE_CHECKING:
+    from snowflake.ml.experiment._experiment_info import ExperimentInfo
+
 
 class ModelComposer:
     """Top-level class to construct contents in a MODEL object in SQL.
@@ -136,6 +139,7 @@ class ModelComposer:
         ext_modules: Optional[list[ModuleType]] = None,
         code_paths: Optional[list[str]] = None,
         task: model_types.Task = model_types.Task.UNKNOWN,
+        experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[model_types.ModelSaveOption] = None,
     ) -> model_meta.ModelMetadata:
         # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
@@ -230,6 +234,7 @@ class ModelComposer:
             options=options,
             user_files=user_files,
             data_sources=self._get_data_sources(model, sample_input_data),
+            experiment_info=experiment_info,
             target_platforms=target_platforms,
         )
 
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py:

@@ -2,7 +2,7 @@ import collections
 import logging
 import pathlib
 import warnings
-from typing import Optional, cast
+from typing import TYPE_CHECKING, Optional, cast
 
 import yaml
 
@@ -23,6 +23,9 @@ from snowflake.ml.model._packager.model_meta import (
 )
 from snowflake.ml.model._packager.model_runtime import model_runtime
 
+if TYPE_CHECKING:
+    from snowflake.ml.experiment._experiment_info import ExperimentInfo
+
 logger = logging.getLogger(__name__)
 
 
@@ -49,6 +52,7 @@ class ModelManifest:
         user_files: Optional[dict[str, list[str]]] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         data_sources: Optional[list[data_source.DataSource]] = None,
+        experiment_info: Optional["ExperimentInfo"] = None,
         target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
     ) -> None:
         if options is None:
@@ -183,7 +187,7 @@ class ModelManifest:
         if self.user_files:
             manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
 
-        lineage_sources = self._extract_lineage_info(data_sources)
+        lineage_sources = self._extract_lineage_info(data_sources, experiment_info)
         if lineage_sources:
             manifest_dict["lineage_sources"] = lineage_sources
 
@@ -210,7 +214,9 @@ class ModelManifest:
         return res
 
     def _extract_lineage_info(
-        self, data_sources: Optional[list[data_source.DataSource]]
+        self,
+        data_sources: Optional[list[data_source.DataSource]],
+        experiment_info: Optional["ExperimentInfo"],
     ) -> list[model_manifest_schema.LineageSourceDict]:
         result = []
         if data_sources:
@@ -229,4 +235,12 @@ class ModelManifest:
                         type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
                     )
                 )
+        if experiment_info:
+            result.append(
+                model_manifest_schema.LineageSourceDict(
+                    type=model_manifest_schema.LineageSourceTypes.EXPERIMENT.value,
+                    entity=experiment_info.fully_qualified_name,
+                    version=experiment_info.run_name,
+                )
+            )
         return result
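The new EXPERIMENT lineage source records which experiment run produced the model. For illustration, the entry appended to `lineage_sources` would look roughly like this (a sketch: the dict keys come from the `LineageSourceDict` constructor above, but the example names and the enum's string value are assumptions):

```python
# Hypothetical shape of the lineage entry _extract_lineage_info appends for an experiment run.
lineage_source = {
    "type": "EXPERIMENT",                       # LineageSourceTypes.EXPERIMENT.value (string value assumed)
    "entity": "MY_DB.MY_SCHEMA.MY_EXPERIMENT",  # experiment_info.fully_qualified_name (placeholder)
    "version": "MY_RUN",                        # experiment_info.run_name (placeholder)
}
```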
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py:

@@ -1,4 +1,5 @@
 import json
+import logging
 import os
 import warnings
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
@@ -23,9 +24,13 @@ from snowflake.ml.model._signatures import utils as model_signature_utils
 from snowflake.ml.model.models import huggingface_pipeline
 from snowflake.snowpark._internal import utils as snowpark_utils
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     import transformers
 
+DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501
+
 
 def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
     # Text
@@ -326,6 +331,23 @@ class HuggingFacePipelineHandler(
                 **device_config,
             )
 
+            # If the task is text-generation, and the tokenizer does not have a chat_template,
+            # set the default chat template.
+            if (
+                hasattr(m, "task")
+                and m.task == "text-generation"
+                and hasattr(m.tokenizer, "chat_template")
+                and not m.tokenizer.chat_template
+            ):
+                warnings.warn(
+                    "The tokenizer does not have default chat_template. "
+                    "Setting the chat_template to default ChatML template.",
+                    UserWarning,
+                    stacklevel=1,
+                )
+                logger.info(DEFAULT_CHAT_TEMPLATE)
+                m.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
+
             m.__dict__.update(pipeline_params)
 
         else:
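The new DEFAULT_CHAT_TEMPLATE is the standard ChatML Jinja template. A minimal sketch of what it produces, rendered with jinja2 directly (the conversation below is made up; transformers applies the same template via `tokenizer.apply_chat_template`):

```python
from jinja2 import Template

# Same ChatML template string the handler now assigns to tokenizer.chat_template.
DEFAULT_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

messages = [{"role": "user", "content": "What is Snowpark?"}]
print(Template(DEFAULT_CHAT_TEMPLATE).render(messages=messages, add_generation_prompt=True))
# <|im_start|>user
# What is Snowpark?
# <|im_end|>
# <|im_start|>assistant
```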
@@ -481,8 +503,25 @@ class HuggingFacePipelineHandler(
 
         # To enable batch_size > 1 for LLM
         # Pipe might not have tokenizer, but should always have a model, and model should always have a config.
-        if getattr(pipe, "tokenizer", None) is not None and pipe.tokenizer.pad_token_id is None:
-            pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
+        if (
+            getattr(pipe, "tokenizer", None) is not None
+            and pipe.tokenizer.pad_token_id is None
+            and hasattr(pipe.model.config, "eos_token_id")
+        ):
+            if isinstance(pipe.model.config.eos_token_id, int):
+                pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
+            elif (
+                isinstance(pipe.model.config.eos_token_id, list)
+                and len(pipe.model.config.eos_token_id) > 0
+                and isinstance(pipe.model.config.eos_token_id[0], int)
+            ):
+                pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id[0]
+            else:
+                warnings.warn(
+                    f"Unexpected type of eos_token_id: {type(pipe.model.config.eos_token_id)}. "
+                    "Not setting pad_token_id to eos_token_id.",
+                    stacklevel=2,
+                )
 
         _HFPipelineModel = _create_custom_model(pipe, model_meta)
         hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())
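The padding fallback now tolerates configs that store `eos_token_id` as a list rather than a single int, as some newer generation configs do. The same logic, sketched against a standalone transformers pipeline (gpt2 is just an example model):

```python
from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")  # example model; gpt2 has no pad token
eos = getattr(pipe.model.config, "eos_token_id", None)
if pipe.tokenizer is not None and pipe.tokenizer.pad_token_id is None:
    if isinstance(eos, int):
        pipe.tokenizer.pad_token_id = eos      # single id: use it directly
    elif isinstance(eos, list) and eos and isinstance(eos[0], int):
        pipe.tokenizer.pad_token_id = eos[0]   # list of ids: take the first
```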
snowflake/ml/model/_packager/model_handlers/sklearn.py:

@@ -1,3 +1,4 @@
+import logging
 import os
 import warnings
 from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union, cast, final
@@ -24,6 +25,8 @@ if TYPE_CHECKING:
     import sklearn.base
     import sklearn.pipeline
 
+logger = logging.getLogger(__name__)
+
 
 def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
     new_steps = []
@@ -201,13 +204,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
             explain_target_method = str(explain_target_method)  # mypy complains if we don't cast to str here
 
             input_signature = handlers_utils.get_input_signature(model_meta, explain_target_method)
-            transformed_background_data = _apply_transforms_up_to_last_step(
-                model=model,
-                data=background_data,
-                input_feature_names=[spec.name for spec in input_signature],
-            )
 
             try:
+                transformed_background_data = _apply_transforms_up_to_last_step(
+                    model=model,
+                    data=background_data,
+                    input_feature_names=[spec.name for spec in input_signature],
+                )
                 model_meta = handlers_utils.add_inferred_explain_method_signature(
                     model_meta=model_meta,
                     explain_method="explain",
@@ -217,6 +220,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
                     output_feature_names=transformed_background_data.columns,
                 )
             except Exception:
+                logger.debug("Explainability is disabled due to an exception.", exc_info=True)
                 if kwargs.get("enable_explainability", None):
                     # user explicitly enabled explainability, so we should raise the error
                     raise ValueError(
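Because the failure is now logged instead of silently swallowed, the reason explainability was skipped can be surfaced by enabling debug logging for the handler module (the logger name follows the module path, per the `logging.getLogger(__name__)` added above):

```python
import logging

logging.basicConfig(level=logging.DEBUG)
# The handler's logger is named after its module path.
logging.getLogger("snowflake.ml.model._packager.model_handlers.sklearn").setLevel(logging.DEBUG)
# Subsequent log_model(...) calls will now record the exception that disabled explainability.
```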
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py:

@@ -10,9 +10,10 @@ REQUIREMENTS = [
     "cryptography",
     "fsspec>=2024.6.1,<2026",
     "importlib_resources>=6.1.1, <7",
-    "numpy>=1.23,<2",
+    "numpy>=1.23,<3",
     "packaging>=20.9,<25",
     "pandas>=2.1.4,<3",
+    "platformdirs<5",
     "pyarrow",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
@@ -28,6 +29,7 @@ REQUIREMENTS = [
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
+    "tqdm<5",
     "typing-extensions>=4.1.0,<5",
     "xgboost>=1.7.3,<3",
 ]
snowflake/ml/model/_packager/model_runtime/model_runtime.py:

@@ -98,9 +98,9 @@ class ModelRuntime:
             dependencies=model_meta_schema.ModelRuntimeDependenciesDict(
                 conda=env_dict["conda"],
                 pip=env_dict["pip"],
-                artifact_repository_map=env_dict["artifact_repository_map"]
-                if env_dict.get("artifact_repository_map") is not None
-                else {},
+                artifact_repository_map=(
+                    env_dict["artifact_repository_map"] if env_dict.get("artifact_repository_map") is not None else {}
+                ),
             ),
             resource_constraint=env_dict["resource_constraint"],
         )
snowflake/ml/model/_signatures/pandas_handler.py:

@@ -86,6 +86,9 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
                 df_col_data = utils.series_dropna(df_col_data)
             df_col_dtype = df_col_data.dtype
 
+            if utils.check_if_series_is_empty(df_col_data):
+                continue
+
             if df_col_dtype == np.dtype("O"):
                 # Check if all objects have the same type
                 if not all(isinstance(data_row, type(df_col_data.iloc[0])) for data_row in df_col_data):
snowflake/ml/model/_signatures/utils.py:

@@ -412,3 +412,7 @@ def infer_dict(name: str, data: dict[str, Any]) -> core.FeatureGroupSpec:
         specs.append(core.FeatureSpec(name=key, dtype=core.DataType.from_numpy_type(np.array(value).dtype)))
 
     return core.FeatureGroupSpec(name=name, specs=specs)
+
+
+def check_if_series_is_empty(series: Optional[pd.Series]) -> bool:
+    return series is None or series.empty
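This helper lets signature inference skip columns that are empty, or become empty once NaN values are dropped. Its behavior, reproduced standalone for illustration:

```python
from typing import Optional

import pandas as pd


def check_if_series_is_empty(series: Optional[pd.Series]) -> bool:
    # Mirrors the helper added to snowflake/ml/model/_signatures/utils.py above.
    return series is None or series.empty


print(check_if_series_is_empty(None))                         # True
print(check_if_series_is_empty(pd.Series([], dtype=object)))  # True
print(check_if_series_is_empty(pd.Series([None]).dropna()))   # True: all-NaN column after dropna
print(check_if_series_is_empty(pd.Series([1, 2, 3])))         # False
```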
snowflake/ml/model/event_handler.py (new file):

@@ -0,0 +1,117 @@
+import os
+import sys
+from typing import Any, Optional
+
+
+class _TqdmStatusContext:
+    """A tqdm-based context manager for status updates."""
+
+    def __init__(self, label: str, tqdm_module: Any, total: Optional[int] = None) -> None:
+        self._label = label
+        self._tqdm = tqdm_module
+        self._total = total or 1
+
+    def __enter__(self) -> "_TqdmStatusContext":
+        self._progress_bar = self._tqdm.tqdm(desc=self._label, file=sys.stdout, total=self._total, leave=True)
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self._progress_bar.close()
+
+    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+        """Update the status by updating the tqdm description."""
+        if state == "complete":
+            self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
+            self._progress_bar.set_description(label)
+        else:
+            self._progress_bar.set_description(f"{self._label}: {label}")
+
+    def increment(self, n: int = 1) -> None:
+        """Increment the progress bar."""
+        self._progress_bar.update(n)
+
+
+class _StreamlitStatusContext:
+    """A streamlit-based context manager for status updates with progress bar support."""
+
+    def __init__(self, label: str, streamlit_module: Any, total: Optional[int] = None) -> None:
+        self._label = label
+        self._streamlit = streamlit_module
+        self._total = total
+        self._current = 0
+        self._progress_bar = None
+
+    def __enter__(self) -> "_StreamlitStatusContext":
+        self._status_container = self._streamlit.status(self._label, state="running", expanded=True)
+        if self._total is not None:
+            with self._status_container:
+                self._progress_bar = self._streamlit.progress(0, text=f"0/{self._total}")
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self._status_container.update(state="complete")
+
+    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+        """Update the status label."""
+        if state != "complete":
+            label = f"{self._label}: {label}"
+        self._status_container.update(label=label, state=state, expanded=expanded)
+        if self._progress_bar is not None:
+            self._progress_bar.progress(
+                self._current / self._total if self._total > 0 else 0,
+                text=f"{label} - {self._current}/{self._total}",
+            )
+
+    def increment(self, n: int = 1) -> None:
+        """Increment the progress."""
+        if self._total is not None:
+            self._current = min(self._current + n, self._total)
+            if self._progress_bar is not None:
+                progress_value = self._current / self._total if self._total > 0 else 0
+                self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
+
+
+class ModelEventHandler:
+    """Event handler for model operations with streamlit-aware status updates."""
+
+    def __init__(self) -> None:
+        self._streamlit = None
+
+        # Try streamlit first
+        try:
+            import streamlit as st
+
+            if st.runtime.exists():
+                USE_STREAMLIT_WIDGETS = os.getenv("USE_STREAMLIT_WIDGETS", "1") == "1"
+                if USE_STREAMLIT_WIDGETS:
+                    self._streamlit = st
+        except ImportError:
+            pass
+
+        import tqdm
+
+        self._tqdm = tqdm
+
+    def update(self, message: str) -> None:
+        """Write a message using streamlit if available, otherwise use tqdm."""
+        if self._streamlit is not None:
+            self._streamlit.write(message)
+        else:
+            self._tqdm.tqdm.write(message)
+
+    def status(self, label: str, *, state: str = "running", expanded: bool = True, total: Optional[int] = None) -> Any:
+        """Context manager that provides status updates with optional enhanced display capabilities.
+
+        Args:
+            label: The status label
+            state: The initial state ("running", "complete", "error")
+            expanded: Whether to show expanded view (streamlit only)
+            total: Total number of steps for progress tracking (optional)
+
+        Returns:
+            Status context (Streamlit or Tqdm)
+        """
+        if self._streamlit is not None:
+            return _StreamlitStatusContext(label, self._streamlit, total)
+        else:
+            return _TqdmStatusContext(label, self._tqdm, total)
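A minimal sketch of how this handler can drive progress reporting (the step labels below are invented): inside a running Streamlit app it renders `st.status` widgets, otherwise it falls back to a tqdm progress bar on stdout.

```python
from snowflake.ml.model.event_handler import ModelEventHandler

handler = ModelEventHandler()
handler.update("starting model deployment")  # one-off message via streamlit or tqdm

with handler.status("Deploying model", total=3) as status:
    status.update("building image")    # hypothetical step labels
    status.increment()
    status.update("creating service")
    status.increment()
    status.update("starting service")
    status.increment()
    status.update("service ready", state="complete")
```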
snowflake/ml/model/model_signature.py:

@@ -16,7 +16,7 @@ from snowflake.ml._internal.exceptions import (
     exceptions as snowml_exceptions,
 )
 from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
-from snowflake.ml.model import type_hints as model_types
+from snowflake.ml.model import type_hints
 from snowflake.ml.model._signatures import (
     base_handler,
     builtins_handler,
@@ -55,9 +55,9 @@ _MODEL_TELEMETRY_SUBPROJECT = "ModelSignature"
 
 
 def _truncate_data(
-    data: model_types.SupportedDataType,
+    data: type_hints.SupportedDataType,
     length: Optional[int] = 100,
-) -> model_types.SupportedDataType:
+) -> type_hints.SupportedDataType:
     for handler in _ALL_DATA_HANDLERS:
         if handler.can_handle(data):
             # If length is None, return the original data
@@ -89,7 +89,7 @@ def _truncate_data(
 
 
 def _infer_signature(
-    data: model_types.SupportedLocalDataType, role: Literal["input", "output"], use_snowflake_identifiers: bool = False
+    data: type_hints.SupportedLocalDataType, role: Literal["input", "output"], use_snowflake_identifiers: bool = False
 ) -> Sequence[core.BaseFeatureSpec]:
     """Infer the inputs/outputs signature given a data that could be dataframe, numpy array or list.
     Dispatching is used to separate logic for different types.
@@ -142,7 +142,7 @@ def _rename_signature_with_snowflake_identifiers(
 
 
 def _validate_array_or_series_type(
-    arr: Union[model_types._SupportedNumpyArray, pd.Series], feature_type: core.DataType, strict: bool = False
+    arr: Union[type_hints._SupportedNumpyArray, pd.Series], feature_type: core.DataType, strict: bool = False
 ) -> bool:
     original_dtype = arr.dtype
     dtype = arr.dtype
@@ -272,6 +272,8 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec]
                 ),
             )
         else:
+            if utils.check_if_series_is_empty(data_col):
+                continue
             if isinstance(data_col.iloc[0], list):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
@@ -649,7 +651,7 @@ def _validate_snowpark_type_feature(
 
 
 def _convert_local_data_to_df(
-    data: model_types.SupportedLocalDataType, ensure_serializable: bool = False
+    data: type_hints.SupportedLocalDataType, ensure_serializable: bool = False
 ) -> pd.DataFrame:
     """Convert local data to pandas DataFrame or Snowpark DataFrame
 
@@ -679,7 +681,7 @@ def _convert_local_data_to_df(
 
 
 def _convert_and_validate_local_data(
-    data: model_types.SupportedLocalDataType, features: Sequence[core.BaseFeatureSpec], strict: bool = False
+    data: type_hints.SupportedLocalDataType, features: Sequence[core.BaseFeatureSpec], strict: bool = False
 ) -> pd.DataFrame:
     """Validate the data with features in model signature and convert to DataFrame
 
@@ -703,8 +705,8 @@ def _convert_and_validate_local_data(
     subproject=_MODEL_TELEMETRY_SUBPROJECT,
 )
 def infer_signature(
-    input_data: model_types.SupportedLocalDataType,
-    output_data: model_types.SupportedLocalDataType,
+    input_data: type_hints.SupportedLocalDataType,
+    output_data: type_hints.SupportedLocalDataType,
     input_feature_names: Optional[list[str]] = None,
     output_feature_names: Optional[list[str]] = None,
     input_data_limit: Optional[int] = 100,
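`infer_signature` is the public entry point for these helpers; a small usage sketch against toy data (the column names are invented):

```python
import pandas as pd

from snowflake.ml.model import model_signature

# Toy input/output frames; infer_signature samples at most input_data_limit rows.
input_df = pd.DataFrame({"feature_a": [1.0, 2.0], "feature_b": [3, 4]})
output_df = pd.DataFrame({"prediction": [0, 1]})

sig = model_signature.infer_signature(
    input_data=input_df,
    output_data=output_df,
    input_data_limit=100,
)
print(sig)
```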