snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry, and is provided for informational purposes only.
Files changed (73)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/file_utils.py +18 -4
  4. snowflake/ml/_internal/platform_capabilities.py +3 -0
  5. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  6. snowflake/ml/_internal/telemetry.py +25 -0
  7. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +0 -1
  11. snowflake/ml/jobs/_utils/constants.py +31 -1
  12. snowflake/ml/jobs/_utils/payload_utils.py +232 -72
  13. snowflake/ml/jobs/_utils/spec_utils.py +78 -38
  14. snowflake/ml/jobs/decorators.py +8 -25
  15. snowflake/ml/jobs/job.py +4 -4
  16. snowflake/ml/jobs/manager.py +5 -0
  17. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  18. snowflake/ml/model/_client/ops/model_ops.py +107 -14
  19. snowflake/ml/model/_client/ops/service_ops.py +1 -1
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  21. snowflake/ml/model/_client/sql/model_version.py +58 -0
  22. snowflake/ml/model/_client/sql/service.py +8 -2
  23. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  26. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  27. snowflake/ml/model/_packager/model_env/model_env.py +49 -29
  28. snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
  29. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
  30. snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
  31. snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
  32. snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
  33. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
  34. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
  35. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  36. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  37. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  38. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  39. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  40. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  41. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  42. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
  43. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
  44. snowflake/ml/model/_packager/model_packager.py +3 -5
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  48. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  49. snowflake/ml/model/_signatures/core.py +54 -33
  50. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  51. snowflake/ml/model/_signatures/numpy_handler.py +12 -20
  52. snowflake/ml/model/_signatures/pandas_handler.py +28 -37
  53. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  54. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  55. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  56. snowflake/ml/model/_signatures/utils.py +120 -8
  57. snowflake/ml/model/custom_model.py +13 -4
  58. snowflake/ml/model/model_signature.py +39 -13
  59. snowflake/ml/model/type_hints.py +28 -2
  60. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  61. snowflake/ml/modeling/metrics/ranking.py +3 -0
  62. snowflake/ml/modeling/metrics/regression.py +3 -0
  63. snowflake/ml/modeling/pipeline/pipeline.py +18 -1
  64. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  65. snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
  66. snowflake/ml/registry/_manager/model_manager.py +55 -7
  67. snowflake/ml/registry/registry.py +52 -4
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
  70. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
  71. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
  72. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py

@@ -23,6 +23,15 @@ logger = logging.getLogger(__name__)
 _REST_COMPLETE_URL = "/api/v2/cortex/inference:complete"
 
 
+class ResponseFormat(TypedDict):
+    """Represents an object describing response format config for structured-output mode"""
+
+    type: str
+    """The response format type (e.g. "json")"""
+    schema: Dict[str, Any]
+    """The schema defining the structure of the response. For json it should be a valid json schema object"""
+
+
 class ConversationMessage(TypedDict):
     """Represents a conversation interaction."""
 
@@ -53,6 +62,9 @@ class CompleteOptions(TypedDict):
     """ A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
    from the language model. """
 
+    response_format: NotRequired[ResponseFormat]
+    """ An object describing response format config for structured-output mode """
+
 
 class ResponseParseException(Exception):
     """This exception is raised when the server response cannot be parsed."""
@@ -108,6 +120,32 @@ def _make_common_request_headers() -> Dict[str, str]:
     return headers
 
 
+def _validate_response_format_object(options: CompleteOptions) -> None:
+    """Validate the response format object for structured-output mode.
+
+    More details can be found in:
+    docs.snowflake.com/en/user-guide/snowflake-cortex/complete-structured-outputs#using-complete-structured-outputs
+
+    Args:
+        options: The complete options object.
+
+    Raises:
+        ValueError: If the response format object is invalid or missing required fields.
+    """
+    if options is not None and options.get("response_format") is not None:
+        options_obj = options.get("response_format")
+        if not isinstance(options_obj, dict):
+            raise ValueError("'response_format' should be an object")
+        if options_obj.get("type") is None:
+            raise ValueError("'type' cannot be empty for 'response_format' object")
+        if not isinstance(options_obj.get("type"), str):
+            raise ValueError("'type' needs to be a str for 'response_format' object")
+        if options_obj.get("schema") is None:
+            raise ValueError("'schema' cannot be empty for 'response_format' object")
+        if not isinstance(options_obj.get("schema"), dict):
+            raise ValueError("'schema' needs to be a dict for 'response_format' object")
+
+
 def _make_request_body(
     model: str,
     prompt: Union[str, List[ConversationMessage]],
@@ -136,12 +174,16 @@
             "response_when_unsafe": "Response filtered by Cortex Guard",
         }
         data["guardrails"] = guardrails_options
+    if "response_format" in options:
+        data["response_format"] = options["response_format"]
+
     return data
 
 
 # XP endpoint returns a dict response which needs to be converted to a format which can
 # be consumed by the SSEClient. This method does that.
 def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
+
     response = requests.Response()
     response.status_code = int(raw_resp["status"])
     response.headers = raw_resp["headers"]
@@ -159,7 +201,6 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
             data = json.loads(data)
         except json.JSONDecodeError:
             raise ValueError(f"Request failed (request id: {request_id})")
-
     if response.status_code < 200 or response.status_code >= 300:
         if "message" not in data:
             raise ValueError(f"Request failed (request id: {request_id})")
@@ -241,11 +282,21 @@ def _return_stream_response(response: requests.Response, deadline: Optional[floa
         if deadline is not None and time.time() > deadline:
             raise TimeoutError()
         try:
-            yield json.loads(event.data)["choices"][0]["delta"]["content"]
+            parsed_resp = json.loads(event.data)
+        except json.JSONDecodeError:
+            raise ResponseParseException("Server response cannot be parsed")
+        try:
+            yield parsed_resp["choices"][0]["delta"]["content"]
         except (json.JSONDecodeError, KeyError, IndexError):
             # For the sake of evolution of the output format,
             # ignore stream messages that don't match the expected format.
-            pass
+
+            # This is the case of midstream errors which were introduced specifically for structured output.
+            # TODO: discuss during code review
+            if parsed_resp.get("error"):
+                yield json.dumps(parsed_resp)
+            else:
+                pass
 
 
 def _complete_call_sql_function_snowpark(
@@ -291,6 +342,8 @@ def _complete_non_streaming_impl(
         raise ValueError("'model' cannot be a snowpark.Column when 'prompt' is a string.")
     if isinstance(options, snowpark.Column):
         raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.")
+    if options and not isinstance(options, snowpark.Column):
+        _validate_response_format_object(options)
     return _complete_non_streaming_immediate(
         snow_api_xp_request_handler=snow_api_xp_request_handler,
         model=model,
@@ -309,6 +362,8 @@ def _complete_rest(
     session: Optional[snowpark.Session] = None,
     deadline: Optional[float] = None,
 ) -> Iterator[str]:
+    if options:
+        _validate_response_format_object(options)
     if snow_api_xp_request_handler is not None:
         response = _call_complete_xp(
             snow_api_xp_request_handler=snow_api_xp_request_handler,
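Taken together, these hunks add structured-output support to Complete: a ResponseFormat TypedDict, a response_format field on CompleteOptions, client-side validation, and midstream-error passthrough for streaming. A usage sketch (the model name and JSON schema are illustrative, not from the diff):

```python
from snowflake.cortex import CompleteOptions, complete

options = CompleteOptions(
    response_format={
        "type": "json",
        "schema": {
            "type": "object",
            "properties": {"sentiment": {"type": "string"}},
            "required": ["sentiment"],
        },
    }
)
# _validate_response_format_object() raises ValueError before any request is
# sent if "type" is not a str or "schema" is not a dict.
result = complete(model="mistral-large2", prompt="Classify: 'Great product!'", options=options)
```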
snowflake/ml/_internal/env_utils.py

@@ -12,7 +12,7 @@ import yaml
 from packaging import requirements, specifiers, version
 
 import snowflake.connector
-from snowflake.ml._internal import env as snowml_env
+from snowflake.ml._internal import env as snowml_env, relax_version_strategy
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import context, exceptions, session
 
@@ -56,6 +56,8 @@ def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
 
         if r.name == "python":
            raise ValueError("Don't specify python as a dependency, use python version argument instead.")
+        if r.name == "cuda":
+            raise ValueError("Don't specify cuda as a dependency, use cuda version argument instead.")
     except requirements.InvalidRequirement:
         raise ValueError(f"Invalid package requirement {req_str} found.")
 
@@ -313,19 +315,14 @@ def get_package_spec_with_supported_ops_only(req: requirements.Requirement) -> r
     return new_req
 
 
-def relax_requirement_version(req: requirements.Requirement) -> requirements.Requirement:
-    """Relax version specifier from a requirement. It detects any ==x.y.z in specifiers and replaced with
-    >=x.y, <(x+1)
-
-    Args:
-        req: The requirement that version specifier to be removed.
-
-    Returns:
-        A new requirement object after relaxations.
-    """
-    new_req = copy.deepcopy(req)
+def _relax_specifier_set(
+    specifier_set: specifiers.SpecifierSet, strategy: relax_version_strategy.RelaxVersionStrategy
+) -> specifiers.SpecifierSet:
+    if strategy == relax_version_strategy.RelaxVersionStrategy.NO_RELAX:
+        return specifier_set
+    specifier_set = copy.deepcopy(specifier_set)
     relaxed_specifier_set = set()
-    for spec in new_req.specifier._specs:
+    for spec in specifier_set._specs:
         if spec.operator != "==":
             relaxed_specifier_set.add(spec)
             continue
@@ -337,9 +334,40 @@
             relaxed_specifier_set.add(spec)
             continue
         assert pinned_version is not None
-        relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
-        relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
-    new_req.specifier._specs = frozenset(relaxed_specifier_set)
+        if strategy == relax_version_strategy.RelaxVersionStrategy.PATCH:
+            relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
+            relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major}.{pinned_version.minor+1}"))
+        elif strategy == relax_version_strategy.RelaxVersionStrategy.MINOR:
+            relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
+            relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
+        elif strategy == relax_version_strategy.RelaxVersionStrategy.MAJOR:
+            relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}"))
+            relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
+    specifier_set._specs = frozenset(relaxed_specifier_set)
+    return specifier_set
+
+
+def relax_requirement_version(req: requirements.Requirement) -> requirements.Requirement:
+    """Relax the version specifier of a requirement. It detects any ==x.y.z in specifiers and replaces it with a
+    relaxed specifier based on the strategy defined in RELAX_VERSION_STRATEGY_MAP.
+
+    NO_RELAX: No relaxation.
+    PATCH: >=x.y, <x.(y+1)
+    MINOR (default): >=x.y, <(x+1)
+    MAJOR: >=x, <(x+1)
+
+    Args:
+        req: The requirement whose version specifier is to be relaxed.
+
+    Returns:
+        A new requirement object after relaxation.
+    """
+    new_req = copy.deepcopy(req)
+    strategy = relax_version_strategy.RELAX_VERSION_STRATEGY_MAP.get(
+        req.name, relax_version_strategy.RelaxVersionStrategy.MINOR
+    )
+    new_req.specifier = _relax_specifier_set(new_req.specifier, strategy)
     return new_req
 
 
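To see the effect of the strategy-aware relaxation, a sketch against the internal helper (strategy assignments follow RELAX_VERSION_STRATEGY_MAP, shown later in this diff; printed specifier order may vary):

```python
from packaging import requirements
from snowflake.ml._internal import env_utils

# MINOR (default): ==1.26.4 relaxes to >=1.26, <2
print(env_utils.relax_requirement_version(requirements.Requirement("numpy==1.26.4")))
# PATCH (mapped for scikit-learn): ==1.5.2 relaxes to >=1.5, <1.6
print(env_utils.relax_requirement_version(requirements.Requirement("scikit-learn==1.5.2")))
# NO_RELAX (mapped for cloudpickle): the pin is kept as-is
print(env_utils.relax_requirement_version(requirements.Requirement("cloudpickle==2.2.1")))
```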
@@ -431,10 +459,11 @@ def save_conda_env_file(
     path: pathlib.Path,
     conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
     python_version: str,
+    cuda_version: Optional[str] = None,
     default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
 ) -> None:
     """Generate conda.yml file given a dict of dependencies after validation.
-    The channels part of conda.yml file will contains Snowflake Anaconda Channel, nodefaults and all channel names
+    The channels part of conda.yml file will contain Snowflake Anaconda Channel, nodefaults and all channel names
     in keys of the dict, ordered by the number of the packages which belong to them.
     The dependencies part of conda.yml file will contain requirements specifications. If the requirement is in the
     value list whose key is DEFAULT_CHANNEL_NAME, then the channel won't be specified explicitly. Otherwise, it will be
@@ -443,7 +472,8 @@ def save_conda_env_file(
     Args:
         path: Path to the conda.yml file.
         conda_chan_deps: Dict of conda dependencies after validation.
-        python_version: A string 'major.minor' showing python version relate to model.
+        python_version: A string 'major.minor' for the model's python version.
+        cuda_version: A string 'major.minor' for the model's cuda version.
         default_channel_override: The default channel to be put in the first place of the channels section.
     """
     assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
@@ -461,6 +491,10 @@ def save_conda_env_file(
 
     env["channels"] = [default_channel_override] + channels + [_NODEFAULTS]
     env["dependencies"] = [f"python=={python_version}.*"]
+
+    if cuda_version is not None:
+        env["dependencies"].extend([f"nvidia::cuda=={cuda_version}.*"])
+
     for chan, reqs in conda_chan_deps.items():
         env["dependencies"].extend(
             [f"{chan}::{str(req)}" if chan != DEFAULT_CHANNEL_NAME else str(req) for req in reqs]
@@ -487,7 +521,12 @@ def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requi
 
 def load_conda_env_file(
     path: pathlib.Path,
-) -> Tuple[DefaultDict[str, List[requirements.Requirement]], Optional[List[requirements.Requirement]], Optional[str]]:
+) -> Tuple[
+    DefaultDict[str, List[requirements.Requirement]],
+    Optional[List[requirements.Requirement]],
+    Optional[str],
+    Optional[str],
+]:
     """Read conda.yml file to get a dict of dependencies after validation.
     The channels part of conda.yml file will be processed with following rules:
     1. If it is Snowflake Anaconda Channel, ignore as it is default.
@@ -515,7 +554,7 @@ def load_conda_env_file(
     and a string 'major.minor.patchlevel' of python version.
     """
     if not path.exists():
-        return collections.defaultdict(list), None, None
+        return collections.defaultdict(list), None, None, None
 
     with open(path, encoding="utf-8") as f:
         env = yaml.safe_load(stream=f)
@@ -526,6 +565,7 @@ def load_conda_env_file(
     pip_deps = []
 
     python_version = None
+    cuda_version = None
 
     channels = env.get("channels", [])
     if len(channels) >= 1:
@@ -541,6 +581,9 @@ def load_conda_env_file(
             # ver is str: python w/ specifier
             if ver:
                 python_version = ver
+            elif dep.startswith("nvidia::cuda"):
+                r = requirements.Requirement(dep.split("nvidia::")[1])
+                cuda_version = list(r.specifier)[0].version.strip(".*")
             elif ver is None:
                 deps.append(dep)
         elif isinstance(dep, dict) and "pip" in dep:
@@ -555,7 +598,7 @@ def load_conda_env_file(
         if channel not in conda_dep_dict:
             conda_dep_dict[channel] = []
 
-    return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version
+    return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version, cuda_version
 
 
 def load_requirements_file(path: pathlib.Path) -> List[requirements.Requirement]:
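A sketch of the new cuda_version round-trip through these helpers (internal API; the file name is arbitrary):

```python
import collections
import pathlib

from packaging import requirements
from snowflake.ml._internal import env_utils

deps = collections.defaultdict(list)
deps[env_utils.DEFAULT_CHANNEL_NAME].append(requirements.Requirement("numpy==1.26.4"))

# Writes "python==3.10.*" and "nvidia::cuda==11.8.*" into the dependencies section.
env_utils.save_conda_env_file(pathlib.Path("conda.yml"), deps, python_version="3.10", cuda_version="11.8")

# The loader now returns a 4-tuple; the new last element is the parsed cuda version.
conda_deps, pip_deps, py_ver, cuda_ver = env_utils.load_conda_env_file(pathlib.Path("conda.yml"))
assert cuda_ver == "11.8"
```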
snowflake/ml/_internal/file_utils.py

@@ -23,6 +23,7 @@ from typing import (
     Tuple,
     Union,
 )
+from urllib import parse
 
 import cloudpickle
 
@@ -294,7 +295,7 @@ def _retry_on_sql_error(exception: Exception) -> bool:
 def upload_directory_to_stage(
     session: snowpark.Session,
     local_path: pathlib.Path,
-    stage_path: pathlib.PurePosixPath,
+    stage_path: Union[pathlib.PurePosixPath, parse.ParseResult],
     *,
     statement_params: Optional[Dict[str, Any]] = None,
 ) -> None:
@@ -314,9 +315,22 @@ def upload_directory_to_stage(
         root_path = pathlib.Path(root)
         for filename in filenames:
             local_file_path = root_path / filename
-            stage_dir_path = (
-                stage_path / pathlib.PurePosixPath(local_file_path.relative_to(local_path).as_posix()).parent
-            )
+            relative_path = pathlib.PurePosixPath(local_file_path.relative_to(local_path).as_posix())
+
+            if isinstance(stage_path, parse.ParseResult):
+                relative_stage_path = (pathlib.PosixPath(stage_path.path) / relative_path).parent
+                new_url = parse.ParseResult(
+                    scheme=stage_path.scheme,
+                    netloc=stage_path.netloc,
+                    path=str(relative_stage_path),
+                    params=stage_path.params,
+                    query=stage_path.query,
+                    fragment=stage_path.fragment,
+                )
+                stage_dir_path = parse.urlunparse(new_url)
+            else:
+                stage_dir_path = str((stage_path / relative_path).parent)
+
             retrying.retry(
                 retry_on_exception=_retry_on_sql_error,
                 stop_max_attempt_number=5,
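To illustrate the new ParseResult branch, this standalone sketch rebuilds a destination URL the same way (the snow:// URL is a made-up example of a URL-style stage path):

```python
import pathlib
from urllib import parse

url = parse.urlparse("snow://model/MY_MODEL/versions/V1")  # hypothetical
relative = pathlib.PurePosixPath("runtimes/python_runtime/env.yml")

# Join the file's relative directory onto the URL's path, keeping scheme/netloc intact.
rebuilt = url._replace(path=str((pathlib.PosixPath(url.path) / relative).parent))
print(parse.urlunparse(rebuilt))
# snow://model/MY_MODEL/versions/V1/runtimes/python_runtime
```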
snowflake/ml/_internal/platform_capabilities.py

@@ -37,6 +37,9 @@ class PlatformCapabilities:
     def is_nested_function_enabled(self) -> bool:
         return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
 
+    def is_live_commit_enabled(self) -> bool:
+        return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)
+
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
         try:
snowflake/ml/_internal/relax_version_strategy.py (new file)

@@ -0,0 +1,16 @@
+from enum import Enum
+
+
+class RelaxVersionStrategy(Enum):
+    NO_RELAX = "no_relax"
+    PATCH = "patch"
+    MINOR = "minor"
+    MAJOR = "major"
+
+
+RELAX_VERSION_STRATEGY_MAP = {
+    # The version of cloudpickle should not be relaxed as it is used for serialization.
+    "cloudpickle": RelaxVersionStrategy.NO_RELAX,
+    # The version of scikit-learn should be relaxed only in patch version as it has breaking changes in minor version.
+    "scikit-learn": RelaxVersionStrategy.PATCH,
+}
snowflake/ml/_internal/telemetry.py

@@ -4,6 +4,9 @@ import enum
 import functools
 import inspect
 import operator
+import sys
+import time
+import traceback
 import types
 from typing import (
     Any,
@@ -75,6 +78,8 @@ class TelemetryField(enum.Enum):
     KEY_FUNC_PARAMS = "func_params"
     KEY_ERROR_INFO = "error_info"
     KEY_ERROR_CODE = "error_code"
+    KEY_STACK_TRACE = "stack_trace"
+    KEY_DURATION = "duration"
     KEY_VERSION = "version"
     KEY_PYTHON_VERSION = "python_version"
     KEY_OS = "operating_system"
@@ -348,6 +353,10 @@ def get_function_usage_statement_params(
         statement_params[TelemetryField.KEY_API_CALLS.value].append({TelemetryField.NAME.value: api_call})
     if custom_tags:
         statement_params[TelemetryField.KEY_CUSTOM_TAGS.value] = custom_tags
+    # Snowpark doesn't support None value in statement_params from version 1.29
+    for k in statement_params:
+        if statement_params[k] is None:
+            statement_params[k] = ""
     return statement_params
 
 
@@ -435,6 +444,7 @@ def send_api_usage_telemetry(
 
     # noqa: DAR402
     """
+    start_time = time.perf_counter()
 
    if subproject is not None and subproject_extractor is not None:
        raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
@@ -555,8 +565,16 @@ def send_api_usage_telemetry(
                )
            else:
                me = e
+
            telemetry_args["error"] = repr(me)
            telemetry_args["error_code"] = me.error_code
+            # exclude telemetry frames
+            excluded_frames = 2
+            tb = traceback.extract_tb(sys.exc_info()[2])
+            formatted_tb = "".join(traceback.format_list(tb[excluded_frames:]))
+            formatted_exception = traceback.format_exception_only(*sys.exc_info()[:2])[0]  # error type + message
+            telemetry_args["stack_trace"] = formatted_tb + formatted_exception
+
            me.original_exception._snowflake_ml_handled = True  # type: ignore[attr-defined]
            if e is not me:
                raise  # Directly raise non-wrapped exceptions to preserve original stacktrace
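In isolation, the stack-trace capture above amounts to the following (standalone sketch; excluded_frames=2 drops the telemetry wrapper's own frames):

```python
import sys
import traceback

try:
    raise ValueError("boom")
except ValueError:
    tb = traceback.extract_tb(sys.exc_info()[2])
    formatted_tb = "".join(traceback.format_list(tb[2:]))  # skip the first two frames
    formatted_exception = traceback.format_exception_only(*sys.exc_info()[:2])[0]
    stack_trace = formatted_tb + formatted_exception  # value stored under KEY_STACK_TRACE
```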
@@ -565,6 +583,7 @@ def send_api_usage_telemetry(
            else:
                raise me.original_exception from e
        finally:
+            telemetry_args["duration"] = time.perf_counter() - start_time  # type: ignore[assignment]
            telemetry.send_function_usage_telemetry(**telemetry_args)
            global _log_counter
            _log_counter += 1
@@ -718,12 +737,14 @@ class _SourceTelemetryClient:
        self,
        func_name: str,
        function_category: str,
+        duration: float,
        func_params: Optional[Dict[str, Any]] = None,
        api_calls: Optional[List[Dict[str, Any]]] = None,
        sfqids: Optional[List[Any]] = None,
        custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
        error: Optional[str] = None,
        error_code: Optional[str] = None,
+        stack_trace: Optional[str] = None,
    ) -> None:
        """
        Send function usage telemetry message.
@@ -731,12 +752,14 @@ class _SourceTelemetryClient:
        Args:
            func_name: Function name.
            function_category: Function category.
+            duration: Function duration.
            func_params: Function parameters.
            api_calls: API calls.
            sfqids: Snowflake query IDs.
            custom_tags: Custom tags.
            error: Error.
            error_code: Error code.
+            stack_trace: Error stack trace.
        """
        data: Dict[str, Any] = {
            TelemetryField.KEY_FUNC_NAME.value: func_name,
@@ -755,11 +778,13 @@ class _SourceTelemetryClient:
        message: Dict[str, Any] = {
            **self._create_basic_telemetry_data(telemetry_type),
            TelemetryField.KEY_DATA.value: data,
+            TelemetryField.KEY_DURATION.value: duration,
        }
 
        if error:
            message[TelemetryField.KEY_ERROR_INFO.value] = error
            message[TelemetryField.KEY_ERROR_CODE.value] = error_code
+            message[TelemetryField.KEY_STACK_TRACE.value] = stack_trace
 
        self._send(message)
 
snowflake/ml/data/_internal/arrow_ingestor.py

@@ -116,7 +116,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
     def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
         ds = self._get_dataset(shuffle=False)
         table = ds.to_table() if limit is None else ds.head(num_rows=limit)
-        return table.to_pandas()
+        return table.to_pandas(split_blocks=True, self_destruct=True)
 
     def _get_dataset(self, shuffle: bool) -> pds.Dataset:
         format = self._format
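The two new flags trade a defensive copy for lower peak memory when converting Arrow tables to pandas; the same call in isolation:

```python
import pyarrow as pa

table = pa.table({"x": [1, 2, 3]})
# split_blocks avoids consolidating columns into a single pandas block;
# self_destruct frees each Arrow buffer as it is converted, so `table`
# must not be touched again afterwards (per pyarrow's to_pandas docs).
df = table.to_pandas(split_blocks=True, self_destruct=True)
```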
snowflake/ml/feature_store/feature_store.py

@@ -144,6 +144,7 @@ _LIST_FEATURE_VIEW_SCHEMA = StructType(
         StructField("refresh_mode", StringType()),
         StructField("scheduling_state", StringType()),
         StructField("warehouse", StringType()),
+        StructField("cluster_by", StringType()),
     ]
 )
 
@@ -1832,6 +1833,12 @@ class FeatureStore:
                 WAREHOUSE = {warehouse}
                 REFRESH_MODE = {feature_view.refresh_mode}
                 INITIALIZE = {feature_view.initialize}
+            """
+            if feature_view.cluster_by:
+                cluster_by_clause = f"CLUSTER BY ({', '.join(feature_view.cluster_by)})"
+                query += f"{cluster_by_clause}"
+
+            query += f"""
                 AS {feature_view.query}
             """
             self._session.sql(query).collect(block=block, statement_params=self._telemetry_stmp)
@@ -2249,6 +2256,7 @@ class FeatureStore:
             values.append(row["refresh_mode"] if "refresh_mode" in row else None)
             values.append(row["scheduling_state"] if "scheduling_state" in row else None)
             values.append(row["warehouse"] if "warehouse" in row else None)
+            values.append(json.dumps(self._extract_cluster_by_columns(row["cluster_by"])) if "cluster_by" in row else None)
             output_values.append(values)
 
     def _lookup_feature_view_metadata(self, row: Row, fv_name: str) -> Tuple[_FeatureViewMetadata, str]:
@@ -2335,6 +2343,7 @@ class FeatureStore:
                 owner=row["owner"],
                 infer_schema_df=infer_schema_df,
                 session=self._session,
+                cluster_by=self._extract_cluster_by_columns(row["cluster_by"]),
             )
             return fv
         else:
@@ -2625,3 +2634,12 @@ class FeatureStore:
         )
 
         return feature_view
+
+    @staticmethod
+    def _extract_cluster_by_columns(cluster_by_clause: str) -> List[str]:
+        # Use regex to extract elements inside the parentheses.
+        match = re.search(r"\((.*?)\)", cluster_by_clause)
+        if match:
+            # Handle both quoted and unquoted column names.
+            return re.findall(identifier.SF_IDENTIFIER_RE, match.group(1))
+        return []
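The helper above pulls column names out of a SHOW-style cluster_by string such as LINEAR("CUSTOMER_ID", EVENT_TS). A self-contained approximation (using a simplified identifier pattern in place of the library's SF_IDENTIFIER_RE):

```python
import re
from typing import List

def extract_cluster_by_columns(cluster_by_clause: str) -> List[str]:
    # Take everything inside the first pair of parentheses, then pick out
    # quoted and unquoted identifiers.
    match = re.search(r"\((.*?)\)", cluster_by_clause)
    if not match:
        return []
    return re.findall(r'"[^"]+"|[A-Za-z_][A-Za-z0-9_$]*', match.group(1))

print(extract_cluster_by_columns('LINEAR("CUSTOMER_ID", EVENT_TS)'))
# ['"CUSTOMER_ID"', 'EVENT_TS']
```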
snowflake/ml/feature_store/feature_view.py

@@ -170,6 +170,7 @@ class FeatureView(lineage_node.LineageNode):
         warehouse: Optional[str] = None,
         initialize: str = "ON_CREATE",
         refresh_mode: str = "AUTO",
+        cluster_by: Optional[List[str]] = None,
         **_kwargs: Any,
     ) -> None:
         """
@@ -200,6 +201,9 @@ class FeatureView(lineage_node.LineageNode):
             refresh_mode: The refresh mode of managed feature view. The value can be 'AUTO', 'FULL' or 'INCREMENTAL'.
                 For managed feature view, the default value is 'AUTO'. For static feature view it has no effect.
                 Check https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table for details.
+            cluster_by: Columns to cluster the feature view by.
+                - Defaults to the join keys from entities.
+                - If `timestamp_col` is provided, it is added to the default clustering keys.
             _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
 
         Example::
@@ -224,6 +228,7 @@ class FeatureView(lineage_node.LineageNode):
             >>> print(registered_fv.status)
             FeatureViewStatus.ACTIVE
 
+        # noqa: DAR401
         """
 
         self._name: SqlIdentifier = SqlIdentifier(name)
@@ -233,7 +238,7 @@ class FeatureView(lineage_node.LineageNode):
             SqlIdentifier(timestamp_col) if timestamp_col is not None else None
         )
         self._desc: str = desc
-        self._infer_schema_df: DataFrame = _kwargs.get("_infer_schema_df", self._feature_df)
+        self._infer_schema_df: DataFrame = _kwargs.pop("_infer_schema_df", self._feature_df)
         self._query: str = self._get_query()
         self._version: Optional[FeatureViewVersion] = None
         self._status: FeatureViewStatus = FeatureViewStatus.DRAFT
@@ -249,6 +254,14 @@ class FeatureView(lineage_node.LineageNode):
         self._refresh_mode: Optional[str] = refresh_mode
         self._refresh_mode_reason: Optional[str] = None
         self._owner: Optional[str] = None
+        self._cluster_by: List[SqlIdentifier] = (
+            [SqlIdentifier(col) for col in cluster_by] if cluster_by is not None else self._get_default_cluster_by()
+        )
+
+        # Validate kwargs
+        if _kwargs:
+            raise TypeError(f"FeatureView.__init__ got an unexpected keyword argument: '{next(iter(_kwargs.keys()))}'")
+
         self._validate()
 
     def slice(self, names: List[str]) -> FeatureViewSlice:
@@ -394,6 +407,10 @@ class FeatureView(lineage_node.LineageNode):
     def timestamp_col(self) -> Optional[SqlIdentifier]:
         return self._timestamp_col
 
+    @property
+    def cluster_by(self) -> Optional[List[SqlIdentifier]]:
+        return self._cluster_by
+
     @property
     def desc(self) -> str:
         return self._desc
@@ -656,6 +673,14 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
             if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
                 raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
 
+        if self.cluster_by is not None:
+            for column in self.cluster_by:
+                if column not in df_cols:
+                    raise ValueError(
+                        f"Column '{column}' in `cluster_by` is not in the feature DataFrame schema. "
+                        f"{df_cols}, {self.cluster_by}"
+                    )
+
         if re.match(_RESULT_SCAN_QUERY_PATTERN, self._query) is not None:
             raise ValueError(f"feature_df should not be reading from RESULT_SCAN. Invalid query: {self._query}")
 
@@ -890,6 +915,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
         owner: Optional[str],
         infer_schema_df: Optional[DataFrame],
         session: Session,
+        cluster_by: Optional[List[str]] = None,
     ) -> FeatureView:
         fv = FeatureView(
             name=name,
@@ -898,6 +924,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
             timestamp_col=timestamp_col,
             desc=desc,
             _infer_schema_df=infer_schema_df,
+            cluster_by=cluster_by,
         )
         fv._version = FeatureViewVersion(version) if version is not None else None
         fv._status = status
@@ -916,5 +943,23 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
         )
         return fv
 
+    #
+    def _get_default_cluster_by(self) -> List[SqlIdentifier]:
+        """
+        Get default columns to cluster the feature view by.
+        Default cluster_by columns are join keys from entities and timestamp_col if it exists.
+
+        Returns:
+            List of SqlIdentifiers representing the default columns to cluster the feature view by.
+        """
+        # We don't focus on the order of entities here, as users can define a custom 'cluster_by'
+        # if a specific order is required.
+        default_cluster_by_cols = [key for entity in self.entities if entity.join_keys for key in entity.join_keys]
+
+        if self.timestamp_col:
+            default_cluster_by_cols.append(self.timestamp_col)
+
+        return default_cluster_by_cols
+
 
 lineage_node.DOMAIN_LINEAGE_REGISTRY["feature_view"] = FeatureView
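Putting the new parameter together with the public feature-store API (entity, column, and DataFrame names below are illustrative):

```python
from snowflake.ml.feature_store import Entity, FeatureView

customer = Entity(name="CUSTOMER", join_keys=["CUSTOMER_ID"])
fv = FeatureView(
    name="CUSTOMER_FEATURES",
    entities=[customer],
    feature_df=features_df,  # a Snowpark DataFrame defined elsewhere
    timestamp_col="EVENT_TS",
    # Optional: when omitted, defaults to the entities' join keys plus timestamp_col.
    cluster_by=["CUSTOMER_ID", "EVENT_TS"],
)
```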
snowflake/ml/fileset/fileset.py

@@ -257,7 +257,6 @@ class FileSet:
                     function_name=telemetry.get_statement_params_full_func_name(
                         inspect.currentframe(), cls.__class__.__name__
                     ),
-                    api_calls=[snowpark.DataFrameWriter.copy_into_location],
                 ),
             )
         except snowpark_exceptions.SnowparkSQLException as e: