snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in that registry.
Files changed (62)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +54 -42
  5. snowflake/ml/_internal/utils/service_logger.py +105 -3
  6. snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
  7. snowflake/ml/data/data_connector.py +13 -2
  8. snowflake/ml/data/data_ingestor.py +8 -0
  9. snowflake/ml/data/torch_utils.py +1 -1
  10. snowflake/ml/dataset/dataset.py +2 -1
  11. snowflake/ml/dataset/dataset_reader.py +14 -4
  12. snowflake/ml/experiment/__init__.py +3 -0
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/callback.py +121 -0
  20. snowflake/ml/experiment/experiment_tracking.py +319 -0
  21. snowflake/ml/jobs/_utils/constants.py +15 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +156 -54
  23. snowflake/ml/jobs/_utils/query_helper.py +16 -5
  24. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  25. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
  26. snowflake/ml/jobs/_utils/spec_utils.py +23 -8
  27. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  28. snowflake/ml/jobs/_utils/types.py +64 -4
  29. snowflake/ml/jobs/job.py +70 -75
  30. snowflake/ml/jobs/manager.py +59 -31
  31. snowflake/ml/lineage/lineage_node.py +2 -2
  32. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  33. snowflake/ml/model/_client/ops/service_ops.py +336 -137
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
  36. snowflake/ml/model/_client/sql/service.py +1 -38
  37. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  38. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  41. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  45. snowflake/ml/model/_signatures/utils.py +4 -0
  46. snowflake/ml/model/event_handler.py +117 -0
  47. snowflake/ml/model/model_signature.py +11 -9
  48. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  49. snowflake/ml/modeling/framework/base.py +1 -1
  50. snowflake/ml/modeling/metrics/classification.py +14 -14
  51. snowflake/ml/modeling/metrics/correlation.py +19 -8
  52. snowflake/ml/modeling/metrics/ranking.py +6 -6
  53. snowflake/ml/modeling/metrics/regression.py +9 -9
  54. snowflake/ml/monitoring/explain_visualize.py +12 -5
  55. snowflake/ml/registry/_manager/model_manager.py +32 -15
  56. snowflake/ml/registry/registry.py +48 -80
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
  59. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
  60. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  61. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,22 @@
+import logging
 import warnings
-from typing import Any, Optional
+from typing import Any, Optional, Union

 from packaging import version

+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.human_readable_id import hrid_generator
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model._client.ops import service_ops
+from snowflake.snowpark import async_job, session
+
+logger = logging.getLogger(__name__)
+
+
+_TELEMETRY_PROJECT = "MLOps"
+_TELEMETRY_SUBPROJECT = "ModelManagement"
+

 class HuggingFacePipelineModel:
     def __init__(
@@ -214,4 +228,159 @@ class HuggingFacePipelineModel:
         self.token = token
         self.trust_remote_code = trust_remote_code
         self.model_kwargs = model_kwargs
+        self.tokenizer = tokenizer
         self.__dict__.update(kwargs)
+
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+        func_params_to_log=[
+            "service_name",
+            "image_build_compute_pool",
+            "service_compute_pool",
+            "image_repo",
+            "gpu_requests",
+            "num_workers",
+            "max_batch_rows",
+        ],
+    )
+    @snowpark._internal.utils.private_preview(version="1.9.1")
+    def create_service(
+        self,
+        *,
+        session: session.Session,
+        # registry.log_model parameters
+        model_name: str,
+        version_name: Optional[str] = None,
+        pip_requirements: Optional[list[str]] = None,
+        conda_dependencies: Optional[list[str]] = None,
+        comment: Optional[str] = None,
+        # model_version_impl.create_service parameters
+        service_name: str,
+        service_compute_pool: str,
+        image_repo: str,
+        image_build_compute_pool: Optional[str] = None,
+        ingress_enabled: bool = False,
+        max_instances: int = 1,
+        cpu_requests: Optional[str] = None,
+        memory_requests: Optional[str] = None,
+        gpu_requests: Optional[Union[str, int]] = None,
+        num_workers: Optional[int] = None,
+        max_batch_rows: Optional[int] = None,
+        force_rebuild: bool = False,
+        build_external_access_integrations: Optional[list[str]] = None,
+        block: bool = True,
+    ) -> Union[str, async_job.AsyncJob]:
+        """Logs a Hugging Face model and creates a service in Snowflake.
+
+        Args:
+            session: The Snowflake session object.
+            model_name: The name of the model in Snowflake.
+            version_name: The version name of the model. Defaults to None.
+            pip_requirements: Pip requirements for the model. Defaults to None.
+            conda_dependencies: Conda dependencies for the model. Defaults to None.
+            comment: Comment for the model. Defaults to None.
+            service_name: The name of the service to create.
+            service_compute_pool: The compute pool for the service.
+            image_repo: The name of the image repository.
+            image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
+                the service compute pool if None.
+            ingress_enabled: Whether ingress is enabled. Defaults to False.
+            max_instances: Maximum number of instances. Defaults to 1.
+            cpu_requests: CPU requests configuration. Defaults to None.
+            memory_requests: Memory requests configuration. Defaults to None.
+            gpu_requests: GPU requests configuration. Defaults to None.
+            num_workers: Number of workers. Defaults to None.
+            max_batch_rows: Maximum batch rows. Defaults to None.
+            force_rebuild: Whether to force rebuild the image. Defaults to False.
+            build_external_access_integrations: External access integrations for building the image. Defaults to None.
+            block: Whether to block the operation. Defaults to True.
+
+        Raises:
+            ValueError: if database and schema name is not provided and session doesn't have a
+                database and schema name.
+
+        Returns:
+            The service ID or an async job object.
+
+        .. # noqa: DAR003
+        """
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+
+        database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
+        session_database_name = session.get_current_database()
+        session_schema_name = session.get_current_schema()
+        if database_name_id is None:
+            if session_database_name is None:
+                raise ValueError("Either database needs to be provided or needs to be available in session.")
+            database_name_id = sql_identifier.SqlIdentifier(session_database_name)
+        if schema_name_id is None:
+            if session_schema_name is None:
+                raise ValueError("Either schema needs to be provided or needs to be available in session.")
+            schema_name_id = sql_identifier.SqlIdentifier(session_schema_name)
+
+        if version_name is None:
+            name_generator = hrid_generator.HRID16()
+            version_name = name_generator.generate()[1]
+
+        service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
+        image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
+
+        service_operator = service_ops.ServiceOperator(
+            session=session,
+            database_name=database_name_id,
+            schema_name=schema_name_id,
+        )
+        logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
+
+        return service_operator.create_service(
+            database_name=database_name_id,
+            schema_name=schema_name_id,
+            model_name=model_name_id,
+            version_name=sql_identifier.SqlIdentifier(version_name),
+            service_database_name=service_db_id,
+            service_schema_name=service_schema_id,
+            service_name=service_id,
+            image_build_compute_pool_name=(
+                sql_identifier.SqlIdentifier(image_build_compute_pool)
+                if image_build_compute_pool
+                else sql_identifier.SqlIdentifier(service_compute_pool)
+            ),
+            service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
+            image_repo_database_name=image_repo_db_id,
+            image_repo_schema_name=image_repo_schema_id,
+            image_repo_name=image_repo_id,
+            ingress_enabled=ingress_enabled,
+            max_instances=max_instances,
+            cpu_requests=cpu_requests,
+            memory_requests=memory_requests,
+            gpu_requests=gpu_requests,
+            num_workers=num_workers,
+            max_batch_rows=max_batch_rows,
+            force_rebuild=force_rebuild,
+            build_external_access_integrations=(
+                None
+                if build_external_access_integrations is None
+                else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+            ),
+            block=block,
+            statement_params=statement_params,
+            # hf model
+            hf_model_args=service_ops.HFModelArgs(
+                hf_model_name=self.model,
+                hf_task=self.task,
+                hf_tokenizer=self.tokenizer,
+                hf_revision=self.revision,
+                hf_token=self.token,
+                hf_trust_remote_code=bool(self.trust_remote_code),
+                hf_model_kwargs=self.model_kwargs,
+                pip_requirements=pip_requirements,
+                conda_dependencies=conda_dependencies,
+                comment=comment,
+                # TODO: remove warehouse in the next release
+                warehouse=session.get_current_warehouse(),
+            ),
+        )
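Taken together, this hunk adds a private-preview create_service helper to HuggingFacePipelineModel that logs the Hugging Face model and stands up an inference service in one call. A minimal usage sketch follows; the connection parameters, compute pool, image repository, and object names are placeholders, and since the method is marked private preview as of 1.9.1 its exact surface may change.

from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
from snowflake.snowpark import Session

# Placeholder connection parameters; substitute real account details.
connection_parameters = {
    "account": "<account>",
    "user": "<user>",
    "password": "<password>",
    "database": "ML_DB",
    "schema": "PUBLIC",
}
session = Session.builder.configs(connection_parameters).create()

# Reference a Hugging Face pipeline by task and model id; no local download is needed here.
hf_model = HuggingFacePipelineModel(task="text-generation", model="openai-community/gpt2")

# Log the model and create the inference service in a single call.
result = hf_model.create_service(
    session=session,
    model_name="MY_HF_MODEL",            # hypothetical registry model name
    service_name="MY_HF_SERVICE",        # hypothetical service name
    service_compute_pool="MY_GPU_POOL",  # hypothetical compute pool
    image_repo="MY_IMAGE_REPO",          # hypothetical image repository
    gpu_requests="1",
    ingress_enabled=True,
    block=True,                          # set False to get an AsyncJob back instead
)
print(result)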
@@ -698,7 +698,7 @@ class BaseTransformer(BaseEstimator):
         self,
         attribute: Optional[Mapping[str, Union[int, float, str, Iterable[Union[int, float, str]]]]],
         dtype: Optional[type] = None,
-    ) -> Optional[npt.NDArray[Union[np.int_, np.float_, np.str_]]]:
+    ) -> Optional[npt.NDArray[Union[np.int_, np.float64, np.str_]]]:
         """
         Convert the attribute from dict to ndarray based on the order of `self.input_cols`.

@@ -96,7 +96,7 @@ def confusion_matrix(
     labels: Optional[npt.ArrayLike] = None,
     sample_weight_col_name: Optional[str] = None,
     normalize: Optional[str] = None,
-) -> Union[npt.NDArray[np.int_], npt.NDArray[np.float_]]:
+) -> Union[npt.NDArray[np.int_], npt.NDArray[np.float64]]:
     """
     Compute confusion matrix to evaluate the accuracy of a classification.

@@ -320,7 +320,7 @@ def f1_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the F1 score, also known as balanced F-score or F-measure.

@@ -414,7 +414,7 @@ def fbeta_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the F-beta score.

@@ -696,7 +696,7 @@ def precision_recall_fscore_support(
     zero_division: Union[str, int] = "warn",
 ) -> Union[
     tuple[float, float, float, None],
-    tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
+    tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]],
 ]:
     """
     Compute precision, recall, F-measure and support for each class.
@@ -855,7 +855,7 @@ def precision_recall_fscore_support(

     res: Union[
         tuple[float, float, float, None],
-        tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
+        tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]],
     ] = result_object[:4]
     warning = result_object[-1]
     if warning:
@@ -1050,7 +1050,7 @@ def _register_multilabel_confusion_matrix_computer(

         def end_partition(
             self,
-        ) -> Iterable[tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]]:
+        ) -> Iterable[tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]]:
             MCM = metrics.multilabel_confusion_matrix(
                 self._y_true,
                 self._y_pred,
@@ -1098,7 +1098,7 @@ def _binary_precision_score(
     pos_label: Union[str, int] = 1,
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:

     statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)

@@ -1173,7 +1173,7 @@ def precision_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the precision.

@@ -1271,7 +1271,7 @@ def recall_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the recall.

@@ -1406,14 +1406,14 @@ def _check_binary_labels(


 def _prf_divide(
-    numerator: npt.NDArray[np.float_],
-    denominator: npt.NDArray[np.float_],
+    numerator: npt.NDArray[np.float64],
+    denominator: npt.NDArray[np.float64],
     metric: str,
     modifier: str,
     average: Optional[str] = None,
     warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
     zero_division: Union[str, int] = "warn",
-) -> npt.NDArray[np.float_]:
+) -> npt.NDArray[np.float64]:
     """Performs division and handles divide-by-zero.

     On zero-division, sets the corresponding result elements equal to
@@ -1436,7 +1436,7 @@ def _prf_divide(
             "warn", this acts as 0, but warnings are also raised.

     Returns:
-        npt.NDArray[np.float_]: Result of the division, an array of floats.
+        npt.NDArray[np.float64]: Result of the division, an array of floats.
     """
     mask = denominator == 0.0
     denominator = denominator.copy()
@@ -1522,7 +1522,7 @@ def _check_zero_division(zero_division: Union[int, float, str]) -> float:
         return np.nan


-def _nanaverage(a: npt.NDArray[np.float_], weights: Optional[npt.ArrayLike] = None) -> Any:
+def _nanaverage(a: npt.NDArray[np.float64], weights: Optional[npt.ArrayLike] = None) -> Any:
     """Compute the weighted average, ignoring NaNs.

     Args:
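The npt.NDArray[np.float_] to npt.NDArray[np.float64] substitutions above (and the identical ones in the ranking and regression hunks below) read as NumPy 2.x compatibility work: np.float_ was an alias for np.float64 in NumPy 1.x and was removed in NumPy 2.0, so the old annotations fail under NumPy 2.x while runtime behavior is unchanged. A tiny illustration:

import numpy as np
import numpy.typing as npt

# np.float64 is the concrete dtype that np.float_ used to alias; it exists on both
# NumPy 1.x and 2.x, so annotations written against it work across versions.
values: npt.NDArray[np.float64] = np.asarray([0.25, 0.5, 1.0], dtype=np.float64)
print(values.dtype)           # float64
print(hasattr(np, "float_"))  # False on NumPy >= 2.0, True on 1.x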
@@ -26,7 +26,7 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
     The below steps explain how correlation matrix is computed in a distributed way:
     Let n = # of rows in the dataframe; sqrt_n = sqrt(n); X, Y are 2 columns in the dataframe
     Correlation(X, Y) = numerator/denominator where
-    numerator = dot(X/sqrt_n, Y/sqrt_n) - sum(X/n)*sum(X/n)
+    numerator = dot(X/sqrt_n, Y/sqrt_n) - sum(X/n)*sum(Y/n)
     denominator = std_dev(X)*std_dev(Y)
     std_dev(X) = sqrt(dot(X/sqrt_n, X/sqrt_n) - sum(X/n)*sum(X/n))

@@ -74,27 +74,38 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
     # Pushing this to a udtf requires creating a temp udtf which takes about 20 secs, so it doesn't make sense
     # to have this in a udtf.
     n_cols = len(columns)
-    sum_arr = np.zeros(n_cols)
-    squared_sum_arr = np.zeros(n_cols)
+    column_means = np.zeros(n_cols)
+    mean_of_squares = np.zeros(n_cols)
     dot_prod = np.zeros((n_cols, n_cols))
     # Get sum, dot_prod and squared sum array from the results.
     for i in range(len(results)):
         x = results[i]
         if x[1] == "sum_by_count":
-            sum_arr = cloudpickle.loads(x[0])
+            column_means = cloudpickle.loads(x[0])
         else:
             row = int(x[1].strip("row_"))
             dot_prod[row, :] = cloudpickle.loads(x[0])
-            squared_sum_arr[row] = dot_prod[row, row]
+            mean_of_squares[row] = dot_prod[row, row]

     # sum(X/n)*sum(Y/n) is computed for all combinations of X,Y (columns in the dataframe)
-    exey_arr = np.einsum("t,m->tm", sum_arr, sum_arr, optimize="optimal")
+    exey_arr = np.einsum("t,m->tm", column_means, column_means, optimize="optimal")
     numerator_matrix = dot_prod - exey_arr

     # standard deviation for all columns in the dataframe
-    stddev_arr = np.sqrt(squared_sum_arr - np.einsum("i, i -> i", sum_arr, sum_arr, optimize="optimal"))
+    variance_arr = mean_of_squares - np.einsum("i, i -> i", column_means, column_means, optimize="optimal")
+    # ensure non-negative values from potential precision issues where variance might be slightly negative
+    variance_arr = np.maximum(variance_arr, 0)
+    stddev_arr = np.sqrt(variance_arr)
     # std_dev(X)*std_dev(Y) is computed for all combinations of X,Y (columns in the dataframe)
     denominator_matrix = np.einsum("t,m->tm", stddev_arr, stddev_arr, optimize="optimal")
-    corr_res = numerator_matrix / denominator_matrix
+
+    # Use np.divide to handle NaN cases
+    corr_res = np.divide(
+        numerator_matrix,
+        denominator_matrix,
+        out=np.full_like(numerator_matrix, np.nan),
+        where=(denominator_matrix != 0),
+    )
+
     correlation_matrix = pd.DataFrame(corr_res, columns=columns, index=columns)
     return correlation_matrix
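For reference, here is a standalone sketch (made-up values, not the library's distributed data flow in snowflake/ml/modeling/metrics/correlation.py) of the two numerical guards introduced above: clamping variances that dip below zero from floating-point error, and using np.divide with out/where so constant columns yield NaN correlations instead of divide-by-zero warnings.

import numpy as np

column_means = np.array([2.0, 1.0, 0.5])
mean_of_squares = np.array([4.0, 1.25, 1.0])   # first column is constant: E[X^2] == E[X]^2

variance_arr = mean_of_squares - column_means * column_means
variance_arr = np.maximum(variance_arr, 0)      # guard 1: never take sqrt of a tiny negative
stddev_arr = np.sqrt(variance_arr)              # [0.0, 0.5, 0.866...]

numerator_matrix = np.outer(stddev_arr, stddev_arr) * 0.9   # stand-in covariance-like values
denominator_matrix = np.einsum("t,m->tm", stddev_arr, stddev_arr)

corr_res = np.divide(                           # guard 2: NaN wherever a std dev is zero
    numerator_matrix,
    denominator_matrix,
    out=np.full_like(numerator_matrix, np.nan),
    where=(denominator_matrix != 0),
)
print(corr_res)                                 # row/column 0 is all NaN, the rest is 0.9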
@@ -26,7 +26,7 @@ def precision_recall_curve(
     probas_pred_col_name: str,
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
-) -> tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
+) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]:
     """
     Compute precision-recall pairs for different probability thresholds.

@@ -125,7 +125,7 @@ def precision_recall_curve(

     kwargs = telemetry.get_sproc_statement_params_kwargs(precision_recall_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session, **kwargs))
-    res: tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
+    res: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] = result_object
     return res


@@ -140,7 +140,7 @@ def roc_auc_score(
     max_fpr: Optional[float] = None,
     multi_class: str = "raise",
     labels: Optional[npt.ArrayLike] = None,
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
     from prediction scores.
@@ -276,7 +276,7 @@ def roc_auc_score(

     kwargs = telemetry.get_sproc_statement_params_kwargs(roc_auc_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, roc_auc_score_anon_sproc(session, **kwargs))
-    auc: Union[float, npt.NDArray[np.float_]] = result_object
+    auc: Union[float, npt.NDArray[np.float64]] = result_object
     return auc


@@ -289,7 +289,7 @@ def roc_curve(
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
     drop_intermediate: bool = True,
-) -> tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
+) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]:
     """
     Compute Receiver operating characteristic (ROC).

@@ -380,6 +380,6 @@ def roc_curve(
     kwargs = telemetry.get_sproc_statement_params_kwargs(roc_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, roc_curve_anon_sproc(session, **kwargs))

-    res: tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
+    res: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] = result_object

     return res
@@ -29,7 +29,7 @@ def d2_absolute_error_score(
     y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     :math:`D^2` regression score function, \
     fraction of absolute error explained.
@@ -111,7 +111,7 @@ def d2_absolute_error_score(

     kwargs = telemetry.get_sproc_statement_params_kwargs(d2_absolute_error_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session, **kwargs))
-    score: Union[float, npt.NDArray[np.float_]] = result_object
+    score: Union[float, npt.NDArray[np.float64]] = result_object
     return score


@@ -124,7 +124,7 @@ def d2_pinball_score(
     sample_weight_col_name: Optional[str] = None,
     alpha: float = 0.5,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     :math:`D^2` regression score function, fraction of pinball loss explained.

@@ -211,7 +211,7 @@ def d2_pinball_score(
     kwargs = telemetry.get_sproc_statement_params_kwargs(d2_pinball_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session, **kwargs))

-    score: Union[float, npt.NDArray[np.float_]] = result_object
+    score: Union[float, npt.NDArray[np.float64]] = result_object
     return score


@@ -224,7 +224,7 @@ def explained_variance_score(
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     force_finite: bool = True,
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Explained variance regression score function.

@@ -326,7 +326,7 @@ def explained_variance_score(

     kwargs = telemetry.get_sproc_statement_params_kwargs(explained_variance_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, explained_variance_score_anon_sproc(session, **kwargs))
-    score: Union[float, npt.NDArray[np.float_]] = result_object
+    score: Union[float, npt.NDArray[np.float64]] = result_object
     return score


@@ -338,7 +338,7 @@ def mean_absolute_error(
     y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Mean absolute error regression loss.

@@ -411,7 +411,7 @@ def mean_absolute_percentage_error(
     y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Mean absolute percentage error (MAPE) regression loss.

@@ -495,7 +495,7 @@ def mean_squared_error(
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     squared: bool = True,
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Mean squared error regression loss.

@@ -264,6 +264,7 @@ def plot_force(
 def plot_influence_sensitivity(
     shap_values: type_hints.SupportedDataType,
     feature_values: type_hints.SupportedDataType,
+    infer_is_categorical: bool = True,
     figsize: tuple[float, float] = DEFAULT_FIGSIZE,
 ) -> Any:
     """
@@ -274,6 +275,8 @@ def plot_influence_sensitivity(
     Args:
         shap_values: pandas Series or 2D array containing the SHAP values for a specific feature
         feature_values: pandas Series or 2D array containing the feature values for the same feature
+        infer_is_categorical: If True, the function will infer if the feature is categorical
+            based on the number of unique values.
         figsize: tuple of (width, height) for the plot

     Returns:
@@ -294,7 +297,7 @@ def plot_influence_sensitivity(
     elif feature_values_df.shape[0] != shap_values_df.shape[0]:
         raise ValueError("Feature values and SHAP values must have the same number of rows.")

-    scatter = _create_scatter_plot(feature_values, shap_values, figsize)
+    scatter = _create_scatter_plot(feature_values, shap_values, infer_is_categorical, figsize)
     return st.altair_chart(scatter) if use_streamlit else scatter


@@ -322,11 +325,13 @@ def _prepare_feature_values_for_streamlit(
     return feature_values, shap_values, st


-def _create_scatter_plot(feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float]) -> alt.Chart:
+def _create_scatter_plot(
+    feature_values: pd.Series, shap_values: pd.Series, infer_is_categorical: bool, figsize: tuple[float, float]
+) -> alt.Chart:
     unique_vals = np.sort(np.unique(feature_values.values))
     max_points_per_unique_value = float(np.max(np.bincount(np.searchsorted(unique_vals, feature_values.values))))
     points_per_value = len(feature_values.values) / len(unique_vals)
-    is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10
+    is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10 if infer_is_categorical else False

     kwargs = (
         {
@@ -403,9 +408,11 @@ def plot_violin(
         .transform_density(density="shap_value", groupby=["feature_name"], as_=["shap_value", "density"])
         .mark_area(orient="vertical")
         .encode(
-            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=True),
+            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=False),
             x=alt.X("shap_value:Q", title="SHAP Value"),
-            row=alt.Row("feature_name:N", sort=column_sort_order).spacing(0),
+            row=alt.Row(
+                "feature_name:N", sort=column_sort_order, header=alt.Header(labelAngle=0, labelAlign="left")
+            ).spacing(0),
             color=alt.Color("feature_name:N", legend=None),
             tooltip=["feature_name", "shap_value"],
         )
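A short usage sketch of the new flag, assuming explain_visualize is imported from the snowflake/ml/monitoring/explain_visualize.py module listed above and the inputs are plain pandas Series: passing infer_is_categorical=False bypasses the points-per-unique-value heuristic shown in _create_scatter_plot and always renders the feature as continuous.

import numpy as np
import pandas as pd

from snowflake.ml.monitoring import explain_visualize

rng = np.random.default_rng(0)
feature_values = pd.Series(rng.normal(size=200))            # raw values of one feature
shap_values = pd.Series(rng.normal(scale=0.1, size=200))    # SHAP values for that feature

# Outside a Streamlit app this returns an Altair chart object.
chart = explain_visualize.plot_influence_sensitivity(
    shap_values,
    feature_values,
    infer_is_categorical=False,  # treat the feature as continuous regardless of cardinality
)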
@@ -1,5 +1,5 @@
 from types import ModuleType
-from typing import Any, Optional, Protocol, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import pandas as pd
 from absl.logging import logging
@@ -17,15 +17,10 @@ from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import exceptions as snowpark_exceptions, session
 from snowflake.snowpark._internal import utils as snowpark_utils

-logger = logging.getLogger(__name__)
-
+if TYPE_CHECKING:
+    from snowflake.ml.experiment._experiment_info import ExperimentInfo

-class EventHandler(Protocol):
-    """Protocol defining the interface for event handlers used during model operations."""
-
-    def update(self, message: str) -> None:
-        """Update with a progress message."""
-        ...
+logger = logging.getLogger(__name__)


 class ModelManager:
@@ -66,9 +61,10 @@ class ModelManager:
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
         task: type_hints.Task = task.Task.UNKNOWN,
+        experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
-        event_handler: EventHandler,
+        progress_status: Optional[Any] = None,
     ) -> model_version_impl.ModelVersion:

         database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -150,9 +146,10 @@ class ModelManager:
             code_paths=code_paths,
             ext_modules=ext_modules,
             task=task,
+            experiment_info=experiment_info,
             options=options,
             statement_params=statement_params,
-            event_handler=event_handler,
+            progress_status=progress_status,
         )

     def _log_model(
@@ -175,9 +172,10 @@ class ModelManager:
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
         task: type_hints.Task = task.Task.UNKNOWN,
+        experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
-        event_handler: EventHandler,
+        progress_status: Optional[Any] = None,
     ) -> model_version_impl.ModelVersion:
         database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
         version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -265,7 +263,9 @@ class ModelManager:
         )

         logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
-        event_handler.update("📦 Packaging model...")
+        if progress_status:
+            progress_status.update("packaging model...")
+            progress_status.increment()

         # Extract save_location from options if present
         save_location = None
@@ -279,6 +279,11 @@ class ModelManager:
             statement_params=statement_params,
             save_location=save_location,
         )
+
+        if progress_status:
+            progress_status.update("creating model manifest...")
+            progress_status.increment()
+
         model_metadata: model_meta.ModelMetadata = mc.save(
             name=model_name_id.resolved(),
             model=model,
@@ -295,7 +300,12 @@ class ModelManager:
             ext_modules=ext_modules,
             options=options,
             task=task,
+            experiment_info=experiment_info,
         )
+
+        if progress_status:
+            progress_status.update("uploading model files...")
+            progress_status.increment()
         statement_params = telemetry.add_statement_params_custom_tags(
             statement_params, model_metadata.telemetry_metadata()
         )
@@ -304,7 +314,9 @@ class ModelManager:
         )

         logger.info("Start creating MODEL object for you in the Snowflake.")
-        event_handler.update("🏗️ Creating model object in Snowflake...")
+        if progress_status:
+            progress_status.update("creating model object in Snowflake...")
+            progress_status.increment()

         self._model_ops.create_from_stage(
             composed_model=mc,
@@ -331,6 +343,10 @@ class ModelManager:
             version_name=version_name_id,
         )

+        if progress_status:
+            progress_status.update("setting model metadata...")
+            progress_status.increment()
+
         if comment:
             mv.comment = comment

@@ -344,7 +360,8 @@ class ModelManager:
             statement_params=statement_params,
         )

-        event_handler.update("✅ Model logged successfully!")
+        if progress_status:
+            progress_status.update("model logged successfully!")

         return mv
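With the EventHandler protocol removed from snowflake/ml/registry/_manager/model_manager.py, progress_status is now optional and duck-typed: the hunks above only ever call .update(message) and .increment() on it, and skip both when it is None. A minimal stand-in for local experimentation (an illustration, not the object the registry actually passes in):

class PrintingProgressStatus:
    """Tiny progress reporter satisfying the .update()/.increment() calls used above."""

    def __init__(self, total_steps: int = 5) -> None:
        self.total_steps = total_steps
        self.completed = 0

    def update(self, message: str) -> None:
        print(f"[{self.completed}/{self.total_steps}] {message}")

    def increment(self) -> None:
        self.completed += 1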