snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (176)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/telemetry.py +4 -2
  7. snowflake/ml/_internal/type_utils.py +3 -3
  8. snowflake/ml/_internal/utils/import_utils.py +31 -0
  9. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
  10. snowflake/ml/data/__init__.py +5 -0
  11. snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
  12. snowflake/ml/data/data_connector.py +1 -1
  13. snowflake/ml/data/torch_utils.py +33 -14
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
  16. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
  18. snowflake/ml/feature_store/examples/example_helper.py +6 -3
  19. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
  20. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
  21. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
  22. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
  23. snowflake/ml/feature_store/feature_store.py +1 -2
  24. snowflake/ml/feature_store/feature_view.py +5 -1
  25. snowflake/ml/model/_client/model/model_version_impl.py +145 -11
  26. snowflake/ml/model/_client/ops/model_ops.py +56 -16
  27. snowflake/ml/model/_client/ops/service_ops.py +46 -30
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  30. snowflake/ml/model/_client/sql/service.py +25 -1
  31. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  34. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
  36. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  37. snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
  38. snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
  39. snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
  41. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
  42. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
  43. snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
  44. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
  45. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  46. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  47. snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
  48. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  49. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  50. snowflake/ml/model/_packager/model_packager.py +0 -11
  51. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  52. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  53. snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
  54. snowflake/ml/model/_signatures/core.py +63 -16
  55. snowflake/ml/model/_signatures/pandas_handler.py +87 -27
  56. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  57. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  58. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  59. snowflake/ml/model/_signatures/utils.py +4 -0
  60. snowflake/ml/model/custom_model.py +47 -7
  61. snowflake/ml/model/model_signature.py +40 -9
  62. snowflake/ml/model/type_hints.py +9 -1
  63. snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
  64. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
  65. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
  66. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
  67. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
  68. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
  69. snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
  70. snowflake/ml/modeling/cluster/dbscan.py +5 -2
  71. snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
  72. snowflake/ml/modeling/cluster/k_means.py +14 -19
  73. snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
  74. snowflake/ml/modeling/cluster/optics.py +6 -6
  75. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
  76. snowflake/ml/modeling/compose/column_transformer.py +15 -5
  77. snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
  78. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  79. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  80. snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
  81. snowflake/ml/modeling/covariance/oas.py +1 -1
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
  85. snowflake/ml/modeling/decomposition/pca.py +28 -15
  86. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
  87. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
  88. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
  89. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
  90. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
  91. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
  92. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
  93. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
  94. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
  96. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  104. snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
  106. snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  108. snowflake/ml/modeling/linear_model/lars.py +0 -10
  109. snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
  110. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  111. snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
  112. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
  113. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
  114. snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
  115. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
  116. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
  119. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
  120. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/perceptron.py +3 -3
  122. snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
  123. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
  124. snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
  125. snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
  126. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
  127. snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
  128. snowflake/ml/modeling/manifold/isomap.py +1 -1
  129. snowflake/ml/modeling/manifold/mds.py +3 -3
  130. snowflake/ml/modeling/manifold/tsne.py +10 -4
  131. snowflake/ml/modeling/metrics/classification.py +12 -16
  132. snowflake/ml/modeling/metrics/ranking.py +3 -3
  133. snowflake/ml/modeling/metrics/regression.py +3 -3
  134. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
  135. snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
  136. snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
  137. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
  138. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
  139. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
  140. snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
  141. snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
  142. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  143. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
  144. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  145. snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
  146. snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
  147. snowflake/ml/modeling/pipeline/pipeline.py +16 -14
  148. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
  149. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
  150. snowflake/ml/modeling/svm/linear_svc.py +25 -16
  151. snowflake/ml/modeling/svm/linear_svr.py +23 -17
  152. snowflake/ml/modeling/svm/nu_svc.py +5 -3
  153. snowflake/ml/modeling/svm/nu_svr.py +3 -1
  154. snowflake/ml/modeling/svm/svc.py +9 -5
  155. snowflake/ml/modeling/svm/svr.py +3 -1
  156. snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
  157. snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
  158. snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
  159. snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
  160. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
  161. snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
  162. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  163. snowflake/ml/monitoring/model_monitor.py +37 -0
  164. snowflake/ml/registry/_manager/model_manager.py +15 -1
  165. snowflake/ml/registry/registry.py +32 -37
  166. snowflake/ml/version.py +1 -1
  167. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
  168. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
  169. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  170. snowflake/ml/monitoring/_client/model_monitor.py +0 -126
  171. snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
  172. snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
  173. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  174. /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
  175. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  176. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
snowflake/cortex/_sse_client.py CHANGED
@@ -1,73 +1,125 @@
-from typing import Iterator, cast
+import json
+from typing import Any, Iterator, Optional
 
-import requests
+_FIELD_SEPARATOR = ":"
 
 
 class Event:
-    def __init__(self, event: str = "message", data: str = "") -> None:
+    """Representation of an event from the event stream."""
+
+    def __init__(
+        self,
+        id: Optional[str] = None,
+        event: str = "message",
+        data: str = "",
+        comment: Optional[str] = None,
+        retry: Optional[int] = None,
+    ) -> None:
+        self.id = id
         self.event = event
         self.data = data
+        self.comment = comment
+        self.retry = retry
 
     def __str__(self) -> str:
         s = f"{self.event} event"
+        if self.id:
+            s += f" #{self.id}"
         if self.data:
-            s += f", {len(self.data)} bytes"
+            s += ", {} byte{}".format(len(self.data), "s" if len(self.data) else "")
         else:
            s += ", no data"
+        if self.comment:
+            s += f", comment: {self.comment}"
+        if self.retry:
+            s += f", retry in {self.retry}ms"
         return s
 
 
+# This is copied from the snowpy library:
+# https://github.com/snowflakedb/snowpy/blob/main/libs/snowflake.core/src/snowflake/core/rest.py#L39
+# TODO(SNOW-1750723) - Currently there's code duplication across snowflake-ml-python
+# and the snowpy library for the Cortex REST API, which was done to meet our GA timelines.
+# Once snowpy has a release with https://github.com/snowflakedb/snowpy/pull/679, we should
+# remove the class here and refer directly to the snowflake.core package.
 class SSEClient:
-    def __init__(self, response: requests.Response) -> None:
+    def __init__(self, event_source: Any, char_enc: str = "utf-8") -> None:
+        self._event_source = event_source
+        self._char_enc = char_enc
 
-        self.response = response
-
-    def _read(self) -> Iterator[str]:
-
-        lines = b""
-        for chunk in self.response:
+    def _read(self) -> Iterator[bytes]:
+        data = b""
+        for chunk in self._event_source:
             for line in chunk.splitlines(True):
-                lines += line
-                if lines.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
-                    yield cast(str, lines)
-                    lines = b""
-        if lines:
-            yield cast(str, lines)
+                data += line
+                if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+                    yield data
+                    data = b""
+        if data:
+            yield data
 
     def events(self) -> Iterator[Event]:
-        for raw_event in self._read():
+        content_type = self._event_source.headers.get("Content-Type")
+        # The check for empty content-type is present because it's being populated after
+        # the change in https://github.com/snowflakedb/snowflake/pull/217654.
+        # This can be removed once the above change makes it to prod or we move to snowpy
+        # for the SSEClient implementation.
+        if content_type == "text/event-stream" or not content_type:
+            return self._handle_sse()
+        elif content_type == "application/json":
+            return self._handle_json()
+        else:
+            raise ValueError(f"Unknown Content-Type: {content_type}")
+
+    def _handle_sse(self) -> Iterator[Event]:
+        for chunk in self._read():
             event = Event()
-            # splitlines() only uses \r and \n
-            for line in raw_event.splitlines():
+            # Split before decoding so splitlines() only uses \r and \n
+            for line_bytes in chunk.splitlines():
+                # Decode the line.
+                line = line_bytes.decode(self._char_enc)
 
-                line = cast(bytes, line).decode("utf-8")
+                # Lines starting with a separator are comments and are to be
+                # ignored.
+                if not line.strip() or line.startswith(_FIELD_SEPARATOR):
+                    continue
 
-                data = line.split(":", 1)
+                data = line.split(_FIELD_SEPARATOR, 1)
                 field = data[0]
 
+                # Ignore unknown fields.
+                if not hasattr(event, field):
+                    continue
+
                 if len(data) > 1:
+                    # From the spec:
                     # "If value starts with a single U+0020 SPACE character,
-                    # remove it from value. .strip() would remove all white spaces"
+                    # remove it from value."
                    if data[1].startswith(" "):
                        value = data[1][1:]
                    else:
                        value = data[1]
                else:
+                    # If no value is present after the separator,
+                    # assume an empty value.
                    value = ""
 
                # The data field may come over multiple lines and their values
                # are concatenated with each other.
+                current_value = getattr(event, field, "")
                if field == "data":
-                    event.data += value + "\n"
-                elif field == "event":
-                    event.event = value
+                    new_value = current_value + value + "\n"
+                else:
+                    new_value = value
+                setattr(event, field, new_value)
 
+            # Events with no data are not dispatched.
            if not event.data:
                continue
 
            # If the data field ends with a newline, remove it.
            if event.data.endswith("\n"):
-                event.data = event.data[0:-1]  # Replace trailing newline - rstrip would remove multiple.
+                event.data = event.data[0:-1]
 
            # Empty event names default to 'message'
            event.event = event.event or "message"
@@ -77,5 +129,16 @@ class SSEClient:
 
             yield event
 
+    def _handle_json(self) -> Iterator[Event]:
+        data_list = json.loads(self._event_source.data.decode(self._char_enc))
+        for data in data_list:
+            yield Event(
+                id=data.get("id"),
+                event=data.get("event"),
+                data=data.get("data"),
+                comment=data.get("comment"),
+                retry=data.get("retry"),
+            )
+
     def close(self) -> None:
-        self.response.close()
+        self._event_source.close()
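
For orientation, a minimal sketch of how the reworked client can be driven. The `FakeEventSource` class is invented for illustration; in the package the event source is a `requests` response, which likewise exposes `headers`, iterates over byte chunks, and supports `close()`:

    import json
    from typing import Dict, Iterator

    from snowflake.cortex._sse_client import SSEClient  # private module; path may change

    class FakeEventSource:
        """Hypothetical stand-in for a requests.Response that streams SSE bytes."""

        headers: Dict[str, str] = {"Content-Type": "text/event-stream"}

        def __iter__(self) -> Iterator[bytes]:
            # Two SSE events, each terminated by a blank line.
            yield b'data: {"delta": "Hel"}\n\n'
            yield b'data: {"delta": "lo"}\n\n'

        def close(self) -> None:
            pass

    client = SSEClient(FakeEventSource())
    for event in client.events():  # text/event-stream dispatches to _handle_sse
        print(event.event, json.loads(event.data))
    client.close()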
snowflake/cortex/_util.py CHANGED
@@ -1,6 +1,8 @@
-from typing import Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, Optional, Union, cast
 
 from snowflake import snowpark
+from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import formatting
 from snowflake.snowpark import context, functions
 
 CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
@@ -64,3 +66,30 @@ def _call_sql_function_immediate(
     empty_df = session.create_dataframe([snowpark.Row()])
     df = empty_df.select(functions.builtin(function)(*lit_args))
     return cast(str, df.collect()[0][0])
+
+
+def call_sql_function_literals(function: str, session: Optional[snowpark.Session], *args: Any) -> str:
+    r"""Call a SQL function with only literal arguments.
+
+    This is useful for calling system functions.
+
+    Args:
+        function: The name of the function to be called.
+        session: The Snowpark session to use.
+        *args: The list of arguments.
+
+    Returns:
+        String value that corresponds to the first cell in the dataframe.
+
+    Raises:
+        SnowflakeMLException: If no session is given and no active session exists.
+    """
+    if session is None:
+        session = context.get_active_session()
+    if session is None:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_SNOWPARK_SESSION,
+        )
+
+    function_arguments = ",".join(["NULL" if arg is None else formatting.format_value_for_select(arg) for arg in args])
+    return cast(str, session.sql(f"SELECT {function}({function_arguments})").collect()[0][0])
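
A usage sketch, assuming an active Snowpark session; the system function and argument are illustrative only, the helper simply renders literals into a `SELECT`:

    from snowflake.cortex._util import call_sql_function_literals  # private helper

    # Renders and runs: SELECT SYSTEM$TYPEOF('hello')
    # None arguments become SQL NULL; other values are escaped by
    # snowflake.ml._internal.utils.formatting.format_value_for_select.
    result = call_sql_function_literals("SYSTEM$TYPEOF", None, "hello")
    print(result)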
snowflake/ml/_internal/telemetry.py CHANGED
@@ -544,7 +544,7 @@ def send_api_usage_telemetry(
             if not isinstance(e, snowml_exceptions.SnowflakeMLException):
                 # already handled via a nested decorated function
                 if getattr(e, "_snowflake_ml_handled", False):
-                    raise e
+                    raise
                 if isinstance(e, snowpark_exceptions.SnowparkClientException):
                     me = snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INTERNAL_SNOWPARK_ERROR, original_exception=e
@@ -558,7 +558,9 @@ def send_api_usage_telemetry(
             telemetry_args["error"] = repr(me)
             telemetry_args["error_code"] = me.error_code
             me.original_exception._snowflake_ml_handled = True  # type: ignore[attr-defined]
-            if me.suppress_source_trace:
+            if e is not me:
+                raise  # Directly raise non-wrapped exceptions to preserve original stacktrace
+            elif me.suppress_source_trace:
                 raise me.original_exception from None
             else:
                 raise me.original_exception from e
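
The `raise e` → bare `raise` change is traceback hygiene: re-raising by name appends an extra frame at the re-raise site, while a bare `raise` re-raises the active exception with its traceback untouched. A standalone illustration:

    import traceback

    def fail() -> None:
        raise ValueError("boom")

    def reraise_by_name() -> None:
        try:
            fail()
        except ValueError as e:
            raise e  # adds a "raise e" frame to the reported traceback

    def reraise_bare() -> None:
        try:
            fail()
        except ValueError:
            raise  # traceback still points straight at fail()

    for fn in (reraise_by_name, reraise_bare):
        try:
            fn()
        except ValueError:
            traceback.print_exc()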
snowflake/ml/_internal/type_utils.py CHANGED
@@ -1,4 +1,4 @@
-import sys
+import importlib
 from typing import Any, Generic, Type, TypeVar, Union, cast
 
 import numpy as np
@@ -51,8 +51,8 @@ class LazyType(Generic[T]):
     def get_class(self) -> Type[T]:
         if self._runtime_class is None:
             try:
-                m = sys.modules[self.module]
-            except KeyError:
+                m = importlib.import_module(self.module)
+            except ModuleNotFoundError:
                 raise ValueError(f"Module {self.module} not imported.")
 
             self._runtime_class = cast("Type[T]", getattr(m, self.qualname))
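
The practical difference, sketched with a stdlib module: `sys.modules` only sees modules that something has already imported, so the old lookup raised for installed-but-unimported modules, while `importlib.import_module` loads them on demand:

    import importlib
    import sys

    # Simulate "installed but not yet imported".
    sys.modules.pop("decimal", None)

    try:
        m = sys.modules["decimal"]  # old behavior: KeyError
    except KeyError:
        m = importlib.import_module("decimal")  # new behavior: import on demand

    assert sys.modules["decimal"] is m  # import_module also caches the module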
snowflake/ml/_internal/utils/import_utils.py CHANGED
@@ -19,6 +19,33 @@ class MissingOptionalDependency:
         raise ImportError(f"Unable to import {self._dep_name}.")
 
 
+def import_with_fallbacks(*targets: str) -> Any:
+    """Import a module which may be located in different locations.
+
+    This method will iterate through the provided targets, returning the first available import target.
+    If none of the requested import targets are available, ImportError will be raised.
+
+    Args:
+        targets: Strings representing the targets which need to be imported. Each should be a list of
+            symbol names joined by dots. Some valid examples:
+            - <some_package>
+            - <some_module>
+            - <some_package>.<some_module>
+            - <some_module>.<some_symbol>
+
+    Returns:
+        The imported target.
+
+    Raises:
+        ImportError: None of the requested targets are available.
+    """
+    for target in targets:
+        result, success = import_or_get_dummy(target)
+        if success:
+            return result
+    raise ImportError(f"None of the requested targets could be imported. Requested: {', '.join(targets)}")
+
+
 def import_or_get_dummy(target: str) -> Tuple[Any, bool]:
     """Try to import the given target or return a dummy object.
 
@@ -43,6 +70,10 @@ def import_or_get_dummy(target: str) -> Tuple[Any, bool]:
     except ImportError:
         pass
 
+    # Don't try symbol resolution if target doesn't contain '.'
+    if "." not in target:
+        return (MissingOptionalDependency(target), False)
+
     # Try to import the target as a symbol
     try:
         res = _try_import_symbol(target)
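
A usage sketch of the new helper; the fallback target below is hypothetical, but it shows the intended pattern of preferring a symbol's new home and falling back to an older one:

    from snowflake.ml._internal.utils.import_utils import import_with_fallbacks

    # The first target that imports successfully wins.
    DataConnector = import_with_fallbacks(
        "snowflake.ml.data.data_connector.DataConnector",
        "snowflake.ml.dataset.data_connector.DataConnector",  # hypothetical old location
    )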
snowflake/ml/_internal/utils/snowpark_dataframe_utils.py CHANGED
@@ -121,3 +121,16 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.DataFrame:
         selected_cols.append(functions.col(src))
     df = df.select(selected_cols)
     return df
+
+
+def is_single_query_snowpark_dataframe(df: snowpark.DataFrame) -> bool:
+    """Check if the dataframe only has a single query.
+
+    Args:
+        df: A snowpark dataframe.
+
+    Returns:
+        True if there is only one query in the dataframe and no post_actions,
+        False otherwise.
+    """
+    return len(df.queries["queries"]) == 1 and len(df.queries["post_actions"]) == 0
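
What the predicate inspects, as a sketch: Snowpark exposes the SQL behind a DataFrame via `df.queries`, a dict of `queries` and `post_actions` lists. Assuming connection parameters are configured:

    from snowflake.snowpark import Session

    from snowflake.ml._internal.utils.snowpark_dataframe_utils import (
        is_single_query_snowpark_dataframe,
    )

    session = Session.builder.getOrCreate()
    df = session.sql("SELECT 1 AS A")

    print(df.queries)  # {'queries': ['SELECT 1 AS A'], 'post_actions': []}
    print(is_single_query_snowpark_dataframe(df))  # True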
snowflake/ml/data/__init__.py CHANGED
@@ -0,0 +1,5 @@
+from .data_connector import DataConnector
+from .data_ingestor import DataIngestor, DataIngestorType
+from .data_source import DataFrameInfo, DatasetInfo, DataSource
+
+__all__ = ["DataConnector", "DataSource", "DataFrameInfo", "DatasetInfo", "DataIngestor", "DataIngestorType"]
snowflake/ml/data/_internal/arrow_ingestor.py CHANGED
@@ -198,7 +198,15 @@ def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
     for column, column_schema in zip(rb, rb.schema):
         # zero_copy_only=False because of nans. Ideally nans should have been imputed in feature engineering.
         array = column.to_numpy(zero_copy_only=False)
+        # If this column is a list, use the underlying type from the list values. Since this is just one column,
+        # there should only be one type within the list.
+        # TODO: Refactor to reduce data copies.
+        if isinstance(column_schema.type, pa.ListType):
+            # Update dtype of outer array:
+            array = np.array(array.tolist(), dtype=column_schema.type.value_type.to_pandas_dtype())
+
         batch_dict[column_schema.name] = array
+
     return batch_dict
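
A self-contained sketch of what the new `pa.ListType` branch does: `to_numpy` on a list column yields an object array of sub-arrays, and re-wrapping it with the list's value type restores a dense 2-D numeric array:

    import numpy as np
    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict({"EMBEDDING": [[1.0, 2.0], [3.0, 4.0]]})
    column, column_schema = batch.column(0), batch.schema.field(0)

    array = column.to_numpy(zero_copy_only=False)
    print(array.dtype)  # object: a 1-D array holding sub-arrays

    if isinstance(column_schema.type, pa.ListType):
        array = np.array(array.tolist(), dtype=column_schema.type.value_type.to_pandas_dtype())
    print(array.dtype, array.shape)  # float64 (2, 2)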
 
snowflake/ml/data/data_connector.py CHANGED
@@ -159,7 +159,7 @@ class DataConnector:
         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
     )
     def to_torch_dataset(
-        self, *, batch_size: int = 1, shuffle: bool = False, drop_last_batch: bool = True
+        self, *, batch_size: Optional[int] = None, shuffle: bool = False, drop_last_batch: bool = True
     ) -> "torch_data.IterableDataset":  # type: ignore[type-arg]
         """Transform the Snowflake data into a PyTorch Iterable Dataset to be used with a DataLoader.
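
A usage sketch of the new default, assuming a populated connector `dc` (hypothetical name). With `batch_size=None` the dataset yields squeezed single rows and batching is left to the DataLoader; with an explicit `batch_size` the dataset batches natively and the DataLoader should pass batches through:

    from torch.utils.data import DataLoader

    # Unbatched rows; let the DataLoader assemble batches of 32.
    loader = DataLoader(dc.to_torch_dataset(batch_size=None, shuffle=True), batch_size=32)

    # Native batching; disable the DataLoader's automatic batching.
    loader = DataLoader(dc.to_torch_dataset(batch_size=32, drop_last_batch=True), batch_size=None)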
 
snowflake/ml/data/torch_utils.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterator, List, Union
+from typing import Any, Dict, Iterator, List, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -14,17 +14,21 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         self,
         ingestor: data_ingestor.DataIngestor,
         *,
-        batch_size: int,
+        batch_size: Optional[int],
         shuffle: bool = False,
         drop_last: bool = False,
-        squeeze_outputs: bool = True
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
+        squeeze = False
+        if batch_size is None:
+            batch_size = 1
+            squeeze = True
+
         self._ingestor = ingestor
         self._batch_size = batch_size
         self._shuffle = shuffle
         self._drop_last = drop_last
-        self._squeeze_outputs = squeeze_outputs
+        self._squeeze_outputs = squeeze
 
     def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]:
         max_idx = 0
@@ -43,15 +47,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         ):
             # Skip indices during multi-process data loading to prevent data duplication
             if counter == filter_idx:
-                # Basic preprocessing on batch values: squeeze away extra dimensions
-                # and convert object arrays (e.g. strings) to lists
-                if self._squeeze_outputs:
-                    yield {
-                        k: (v.squeeze().tolist() if v.dtype == np.object_ else v.squeeze()) for k, v in batch.items()
-                    }
-                else:
-                    yield batch  # type: ignore[misc]
-
+                yield {k: _preprocess_array(v, squeeze=self._squeeze_outputs) for k, v in batch.items()}
             if counter < max_idx:
                 counter += 1
             else:
@@ -65,4 +61,27 @@ class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Dict[str, Any]]):
         self, ingestor: data_ingestor.DataIngestor, *, batch_size: int, shuffle: bool = False, drop_last: bool = False
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_datapipe() instead"""
-        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, squeeze_outputs=False)
+        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+
+
+def _preprocess_array(arr: npt.NDArray[Any], squeeze: bool = False) -> Union[npt.NDArray[Any], List[np.object_]]:
+    """Preprocesses batch column values."""
+    single_dimensional = arr.ndim < 2 and not arr.dtype == np.object_
+
+    # Squeeze away all extra dimensions. This is only used when batch_size = None.
+    if squeeze:
+        arr = arr.squeeze(axis=0)
+
+    # For single dimensional data,
+    if single_dimensional:
+        axis = 0 if arr.ndim == 0 else 1
+        arr = np.expand_dims(arr, axis=axis)
+
+    # Handle object arrays.
+    if arr.dtype == np.object_:
+        array_list = arr.tolist()
+        # If this is an array of arrays, convert the dtype to match the underlying array.
+        # Otherwise, if this is a numpy array of strings, convert the array to a list.
+        arr = np.array(array_list, dtype=arr.flat[0].dtype) if isinstance(arr.flat[0], np.ndarray) else array_list
+
+    return arr
6
6
 
7
7
 
8
8
  # This function will be invoked by example_helper.py. Do not change the name.
9
- def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
9
+ def create_draft_feature_view(
10
+ session: Session, source_dfs: List[DataFrame], source_tables: List[str], database: str, schema: str
11
+ ) -> FeatureView:
10
12
  """Create a feature view about airplane model."""
11
13
  query = session.sql(
12
- """
14
+ f"""
13
15
  select
14
16
  PLANE_MODEL,
15
17
  SEATING_CAPACITY
16
18
  from
17
- PLANE_MODEL_ATTRIBUTES
19
+ {database}.{schema}.PLANE_MODEL_ATTRIBUTES
18
20
  """
19
21
  )
20
22
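
Every example feature module below gains the same two parameters, so callers now pass an explicit database and schema instead of relying on the session's current context. A sketch with placeholder names:

    from snowflake.snowpark import Session

    from snowflake.ml.feature_store.examples.airline_features.features import plane_features

    session = Session.builder.getOrCreate()  # assumes connection parameters are configured
    fv = plane_features.create_draft_feature_view(
        session,
        source_dfs=[],
        source_tables=[],
        database="MY_DB",    # hypothetical: where PLANE_MODEL_ATTRIBUTES lives
        schema="MY_SCHEMA",  # hypothetical
    )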
 
snowflake/ml/feature_store/examples/airline_features/features/weather_features.py CHANGED
@@ -6,10 +6,12 @@ from snowflake.snowpark import DataFrame, Session
 
 
 # This function will be invoked by example_helper.py. Do not change the name.
-def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+def create_draft_feature_view(
+    session: Session, source_dfs: List[DataFrame], source_tables: List[str], database: str, schema: str
+) -> FeatureView:
     """Create a feature view about airport weather."""
     query = session.sql(
-        """
+        f"""
         select
             DATETIME_UTC AS TS,
             AIRPORT_ZIP_CODE,
@@ -21,9 +23,9 @@ def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
             sum(RAIN_MM_H) over (
                 partition by AIRPORT_ZIP_CODE
                 order by DATETIME_UTC
-                range between interval '1 day' preceding and current row
+                range between interval '60 minutes' preceding and current row
             ) RAIN_SUM_60M
-        from AIRPORT_WEATHER_STATION
+        from {database}.{schema}.AIRPORT_WEATHER_STATION
         """
     )
 
@@ -37,6 +39,6 @@ def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
     ).attach_feature_desc(
         {
             "RAIN_SUM_30M": "The sum of rain fall over past 30 minutes for one zipcode.",
-            "RAIN_SUM_60M": "The sum of rain fall over past 1 day for one zipcode.",
+            "RAIN_SUM_60M": "The sum of rain fall over past 1 hour for one zipcode.",
         }
     )
snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py CHANGED
@@ -8,7 +8,9 @@ from snowflake.snowpark import DataFrame, Session
 
 
 # This function will be invoked by example_helper.py. Do not change the name.
-def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+def create_draft_feature_view(
+    session: Session, source_dfs: List[DataFrame], source_tables: List[str], database: str, schema: str
+) -> FeatureView:
     """Create a feature view about trip station."""
     query = session.sql(
         f"""
@@ -17,7 +19,7 @@ def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
             count(end_station_id) as f_count,
             avg(end_station_latitude) as f_avg_latitude,
             avg(end_station_longitude) as f_avg_longtitude
-        from {source_tables[0]}
+        from {database}.{schema}.{source_tables[0]}
         group by end_station_id
         """
     )
snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py CHANGED
@@ -6,7 +6,9 @@ from snowflake.snowpark import DataFrame, Session, functions as F
 
 
 # This function will be invoked by example_helper.py. Do not change the name.
-def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+def create_draft_feature_view(
+    session: Session, source_dfs: List[DataFrame], source_tables: List[str], database: str, schema: str
+) -> FeatureView:
     """Create a feature view about trip."""
     feature_df = source_dfs[0].select(
         "trip_id",
snowflake/ml/feature_store/examples/example_helper.py CHANGED
@@ -66,7 +66,9 @@ class ExampleHelper:
                 continue
             mod_path = f"{__package__}.{self._selected_example}.features.{f_name.rstrip('.py')}"
             mod = importlib.import_module(mod_path)
-            fv = mod.create_draft_feature_view(self._session, self._source_dfs, self._source_tables)
+            fv = mod.create_draft_feature_view(
+                self._session, self._source_dfs, self._source_tables, self._database_name, self._dataset_schema
+            )
             fvs.append(fv)
 
         return fvs
@@ -140,7 +142,7 @@ class ExampleHelper:
             """
         ).collect()
 
-        return [destination_table]
+        return [schema_dict["destination_table_name"]]
 
     def _load_parquet(self, schema_dict: Dict[str, str], temp_stage_name: str) -> List[str]:
         regex_pattern = schema_dict["load_files_pattern"]
@@ -173,13 +175,14 @@ class ExampleHelper:
                 dest_table_name = (
                     f"{self._database_name}.{self._dataset_schema}.{schema_dict['destination_table_name']}"
                 )
+                result.append(schema_dict["destination_table_name"])
             else:
                 regex_pattern = schema_dict["destination_table_name"]
                 dest_table_name = re.match(regex_pattern, file_name).group("table_name")  # type: ignore[union-attr]
+                result.append(dest_table_name)
                 dest_table_name = f"{self._database_name}.{self._dataset_schema}.{dest_table_name}"
 
             df.write.mode("overwrite").save_as_table(dest_table_name)
-            result.append(dest_table_name)
 
         return result
snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py CHANGED
@@ -8,7 +8,9 @@ from snowflake.snowpark import DataFrame, Session
 
 
 # This function will be invoked by example_helper.py. Do not change the name.
-def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+def create_draft_feature_view(
+    session: Session, source_dfs: List[DataFrame], source_tables: List[str], database: str, schema: str
+) -> FeatureView:
     """Create a draft feature view."""
     feature_df = session.sql(
         f"""
@@ -25,7 +27,7 @@ def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
                 order by TPEP_DROPOFF_DATETIME
                 range between interval '10 hours' preceding and current row
             ) AVG_FARE_10h
-        from {source_tables[0]}
+        from {database}.{schema}.{source_tables[0]}
         """
     )
snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py CHANGED
@@ -6,7 +6,9 @@ from snowflake.snowpark import DataFrame, Session
 
 
 # This function will be invoked by example_helper.py. Do not change the name.
-def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+def create_draft_feature_view(
+    session: Session, source_dfs: List[DataFrame], source_tables: List[str], database: str, schema: str
+) -> FeatureView:
     """Create a draft feature view."""
     feature_df = session.sql(
         f"""
@@ -16,7 +18,7 @@ def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
             TRIP_DISTANCE,
             FARE_AMOUNT
         from
-            {source_tables[0]}
+            {database}.{schema}.{source_tables[0]}
         """
    )
snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py CHANGED
@@ -6,7 +6,9 @@ from snowflake.snowpark import DataFrame, Session, functions as F
 
 
 # This function will be invoked by example_helper.py. Do not change the name.
-def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+def create_draft_feature_view(
+    session: Session, source_dfs: List[DataFrame], source_tables: List[str], database: str, schema: str
+) -> FeatureView:
     """Create a feature view about trip station."""
     feature_df = source_dfs[0].select(
         "WINE_ID",
snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py CHANGED
@@ -6,7 +6,9 @@ from snowflake.snowpark import DataFrame, Session
 
 
 # This function will be invoked by example_helper.py. Do not change the name.
-def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+def create_draft_feature_view(
+    session: Session, source_dfs: List[DataFrame], source_tables: List[str], database: str, schema: str
+) -> FeatureView:
     """Create a feature view about trip station."""
     feature_df = source_dfs[0].select("WINE_ID", "SULPHATES", "ALCOHOL")
snowflake/ml/feature_store/feature_store.py CHANGED
@@ -1886,8 +1886,7 @@ class FeatureStore:
         if found_dts[0]["refresh_mode"] != "INCREMENTAL":
             warnings.warn(
                 "Your pipeline won't be incrementally refreshed due to: "
-                + f"\"{found_dts[0]['refresh_mode_reason']}\". "
-                + "It will likely incurr higher cost.",
+                + f"\"{found_dts[0]['refresh_mode_reason']}\".",
                 stacklevel=2,
                 category=UserWarning,
             )
snowflake/ml/feature_store/feature_view.py CHANGED
@@ -169,6 +169,7 @@ class FeatureView(lineage_node.LineageNode):
         desc: str = "",
         warehouse: Optional[str] = None,
         initialize: str = "ON_CREATE",
+        refresh_mode: str = "AUTO",
         **_kwargs: Any,
     ) -> None:
         """
@@ -196,6 +197,9 @@ class FeatureView(lineage_node.LineageNode):
             after you register the feature view. It supports ON_CREATE (default) or ON_SCHEDULE. ON_CREATE refreshes
             the feature view synchronously at creation. ON_SCHEDULE refreshes the feature view at the next scheduled
             refresh. It is only effective when refresh_freq is not None.
+            refresh_mode: The refresh mode of managed feature view. The value can be 'AUTO', 'FULL' or 'INCREMENTAL'.
+                For managed feature view, the default value is 'AUTO'. For static feature view it has no effect.
+                Check https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table for details.
             _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
 
         Example::
@@ -242,7 +246,7 @@ class FeatureView(lineage_node.LineageNode):
         self._schema: Optional[SqlIdentifier] = None
         self._initialize: str = initialize
         self._warehouse: Optional[SqlIdentifier] = SqlIdentifier(warehouse) if warehouse is not None else None
-        self._refresh_mode: Optional[str] = _kwargs.get("refresh_mode", "AUTO")
+        self._refresh_mode: Optional[str] = refresh_mode
         self._refresh_mode_reason: Optional[str] = None
         self._owner: Optional[str] = None
         self._validate()
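
With `refresh_mode` promoted to a first-class keyword argument, callers no longer need the `_kwargs` back door. A construction sketch; `entity` and `feature_df` are assumed to exist already:

    from snowflake.ml.feature_store import FeatureView

    fv = FeatureView(
        name="MY_FV",                # hypothetical
        entities=[entity],
        feature_df=feature_df,
        refresh_freq="1 day",
        refresh_mode="INCREMENTAL",  # or 'AUTO' (default) / 'FULL'
        initialize="ON_CREATE",
    )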