snowflake-ml-python 1.7.1__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +82 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/utils/identifier.py +4 -2
  12. snowflake/ml/_internal/utils/jwt_generator.py +141 -0
  13. snowflake/ml/data/__init__.py +3 -0
  14. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  15. snowflake/ml/data/data_connector.py +53 -11
  16. snowflake/ml/data/data_ingestor.py +2 -1
  17. snowflake/ml/data/torch_utils.py +18 -5
  18. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  19. snowflake/ml/fileset/fileset.py +18 -18
  20. snowflake/ml/model/_client/model/model_version_impl.py +24 -8
  21. snowflake/ml/model/_client/ops/model_ops.py +2 -6
  22. snowflake/ml/model/_client/ops/service_ops.py +12 -7
  23. snowflake/ml/model/_client/sql/model_version.py +11 -0
  24. snowflake/ml/model/_client/sql/stage.py +1 -1
  25. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  27. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  28. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  29. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  31. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  32. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  33. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  34. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  35. snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
  36. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  37. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  39. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  40. snowflake/ml/model/_packager/model_handlers/sklearn.py +10 -9
  41. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  42. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  45. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  46. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  48. snowflake/ml/model/_signatures/pandas_handler.py +1 -1
  49. snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
  50. snowflake/ml/model/_signatures/utils.py +0 -1
  51. snowflake/ml/model/type_hints.py +1 -0
  52. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  53. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  54. snowflake/ml/modeling/pipeline/pipeline.py +6 -176
  55. snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
  56. snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
  57. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
  58. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
  59. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +5 -170
  60. snowflake/ml/monitoring/_manager/model_monitor_manager.py +9 -9
  61. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -2
  62. snowflake/ml/monitoring/model_monitor.py +26 -11
  63. snowflake/ml/registry/_manager/model_manager.py +70 -33
  64. snowflake/ml/registry/registry.py +53 -34
  65. snowflake/ml/utils/authentication.py +75 -0
  66. snowflake/ml/version.py +1 -1
  67. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +120 -53
  68. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +71 -74
  69. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
  70. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  71. snowflake/ml/fileset/parquet_parser.py +0 -170
  72. snowflake/ml/fileset/tf_dataset.py +0 -88
  73. snowflake/ml/fileset/torch_datapipe.py +0 -57
  74. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  75. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  76. snowflake/ml/monitoring/entities/output_score_type.py +0 -90
  77. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
  78. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0

snowflake/ml/_internal/utils/jwt_generator.py
@@ -0,0 +1,141 @@
+import base64
+import hashlib
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+
+import jwt
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import types
+
+logger = logging.getLogger(__name__)
+
+ISSUER = "iss"
+EXPIRE_TIME = "exp"
+ISSUE_TIME = "iat"
+SUBJECT = "sub"
+
+
+class JWTGenerator:
+    """
+    Creates and signs a JWT with the specified private key file, username, and account identifier. The JWTGenerator
+    keeps the generated token and only regenerates the token if a specified period of time has passed.
+    """
+
+    _DEFAULT_LIFETIME = timedelta(minutes=59)  # The tokens will have a 59-minute lifetime
+    _DEFAULT_RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes
+    ALGORITHM = "RS256"  # Tokens will be generated using RSA with SHA256
+
+    def __init__(
+        self,
+        account: str,
+        user: str,
+        private_key: types.PRIVATE_KEY_TYPES,
+        lifetime: Optional[timedelta] = None,
+        renewal_delay: Optional[timedelta] = None,
+    ) -> None:
+        """
+        Create a new JWTGenerator object.
+
+        Args:
+            account: The account identifier.
+            user: The username.
+            private_key: The private key used to sign the JWT.
+            lifetime: The lifetime of the token.
+            renewal_delay: The time before the token expires to renew it.
+        """
+
+        # Construct the fully qualified name of the user in uppercase.
+        self.account = JWTGenerator._prepare_account_name_for_jwt(account)
+        self.user = user.upper()
+        self.qualified_username = self.account + "." + self.user
+        self.private_key = private_key
+        self.public_key_fp = JWTGenerator._calculate_public_key_fingerprint(self.private_key)
+
+        self.issuer = self.qualified_username + "." + self.public_key_fp
+        self.lifetime = lifetime or JWTGenerator._DEFAULT_LIFETIME
+        self.renewal_delay = renewal_delay or JWTGenerator._DEFAULT_RENEWAL_DELTA
+        self.renew_time = datetime.now(timezone.utc)
+        self.token: Optional[str] = None
+
+        logger.info(
+            """Creating JWTGenerator with arguments
+            account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
+            self.account,
+            self.user,
+            self.lifetime,
+            self.renewal_delay,
+        )
+
+    @staticmethod
+    def _prepare_account_name_for_jwt(raw_account: str) -> str:
+        account = raw_account
+        if ".global" not in account:
+            # Handle the general case.
+            idx = account.find(".")
+            if idx > 0:
+                account = account[0:idx]
+        else:
+            # Handle the replication case.
+            idx = account.find("-")
+            if idx > 0:
+                account = account[0:idx]
+        # Use uppercase for the account identifier.
+        return account.upper()
+
+    def get_token(self) -> str:
+        now = datetime.now(timezone.utc)  # Fetch the current time
+        if self.token is not None and self.renew_time > now:
+            return self.token
+
+        # If the token has expired or doesn't exist, regenerate the token.
+        logger.info(
+            "Generating a new token because the present time (%s) is later than the renewal time (%s)",
+            now,
+            self.renew_time,
+        )
+        # Calculate the next time we need to renew the token.
+        self.renew_time = now + self.renewal_delay
+
+        # Create our payload
+        payload = {
+            # Set the issuer to the fully qualified username concatenated with the public key fingerprint.
+            ISSUER: self.issuer,
+            # Set the subject to the fully qualified username.
+            SUBJECT: self.qualified_username,
+            # Set the issue time to now.
+            ISSUE_TIME: now,
+            # Set the expiration time, based on the lifetime specified for this object.
+            EXPIRE_TIME: now + self.lifetime,
+        }
+
+        # Regenerate the actual token
+        token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
+        # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
+        # If the token is a byte string, convert it to a string.
+        if isinstance(token, bytes):
+            token = token.decode("utf-8")
+        self.token = token
+        logger.info(
+            "Generated a JWT with the following payload: %s",
+            jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
+        )
+
+        return token
+
+    @staticmethod
+    def _calculate_public_key_fingerprint(private_key: types.PRIVATE_KEY_TYPES) -> str:
+        # Get the raw bytes of public key.
+        public_key_raw = private_key.public_key().public_bytes(
+            serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
+        )
+
+        # Get the sha256 hash of the raw bytes.
+        sha256hash = hashlib.sha256()
+        sha256hash.update(public_key_raw)
+
+        # Base64-encode the value and prepend the prefix 'SHA256:'.
+        public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8")
+        logger.info("Public key fingerprint is %s", public_key_fp)
+
+        return public_key_fp
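
A minimal usage sketch of this new internal helper, assuming PyJWT and cryptography are installed; the key file, account, and user below are hypothetical, while the constructor arguments and get_token() follow the code added above:

from datetime import timedelta

from cryptography.hazmat.primitives import serialization

from snowflake.ml._internal.utils.jwt_generator import JWTGenerator

# Hypothetical PKCS#8 key file and identifiers, for illustration only.
with open("rsa_key.p8", "rb") as key_file:
    private_key = serialization.load_pem_private_key(key_file.read(), password=None)

generator = JWTGenerator(
    account="myorg-myaccount",
    user="ML_SERVICE_USER",
    private_key=private_key,
    lifetime=timedelta(minutes=59),
    renewal_delay=timedelta(minutes=54),
)
token = generator.get_token()  # cached; only regenerated after the renewal delay elapses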

snowflake/ml/data/__init__.py
@@ -1,5 +1,8 @@
+from pkgutil import extend_path
+
 from .data_connector import DataConnector
 from .data_ingestor import DataIngestor, DataIngestorType
 from .data_source import DataFrameInfo, DatasetInfo, DataSource
 
 __all__ = ["DataConnector", "DataSource", "DataFrameInfo", "DatasetInfo", "DataIngestor", "DataIngestorType"]
+__path__ = extend_path(__path__, __name__)

snowflake/ml/data/_internal/arrow_ingestor.py
@@ -2,7 +2,7 @@ import collections
 import logging
 import os
 import time
-from typing import Any, Deque, Dict, Iterator, List, Optional, Union
+from typing import Any, Deque, Dict, Iterator, List, Optional, Sequence, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -47,7 +47,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
     def __init__(
         self,
         session: snowpark.Session,
-        data_sources: List[data_source.DataSource],
+        data_sources: Sequence[data_source.DataSource],
         format: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
@@ -60,14 +60,14 @@ class ArrowIngestor(data_ingestor.DataIngestor):
             kwargs: Miscellaneous arguments passed to underlying PyArrow Dataset initializer.
         """
         self._session = session
-        self._data_sources = data_sources
+        self._data_sources = list(data_sources)
         self._format = format
         self._kwargs = kwargs
 
         self._schema: Optional[pa.Schema] = None
 
     @classmethod
-    def from_sources(cls, session: snowpark.Session, sources: List[data_source.DataSource]) -> "ArrowIngestor":
+    def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
         return cls(session, sources)
 
     @property

snowflake/ml/data/data_connector.py
@@ -1,5 +1,16 @@
 import os
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    cast,
+)
 
 import numpy.typing as npt
 from typing_extensions import deprecated
@@ -12,6 +23,7 @@ from snowflake.ml.modeling._internal.constants import (
     IN_ML_RUNTIME_ENV_VAR,
     USE_OPTIMIZED_DATA_INGESTOR,
 )
+from snowflake.snowpark import context as sf_context
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -35,8 +47,10 @@ class DataConnector:
     def __init__(
        self,
        ingestor: data_ingestor.DataIngestor,
+       **kwargs: Any,
    ) -> None:
        self._ingestor = ingestor
+       self._kwargs = kwargs
 
    @classmethod
    @snowpark._internal.utils.private_preview(version="1.6.0")
@@ -44,20 +58,34 @@
        cls: Type[DataConnectorType],
        df: snowpark.DataFrame,
        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-       **kwargs: Any
+       **kwargs: Any,
    ) -> DataConnectorType:
        if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
            raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
-       source = data_source.DataFrameInfo(df.queries["queries"][0])
-       assert df._session is not None
-       return cls.from_sources(df._session, [source], ingestor_class=ingestor_class, **kwargs)
+       return cast(
+           DataConnectorType,
+           cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs),
+       )
+
+   @classmethod
+   @snowpark._internal.utils.private_preview(version="1.7.3")
+   def from_sql(
+       cls: Type[DataConnectorType],
+       query: str,
+       session: Optional[snowpark.Session] = None,
+       ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+       **kwargs: Any,
+   ) -> DataConnectorType:
+       session = session or sf_context.get_active_session()
+       source = data_source.DataFrameInfo(query)
+       return cls.from_sources(session, [source], ingestor_class=ingestor_class, **kwargs)
 
    @classmethod
    def from_dataset(
        cls: Type[DataConnectorType],
        ds: "dataset.Dataset",
        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-       **kwargs: Any
+       **kwargs: Any,
    ) -> DataConnectorType:
        dsv = ds.selected_version
        assert dsv is not None
@@ -75,9 +103,9 @@
    def from_sources(
        cls: Type[DataConnectorType],
        session: snowpark.Session,
-       sources: List[data_source.DataSource],
+       sources: Sequence[data_source.DataSource],
        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-       **kwargs: Any
+       **kwargs: Any,
    ) -> DataConnectorType:
        ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
        ingestor = ingestor_class.from_sources(session, sources)
@@ -130,7 +158,11 @@
        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
    )
    def to_torch_datapipe(
-       self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+       self,
+       *,
+       batch_size: int,
+       shuffle: bool = False,
+       drop_last_batch: bool = True,
    ) -> "torch_data.IterDataPipe":  # type: ignore[type-arg]
        """Transform the Snowflake data into a ready-to-use Pytorch datapipe.
 
@@ -149,8 +181,13 @@
        """
        from snowflake.ml.data import torch_utils
 
+       expand_dims = self._kwargs.get("expand_dims", True)
        return torch_utils.TorchDataPipeWrapper(
-           self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
+           self._ingestor,
+           batch_size=batch_size,
+           shuffle=shuffle,
+           drop_last=drop_last_batch,
+           expand_dims=expand_dims,
        )
 
    @telemetry.send_api_usage_telemetry(
@@ -179,8 +216,13 @@
        """
        from snowflake.ml.data import torch_utils
 
+       expand_dims = self._kwargs.get("expand_dims", True)
        return torch_utils.TorchDatasetWrapper(
-           self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
+           self._ingestor,
+           batch_size=batch_size,
+           shuffle=shuffle,
+           drop_last=drop_last_batch,
+           expand_dims=expand_dims,
        )
 
    @telemetry.send_api_usage_telemetry(

snowflake/ml/data/data_ingestor.py
@@ -6,6 +6,7 @@ from typing import (
     List,
     Optional,
     Protocol,
+    Sequence,
     Type,
     TypeVar,
 )
@@ -25,7 +26,7 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
 class DataIngestor(Protocol):
     @classmethod
     def from_sources(
-        cls: Type[DataIngestorType], session: snowpark.Session, sources: List[data_source.DataSource]
+        cls: Type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
     ) -> DataIngestorType:
         raise NotImplementedError
 

snowflake/ml/data/torch_utils.py
@@ -17,6 +17,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         batch_size: Optional[int],
         shuffle: bool = False,
         drop_last: bool = False,
+        expand_dims: bool = True,
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
         squeeze = False
@@ -29,6 +30,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         self._shuffle = shuffle
         self._drop_last = drop_last
         self._squeeze_outputs = squeeze
+        self._expand_dims = expand_dims
 
     def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]:
         max_idx = 0
@@ -47,7 +49,10 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         ):
             # Skip indices during multi-process data loading to prevent data duplication
             if counter == filter_idx:
-                yield {k: _preprocess_array(v, squeeze=self._squeeze_outputs) for k, v in batch.items()}
+                yield {
+                    k: _preprocess_array(v, squeeze=self._squeeze_outputs, expand_dims=self._expand_dims)
+                    for k, v in batch.items()
+                }
             if counter < max_idx:
                 counter += 1
             else:
@@ -58,13 +63,21 @@ class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Di
     """Wrap a DataIngestor into a PyTorch IterDataPipe"""
 
     def __init__(
-        self, ingestor: data_ingestor.DataIngestor, *, batch_size: int, shuffle: bool = False, drop_last: bool = False
+        self,
+        ingestor: data_ingestor.DataIngestor,
+        *,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_last: bool = False,
+        expand_dims: bool = True,
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_datapipe() instead"""
-        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, expand_dims=expand_dims)
 
 
-def _preprocess_array(arr: npt.NDArray[Any], squeeze: bool = False) -> Union[npt.NDArray[Any], List[np.object_]]:
+def _preprocess_array(
+    arr: npt.NDArray[Any], squeeze: bool = False, expand_dims: bool = True
+) -> Union[npt.NDArray[Any], List[np.object_]]:
     """Preprocesses batch column values."""
     single_dimensional = arr.ndim < 2 and not arr.dtype == np.object_
 
@@ -73,7 +86,7 @@ def _preprocess_array(arr: npt.NDArray[Any], squeeze: bool = False) -> Union[npt
         arr = arr.squeeze(axis=0)
 
     # For single dimensional data,
-    if single_dimensional:
+    if single_dimensional and expand_dims:
         axis = 0 if arr.ndim == 0 else 1
         arr = np.expand_dims(arr, axis=axis)
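
To illustrate what the new expand_dims flag controls, a standalone numpy sketch of the behavior in _preprocess_array (not the library code itself):

import numpy as np

col = np.array([1.0, 2.0, 3.0])  # a single-dimensional batch column

# expand_dims=True (default): 1-D columns gain a trailing axis, matching the old behavior.
expanded = np.expand_dims(col, axis=1)
print(expanded.shape)  # (3, 1)

# expand_dims=False (as used by FileSet via DataConnector(ingester, expand_dims=False)
# in the fileset.py hunks below): the column keeps its original shape.
print(col.shape)  # (3,)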
 

snowflake/ml/feature_store/examples/example_helper.py
@@ -45,8 +45,9 @@ class ExampleHelper:
         """Return a dataframe object about descriptions of all examples."""
         root_dir = Path(__file__).parent
         rows = []
+        hide_folders = ["citibike_trip_features", "source_data"]
         for f_name in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name != "source_data":
+            if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name not in hide_folders:
                 source_file_path = root_dir.joinpath(f"{f_name}/source.yaml")
                 source_dict = self._read_yaml(str(source_file_path))
                 rows.append((f_name, source_dict["model_category"], source_dict["desc"], source_dict["label_columns"]))

snowflake/ml/fileset/fileset.py
@@ -11,11 +11,9 @@ from snowflake.ml._internal.exceptions import (
     fileset_error_messages,
     fileset_errors,
 )
-from snowflake.ml._internal.utils import (
-    identifier,
-    import_utils,
-    snowpark_dataframe_utils,
-)
+from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
+from snowflake.ml.data import data_connector
+from snowflake.ml.data._internal import arrow_ingestor
 from snowflake.ml.fileset import sfcfs
 from snowflake.snowpark import exceptions as snowpark_exceptions, functions
 
@@ -285,6 +283,16 @@
         """Get the Snowflake absolute path to this FileSet directory."""
         return _fileset_absolute_path(self._target_stage_loc, self.name)
 
+    def _to_data_connector(self) -> data_connector.DataConnector:
+        self._fs.optimize_read(self._list_files())
+        ingester = arrow_ingestor.ArrowIngestor(
+            self._snowpark_session,
+            self._list_files(),
+            format="parquet",
+            filesystem=self._fs,
+        )
+        return data_connector.DataConnector(ingester, expand_dims=False)
+
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
     )
@@ -362,13 +370,9 @@
         ----
         {'_COL_1':[10]}
         """
-        IterableWrapper, _ = import_utils.import_or_get_dummy("torchdata.datapipes.iter.IterableWrapper")
-        torch_datapipe_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.torch_datapipe")
-
-        self._fs.optimize_read(self._list_files())
-
-        input_dp = IterableWrapper(self._list_files())
-        return torch_datapipe_module.ReadAndParseParquet(input_dp, self._fs, batch_size, shuffle, drop_last_batch)
+        return self._to_data_connector().to_torch_datapipe(
+            batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
+        )
 
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -402,12 +406,8 @@
         ----
         {'_COL_1': <tf.Tensor: shape=(1,), dtype=int64, numpy=[10]>}
         """
-        tf_dataset_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.tf_dataset")
-
-        self._fs.optimize_read(self._list_files())
-
-        return tf_dataset_module.read_and_parse_parquet(
-            self._list_files(), self._fs, batch_size, shuffle, drop_last_batch
+        return self._to_data_connector().to_tf_dataset(
+            batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
         )
 
     @telemetry.send_api_usage_telemetry(
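
The public FileSet methods keep their signatures; both readers are now backed by the shared DataConnector/ArrowIngestor path instead of the removed fileset.torch_datapipe and fileset.tf_dataset modules. A hedged sketch, with a hypothetical stage location and FileSet name (constructor keyword names as commonly documented, not shown in this diff):

from snowflake.snowpark import Session

from snowflake.ml.fileset import fileset

session = Session.builder.getOrCreate()  # assumes a configured default connection

fs = fileset.FileSet(
    snowpark_session=session,
    target_stage_loc="@MY_DB.MY_SCHEMA.MY_STAGE/",
    name="train_v1",
)

dp = fs.to_torch_datapipe(batch_size=128, shuffle=True, drop_last_batch=True)
ds = fs.to_tf_dataset(batch_size=128, shuffle=True, drop_last_batch=True)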

snowflake/ml/model/_client/model/model_version_impl.py
@@ -14,7 +14,7 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._packager.model_handlers import snowmlmodel
-from snowflake.snowpark import Session, dataframe
+from snowflake.snowpark import Session, async_job, dataframe
 
 _TELEMETRY_PROJECT = "MLOps"
 _TELEMETRY_SUBPROJECT = "ModelManagement"
@@ -447,13 +447,15 @@ class ModelVersion(lineage_node.LineageNode):
         target_function_info = functions[0]
 
         if service_name:
+            database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
+
             return self._model_ops.invoke_method(
                 method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
                 signature=target_function_info["signature"],
                 X=X,
-                database_name=None,
-                schema_name=None,
-                service_name=sql_identifier.SqlIdentifier(service_name),
+                database_name=database_name_id,
+                schema_name=schema_name_id,
+                service_name=service_name_id,
                 strict_input_validation=strict_input_validation,
                 statement_params=statement_params,
             )
@@ -631,7 +633,8 @@
         max_batch_rows: Optional[int] = None,
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
-    ) -> str:
+        block: bool = True,
+    ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
         Args:
@@ -659,6 +662,9 @@
             force_rebuild: Whether to force a model inference image rebuild.
             build_external_access_integration: (Deprecated) The external access integration for image build. This is
                 usually permitting access to conda & PyPI repositories.
+            block: A bool value indicating whether this function will wait until the service is available.
+                When it is ``False``, this function executes the underlying service creation asynchronously
+                and returns an :class:`AsyncJob`.
         """
         ...
 
@@ -679,7 +685,8 @@
         max_batch_rows: Optional[int] = None,
         force_rebuild: bool = False,
         build_external_access_integrations: Optional[List[str]] = None,
-    ) -> str:
+        block: bool = True,
+    ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
         Args:
@@ -707,6 +714,9 @@
             force_rebuild: Whether to force a model inference image rebuild.
             build_external_access_integrations: The external access integrations for image build. This is usually
                 permitting access to conda & PyPI repositories.
+            block: A bool value indicating whether this function will wait until the service is available.
+                When it is ``False``, this function executes the underlying service creation asynchronously
+                and returns an :class:`AsyncJob`.
         """
         ...
 
@@ -742,7 +752,8 @@
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
         build_external_access_integrations: Optional[List[str]] = None,
-    ) -> str:
+        block: bool = True,
+    ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
         Args:
@@ -772,12 +783,16 @@
                 usually permitting access to conda & PyPI repositories.
             build_external_access_integrations: The external access integrations for image build. This is usually
                 permitting access to conda & PyPI repositories.
+            block: A bool value indicating whether this function will wait until the service is available.
+                When it is False, this function executes the underlying service creation asynchronously
+                and returns an AsyncJob.
 
         Raises:
             ValueError: Illegal external access integration arguments.
 
         Returns:
-            Result information about service creation from server.
+            If `block=True`, return result information about service creation from server.
+            Otherwise, return the service creation AsyncJob.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -829,6 +844,7 @@
                 if build_external_access_integrations is None
                 else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
             ),
+            block=block,
            statement_params=statement_params,
        )
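
A hedged sketch of the new asynchronous deployment path; the model, service, compute pool, and image repo names are hypothetical, while block and the AsyncJob return type follow the signatures above:

from snowflake.snowpark import Session

from snowflake.ml.registry import Registry

session = Session.builder.getOrCreate()  # assumes a configured default connection
reg = Registry(session=session)
mv = reg.get_model("MY_MODEL").version("V1")

# block=False returns a snowflake.snowpark.AsyncJob instead of waiting for the service.
job = mv.create_service(
    service_name="MY_DB.MY_SCHEMA.MY_INFERENCE_SERVICE",
    service_compute_pool="MY_COMPUTE_POOL",
    image_repo="MY_DB.MY_SCHEMA.MY_IMAGE_REPO",
    block=False,
)

# ... do other work, then wait on the deployment when convenient.
job.result()

# run() now parses fully qualified service names into database/schema/service identifiers.
input_df = session.table("MY_DB.MY_SCHEMA.TEST_DATA").limit(10)
predictions = mv.run(input_df, function_name="PREDICT", service_name="MY_DB.MY_SCHEMA.MY_INFERENCE_SERVICE")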
 

snowflake/ml/model/_client/ops/model_ops.py
@@ -168,14 +168,10 @@ class ModelOperator:
         schema_name: Optional[sql_identifier.SqlIdentifier],
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
+        model_exists: bool,
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> None:
-        if self.validate_existence(
-            database_name=database_name,
-            schema_name=schema_name,
-            model_name=model_name,
-            statement_params=statement_params,
-        ):
+        if model_exists:
             return self._model_version_client.add_version_from_model_version(
                 source_database_name=source_database_name,
                 source_schema_name=source_schema_name,

snowflake/ml/model/_client/ops/service_ops.py
@@ -6,7 +6,7 @@ import re
 import tempfile
 import threading
 import time
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
 from packaging import version
 
@@ -15,7 +15,7 @@ from snowflake.ml._internal import file_utils
 from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
-from snowflake.snowpark import exceptions, row, session
+from snowflake.snowpark import async_job, exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -107,8 +107,9 @@ class ServiceOperator:
         max_batch_rows: Optional[int],
         force_rebuild: bool,
         build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
+        block: bool,
         statement_params: Optional[Dict[str, Any]] = None,
-    ) -> str:
+    ) -> Union[str, async_job.AsyncJob]:
 
         # Fall back to the registry's database and schema if not provided
         database_name = database_name or self._database_name
@@ -204,11 +205,15 @@
         log_thread = self._start_service_log_streaming(
             async_job, services, model_inference_service_exists, force_rebuild, statement_params
         )
-        log_thread.join()
 
-        res = cast(str, cast(List[row.Row], async_job.result())[0][0])
-        module_logger.info(f"Inference service {service_name} deployment complete: {res}")
-        return res
+        if block:
+            log_thread.join()
+
+            res = cast(str, cast(List[row.Row], async_job.result())[0][0])
+            module_logger.info(f"Inference service {service_name} deployment complete: {res}")
+            return res
+        else:
+            return async_job
 
     def _start_service_log_streaming(
         self,

snowflake/ml/model/_client/sql/model_version.py
@@ -10,6 +10,7 @@ from snowflake.ml._internal.utils import (
     sql_identifier,
 )
 from snowflake.ml.model._client.sql import _base
+from snowflake.ml.model._model_composer.model_method import constants
 from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
@@ -333,6 +334,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
 
         args_sql = ", ".join(args_sql_list)
 
+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
                SELECT *,
@@ -412,6 +418,11 @@
 
         args_sql = ", ".join(args_sql_list)
 
+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
        sql = textwrap.dedent(
            f"""WITH {','.join(with_statements)}
               SELECT *,
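
The wide-input branch above kicks in when the number of input columns exceeds SNOWPARK_UDF_INPUT_COL_LIMIT (a new constant in model_method/constants.py; its value is not shown in this diff). A standalone sketch of the argument string it builds, with hypothetical column names and an assumed limit:

# Illustrative only; mirrors the f-string logic added above, outside the SQL client.
ASSUMED_LIMIT = 500  # hypothetical stand-in for constants.SNOWPARK_UDF_INPUT_COL_LIMIT

input_cols = [f"COL_{i}" for i in range(600)]  # more columns than the limit

if len(input_cols) > ASSUMED_LIMIT:
    pairs = ", ".join(f"'{col}', {col}" for col in input_cols)
    args_sql = f"object_construct_keep_null({pairs})"
    # The model method then receives a single OBJECT argument instead of 600 positional columns.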