snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +67 -10
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  6. snowflake/ml/_internal/telemetry.py +12 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  8. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  9. snowflake/ml/data/data_connector.py +133 -0
  10. snowflake/ml/data/data_ingestor.py +28 -0
  11. snowflake/ml/data/data_source.py +23 -0
  12. snowflake/ml/dataset/dataset.py +1 -13
  13. snowflake/ml/dataset/dataset_reader.py +18 -118
  14. snowflake/ml/feature_store/access_manager.py +7 -1
  15. snowflake/ml/feature_store/entity.py +19 -2
  16. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  18. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  19. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  20. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  21. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  22. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  23. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  24. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  25. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  26. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  27. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  28. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  29. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  30. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  31. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  32. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  33. snowflake/ml/feature_store/feature_store.py +579 -53
  34. snowflake/ml/feature_store/feature_view.py +168 -5
  35. snowflake/ml/fileset/stage_fs.py +18 -10
  36. snowflake/ml/lineage/lineage_node.py +1 -1
  37. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  38. snowflake/ml/model/_model_composer/model_composer.py +11 -14
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +24 -16
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  41. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  42. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  43. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  44. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  45. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  46. snowflake/ml/model/_packager/model_handlers/_base.py +11 -1
  47. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  48. snowflake/ml/model/_packager/model_handlers/catboost.py +42 -0
  49. snowflake/ml/model/_packager/model_handlers/lightgbm.py +68 -0
  50. snowflake/ml/model/_packager/model_handlers/xgboost.py +59 -0
  51. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  52. snowflake/ml/model/model_signature.py +4 -4
  53. snowflake/ml/model/type_hints.py +4 -0
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  56. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  57. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  58. snowflake/ml/registry/registry.py +100 -13
  59. snowflake/ml/version.py +1 -1
  60. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +48 -2
  61. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +64 -42
  62. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  63. snowflake/ml/_internal/lineage/data_source.py +0 -10
  64. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/ml/data/data_connector.py (new file)
@@ -0,0 +1,133 @@
+ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar
+
+ import numpy.typing as npt
+
+ from snowflake import snowpark
+ from snowflake.ml._internal import telemetry
+ from snowflake.ml.data import data_ingestor, data_source
+ from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor as DefaultIngestor
+
+ if TYPE_CHECKING:
+     import pandas as pd
+     import tensorflow as tf
+     from torch.utils import data as torch_data
+
+     # This module can't actually depend on dataset to avoid a circular dependency
+     # Dataset -> DatasetReader -> DataConnector -!-> Dataset
+     from snowflake.ml import dataset
+
+ _PROJECT = "DataConnector"
+
+ DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
+
+
+ class DataConnector:
+     """Snowflake data reader which provides application integration connectors"""
+
+     def __init__(
+         self,
+         ingestor: data_ingestor.DataIngestor,
+     ) -> None:
+         self._ingestor = ingestor
+
+     @classmethod
+     def from_dataframe(cls: Type[DataConnectorType], df: snowpark.DataFrame, **kwargs: Any) -> DataConnectorType:
+         if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
+             raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
+         source = data_source.DataFrameInfo(df.queries["queries"][0])
+         assert df._session is not None
+         ingestor = DefaultIngestor(df._session, [source])
+         return cls(ingestor, **kwargs)
+
+     @classmethod
+     def from_dataset(cls: Type[DataConnectorType], ds: "dataset.Dataset", **kwargs: Any) -> DataConnectorType:
+         dsv = ds.selected_version
+         assert dsv is not None
+         source = data_source.DatasetInfo(
+             ds.fully_qualified_name, dsv.name, dsv.url(), exclude_cols=(dsv.label_cols + dsv.exclude_cols)
+         )
+         ingestor = DefaultIngestor(ds._session, [source])
+         return cls(ingestor, **kwargs)
+
+     @property
+     def data_sources(self) -> List[data_source.DataSource]:
+         return self._ingestor.data_sources
+
+     @telemetry.send_api_usage_telemetry(
+         project=_PROJECT,
+         subproject_extractor=lambda self: type(self).__name__,
+         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
+     )
+     def to_tf_dataset(
+         self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+     ) -> "tf.data.Dataset":
+         """Transform the Snowflake data into a ready-to-use TensorFlow tf.data.Dataset.
+
+         Args:
+             batch_size: It specifies the size of each data batch which will be
+                 yield in the result datapipe
+             shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
+                 rows in each file will also be shuffled.
+             drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
+                 then the last batch will get dropped if its size is smaller than the given batch_size.
+
+         Returns:
+             A tf.data.Dataset that yields batched tf.Tensors.
+         """
+         import tensorflow as tf
+
+         def generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]:
+             yield from self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
+
+         # Derive TensorFlow signature
+         first_batch = next(self._ingestor.to_batches(1, shuffle=False, drop_last_batch=False))
+         tf_signature = {
+             k: tf.TensorSpec(shape=(None,), dtype=tf.dtypes.as_dtype(v.dtype), name=k) for k, v in first_batch.items()
+         }
+
+         return tf.data.Dataset.from_generator(generator, output_signature=tf_signature)
+
+     @telemetry.send_api_usage_telemetry(
+         project=_PROJECT,
+         subproject_extractor=lambda self: type(self).__name__,
+         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
+     )
+     def to_torch_datapipe(
+         self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+     ) -> "torch_data.IterDataPipe":  # type: ignore[type-arg]
+         """Transform the Snowflake data into a ready-to-use Pytorch datapipe.
+
+         Return a Pytorch datapipe which iterates on rows of data.
+
+         Args:
+             batch_size: It specifies the size of each data batch which will be
+                 yield in the result datapipe
+             shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
+                 rows in each file will also be shuffled.
+             drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
+                 then the last batch will get dropped if its size is smaller than the given batch_size.
+
+         Returns:
+             A Pytorch iterable datapipe that yield data.
+         """
+         from torch.utils.data.datapipes import iter as torch_iter
+
+         return torch_iter.IterableWrapper(  # type: ignore[no-untyped-call]
+             self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
+         )
+
+     @telemetry.send_api_usage_telemetry(
+         project=_PROJECT,
+         subproject_extractor=lambda self: type(self).__name__,
+         func_params_to_log=["limit"],
+     )
+     def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
+         """Retrieve the Snowflake data as a Pandas DataFrame.
+
+         Args:
+             limit: If specified, the maximum number of rows to load into the DataFrame.
+
+         Returns:
+             A Pandas DataFrame.
+         """
+         return self._ingestor.to_pandas(limit)
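A minimal usage sketch of the DataConnector API added above. The connection parameters and table name are placeholders, and the snippet assumes the optional tensorflow dependency is installed; only from_dataframe, to_pandas, and to_tf_dataset come from the file shown in this hunk.

# Hedged illustration of snowflake.ml.data.data_connector.DataConnector (new in 1.6.0).
from snowflake.snowpark import Session
from snowflake.ml.data.data_connector import DataConnector

connection_parameters = {}  # placeholder: account, user, password, etc.
session = Session.builder.configs(connection_parameters).create()

# The source DataFrame must compile to exactly one query with no post-actions.
df = session.table("MY_TABLE")  # placeholder table name
dc = DataConnector.from_dataframe(df)

pdf = dc.to_pandas(limit=1_000)          # materialize at most 1,000 rows as pandas
tf_ds = dc.to_tf_dataset(batch_size=32)  # stream batches into TensorFlow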
snowflake/ml/data/data_ingestor.py (new file)
@@ -0,0 +1,28 @@
+ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Protocol, TypeVar
+
+ from numpy import typing as npt
+
+ from snowflake.ml.data import data_source
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+
+ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
+
+
+ class DataIngestor(Protocol):
+     @property
+     def data_sources(self) -> List[data_source.DataSource]:
+         raise NotImplementedError
+
+     def to_batches(
+         self,
+         batch_size: int,
+         shuffle: bool = True,
+         drop_last_batch: bool = True,
+     ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+         raise NotImplementedError
+
+     def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
+         raise NotImplementedError
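Because DataIngestor is a typing.Protocol, DataConnector can wrap any object that provides these three members; nothing needs to subclass it. Below is a purely illustrative, pandas-backed ingestor sketch (InMemoryIngestor is not part of the package):

from typing import Any, Dict, Iterator, List, Optional

import numpy.typing as npt
import pandas as pd

from snowflake.ml.data import data_source


class InMemoryIngestor:
    """Illustrative ingestor over a pandas DataFrame; structurally satisfies DataIngestor."""

    def __init__(self, df: pd.DataFrame) -> None:
        self._df = df

    @property
    def data_sources(self) -> List[data_source.DataSource]:
        # The DataSource union admits plain strings, used here as an opaque label.
        return ["in_memory"]

    def to_batches(
        self, batch_size: int, shuffle: bool = True, drop_last_batch: bool = True
    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
        df = self._df.sample(frac=1) if shuffle else self._df
        for start in range(0, len(df), batch_size):
            chunk = df.iloc[start : start + batch_size]
            if drop_last_batch and len(chunk) < batch_size:
                break
            yield {col: chunk[col].to_numpy() for col in chunk.columns}

    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
        return self._df.head(limit) if limit is not None else self._df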
snowflake/ml/data/data_source.py (new file)
@@ -0,0 +1,23 @@
+ import dataclasses
+ from typing import List, Optional, Union
+
+
+ @dataclasses.dataclass(frozen=True)
+ class DataFrameInfo:
+     """Serializable information from Snowpark DataFrames"""
+
+     sql: str
+     query_id: Optional[str] = None
+
+
+ @dataclasses.dataclass(frozen=True)
+ class DatasetInfo:
+     """Serializable information from SnowML Datasets"""
+
+     fully_qualified_name: str
+     version: str
+     url: Optional[str] = None
+     exclude_cols: Optional[List[str]] = None
+
+
+ DataSource = Union[DataFrameInfo, DatasetInfo, str]
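For reference, these frozen dataclasses are plain value objects and can be constructed directly; the identifiers below are placeholders:

from snowflake.ml.data import data_source

df_info = data_source.DataFrameInfo(sql="SELECT * FROM MY_TABLE")
ds_info = data_source.DatasetInfo(
    fully_qualified_name="MY_DB.MY_SCHEMA.MY_DATASET",  # placeholder names
    version="v1",
    exclude_cols=["LABEL"],
)
# The DataSource alias also admits a bare string (e.g. a stage path).
sources = [df_info, ds_info, "@MY_STAGE/data/"]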
snowflake/ml/dataset/dataset.py
@@ -11,7 +11,6 @@ from snowflake.ml._internal.exceptions import (
      error_codes,
      exceptions as snowml_exceptions,
  )
- from snowflake.ml._internal.lineage import data_source
  from snowflake.ml._internal.utils import (
      formatting,
      identifier,
@@ -177,18 +176,7 @@ class Dataset(lineage_node.LineageNode):
                  original_exception=RuntimeError("No Dataset version selected."),
              )
          if self._reader is None:
-             v = self.selected_version
-             self._reader = dataset_reader.DatasetReader(
-                 self._session,
-                 [
-                     data_source.DataSource(
-                         fully_qualified_name=self._lineage_node_name,
-                         version=v.name,
-                         url=v.url(),
-                         exclude_cols=(v.label_cols + v.exclude_cols),
-                     )
-                 ],
-             )
+             self._reader = dataset_reader.DatasetReader.from_dataset(self, snowpark_session=self._session)
          return self._reader

      @staticmethod
snowflake/ml/dataset/dataset_reader.py
@@ -1,48 +1,31 @@
- from typing import Any, List
-
- import pandas as pd
- from pyarrow import parquet as pq
+ from typing import List, Optional

  from snowflake import snowpark
  from snowflake.ml._internal import telemetry
- from snowflake.ml._internal.lineage import data_source, lineage_utils
- from snowflake.ml._internal.utils import import_utils
+ from snowflake.ml._internal.lineage import lineage_utils
+ from snowflake.ml.data import data_connector, data_ingestor, data_source
+ from snowflake.ml.data._internal import ingestor_utils
  from snowflake.ml.fileset import snowfs

  _PROJECT = "Dataset"
  _SUBPROJECT = "DatasetReader"
- TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.


- class DatasetReader:
+ class DatasetReader(data_connector.DataConnector):
      """Snowflake Dataset abstraction which provides application integration connectors"""

      @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
      def __init__(
          self,
-         session: snowpark.Session,
-         sources: List[data_source.DataSource],
+         ingestor: data_ingestor.DataIngestor,
+         *,
+         snowpark_session: snowpark.Session,
      ) -> None:
-         """Initialize a DatasetVersion object.
+         super().__init__(ingestor)

-         Args:
-             session: Snowpark Session to interact with Snowflake backend.
-             sources: Data sources to read from.
-
-         Raises:
-             ValueError: `sources` arg was empty or null
-         """
-         if not sources:
-             raise ValueError("Invalid input: empty `sources` list not allowed")
-         self._session = session
-         self._sources = sources
-         self._fs: snowfs.SnowFileSystem = snowfs.SnowFileSystem(
-             snowpark_session=self._session,
-             cache_type="bytes",
-             block_size=2 * TARGET_FILE_SIZE,
-         )
-
-         self._files: List[str] = []
+         self._session: snowpark.Session = snowpark_session
+         self._fs: snowfs.SnowFileSystem = ingestor_utils.get_dataset_filesystem(self._session)
+         self._files: Optional[List[str]] = None

      def _list_files(self) -> List[str]:
          """Private helper function that lists all files in this DatasetVersion and caches the results."""
@@ -50,18 +33,14 @@ class DatasetReader:
              return self._files

          files: List[str] = []
-         for source in self._sources:
-             # Sort within each source for consistent ordering
-             files.extend(sorted(self._fs.ls(source.url)))  # type: ignore[arg-type]
+         for source in self.data_sources:
+             assert isinstance(source, data_source.DatasetInfo)
+             files.extend(ingestor_utils.get_dataset_files(self._session, source, filesystem=self._fs))
          files.sort()

          self._files = files
          return self._files

-     @property
-     def data_sources(self) -> List[data_source.DataSource]:
-         return self._sources
-
      @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
      def files(self) -> List[str]:
          """Get the list of remote file paths for the current DatasetVersion.
@@ -85,76 +64,6 @@ class DatasetReader:
          """Return an fsspec FileSystem which can be used to load the DatasetVersion's `files()`"""
          return self._fs

-     @telemetry.send_api_usage_telemetry(
-         project=_PROJECT,
-         subproject=_SUBPROJECT,
-         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
-     )
-     def to_torch_datapipe(self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True) -> Any:
-         """Transform the Snowflake data into a ready-to-use Pytorch datapipe.
-
-         Return a Pytorch datapipe which iterates on rows of data.
-
-         Args:
-             batch_size: It specifies the size of each data batch which will be
-                 yield in the result datapipe
-             shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
-                 rows in each file will also be shuffled.
-             drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
-                 then the last batch will get dropped if its size is smaller than the given batch_size.
-
-         Returns:
-             A Pytorch iterable datapipe that yield data.
-
-         Examples:
-             >>> dp = dataset.to_torch_datapipe(batch_size=1)
-             >>> for data in dp:
-             >>>     print(data)
-             ----
-             {'_COL_1':[10]}
-         """
-         IterableWrapper, _ = import_utils.import_or_get_dummy("torchdata.datapipes.iter.IterableWrapper")
-         torch_datapipe_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.torch_datapipe")
-
-         self._fs.optimize_read(self._list_files())
-
-         input_dp = IterableWrapper(self._list_files())
-         return torch_datapipe_module.ReadAndParseParquet(input_dp, self._fs, batch_size, shuffle, drop_last_batch)
-
-     @telemetry.send_api_usage_telemetry(
-         project=_PROJECT,
-         subproject=_SUBPROJECT,
-         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
-     )
-     def to_tf_dataset(self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True) -> Any:
-         """Transform the Snowflake data into a ready-to-use TensorFlow tf.data.Dataset.
-
-         Args:
-             batch_size: It specifies the size of each data batch which will be
-                 yield in the result datapipe
-             shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
-                 rows in each file will also be shuffled.
-             drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
-                 then the last batch will get dropped if its size is smaller than the given batch_size.
-
-         Returns:
-             A tf.data.Dataset that yields batched tf.Tensors.
-
-         Examples:
-             >>> dp = dataset.to_tf_dataset(batch_size=1)
-             >>> for data in dp:
-             >>>     print(data)
-             ----
-             {'_COL_1': <tf.Tensor: shape=(1,), dtype=int64, numpy=[10]>}
-         """
-         tf_dataset_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.tf_dataset")
-
-         self._fs.optimize_read(self._list_files())
-
-         return tf_dataset_module.read_and_parse_parquet(
-             self._list_files(), self._fs, batch_size, shuffle, drop_last_batch
-         )
-
      @telemetry.send_api_usage_telemetry(
@@ -177,7 +86,8 @@ class DatasetReader:
          """
          file_path_pattern = ".*data_.*[.]parquet"
          dfs: List[snowpark.DataFrame] = []
-         for source in self._sources:
+         for source in self.data_sources:
+             assert isinstance(source, data_source.DatasetInfo) and source.url is not None
              df = self._session.read.option("pattern", file_path_pattern).parquet(source.url)
              if only_feature_cols and source.exclude_cols:
                  df = df.drop(source.exclude_cols)
@@ -186,14 +96,4 @@ class DatasetReader:
          combined_df = dfs[0]
          for df in dfs[1:]:
              combined_df = combined_df.union_all_by_name(df)
-         return lineage_utils.patch_dataframe(combined_df, data_sources=self._sources, inplace=True)
-
-     @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
-     def to_pandas(self) -> pd.DataFrame:
-         """Retrieve the DatasetVersion contents as a Pandas Dataframe"""
-         files = self._list_files()
-         if not files:
-             return pd.DataFrame()  # Return empty DataFrame
-         self._fs.optimize_read(files)
-         pd_ds = pq.ParquetDataset(files, filesystem=self._fs)
-         return pd_ds.read_pandas().to_pandas()
+         return lineage_utils.patch_dataframe(combined_df, data_sources=self.data_sources, inplace=True)
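With this refactor, DatasetReader keeps files(), filesystem(), and to_snowpark_dataframe() and inherits to_pandas, to_tf_dataset, and to_torch_datapipe from DataConnector. A hedged sketch of typical access through a Dataset, assuming an existing Snowpark `session` and that the dataset.load_dataset helper and Dataset.read accessor behave as in 1.5.x; dataset and version names are placeholders:

from snowflake.ml import dataset

ds = dataset.load_dataset(session, "MY_DATASET", "v1")  # placeholder names
reader = ds.read                                         # DatasetReader, now a DataConnector subclass

files = reader.files()                                   # stage file paths backing the version
pdf = reader.to_pandas(limit=100)                        # `limit` keyword comes from DataConnector
sp_df = reader.to_snowpark_dataframe(only_feature_cols=True)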
snowflake/ml/feature_store/access_manager.py
@@ -273,7 +273,13 @@ def setup_feature_store(
      assert current_role is not None  # to make mypy happy
      try:
          session.use_role(producer_role)
-         fs = FeatureStore(session, database, schema, warehouse, creation_mode=CreationMode.CREATE_IF_NOT_EXIST)
+         fs = FeatureStore(
+             session,
+             database,
+             schema,
+             default_warehouse=warehouse,
+             creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
+         )
      finally:
          session.use_role(current_role)

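The hunk above now passes the warehouse as the default_warehouse keyword; a matching direct call, mirroring the arguments used in setup_feature_store and assuming FeatureStore and CreationMode are importable from snowflake.ml.feature_store (identifiers are placeholders and `session` is an existing Snowpark session):

from snowflake.ml.feature_store import CreationMode, FeatureStore

fs = FeatureStore(
    session,
    "MY_DB",       # database
    "MY_SCHEMA",   # feature store schema
    default_warehouse="MY_WH",
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
)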
snowflake/ml/feature_store/entity.py
@@ -22,7 +22,7 @@ class Entity:
      It can also be used for FeatureView search and lineage tracking.
      """

-     def __init__(self, name: str, join_keys: List[str], desc: str = "") -> None:
+     def __init__(self, name: str, join_keys: List[str], *, desc: str = "") -> None:
          """
          Creates an Entity instance.

@@ -30,6 +30,23 @@ class Entity:
              name: name of the Entity.
              join_keys: join keys associated with a FeatureView, used for feature retrieval.
              desc: description of the Entity.
+
+         Example::
+
+             >>> fs = FeatureStore(...)
+             >>> e_1 = Entity(
+             ...     name="my_entity",
+             ...     join_keys=['col_1'],
+             ...     desc='My first entity.'
+             ... )
+             >>> fs.register_entity(e_1)
+             >>> fs.list_entities().show()
+             -----------------------------------------------------------
+             |"NAME"     |"JOIN_KEYS"  |"DESC"            |"OWNER"     |
+             -----------------------------------------------------------
+             |MY_ENTITY  |["COL_1"]    |My first entity.  |REGTEST_RL  |
+             -----------------------------------------------------------
+
          """
          self._validate(name, join_keys)

@@ -65,7 +82,7 @@ class Entity:

      @staticmethod
      def _construct_entity(name: str, join_keys: List[str], desc: str, owner: str) -> "Entity":
-         e = Entity(name, join_keys, desc)
+         e = Entity(name, join_keys, desc=desc)
          e.owner = owner
          return e

snowflake/ml/feature_store/examples/citibike_trip_features/entities.py (new file)
@@ -0,0 +1,20 @@
+ from typing import List
+
+ from snowflake.ml.feature_store import Entity
+
+ end_station_id = Entity(
+     name="end_station_id",
+     join_keys=["end_station_id"],
+     desc="The id of an end station.",
+ )
+
+ trip_id = Entity(
+     name="trip_id",
+     join_keys=["trip_id"],
+     desc="The id of a trip.",
+ )
+
+
+ # This will be invoked by example_helper.py. Do not change function name.
+ def get_all_entities() -> List[Entity]:
+     return [end_station_id, trip_id]
snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py (new file)
@@ -0,0 +1,31 @@
+ from typing import List
+
+ from snowflake.ml.feature_store import FeatureView
+ from snowflake.ml.feature_store.examples.citibike_trip_features.entities import (
+     end_station_id,
+ )
+ from snowflake.snowpark import DataFrame, Session
+
+
+ # This function will be invoked by example_helper.py. Do not change the name.
+ def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+     """Create a feature view about trip station."""
+     query = session.sql(
+         f"""
+         select
+             end_station_id,
+             count(end_station_id) as f_count_1d,
+             avg(end_station_latitude) as f_avg_latitude_1d,
+             avg(end_station_longitude) as f_avg_longtitude_1d
+         from {source_tables[0]}
+         group by end_station_id
+         """
+     )
+
+     return FeatureView(
+         name="f_station_1d",  # name of feature view
+         entities=[end_station_id],  # entities
+         feature_df=query,  # definition query
+         refresh_freq="1d",  # refresh frequency. '1d' means it refreshes everyday
+         desc="Station features refreshed every day.",
+     )
snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py (new file)
@@ -0,0 +1,24 @@
+ from typing import List
+
+ from snowflake.ml.feature_store import FeatureView
+ from snowflake.ml.feature_store.examples.citibike_trip_features.entities import trip_id
+ from snowflake.snowpark import DataFrame, Session, functions as F
+
+
+ # This function will be invoked by example_helper.py. Do not change the name.
+ def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+     """Create a feature view about trip."""
+     feature_df = source_dfs[0].select(
+         "trip_id",
+         F.col("birth_year").alias("f_birth_year"),
+         F.col("gender").alias("f_gender"),
+         F.col("bikeid").alias("f_bikeid"),
+     )
+
+     return FeatureView(
+         name="f_trip",  # name of feature view
+         entities=[trip_id],  # entities
+         feature_df=feature_df,  # definition query
+         refresh_freq=None,  # refresh frequency. None indicates it never refresh
+         desc="Static trip features",
+     )
snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml (new file)
@@ -0,0 +1,4 @@
+ ---
+ source_data: citibike_trips
+ label_columns: tripduration
+ add_id_column: trip_id