snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +66 -35
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/env_utils.py +11 -5
  6. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  7. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  8. snowflake/ml/_internal/telemetry.py +26 -2
  9. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  10. snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
  11. snowflake/ml/data/data_connector.py +186 -0
  12. snowflake/ml/data/data_ingestor.py +45 -0
  13. snowflake/ml/data/data_source.py +23 -0
  14. snowflake/ml/data/ingestor_utils.py +62 -0
  15. snowflake/ml/data/torch_dataset.py +33 -0
  16. snowflake/ml/dataset/dataset.py +1 -13
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +23 -117
  19. snowflake/ml/feature_store/access_manager.py +7 -1
  20. snowflake/ml/feature_store/entity.py +19 -2
  21. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  22. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  23. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  24. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  26. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
  27. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
  28. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
  29. snowflake/ml/feature_store/examples/example_helper.py +278 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  31. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
  32. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
  34. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  35. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  36. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  37. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  38. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  39. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  40. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
  43. snowflake/ml/feature_store/feature_store.py +637 -76
  44. snowflake/ml/feature_store/feature_view.py +316 -9
  45. snowflake/ml/fileset/stage_fs.py +18 -10
  46. snowflake/ml/lineage/lineage_node.py +1 -1
  47. snowflake/ml/model/_client/model/model_impl.py +11 -2
  48. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  49. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  50. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  51. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  52. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  53. snowflake/ml/model/_client/sql/model_version.py +13 -4
  54. snowflake/ml/model/_client/sql/service.py +129 -0
  55. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  56. snowflake/ml/model/_model_composer/model_composer.py +14 -14
  57. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
  58. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
  59. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  60. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  61. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  62. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  63. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  64. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  65. snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
  66. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  67. snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
  68. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  69. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  70. snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
  71. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  72. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  73. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  74. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  75. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  76. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  77. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  78. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  79. snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
  80. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  81. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  82. snowflake/ml/model/_packager/model_packager.py +2 -1
  83. snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
  84. snowflake/ml/model/model_signature.py +4 -4
  85. snowflake/ml/model/type_hints.py +2 -0
  86. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  87. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  88. snowflake/ml/modeling/framework/base.py +28 -19
  89. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  90. snowflake/ml/modeling/pipeline/pipeline.py +7 -4
  91. snowflake/ml/registry/_manager/model_manager.py +16 -2
  92. snowflake/ml/registry/registry.py +100 -13
  93. snowflake/ml/utils/sql_client.py +22 -0
  94. snowflake/ml/version.py +1 -1
  95. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
  96. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
  97. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
  98. snowflake/ml/_internal/lineage/data_source.py +0 -10
  99. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  100. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ from snowflake.cortex._classify_text import ClassifyText
1
2
  from snowflake.cortex._complete import Complete, CompleteOptions
2
3
  from snowflake.cortex._extract_answer import ExtractAnswer
3
4
  from snowflake.cortex._sentiment import Sentiment
@@ -5,6 +6,7 @@ from snowflake.cortex._summarize import Summarize
5
6
  from snowflake.cortex._translate import Translate
6
7
 
7
8
  __all__ = [
9
+ "ClassifyText",
8
10
  "Complete",
9
11
  "CompleteOptions",
10
12
  "ExtractAnswer",
@@ -0,0 +1,36 @@
1
+ from typing import List, Optional, Union
2
+
3
+ from snowflake import snowpark
4
+ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
5
+ from snowflake.ml._internal import telemetry
6
+
7
+
8
@telemetry.send_api_usage_telemetry(
    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
)
def ClassifyText(
    str_input: Union[str, snowpark.Column],
    categories: Union[List[str], snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[str, snowpark.Column]:
    """Use the LLM inference service to classify the input text into one of the target categories.

    Delegates to the ``snowflake.cortex.classify_text`` SQL function.

    Args:
        str_input: The text to classify: a single string or a Column of strings.
        categories: The candidate categories to classify the input text into, given either as a
            list of strings or as a Column.
        session: The snowpark session to use. Will be inferred by context if not specified.

    Returns:
        The classification response(s), either a ``str`` or a ``snowpark.Column``, as produced by
        ``call_sql_function``.
    """
    return _classify_text_impl("snowflake.cortex.classify_text", str_input, categories, session=session)
28
+
29
+
30
def _classify_text_impl(
    function: str,
    str_input: Union[str, snowpark.Column],
    categories: Union[List[str], snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[str, snowpark.Column]:
    """Invoke the SQL function named ``function`` on the input text and candidate categories.

    The SQL function name is a parameter; ``ClassifyText`` passes ``snowflake.cortex.classify_text``.
    The positional argument order handed to ``call_sql_function`` is (input text, categories).
    """
    sql_args = (str_input, categories)
    return call_sql_function(function, session, *sql_args)
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  import logging
3
- from typing import Iterator, List, Optional, TypedDict, Union, cast
3
+ import time
4
+ from typing import Any, Callable, Iterator, List, Optional, TypedDict, Union, cast
4
5
  from urllib.parse import urlunparse
5
6
 
6
7
  import requests
@@ -52,12 +53,43 @@ class ResponseParseException(Exception):
52
53
  pass
53
54
 
54
55
 
56
_MAX_RETRY_SECONDS = 30

# HTTP status codes treated as transient failures and retried with backoff.
_RETRY_STATUS_CODES = frozenset({429, 503, 504})


def retry(func: Callable[..., requests.Response]) -> Callable[..., requests.Response]:
    """Decorate a request function so transient HTTP failures are retried with exponential backoff.

    The decorated function accepts an extra ``deadline`` keyword argument: seconds since the epoch
    (as returned by ``time.time()``), or None / omitted for no deadline. It is consumed by the
    wrapper and not forwarded to ``func``.

    Responses with status 429, 503 or 504 are retried. Before retry ``i`` the wrapper waits
    ``max(Retry-After, min(2**i, _MAX_RETRY_SECONDS))`` seconds, truncated so it never sleeps past
    ``deadline``.

    The returned wrapper raises:
        TimeoutError: If ``deadline`` passes before a 2xx response is received.
        requests.HTTPError: For any other non-2xx response.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(func)
    def inner(*args: Any, **kwargs: Any) -> requests.Response:
        # `kwargs` is a fresh dict for every call, so popping does not affect the caller.
        # Defaulting to None allows callers that do not pass `deadline` at all.
        deadline = cast(Optional[float], kwargs.pop("deadline", None))
        exp_retry_seconds = 0.5
        while True:
            if deadline is not None and time.time() > deadline:
                raise TimeoutError()
            response = func(*args, **kwargs)
            if 200 <= response.status_code < 300:
                return response
            if response.status_code not in _RETRY_STATUS_CODES:
                response.raise_for_status()
                # raise_for_status() only raises for 4xx/5xx. Without this, any other non-2xx
                # status (e.g. an unfollowed 3xx) would fall through and be retried forever.
                raise requests.HTTPError(f"unexpected status code {response.status_code}", response=response)
            logger.debug(f"request failed with status code {response.status_code}, retrying")

            # Formula: delay(i) = max(RetryAfterHeader, min(2^i, _MAX_RETRY_SECONDS)).
            exp_retry_seconds = min(2 * exp_retry_seconds, _MAX_RETRY_SECONDS)
            retry_seconds: float = exp_retry_seconds
            retry_after_header = response.headers.get("retry-after")
            if retry_after_header is not None:
                try:
                    retry_seconds = max(retry_seconds, int(retry_after_header))
                except ValueError:
                    # Retry-After may also be an HTTP-date; fall back to the exponential delay.
                    logger.debug(f"ignoring unparseable retry-after header: {retry_after_header!r}")

            # The failed response may be streamed; release its connection before waiting.
            response.close()

            if deadline is not None:
                # Sleeping past the deadline is pointless: the check at the top of the loop raises.
                retry_seconds = min(retry_seconds, max(0.0, deadline - time.time()))
            logger.debug(f"sleeping for {retry_seconds}s before retrying")
            time.sleep(retry_seconds)

    return inner
85
+
86
+
87
+ @retry
55
88
  def _call_complete_rest(
56
89
  model: str,
57
90
  prompt: Union[str, List[ConversationMessage]],
58
91
  options: Optional[CompleteOptions] = None,
59
92
  session: Optional[snowpark.Session] = None,
60
- stream: bool = False,
61
93
  ) -> requests.Response:
62
94
  session = session or context.get_active_session()
63
95
  if session is None:
@@ -78,7 +110,7 @@ def _call_complete_rest(
78
110
  scheme = "https"
79
111
  if hasattr(session.connection, "scheme"):
80
112
  scheme = session.connection.scheme
81
- url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference/complete", "", "", ""))
113
+ url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference:complete", "", "", ""))
82
114
 
83
115
  headers = {
84
116
  "Content-Type": "application/json",
@@ -88,7 +120,7 @@ def _call_complete_rest(
88
120
 
89
121
  data = {
90
122
  "model": model,
91
- "stream": stream,
123
+ "stream": True,
92
124
  }
93
125
  if isinstance(prompt, List):
94
126
  data["messages"] = prompt
@@ -104,33 +136,20 @@ def _call_complete_rest(
104
136
  if "top_p" in options:
105
137
  data["top_p"] = options["top_p"]
106
138
 
107
- logger.debug(f"making POST request to {url} (model={model}, stream={stream})")
108
- response = requests.post(
139
+ logger.debug(f"making POST request to {url} (model={model})")
140
+ return requests.post(
109
141
  url,
110
142
  json=data,
111
143
  headers=headers,
112
- stream=stream,
144
+ stream=True,
113
145
  )
114
- response.raise_for_status()
115
- return response
116
146
 
117
147
 
118
- def _process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
119
- if stream:
120
- return _return_stream_response(response)
121
-
122
- try:
123
- content = response.json()["choices"][0]["message"]["content"]
124
- assert isinstance(content, str)
125
- return content
126
- except (KeyError, IndexError, AssertionError) as e:
127
- # Unlike the streaming case, errors are not ignored because a message must be returned.
128
- raise ResponseParseException("Failed to parse message from response.") from e
129
-
130
-
131
- def _return_stream_response(response: requests.Response) -> Iterator[str]:
148
+ def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
132
149
  client = SSEClient(response)
133
150
  for event in client.events():
151
+ if deadline is not None and time.time() > deadline:
152
+ raise TimeoutError()
134
153
  try:
135
154
  yield json.loads(event.data)["choices"][0]["delta"]["content"]
136
155
  except (json.JSONDecodeError, KeyError, IndexError):
@@ -206,19 +225,23 @@ def _complete_impl(
206
225
  prompt: Union[str, List[ConversationMessage], snowpark.Column],
207
226
  options: Optional[CompleteOptions] = None,
208
227
  session: Optional[snowpark.Session] = None,
209
- use_rest_api_experimental: bool = False,
210
228
  stream: bool = False,
211
229
  function: str = "snowflake.cortex.complete",
230
+ timeout: Optional[float] = None,
231
+ deadline: Optional[float] = None,
212
232
  ) -> Union[str, Iterator[str], snowpark.Column]:
213
- if use_rest_api_experimental:
233
+ if timeout is not None and deadline is not None:
234
+ raise ValueError('only one of "timeout" and "deadline" must be set')
235
+ if timeout is not None:
236
+ deadline = time.time() + timeout
237
+ if stream:
214
238
  if not isinstance(model, str):
215
239
  raise ValueError("in REST mode, 'model' must be a string")
216
240
  if not isinstance(prompt, str) and not isinstance(prompt, List):
217
241
  raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
218
- response = _call_complete_rest(model, prompt, options, session=session, stream=stream)
219
- return _process_rest_response(response, stream=stream)
220
- if stream is True:
221
- raise ValueError("streaming can only be enabled in REST mode, set use_rest_api_experimental=True")
242
+ response = _call_complete_rest(model, prompt, options, session=session, deadline=deadline)
243
+ assert response.status_code >= 200 and response.status_code < 300
244
+ return _return_stream_response(response, deadline)
222
245
  return _complete_sql_impl(function, model, prompt, options, session)
223
246
 
224
247
 
@@ -231,8 +254,9 @@ def Complete(
231
254
  *,
232
255
  options: Optional[CompleteOptions] = None,
233
256
  session: Optional[snowpark.Session] = None,
234
- use_rest_api_experimental: bool = False,
235
257
  stream: bool = False,
258
+ timeout: Optional[float] = None,
259
+ deadline: Optional[float] = None,
236
260
  ) -> Union[str, Iterator[str], snowpark.Column]:
237
261
  """Complete calls into the LLM inference service to perform completion.
238
262
 
@@ -241,19 +265,26 @@ def Complete(
241
265
  prompt: A Column of prompts to send to the LLM.
242
266
  options: A instance of snowflake.cortex.CompleteOptions
243
267
  session: The snowpark session to use. Will be inferred by context if not specified.
244
- use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
245
- experimental and can be removed at any time.
246
268
  stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
247
269
  output as it is received. Each update is a string containing the new text content since the previous update.
248
- The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
270
+ timeout (float): Timeout in seconds to retry failed REST requests.
271
+ deadline (float): Time in seconds since the epoch (as returned by time.time()) to retry failed REST requests.
249
272
 
250
273
  Raises:
251
- ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.
274
+ ValueError: incorrect argument.
252
275
 
253
276
  Returns:
254
277
  A column of string responses.
255
278
  """
256
279
  try:
257
- return _complete_impl(model, prompt, options, session, use_rest_api_experimental, stream)
280
+ return _complete_impl(
281
+ model,
282
+ prompt,
283
+ options=options,
284
+ session=session,
285
+ stream=stream,
286
+ timeout=timeout,
287
+ deadline=deadline,
288
+ )
258
289
  except ValueError as err:
259
290
  raise err
snowflake/cortex/_util.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, Optional, Union, cast
1
+ from typing import Dict, List, Optional, Union, cast
2
2
 
3
3
  from snowflake import snowpark
4
4
  from snowflake.snowpark import context, functions
@@ -23,7 +23,7 @@ class SnowflakeConfigurationException(Exception):
23
23
  def call_sql_function(
24
24
  function: str,
25
25
  session: Optional[snowpark.Session],
26
- *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
26
+ *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
27
27
  ) -> Union[str, snowpark.Column]:
28
28
  handle_as_column = False
29
29
 
@@ -40,7 +40,7 @@ def call_sql_function(
40
40
 
41
41
 
42
42
  def _call_sql_function_column(
43
- function: str, *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]]
43
+ function: str, *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]]
44
44
  ) -> snowpark.Column:
45
45
  return cast(snowpark.Column, functions.builtin(function)(*args))
46
46
 
@@ -48,7 +48,7 @@ def _call_sql_function_column(
48
48
  def _call_sql_function_immediate(
49
49
  function: str,
50
50
  session: Optional[snowpark.Session],
51
- *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
51
+ *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
52
52
  ) -> str:
53
53
  session = session or context.get_active_session()
54
54
  if session is None:
@@ -27,7 +27,6 @@ class CONDA_OS(Enum):
27
27
  NO_ARCH = "noarch"
28
28
 
29
29
 
30
- _SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
31
30
  _NODEFAULTS = "nodefaults"
32
31
  _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
33
32
  _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
@@ -36,6 +35,7 @@ _SUPPORTED_PACKAGE_SPEC_OPS = ["==", ">=", "<=", ">", "<"]
36
35
  DEFAULT_CHANNEL_NAME = ""
37
36
  SNOWML_SPROC_ENV = "IN_SNOWML_SPROC"
38
37
  SNOWPARK_ML_PKG_NAME = "snowflake-ml-python"
38
+ SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
39
39
 
40
40
 
41
41
  def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
@@ -370,7 +370,7 @@ def get_matched_package_versions_in_snowflake_conda_channel(
370
370
 
371
371
  assert not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
372
372
 
373
- url = f"{_SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
373
+ url = f"{SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
374
374
 
375
375
  if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
376
376
  try:
@@ -477,6 +477,7 @@ def save_conda_env_file(
477
477
  path: pathlib.Path,
478
478
  conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
479
479
  python_version: str,
480
+ default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
480
481
  ) -> None:
481
482
  """Generate conda.yml file given a dict of dependencies after validation.
482
483
  The channels part of conda.yml file will contains Snowflake Anaconda Channel, nodefaults and all channel names
@@ -489,6 +490,7 @@ def save_conda_env_file(
489
490
  path: Path to the conda.yml file.
490
491
  conda_chan_deps: Dict of conda dependencies after validated.
491
492
  python_version: A string 'major.minor' showing python version relate to model.
493
+ default_channel_override: The default channel to be put in the first place of the channels section.
492
494
  """
493
495
  assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
494
496
  path.parent.mkdir(parents=True, exist_ok=True)
@@ -499,7 +501,11 @@ def save_conda_env_file(
499
501
  channels = list(dict(sorted(conda_chan_deps.items(), key=lambda item: len(item[1]), reverse=True)).keys())
500
502
  if DEFAULT_CHANNEL_NAME in channels:
501
503
  channels.remove(DEFAULT_CHANNEL_NAME)
502
- env["channels"] = [_SNOWFLAKE_CONDA_CHANNEL_URL] + channels + [_NODEFAULTS]
504
+
505
+ if default_channel_override in channels:
506
+ channels.remove(default_channel_override)
507
+
508
+ env["channels"] = [default_channel_override] + channels + [_NODEFAULTS]
503
509
  env["dependencies"] = [f"python=={python_version}.*"]
504
510
  for chan, reqs in conda_chan_deps.items():
505
511
  env["dependencies"].extend(
@@ -567,8 +573,8 @@ def load_conda_env_file(
567
573
  python_version = None
568
574
 
569
575
  channels = env.get("channels", [])
570
- if _SNOWFLAKE_CONDA_CHANNEL_URL in channels:
571
- channels.remove(_SNOWFLAKE_CONDA_CHANNEL_URL)
576
+ if len(channels) >= 1:
577
+ channels = channels[1:] # Skip the first channel which is the default channel
572
578
  if _NODEFAULTS in channels:
573
579
  channels.remove(_NODEFAULTS)
574
580
 
@@ -4,7 +4,10 @@ ATTRIBUTE_NOT_SET = (
4
4
  "-differences."
5
5
  )
6
6
  SIZE_MISMATCH = "Size mismatch: {}={}, {}={}."
7
# Format args: parameter name, model name, valid parameters.
# The trailing space after "{}." separates the two implicitly concatenated sentences.
INVALID_MODEL_PARAM = (
    "Invalid parameter {} for model {}. Valid parameters: {}. "
    "Note: Scikit learn params cannot be set until the model has been fit."
)
8
11
  UNSUPPORTED_MODEL_CONVERSION = "Object doesn't support {}. Please use {}."
9
12
  INCOMPATIBLE_NEW_SKLEARN_PARAM = "Incompatible scikit-learn version: {} requires scikit-learn>={}. Installed: {}."
10
13
  REMOVED_SKLEARN_PARAM = "Incompatible scikit-learn version: {} is removed in scikit-learn>={}. Installed: {}."
@@ -1,9 +1,9 @@
1
1
  import copy
2
2
  import functools
3
- from typing import Any, Callable, List, Optional
3
+ from typing import Any, Callable, List, Optional, get_args
4
4
 
5
5
  from snowflake import snowpark
6
- from snowflake.ml._internal.lineage import data_source
6
+ from snowflake.ml.data import data_source
7
7
 
8
8
  _DATA_SOURCES_ATTR = "_data_sources"
9
9
 
@@ -39,7 +39,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
39
39
  result: Optional[List[data_source.DataSource]] = None
40
40
  for arg in args:
41
41
  srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
42
- if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
42
+ if isinstance(srcs, list) and all(isinstance(s, get_args(data_source.DataSource)) for s in srcs):
43
43
  if result is None:
44
44
  result = []
45
45
  result += srcs
@@ -49,7 +49,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
49
49
def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
    """Attach ``data_sources`` to ``obj`` under the ``_DATA_SOURCES_ATTR`` attribute.

    Does nothing when ``data_sources`` is None or empty. Every element must be an instance of
    one of the types in the ``data_source.DataSource`` union (checked with an assert).
    """
    if not data_sources:
        return
    allowed_types = get_args(data_source.DataSource)
    assert all(isinstance(src, allowed_types) for src in data_sources)
    setattr(obj, _DATA_SOURCES_ATTR, data_sources)
54
54
 
55
55
 
@@ -44,6 +44,20 @@ _Args = ParamSpec("_Args")
44
44
  _ReturnValue = TypeVar("_ReturnValue")
45
45
 
46
46
 
47
@enum.unique
class TelemetryProject(enum.Enum):
    """Known telemetry project names; member values are the strings reported."""

    # TODO: Update with remaining projects.
    MLOPS = "MLOps"
    MODELING = "ModelDevelopment"
52
+
53
+
54
@enum.unique
class TelemetrySubProject(enum.Enum):
    """Known telemetry subproject names; member values are the strings reported."""

    # TODO: Update with remaining subprojects.
    MONITORING = "Monitoring"
    REGISTRY = "ModelManagement"
+
60
+
47
61
  @enum.unique
48
62
  class TelemetryField(enum.Enum):
49
63
  # constants
@@ -277,6 +291,7 @@ def send_api_usage_telemetry(
277
291
  ]
278
292
  ] = None,
279
293
  sfqids_extractor: Optional[Callable[..., List[str]]] = None,
294
+ subproject_extractor: Optional[Callable[[Any], str]] = None,
280
295
  custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
281
296
  ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
282
297
  """
@@ -290,6 +305,7 @@ def send_api_usage_telemetry(
290
305
  conn_attr_name: Name of the SnowflakeConnection attribute in `self`.
291
306
  api_calls_extractor: Extract API calls from `self`.
292
307
  sfqids_extractor: Extract sfqids from `self`.
308
+ subproject_extractor: Extract subproject at runtime from `self`.
293
309
  custom_tags: Custom tags.
294
310
 
295
311
  Returns:
@@ -297,10 +313,14 @@ def send_api_usage_telemetry(
297
313
 
298
314
  Raises:
299
315
  TypeError: If `conn_attr_name` is provided but the conn attribute is not of type SnowflakeConnection.
316
+ ValueError: If both `subproject` and `subproject_extractor` are provided
300
317
 
301
318
  # noqa: DAR402
302
319
  """
303
320
 
321
+ if subproject is not None and subproject_extractor is not None:
322
+ raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
323
+
304
324
  def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
305
325
  @functools.wraps(func)
306
326
  def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
@@ -322,9 +342,13 @@ def send_api_usage_telemetry(
322
342
  if sfqids_extractor:
323
343
  sfqids = sfqids_extractor(args[0])
324
344
 
345
+ subproject_name = subproject
346
+ if subproject_extractor is not None:
347
+ subproject_name = subproject_extractor(args[0])
348
+
325
349
  statement_params = get_function_usage_statement_params(
326
350
  project=project,
327
- subproject=subproject,
351
+ subproject=subproject_name,
328
352
  function_category=TelemetryField.FUNC_CAT_USAGE.value,
329
353
  function_name=_get_full_func_name(func),
330
354
  function_parameters=params,
@@ -381,7 +405,7 @@ def send_api_usage_telemetry(
381
405
  raise e.original_exception from e
382
406
 
383
407
  # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton.
384
- telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
408
+ telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject_name)
385
409
  telemetry_args = dict(
386
410
  func_name=_get_full_func_name(func),
387
411
  function_category=TelemetryField.FUNC_CAT_USAGE.value,
@@ -26,30 +26,11 @@ def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
26
26
  pkg_versions: List[str], session: Session, subproject: Optional[str] = None
27
27
  ) -> List[str]:
28
28
  if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
29
- return _get_valid_pkg_versions_supported_in_snowflake_conda_channel_sync(pkg_versions, session, subproject)
29
+ return pkg_versions
30
30
  else:
31
31
  return _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(pkg_versions, session, subproject)
32
32
 
33
33
 
34
- def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_sync(
35
- pkg_versions: List[str], session: Session, subproject: Optional[str] = None
36
- ) -> List[str]:
37
- for pkg_version in pkg_versions:
38
- if pkg_version not in cache:
39
- pkg_version_list = _query_pkg_version_supported_in_snowflake_conda_channel(
40
- pkg_version=pkg_version, session=session, block=True, subproject=subproject
41
- )
42
- assert isinstance(pkg_version_list, list) # keep mypy happy
43
- try:
44
- cache[pkg_version] = pkg_version_list[0]["VERSION"]
45
- except IndexError:
46
- cache[pkg_version] = None
47
-
48
- pkg_version_conda_list = _get_conda_packages_and_emit_warnings(pkg_versions)
49
-
50
- return pkg_version_conda_list
51
-
52
-
53
34
  def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
54
35
  pkg_versions: List[str], session: Session, subproject: Optional[str] = None
55
36
  ) -> List[str]:
@@ -60,7 +41,11 @@ def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
60
41
  async_job = _query_pkg_version_supported_in_snowflake_conda_channel(
61
42
  pkg_version=pkg_version, session=session, block=False, subproject=subproject
62
43
  )
63
- assert isinstance(async_job, AsyncJob)
44
+ if isinstance(async_job, list):
45
+ raise RuntimeError(
46
+ "Async job was expected, executed query was returned. Please contact Snowflake support."
47
+ )
48
+
64
49
  pkg_version_async_job_list.append((pkg_version, async_job))
65
50
 
66
51
  # Populate the cache.
@@ -143,7 +128,8 @@ def _get_conda_packages_and_emit_warnings(pkg_versions: List[str]) -> List[str]:
143
128
  warnings.warn(
144
129
  f"Package {', '.join([pkg[0] for pkg in pkg_version_warning_list])} is not supported "
145
130
  f"in snowflake conda channel for python runtime "
146
- f"{', '.join([pkg[1] for pkg in pkg_version_warning_list])}."
131
+ f"{', '.join([pkg[1] for pkg in pkg_version_warning_list])}.",
132
+ stacklevel=1,
147
133
  )
148
134
 
149
135
  return pkg_version_conda_list