snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (62)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +54 -42
  5. snowflake/ml/_internal/utils/service_logger.py +105 -3
  6. snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
  7. snowflake/ml/data/data_connector.py +13 -2
  8. snowflake/ml/data/data_ingestor.py +8 -0
  9. snowflake/ml/data/torch_utils.py +1 -1
  10. snowflake/ml/dataset/dataset.py +2 -1
  11. snowflake/ml/dataset/dataset_reader.py +14 -4
  12. snowflake/ml/experiment/__init__.py +3 -0
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/callback.py +121 -0
  20. snowflake/ml/experiment/experiment_tracking.py +319 -0
  21. snowflake/ml/jobs/_utils/constants.py +15 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +156 -54
  23. snowflake/ml/jobs/_utils/query_helper.py +16 -5
  24. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  25. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
  26. snowflake/ml/jobs/_utils/spec_utils.py +23 -8
  27. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  28. snowflake/ml/jobs/_utils/types.py +64 -4
  29. snowflake/ml/jobs/job.py +70 -75
  30. snowflake/ml/jobs/manager.py +59 -31
  31. snowflake/ml/lineage/lineage_node.py +2 -2
  32. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  33. snowflake/ml/model/_client/ops/service_ops.py +336 -137
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
  36. snowflake/ml/model/_client/sql/service.py +1 -38
  37. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  38. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  41. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  45. snowflake/ml/model/_signatures/utils.py +4 -0
  46. snowflake/ml/model/event_handler.py +117 -0
  47. snowflake/ml/model/model_signature.py +11 -9
  48. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  49. snowflake/ml/modeling/framework/base.py +1 -1
  50. snowflake/ml/modeling/metrics/classification.py +14 -14
  51. snowflake/ml/modeling/metrics/correlation.py +19 -8
  52. snowflake/ml/modeling/metrics/ranking.py +6 -6
  53. snowflake/ml/modeling/metrics/regression.py +9 -9
  54. snowflake/ml/monitoring/explain_visualize.py +12 -5
  55. snowflake/ml/registry/_manager/model_manager.py +32 -15
  56. snowflake/ml/registry/registry.py +48 -80
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
  59. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
  60. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  61. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/model/_client/sql/service.py
+++ b/snowflake/ml/model/_client/sql/service.py
@@ -2,7 +2,7 @@ import dataclasses
 import enum
 import logging
 import textwrap
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import (
@@ -69,43 +69,6 @@ class ServiceSQLClient(_base._BaseSQLClient):
     CONTAINER_STATUS = "status"
     MESSAGE = "message"
 
-    def build_model_container(
-        self,
-        *,
-        database_name: Optional[sql_identifier.SqlIdentifier],
-        schema_name: Optional[sql_identifier.SqlIdentifier],
-        model_name: sql_identifier.SqlIdentifier,
-        version_name: sql_identifier.SqlIdentifier,
-        compute_pool_name: sql_identifier.SqlIdentifier,
-        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_name: sql_identifier.SqlIdentifier,
-        gpu: Optional[Union[str, int]],
-        force_rebuild: bool,
-        external_access_integration: sql_identifier.SqlIdentifier,
-        statement_params: Optional[dict[str, Any]] = None,
-    ) -> None:
-        actual_image_repo_database = image_repo_database_name or self._database_name
-        actual_image_repo_schema = image_repo_schema_name or self._schema_name
-        actual_model_database = database_name or self._database_name
-        actual_model_schema = schema_name or self._schema_name
-        fq_model_name = self.fully_qualified_object_name(actual_model_database, actual_model_schema, model_name)
-        fq_image_repo_name = identifier.get_schema_level_object_identifier(
-            actual_image_repo_database.identifier(),
-            actual_image_repo_schema.identifier(),
-            image_repo_name.identifier(),
-        )
-        is_gpu_str = "TRUE" if gpu else "FALSE"
-        force_rebuild_str = "TRUE" if force_rebuild else "FALSE"
-        query_result_checker.SqlResultValidator(
-            self._session,
-            (
-                f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
-                f" '{fq_image_repo_name}', '{is_gpu_str}', '{force_rebuild_str}', '', '{external_access_integration}')"
-            ),
-            statement_params=statement_params,
-        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
-
     def deploy_model(
         self,
         *,
--- a/snowflake/ml/model/_model_composer/model_composer.py
+++ b/snowflake/ml/model/_model_composer/model_composer.py
@@ -3,7 +3,7 @@ import tempfile
 import uuid
 import warnings
 from types import ModuleType
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 from urllib import parse
 
 from absl import logging
@@ -21,6 +21,9 @@ from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import Session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
+if TYPE_CHECKING:
+    from snowflake.ml.experiment._experiment_info import ExperimentInfo
+
 
 class ModelComposer:
     """Top-level class to construct contents in a MODEL object in SQL.
@@ -136,6 +139,7 @@ class ModelComposer:
         ext_modules: Optional[list[ModuleType]] = None,
         code_paths: Optional[list[str]] = None,
         task: model_types.Task = model_types.Task.UNKNOWN,
+        experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[model_types.ModelSaveOption] = None,
     ) -> model_meta.ModelMetadata:
         # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
@@ -230,6 +234,7 @@ class ModelComposer:
             options=options,
             user_files=user_files,
             data_sources=self._get_data_sources(model, sample_input_data),
+            experiment_info=experiment_info,
             target_platforms=target_platforms,
         )
 
--- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py
+++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py
@@ -2,7 +2,7 @@ import collections
 import logging
 import pathlib
 import warnings
-from typing import Optional, cast
+from typing import TYPE_CHECKING, Optional, cast
 
 import yaml
 
@@ -23,6 +23,9 @@ from snowflake.ml.model._packager.model_meta import (
 )
 from snowflake.ml.model._packager.model_runtime import model_runtime
 
+if TYPE_CHECKING:
+    from snowflake.ml.experiment._experiment_info import ExperimentInfo
+
 logger = logging.getLogger(__name__)
 
 
@@ -49,6 +52,7 @@ class ModelManifest:
         user_files: Optional[dict[str, list[str]]] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         data_sources: Optional[list[data_source.DataSource]] = None,
+        experiment_info: Optional["ExperimentInfo"] = None,
         target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
     ) -> None:
         if options is None:
@@ -183,7 +187,7 @@ class ModelManifest:
         if self.user_files:
             manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
 
-        lineage_sources = self._extract_lineage_info(data_sources)
+        lineage_sources = self._extract_lineage_info(data_sources, experiment_info)
         if lineage_sources:
             manifest_dict["lineage_sources"] = lineage_sources
 
@@ -210,7 +214,9 @@ class ModelManifest:
         return res
 
     def _extract_lineage_info(
-        self, data_sources: Optional[list[data_source.DataSource]]
+        self,
+        data_sources: Optional[list[data_source.DataSource]],
+        experiment_info: Optional["ExperimentInfo"],
    ) -> list[model_manifest_schema.LineageSourceDict]:
         result = []
         if data_sources:
@@ -229,4 +235,12 @@ class ModelManifest:
                         type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
                     )
                 )
+        if experiment_info:
+            result.append(
+                model_manifest_schema.LineageSourceDict(
+                    type=model_manifest_schema.LineageSourceTypes.EXPERIMENT.value,
+                    entity=experiment_info.fully_qualified_name,
+                    version=experiment_info.run_name,
+                )
+            )
         return result
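
For context, the new EXPERIMENT entry lands in the manifest's lineage_sources alongside the existing DATASET and QUERY entries. A minimal sketch of the resulting fragment, with placeholder names rather than real ones:

    # Hypothetical manifest fragment; keys follow LineageSourceDict from the
    # hunks above, values stand in for the experiment_info fields.
    lineage_sources = [
        {
            "type": "EXPERIMENT",                       # LineageSourceTypes.EXPERIMENT.value
            "entity": "MY_DB.MY_SCHEMA.MY_EXPERIMENT",  # experiment_info.fully_qualified_name
            "version": "MY_RUN",                        # experiment_info.run_name
        }
    ]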
--- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py
+++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py
@@ -83,6 +83,7 @@ class SnowparkMLDataDict(TypedDict):
 class LineageSourceTypes(enum.Enum):
     DATASET = "DATASET"
     QUERY = "QUERY"
+    EXPERIMENT = "EXPERIMENT"
 
 
 class LineageSourceDict(TypedDict):
--- a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
+++ b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import os
 import warnings
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
@@ -23,9 +24,13 @@ from snowflake.ml.model._signatures import utils as model_signature_utils
 from snowflake.ml.model.models import huggingface_pipeline
 from snowflake.snowpark._internal import utils as snowpark_utils
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     import transformers
 
+DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501
+
 
 def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
     # Text
@@ -326,6 +331,23 @@ class HuggingFacePipelineHandler(
                 **device_config,
             )
 
+            # If the task is text-generation, and the tokenizer does not have a chat_template,
+            # set the default chat template.
+            if (
+                hasattr(m, "task")
+                and m.task == "text-generation"
+                and hasattr(m.tokenizer, "chat_template")
+                and not m.tokenizer.chat_template
+            ):
+                warnings.warn(
+                    "The tokenizer does not have default chat_template. "
+                    "Setting the chat_template to default ChatML template.",
+                    UserWarning,
+                    stacklevel=1,
+                )
+                logger.info(DEFAULT_CHAT_TEMPLATE)
+                m.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
+
             m.__dict__.update(pipeline_params)
 
         else:
@@ -481,8 +503,25 @@
 
         # To enable batch_size > 1 for LLM
         # Pipe might not have tokenizer, but should always have a model, and model should always have a config.
-        if getattr(pipe, "tokenizer", None) is not None and pipe.tokenizer.pad_token_id is None:
-            pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
+        if (
+            getattr(pipe, "tokenizer", None) is not None
+            and pipe.tokenizer.pad_token_id is None
+            and hasattr(pipe.model.config, "eos_token_id")
+        ):
+            if isinstance(pipe.model.config.eos_token_id, int):
+                pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
+            elif (
+                isinstance(pipe.model.config.eos_token_id, list)
+                and len(pipe.model.config.eos_token_id) > 0
+                and isinstance(pipe.model.config.eos_token_id[0], int)
+            ):
+                pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id[0]
+            else:
+                warnings.warn(
+                    f"Unexpected type of eos_token_id: {type(pipe.model.config.eos_token_id)}. "
+                    "Not setting pad_token_id to eos_token_id.",
+                    stacklevel=2,
+                )
 
         _HFPipelineModel = _create_custom_model(pipe, model_meta)
         hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())
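
To see what the new default produces, here is a minimal sketch using the transformers chat-template API; the gpt2 checkpoint is an arbitrary stand-in for a tokenizer that ships without a chat template:

    from transformers import AutoTokenizer

    # Same ChatML template string the handler now assigns by default.
    DEFAULT_CHAT_TEMPLATE = (
        "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'"
        " + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    )

    tok = AutoTokenizer.from_pretrained("gpt2")  # no chat template of its own
    tok.chat_template = DEFAULT_CHAT_TEMPLATE
    text = tok.apply_chat_template(
        [{"role": "user", "content": "Hi"}], tokenize=False, add_generation_prompt=True
    )
    # text renders as:
    # <|im_start|>user
    # Hi<|im_end|>
    # <|im_start|>assistant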
--- a/snowflake/ml/model/_packager/model_handlers/sklearn.py
+++ b/snowflake/ml/model/_packager/model_handlers/sklearn.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import warnings
 from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union, cast, final
@@ -24,6 +25,8 @@ if TYPE_CHECKING:
     import sklearn.base
     import sklearn.pipeline
 
+logger = logging.getLogger(__name__)
+
 
 def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
     new_steps = []
@@ -201,13 +204,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             explain_target_method = str(explain_target_method)  # mypy complains if we don't cast to str here
 
             input_signature = handlers_utils.get_input_signature(model_meta, explain_target_method)
-            transformed_background_data = _apply_transforms_up_to_last_step(
-                model=model,
-                data=background_data,
-                input_feature_names=[spec.name for spec in input_signature],
-            )
 
             try:
+                transformed_background_data = _apply_transforms_up_to_last_step(
+                    model=model,
+                    data=background_data,
+                    input_feature_names=[spec.name for spec in input_signature],
+                )
                 model_meta = handlers_utils.add_inferred_explain_method_signature(
                     model_meta=model_meta,
                     explain_method="explain",
@@ -217,6 +220,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     output_feature_names=transformed_background_data.columns,
                 )
             except Exception:
+                logger.debug("Explainability is disabled due to an exception.", exc_info=True)
                 if kwargs.get("enable_explainability", None):
                     # user explicitly enabled explainability, so we should raise the error
                     raise ValueError(
--- a/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
+++ b/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
@@ -10,9 +10,10 @@ REQUIREMENTS = [
     "cryptography",
     "fsspec>=2024.6.1,<2026",
     "importlib_resources>=6.1.1, <7",
-    "numpy>=1.23,<2",
+    "numpy>=1.23,<3",
     "packaging>=20.9,<25",
     "pandas>=2.1.4,<3",
+    "platformdirs<5",
     "pyarrow",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
@@ -28,6 +29,7 @@ REQUIREMENTS = [
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
+    "tqdm<5",
     "typing-extensions>=4.1.0,<5",
     "xgboost>=1.7.3,<3",
 ]
--- a/snowflake/ml/model/_packager/model_runtime/model_runtime.py
+++ b/snowflake/ml/model/_packager/model_runtime/model_runtime.py
@@ -98,9 +98,9 @@ class ModelRuntime:
             dependencies=model_meta_schema.ModelRuntimeDependenciesDict(
                 conda=env_dict["conda"],
                 pip=env_dict["pip"],
-                artifact_repository_map=env_dict["artifact_repository_map"]
-                if env_dict.get("artifact_repository_map") is not None
-                else {},
+                artifact_repository_map=(
+                    env_dict["artifact_repository_map"] if env_dict.get("artifact_repository_map") is not None else {}
+                ),
             ),
             resource_constraint=env_dict["resource_constraint"],
         )
--- a/snowflake/ml/model/_signatures/pandas_handler.py
+++ b/snowflake/ml/model/_signatures/pandas_handler.py
@@ -86,6 +86,9 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
                 df_col_data = utils.series_dropna(df_col_data)
             df_col_dtype = df_col_data.dtype
 
+            if utils.check_if_series_is_empty(df_col_data):
+                continue
+
             if df_col_dtype == np.dtype("O"):
                 # Check if all objects have the same type
                 if not all(isinstance(data_row, type(df_col_data.iloc[0])) for data_row in df_col_data):
--- a/snowflake/ml/model/_signatures/utils.py
+++ b/snowflake/ml/model/_signatures/utils.py
@@ -412,3 +412,7 @@ def infer_dict(name: str, data: dict[str, Any]) -> core.FeatureGroupSpec:
         specs.append(core.FeatureSpec(name=key, dtype=core.DataType.from_numpy_type(np.array(value).dtype)))
 
     return core.FeatureGroupSpec(name=name, specs=specs)
+
+
+def check_if_series_is_empty(series: Optional[pd.Series]) -> bool:
+    return series is None or series.empty
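
The helper's contract is small enough to pin down with a self-contained sketch; the function body is copied from the hunk above:

    from typing import Optional

    import pandas as pd

    def check_if_series_is_empty(series: Optional[pd.Series]) -> bool:
        return series is None or series.empty

    # Both None and a zero-length Series count as empty; anything else does not.
    assert check_if_series_is_empty(None)
    assert check_if_series_is_empty(pd.Series([], dtype="float64"))
    assert not check_if_series_is_empty(pd.Series([1.0]))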
--- /dev/null
+++ b/snowflake/ml/model/event_handler.py
@@ -0,0 +1,117 @@
+import os
+import sys
+from typing import Any, Optional
+
+
+class _TqdmStatusContext:
+    """A tqdm-based context manager for status updates."""
+
+    def __init__(self, label: str, tqdm_module: Any, total: Optional[int] = None) -> None:
+        self._label = label
+        self._tqdm = tqdm_module
+        self._total = total or 1
+
+    def __enter__(self) -> "_TqdmStatusContext":
+        self._progress_bar = self._tqdm.tqdm(desc=self._label, file=sys.stdout, total=self._total, leave=True)
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self._progress_bar.close()
+
+    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+        """Update the status by updating the tqdm description."""
+        if state == "complete":
+            self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
+            self._progress_bar.set_description(label)
+        else:
+            self._progress_bar.set_description(f"{self._label}: {label}")
+
+    def increment(self, n: int = 1) -> None:
+        """Increment the progress bar."""
+        self._progress_bar.update(n)
+
+
+class _StreamlitStatusContext:
+    """A streamlit-based context manager for status updates with progress bar support."""
+
+    def __init__(self, label: str, streamlit_module: Any, total: Optional[int] = None) -> None:
+        self._label = label
+        self._streamlit = streamlit_module
+        self._total = total
+        self._current = 0
+        self._progress_bar = None
+
+    def __enter__(self) -> "_StreamlitStatusContext":
+        self._status_container = self._streamlit.status(self._label, state="running", expanded=True)
+        if self._total is not None:
+            with self._status_container:
+                self._progress_bar = self._streamlit.progress(0, text=f"0/{self._total}")
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self._status_container.update(state="complete")
+
+    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+        """Update the status label."""
+        if state != "complete":
+            label = f"{self._label}: {label}"
+        self._status_container.update(label=label, state=state, expanded=expanded)
+        if self._progress_bar is not None:
+            self._progress_bar.progress(
+                self._current / self._total if self._total > 0 else 0,
+                text=f"{label} - {self._current}/{self._total}",
+            )
+
+    def increment(self, n: int = 1) -> None:
+        """Increment the progress."""
+        if self._total is not None:
+            self._current = min(self._current + n, self._total)
+            if self._progress_bar is not None:
+                progress_value = self._current / self._total if self._total > 0 else 0
+                self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
+
+
+class ModelEventHandler:
+    """Event handler for model operations with streamlit-aware status updates."""
+
+    def __init__(self) -> None:
+        self._streamlit = None
+
+        # Try streamlit first
+        try:
+            import streamlit as st
+
+            if st.runtime.exists():
+                USE_STREAMLIT_WIDGETS = os.getenv("USE_STREAMLIT_WIDGETS", "1") == "1"
+                if USE_STREAMLIT_WIDGETS:
+                    self._streamlit = st
+        except ImportError:
+            pass
+
+        import tqdm
+
+        self._tqdm = tqdm
+
+    def update(self, message: str) -> None:
+        """Write a message using streamlit if available, otherwise use tqdm."""
+        if self._streamlit is not None:
+            self._streamlit.write(message)
+        else:
+            self._tqdm.tqdm.write(message)
+
+    def status(self, label: str, *, state: str = "running", expanded: bool = True, total: Optional[int] = None) -> Any:
+        """Context manager that provides status updates with optional enhanced display capabilities.
+
+        Args:
+            label: The status label
+            state: The initial state ("running", "complete", "error")
+            expanded: Whether to show expanded view (streamlit only)
+            total: Total number of steps for progress tracking (optional)
+
+        Returns:
+            Status context (Streamlit or Tqdm)
+        """
+        if self._streamlit is not None:
+            return _StreamlitStatusContext(label, self._streamlit, total)
+        else:
+            return _TqdmStatusContext(label, self._tqdm, total)
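
A minimal usage sketch for the new ModelEventHandler; it assumes no Streamlit runtime is active, so the tqdm backend is selected. The labels and steps are illustrative:

    from snowflake.ml.model.event_handler import ModelEventHandler

    handler = ModelEventHandler()
    handler.update("starting model operation")  # plain message via tqdm.write

    # Three-step progress display driven through the status context manager.
    with handler.status("uploading artifacts", total=3) as status:
        for step in ("manifest", "weights", "environment"):
            status.update(f"writing {step}")
            status.increment()
        status.update("uploading artifacts", state="complete")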
--- a/snowflake/ml/model/model_signature.py
+++ b/snowflake/ml/model/model_signature.py
@@ -16,7 +16,7 @@ from snowflake.ml._internal.exceptions import (
     exceptions as snowml_exceptions,
 )
 from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
-from snowflake.ml.model import type_hints as model_types
+from snowflake.ml.model import type_hints
 from snowflake.ml.model._signatures import (
     base_handler,
     builtins_handler,
@@ -55,9 +55,9 @@ _MODEL_TELEMETRY_SUBPROJECT = "ModelSignature"
 
 
 def _truncate_data(
-    data: model_types.SupportedDataType,
+    data: type_hints.SupportedDataType,
     length: Optional[int] = 100,
-) -> model_types.SupportedDataType:
+) -> type_hints.SupportedDataType:
     for handler in _ALL_DATA_HANDLERS:
         if handler.can_handle(data):
             # If length is None, return the original data
@@ -89,7 +89,7 @@ def _truncate_data(
 
 
 def _infer_signature(
-    data: model_types.SupportedLocalDataType, role: Literal["input", "output"], use_snowflake_identifiers: bool = False
+    data: type_hints.SupportedLocalDataType, role: Literal["input", "output"], use_snowflake_identifiers: bool = False
 ) -> Sequence[core.BaseFeatureSpec]:
     """Infer the inputs/outputs signature given a data that could be dataframe, numpy array or list.
     Dispatching is used to separate logic for different types.
@@ -142,7 +142,7 @@ def _rename_signature_with_snowflake_identifiers(
 
 
 def _validate_array_or_series_type(
-    arr: Union[model_types._SupportedNumpyArray, pd.Series], feature_type: core.DataType, strict: bool = False
+    arr: Union[type_hints._SupportedNumpyArray, pd.Series], feature_type: core.DataType, strict: bool = False
 ) -> bool:
     original_dtype = arr.dtype
     dtype = arr.dtype
@@ -272,6 +272,8 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
         else:
+            if utils.check_if_series_is_empty(data_col):
+                continue
             if isinstance(data_col.iloc[0], list):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
@@ -649,7 +651,7 @@ def _validate_snowpark_type_feature(
 
 
 def _convert_local_data_to_df(
-    data: model_types.SupportedLocalDataType, ensure_serializable: bool = False
+    data: type_hints.SupportedLocalDataType, ensure_serializable: bool = False
 ) -> pd.DataFrame:
     """Convert local data to pandas DataFrame or Snowpark DataFrame
 
@@ -679,7 +681,7 @@
 
 
 def _convert_and_validate_local_data(
-    data: model_types.SupportedLocalDataType, features: Sequence[core.BaseFeatureSpec], strict: bool = False
+    data: type_hints.SupportedLocalDataType, features: Sequence[core.BaseFeatureSpec], strict: bool = False
 ) -> pd.DataFrame:
     """Validate the data with features in model signature and convert to DataFrame
 
@@ -703,8 +705,8 @@
     subproject=_MODEL_TELEMETRY_SUBPROJECT,
 )
 def infer_signature(
-    input_data: model_types.SupportedLocalDataType,
-    output_data: model_types.SupportedLocalDataType,
+    input_data: type_hints.SupportedLocalDataType,
+    output_data: type_hints.SupportedLocalDataType,
     input_feature_names: Optional[list[str]] = None,
     output_feature_names: Optional[list[str]] = None,
     input_data_limit: Optional[int] = 100,