wedata-feature-engineering 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. wedata/__init__.py +1 -1
  2. wedata/feature_store/client.py +113 -41
  3. wedata/feature_store/constants/constants.py +19 -0
  4. wedata/feature_store/entities/column_info.py +4 -4
  5. wedata/feature_store/entities/feature_lookup.py +5 -1
  6. wedata/feature_store/entities/feature_spec.py +46 -46
  7. wedata/feature_store/entities/feature_table.py +42 -99
  8. wedata/feature_store/entities/training_set.py +13 -12
  9. wedata/feature_store/feature_table_client/feature_table_client.py +85 -30
  10. wedata/feature_store/spark_client/spark_client.py +30 -56
  11. wedata/feature_store/training_set_client/training_set_client.py +209 -38
  12. wedata/feature_store/utils/common_utils.py +213 -3
  13. wedata/feature_store/utils/feature_lookup_utils.py +6 -6
  14. wedata/feature_store/utils/feature_spec_utils.py +6 -6
  15. wedata/feature_store/utils/feature_utils.py +5 -5
  16. wedata/feature_store/utils/on_demand_utils.py +107 -0
  17. wedata/feature_store/utils/schema_utils.py +1 -1
  18. wedata/feature_store/utils/signature_utils.py +205 -0
  19. wedata/feature_store/utils/training_set_utils.py +18 -19
  20. wedata/feature_store/utils/uc_utils.py +1 -1
  21. {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.6.dist-info}/METADATA +1 -1
  22. wedata_feature_engineering-0.1.6.dist-info/RECORD +43 -0
  23. feature_store/__init__.py +0 -6
  24. feature_store/client.py +0 -169
  25. feature_store/constants/__init__.py +0 -0
  26. feature_store/constants/constants.py +0 -28
  27. feature_store/entities/__init__.py +0 -0
  28. feature_store/entities/column_info.py +0 -117
  29. feature_store/entities/data_type.py +0 -92
  30. feature_store/entities/environment_variables.py +0 -55
  31. feature_store/entities/feature.py +0 -53
  32. feature_store/entities/feature_column_info.py +0 -64
  33. feature_store/entities/feature_function.py +0 -55
  34. feature_store/entities/feature_lookup.py +0 -179
  35. feature_store/entities/feature_spec.py +0 -454
  36. feature_store/entities/feature_spec_constants.py +0 -25
  37. feature_store/entities/feature_table.py +0 -164
  38. feature_store/entities/feature_table_info.py +0 -40
  39. feature_store/entities/function_info.py +0 -184
  40. feature_store/entities/on_demand_column_info.py +0 -44
  41. feature_store/entities/source_data_column_info.py +0 -21
  42. feature_store/entities/training_set.py +0 -134
  43. feature_store/feature_table_client/__init__.py +0 -0
  44. feature_store/feature_table_client/feature_table_client.py +0 -313
  45. feature_store/spark_client/__init__.py +0 -0
  46. feature_store/spark_client/spark_client.py +0 -286
  47. feature_store/training_set_client/__init__.py +0 -0
  48. feature_store/training_set_client/training_set_client.py +0 -196
  49. feature_store/utils/__init__.py +0 -0
  50. feature_store/utils/common_utils.py +0 -96
  51. feature_store/utils/feature_lookup_utils.py +0 -570
  52. feature_store/utils/feature_spec_utils.py +0 -286
  53. feature_store/utils/feature_utils.py +0 -73
  54. feature_store/utils/schema_utils.py +0 -117
  55. feature_store/utils/topological_sort.py +0 -158
  56. feature_store/utils/training_set_utils.py +0 -580
  57. feature_store/utils/uc_utils.py +0 -281
  58. feature_store/utils/utils.py +0 -252
  59. feature_store/utils/validation_utils.py +0 -55
  60. wedata/feature_store/utils/utils.py +0 -252
  61. wedata_feature_engineering-0.1.5.dist-info/RECORD +0 -79
  62. {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.6.dist-info}/WHEEL +0 -0
  63. {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.6.dist-info}/top_level.txt +0 -0
feature_store/utils/feature_spec_utils.py
@@ -1,286 +0,0 @@
- import logging
- from dataclasses import dataclass
- from functools import reduce
- from typing import Dict, List, Set, Tuple, Type, Union
-
- import yaml
- from mlflow.utils.file_utils import YamlSafeDumper
-
- from feature_store.entities.column_info import ColumnInfo
- from feature_store.entities.feature_column_info import FeatureColumnInfo
- from feature_store.entities.feature_spec import FeatureSpec
- from feature_store.entities.on_demand_column_info import OnDemandColumnInfo
- from feature_store.entities.source_data_column_info import SourceDataColumnInfo
- from feature_store.utils.topological_sort import topological_sort
-
- DEFAULT_GRAPH_DEPTH_LIMIT = 5
-
- COLUMN_INFO_TYPE_SOURCE = "SOURCE"
- COLUMN_INFO_TYPE_ON_DEMAND = "ON_DEMAND"
- COLUMN_INFO_TYPE_FEATURE = "FEATURE"
-
- _logger = logging.getLogger(__name__)
-
-
- @dataclass
- class FeatureExecutionGroup:
-     type: str  # could be FEATURE, ON_DEMAND, SOURCE
-     features: Union[
-         List[FeatureColumnInfo], List[OnDemandColumnInfo], List[SourceDataColumnInfo]
-     ]
-
-
- # Small number has high priority. Besides SOURCE, preferring FEATURE over ON_DEMAND in topological
- # sorting to make sure ON_DEMAND columns after FEATURE in simple cases to align with previous
- # assumption before implementing TLT.
- # NOTE: changing this priority may cause performance regression, proceed with caution.
- COLUMN_TYPE_PRIORITY = {
-     COLUMN_INFO_TYPE_SOURCE: 0,
-     COLUMN_INFO_TYPE_ON_DEMAND: 1,
-     COLUMN_INFO_TYPE_FEATURE: 2,
- }
-
-
- class _GraphNode:
-     def __init__(self, column_info: ColumnInfo):
-         info = column_info.info
-         self.column_info = column_info
-         self.output_name = info.output_name
-
-         if isinstance(column_info.info, SourceDataColumnInfo):
-             self.input_names = set()
-             self.type = COLUMN_INFO_TYPE_SOURCE
-         elif isinstance(column_info.info, FeatureColumnInfo):
-             self.input_names = set(info.lookup_key)
-             self.type = COLUMN_INFO_TYPE_FEATURE
-         elif isinstance(column_info.info, OnDemandColumnInfo):
-             self.input_names = set(info.input_bindings.values())
-             self.type = COLUMN_INFO_TYPE_ON_DEMAND
-         else:
-             raise ValueError("unknown column info type")
-
-     def __str__(self):
-         return "node<" + self.output_name + ">"
-
-     def __repr__(self):
-         return str(self)
-
-
- def _column_info_sort_key(node: _GraphNode) -> Tuple[int, str]:
-     """
-     Returns a tuple of an int and a str as the sorting key for _GraphNode. Priority is determined by
-     the first element and then use the second element to break ties.
-     """
-     return COLUMN_TYPE_PRIORITY[node.type], node.output_name
-
-
- def _should_be_grouped(node: _GraphNode) -> bool:
-     """
-     Returns True if the given node is of type that should be grouped together as much as possible.
-     """
-     return node.type == COLUMN_INFO_TYPE_FEATURE
-
-
- def _validate_graph_depth(nodes: List[_GraphNode], depth_limit: int):
-     name_to_node = {node.output_name: node for node in nodes}
-     visited_depth = {}
-
-     def dfs(node: _GraphNode, depth: int):
-         if depth > depth_limit:
-             raise ValueError(
-                 f"The given graph contains a dependency path longer than the limit {depth_limit}"
-             )
-         if (
-             node.output_name in visited_depth
-             and depth <= visited_depth[node.output_name]
-         ):
-             return
-         visited_depth[node.output_name] = depth
-         for column_name in node.input_names:
-             dependency = name_to_node[column_name]
-             dfs(dependency, depth + 1)
-
-     for node in nodes:
-         dfs(node, 1)
-
-
- def get_encoded_graph_map(column_infos: List[ColumnInfo]) -> Dict[str, List[str]]:
-     """
-     Creates a dictionary of columns with their dependency columns for metric use. Columns are
-     encoded with a string representing the type and index. For example:
-     {
-         "f3": ["s1", "s2"],
-         "o4": ["f3"],
-         "o5": []
-     }
-     "s1" and "s2" are SourceColumnInfos, "f3" is FeatureColumnInfo and "o4", "o5" are
-     OnDemandColumnInfos. "f3" depends on "s1" and "s2", "o5" doesn't depend on any column, etc.
-     :param column_infos: A list of ColumnInfos.
-     """
-     nodes = {info.output_name: _GraphNode(info) for info in column_infos}
-     next_node_index = 0
-     # A map from column info's output_name to its label.
-     node_label = {}
-
-     def get_node_label(node):
-         nonlocal next_node_index
-         output_name = node.output_name
-         if output_name not in node_label:
-             if node.type == COLUMN_INFO_TYPE_SOURCE:
-                 type_simple_str = "s"
-             if node.type == COLUMN_INFO_TYPE_FEATURE:
-                 type_simple_str = "f"
-             if node.type == COLUMN_INFO_TYPE_ON_DEMAND:
-                 type_simple_str = "o"
-             new_label = type_simple_str + str(next_node_index)
-             next_node_index += 1
-             node_label[output_name] = new_label
-         return node_label[output_name]
-
-     graph_map = {}
-     for node in nodes.values():
-         label = get_node_label(node)
-         dependencies = []
-         for dep_name in sorted(node.input_names):
-             if dep_name not in nodes:
-                 # skip the column if it's not in the feature spec.
-                 continue
-             dep = get_node_label(nodes[dep_name])
-             dependencies.append(dep)
-         graph_map[label] = dependencies
-     return graph_map
-
-
- def assign_topological_ordering(
-     column_infos: List[ColumnInfo],
-     allow_missing_source_columns=False,
-     graph_depth_limit=DEFAULT_GRAPH_DEPTH_LIMIT,
- ) -> List[ColumnInfo]:
-     """
-     Assigns the topological ordering for each ColumnInfo of the input. Returns a list of new
-     ColumnInfo objects with topological_ordering set to an integer.
-
-     :param column_infos: a list of ColumnInfos.
-     :param allow_missing_source_columns: ONLY USED BY FSE TEMPORARILY. Allow lookup key or
-         function input be missing from source columns. If true, this method will assign
-         topological_ordering to columns as if the missing sources are added in the column_infos.
-     :param graph_depth_limit raises if the given graph exceed the limit.
-     :raises ValueError if there is a cycle in the graph.
-     """
-     nodes = list(map(lambda c: _GraphNode(c), column_infos))
-     # allow_missing_source_columns is used when feature_serving_endpoint_client creates training
-     # sets. It doesn't include source columns in the dataframe.
-     # TODO[ML-33809]: clean up allow_missing_source_columns.
-     all_output_names = set([n.output_name for n in nodes])
-     all_input_names = reduce(lambda a, b: a | b, [n.input_names for n in nodes])
-     missing_inputs = all_input_names - all_output_names
-     if allow_missing_source_columns:
-         for input_name in missing_inputs:
-             if input_name not in all_output_names:
-                 nodes.append(
-                     _GraphNode(ColumnInfo(SourceDataColumnInfo(input_name), False))
-                 )
-     elif len(missing_inputs) > 0:
-         missing_input_names_str = ", ".join(
-             [f"'{name}'" for name in sorted(missing_inputs)]
-         )
-         raise ValueError(
-             f"Input columns {missing_input_names_str} required by FeatureLookups or "
-             "FeatureFunctions are not provided by input DataFrame or other FeatureFunctions and "
-             "FeatureLookups"
-         )
-     output_name_to_node = {node.output_name: node for node in nodes}
-     graph = {
-         node: [output_name_to_node[input_name] for input_name in node.input_names]
-         for node in nodes
-     }
-     sorted_nodes = topological_sort(graph, _column_info_sort_key, _should_be_grouped)
-     # validate depth after sorting the graph because cycle is detected during sorting.
-     _validate_graph_depth(nodes, graph_depth_limit)
-     name_to_ordering = {node.output_name: i for i, node in enumerate(sorted_nodes)}
-     return [
-         column.with_topological_ordering(name_to_ordering[column.output_name])
-         for column in column_infos
-     ]
-
-
- def get_feature_execution_groups(
-     feature_spec: FeatureSpec, df_columns: List[str] = []
- ) -> List[FeatureExecutionGroup]:
-     """
-     Splits the list of column_infos in feature_spec into groups based on the topological_ordering of
-     the column_infos such that each group contains only one type of feature columns and columns
-     don't depend on other columns in the same group. The type of feature column is equivalent to the
-     class type of column_info.info field.
-     Example:
-         Given FeatureSpec with some columns, after sorting the columns by topological_ordering,
-         assuming the sorted list:
-             [source_1, feature_2, feature_3, on_demand_4, on_demand_5]
-         where feature_2 depends on feature_3. The resulting groups will be:
-             [
-                 group(SOURCE, [source_1]),
-                 group(FEATURE, [feature_2]),
-                 group(FEATURE, [feature_3]),
-                 group(ON_DEMAND, [on_demand_4, on_demand_5]),
-             ]
-
-     :param feature_spec: A FeatureSpec with topologically sorted column_infos.
-     :param df_columns: the columns from the DF used to create_training_set or score_batch.
-     """
-     # convert column infos into _GraphNode
-     nodes = list(map(lambda c: _GraphNode(c), feature_spec.column_infos))
-     if any(info.topological_ordering is None for info in feature_spec.column_infos):
-         # The old version of feature_spec may not have topological_ordering, we can safely assume
-         # they are already sorted because of validations during the feature_spec creation.
-         _logger.warning(
-             "Processing a feature spec that at least one of the column_infos has no "
-             "topological_ordering"
-         )
-     else:
-         # sort nodes by topological_ordering
-         nodes = sorted(nodes, key=lambda n: n.column_info.topological_ordering)
-     # A buffer holding the columns in a group.
-     buffer = []
-     # output names of columns in the current buffer.
-     buffered_output_names = set()
-     # Used to validate the topological sorting.
-     # df_columns is used to be backward compatible. In old FeatureSpecs, source columns might not
-     # exist. So we need to consider the df as initial resolved columns.
-     resolved_columns = set(df_columns)
-     result_list = []
-     last_type = None
-     for node in nodes:
-         if not node.input_names.issubset(resolved_columns):
-             raise ValueError(
-                 "The column_infos in the FeatureSpec is not topologically sorted"
-             )
-         if node.type != last_type or buffered_output_names.intersection(
-             node.input_names
-         ):
-             # split group if the current node has a different type from the previous node OR
-             # any of the inputs are from the nodes in the current group.
-             if buffer:
-                 result_list.append(FeatureExecutionGroup(last_type, buffer))
-                 buffer = []
-                 buffered_output_names.clear()
-             last_type = node.type
-         buffer.append(node.column_info.info)
-         resolved_columns.add(node.output_name)
-         buffered_output_names.add(node.output_name)
-     if buffer:
-         result_list.append(FeatureExecutionGroup(last_type, buffer))
-     return result_list
-
-
- def convert_to_yaml_string(feature_spec: FeatureSpec) -> str:
-     """
-     Converts the given FeatureSpec to a YAML string.
-     """
-     feature_spec_dict = feature_spec._to_dict()
-     return yaml.dump(
-         feature_spec_dict,
-         default_flow_style=False,
-         allow_unicode=True,
-         sort_keys=False,
-         Dumper=YamlSafeDumper,
-     )
feature_store/utils/feature_utils.py
@@ -1,73 +0,0 @@
- import copy
- from typing import List, Union
-
- from feature_store.entities.feature_function import FeatureFunction
- from feature_store.entities.feature_lookup import FeatureLookup
- from feature_store.spark_client.spark_client import SparkClient
- from feature_store.utils import uc_utils
- from feature_store.utils.feature_lookup_utils import get_feature_lookups_with_full_table_names
-
-
- def format_feature_lookups_and_functions(
-     _spark_client: SparkClient, features: List[Union[FeatureLookup, FeatureFunction]]
- ):
-     fl_idx = []
-     ff_idx = []
-     feature_lookups = []
-     feature_functions = []
-     for idx, feature in enumerate(features):
-         if isinstance(feature, FeatureLookup):
-             fl_idx.append(idx)
-             feature_lookups.append(feature)
-         elif isinstance(feature, FeatureFunction):
-             ff_idx.append(idx)
-             feature_functions.append(feature)
-         else:
-             raise ValueError(
-                 f"Expected a list of FeatureLookups for 'feature_lookups', but received type '{type(feature)}'."
-             )
-
-     # FeatureLookups and FeatureFunctions must have fully qualified table, UDF names
-     feature_lookups = get_feature_lookups_with_full_table_names(
-         feature_lookups,
-         _spark_client.get_current_catalog(),
-         _spark_client.get_current_database(),
-     )
-     feature_functions = get_feature_functions_with_full_udf_names(
-         feature_functions,
-         _spark_client.get_current_catalog(),
-         _spark_client.get_current_database(),
-     )
-
-     # Restore original order of FeatureLookups, FeatureFunctions. Copy to avoid mutating original list.
-     features = features.copy()
-     for idx, feature in zip(fl_idx + ff_idx, feature_lookups + feature_functions):
-         features[idx] = feature
-
-     return features
-
-
- def get_feature_functions_with_full_udf_names(
-     feature_functions: List[FeatureFunction], current_catalog: str, current_schema: str
- ):
-     """
-     Takes in a list of FeatureFunctions, and returns copies with:
-     1. Fully qualified UDF names.
-     2. If output_name is empty, fully qualified UDF names as output_name.
-     """
-     udf_names = {ff.udf_name for ff in feature_functions}
-     uc_utils._check_qualified_udf_names(udf_names)
-     uc_utils._verify_all_udfs_in_uc(udf_names, current_catalog, current_schema)
-
-     standardized_feature_functions = []
-     for ff in feature_functions:
-         ff_copy = copy.deepcopy(ff)
-         del ff
-
-         ff_copy._udf_name = uc_utils.get_full_udf_name(
-             ff_copy.udf_name, current_catalog, current_schema
-         )
-         if not ff_copy.output_name:
-             ff_copy._output_name = ff_copy.udf_name
-         standardized_feature_functions.append(ff_copy)
-     return standardized_feature_functions
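
A conceptual stand-in (not the package's FeatureFunction or uc_utils, whose constructors and checks are not shown in this diff) illustrating the qualification rule implemented above: an unqualified UDF name is prefixed with the current catalog and schema, and an empty output_name falls back to the fully qualified UDF name.

# Conceptual sketch only. FakeFeatureFunction and qualify() are hypothetical
# stand-ins used to show the renaming rule, not the package's own API.
from dataclasses import dataclass

@dataclass
class FakeFeatureFunction:
    udf_name: str
    output_name: str = ""

def qualify(ff: FakeFeatureFunction, catalog: str, schema: str) -> FakeFeatureFunction:
    # Treat a two-dot name as already fully qualified; otherwise prefix it.
    full = ff.udf_name if ff.udf_name.count(".") == 2 else f"{catalog}.{schema}.{ff.udf_name}"
    return FakeFeatureFunction(udf_name=full, output_name=ff.output_name or full)

print(qualify(FakeFeatureFunction("revenue_per_user"), "my_catalog", "my_schema"))
# -> udf_name and output_name both become 'my_catalog.my_schema.revenue_per_user'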
feature_store/utils/schema_utils.py
@@ -1,117 +0,0 @@
- import logging
-
- from feature_store.constants.constants import _ERROR, _WARN
-
- _logger = logging.getLogger(__name__)
-
-
- def catalog_matches_delta_schema(catalog_features, df_schema, column_filter=None):
-     """
-     Confirm that the column names and column types are the same.
-
-     Returns True if identical, False if there is a mismatch.
-
-     If column_filter is not None, only columns in column_filter must match.
-     """
-     if column_filter is not None:
-         catalog_features = [c for c in catalog_features if c.name in column_filter]
-         df_schema = [c for c in df_schema if c.name in column_filter]
-
-     catalog_schema = {
-         feature.name: feature.data_type
-         for feature in catalog_features
-     }
-     delta_schema = {
-         feature.name: feature.dataType
-         for feature in df_schema
-     }
-
-     complex_catalog_schema = get_complex_catalog_schema(
-         catalog_features, catalog_schema
-     )
-     complex_delta_schema = get_complex_delta_schema(df_schema, delta_schema)
-
-     return (
-         catalog_schema == delta_schema
-         and complex_catalog_schema == complex_delta_schema
-     )
-
-
- def get_complex_delta_schema(delta_features, delta_feature_names_to_fs_types):
-     """
-     1. Filter delta features to features that have complex datatypes.
-     2. Take the existing Spark DataType stored on the Delta features. This is later used for
-         comparison against the Catalog schema's complex Spark DataTypes.
-     3. Return a mapping of feature name to their respective complex Spark DataTypes.
-
-     :param delta_features: List[Feature]. List of features stored in Delta.
-     :param delta_feature_names_to_fs_types: Map[str, feature_store.DataType]. A mapping of feature
-         names to their respective Feature Store DataTypes.
-     :return: Map[str, spark.sql.types.DataType]. A mapping of feature names to their respective
-         Spark DataTypes.
-     """
-     complex_delta_features = [
-         feature
-         for feature in delta_features
-         if delta_feature_names_to_fs_types[feature.name] in DATA_TYPES_REQUIRES_DETAILS
-     ]
-     complex_delta_feature_names_to_spark_types = {
-         feature.name: feature.dataType for feature in complex_delta_features
-     }
-     return complex_delta_feature_names_to_spark_types
-
-
- def get_complex_catalog_schema(catalog_features, catalog_feature_names_to_fs_types):
-     """
-     1. Filter catalog features to features that have complex datatypes.
-     2. Convert the JSON string stored in each feature's data_type_details to the corresponding
-         Spark DataType. This is later used for comparison against the Delta schema's complex Spark
-         DataTypes.
-     3. Return a mapping of feature name to their respective complex Spark DataTypes.
-
-     :param catalog_features: List[Feature]. List of features stored in the Catalog.
-     :param catalog_feature_names_to_fs_types: Map[str, feature_store.DataType]. A mapping of feature
-         names to their respective Feature Store DataTypes.
-     :return: Map[str, spark.sql.types.DataType]. A mapping of feature names to their respective
-         Spark DataTypes.
-     """
-     complex_catalog_features = [
-         feature
-         for feature in catalog_features
-         if catalog_feature_names_to_fs_types[feature.name]
-         in DATA_TYPES_REQUIRES_DETAILS
-     ]
-     complex_catalog_feature_names_to_spark_types = {
-         feature.name: feature.data_type_details
-         for feature in complex_catalog_features
-     }
-     return complex_catalog_feature_names_to_spark_types
-
-
- def log_catalog_schema_not_match_delta_schema(catalog_features, df_schema, level):
-     """
-     Log the catalog schema does not match the delta table schema.
-
-     Example warning:
-         Expected recorded schema from Feature Catalog to be identical with
-         schema in delta table.Feature Catalog's schema is
-         '{'id': 'INTEGER', 'feat1': 'INTEGER'}' while delta table's
-         schema is '{'id': 'INTEGER', 'feat1': 'FLOAT'}'
-     """
-     catalog_schema = {feature.name: feature.data_type for feature in catalog_features}
-     delta_schema = {
-         feature.name: feature.dataType
-         for feature in df_schema
-     }
-     msg = (
-         f"Expected recorded schema from Feature Catalog to be identical with schema "
-         f"in Delta table. "
-         f"Feature Catalog's schema is '{catalog_schema}' while Delta table's schema "
-         f"is '{delta_schema}'"
-     )
-     if level == _WARN:
-         _logger.warning(msg)
-     elif level == _ERROR:
-         raise RuntimeError(msg)
-     else:
-         _logger.info(msg)
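
For reference, an illustrative sketch of driving the logging helper above with lightweight stand-ins for catalog features (exposing .data_type) and Delta/DataFrame fields (exposing .dataType). It assumes the pre-0.1.6 import paths shown in this diff and that _WARN is exported by feature_store.constants.constants, as the removed module's own import suggests.

# Illustrative only: the namedtuples are stand-ins, not the package's Feature
# or Spark StructField classes.
import logging
from collections import namedtuple

from feature_store.constants.constants import _WARN
from feature_store.utils.schema_utils import log_catalog_schema_not_match_delta_schema

logging.basicConfig(level=logging.WARNING)

CatalogFeature = namedtuple("CatalogFeature", ["name", "data_type"])
DeltaField = namedtuple("DeltaField", ["name", "dataType"])

catalog = [CatalogFeature("id", "INTEGER"), CatalogFeature("feat1", "INTEGER")]
delta = [DeltaField("id", "INTEGER"), DeltaField("feat1", "FLOAT")]

# Logs a warning describing both schemas; with level=_ERROR it raises RuntimeError.
log_catalog_schema_not_match_delta_schema(catalog, delta, level=_WARN)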
feature_store/utils/topological_sort.py
@@ -1,158 +0,0 @@
- from collections import defaultdict, deque
- from queue import PriorityQueue
- from typing import Callable, Dict, Hashable, List, Optional
-
- __all__ = ["find_cycle", "topological_sort"]
-
-
- class _NodeInfo:
-     def __init__(self, node):
-         self.node = node
-         # number of non-processed predecessors.
-         self.n_blockers = 0
-         # list of nodes that depend on this node.
-         self.successors = []
-
-
- def find_cycle(
-     node_dependencies: Dict[Hashable, List[Hashable]]
- ) -> Optional[List[Hashable]]:
-     """
-     Finds a cycle in the node_dependencies graph. Returns a list of node(s) that forms a cycle or
-     None if no cycle can be found.
-     :param node_dependencies: A dict with hashable objects as keys and their list of dependency
-         nodes as values.
-     """
-     # A stack used to perform DFS on the graph.
-     stack = deque()
-     # Another stack storing the path from root to current node in DFS. Used to detect cycle.
-     backtrack_stack = []
-     # A set of nodes that no cycle can be found starting from these nodes.
-     resolved = set()
-     # Create a copy of the dependency graph with defaultdict for convenience.
-     default_dependency = defaultdict(list)
-     default_dependency.update(node_dependencies)
-
-     # Perform DFS on every node in the graph.
-     for node in node_dependencies.keys():
-         if node in resolved:
-             # Skip a node if it's already resolved.
-             continue
-         # DFS from the node
-         stack.append(node)
-         while stack:
-             top = stack[-1]
-             if top not in backtrack_stack:
-                 # First time visiting this node. There will be a second visit after the dependencies
-                 # are resolved if it has dependencies.
-                 backtrack_stack.append(top)
-             # If not expended after traversing the dependencies, meaning there is no dependency or
-             # all dependencies are resolved.
-             expanded = False
-             for depend in default_dependency[top]:
-                 if depend in backtrack_stack:
-                     # found a cycle
-                     index = backtrack_stack.index(depend)
-                     return backtrack_stack[index:]
-                 if depend in resolved:
-                     continue
-                 # Only adding node to stack. backtrack_stack only contains nodes in the current DFS
-                 # path.
-                 stack.append(depend)
-                 expanded = True
-             if not expanded:
-                 stack.pop()
-                 resolved.add(top)
-                 backtrack_stack.pop()
-     return None
-
-
- def _all_items_in_queue_should_be_grouped(
-     queue: PriorityQueue, should_be_grouped: Callable
- ) -> bool:
-     temp = []
-     should_group = True
-     # note: avoid using queue.qsize() because it's not guaranteed to be accurate.
-     while not queue.empty():
-         k, node = queue.get()
-         temp.append((k, node))
-         if not should_be_grouped(node):
-             should_group = False
-     for item in temp:
-         queue.put(item)
-     return should_group
-
-
- def topological_sort(
-     node_dependencies: Dict[Hashable, List[Hashable]],
-     key: Callable = None,
-     should_be_grouped: Callable = None,
- ) -> List[Hashable]:
-     """
-     Topological sort the given node_dependencies graph. Returns a sorted list of nodes.
-     :param node_dependencies: A dict with hashable objects as keys and their list of dependency
-         nodes as values.
-     :param key: a Callable that returns a sort key when called with a hashable object. The key is
-         used to break ties in topological sorting. An object with smaller key is added
-         to the result list first.
-     :raises ValueError if a cycle is found in the graph.
-     """
-     # Calling a dedicated find_cycle function to be able to give a detailed error message.
-     cycle = find_cycle(node_dependencies)
-     if cycle is not None:
-         raise ValueError(
-             "Following nodes form a cycle: ",
-             cycle,
-             ". Please resolve any circular dependencies before calling Feature Store.",
-         )
-
-     # A priority-queue storing the nodes whose dependency has been resolved.
-     # priority is determined by the given key function.
-     ready_queue = PriorityQueue()
-     # Map from node to _NodeInfo.
-     nodes = {}
-     if key is None:
-         key = hash  # use the built-in hash function by default
-     if should_be_grouped is None:
-         should_be_grouped = lambda _: False
-     # Perform Kahn's algorithm by traversing the graph starting from nodes without dependency.
-     # Node is removed from its successors' dependency once resolved. And node whose dependency gets
-     # all resolved is added to the priority queue.
-     for node, dependencies in node_dependencies.items():
-         # Initialize the graph to topologically sort based on the input node_dependencies.
-         # All nodes, its successors and number of predecessors should be populated.
-         if node not in nodes:
-             nodes[node] = _NodeInfo(node)
-         for dependency in dependencies:
-             if dependency not in nodes:
-                 nodes[dependency] = _NodeInfo(dependency)
-             nodes[dependency].successors.append(node)
-         if len(dependencies):
-             nodes[node].n_blockers = len(dependencies)
-     # Initialize the ready_queue to start traversing the graph from nodes without any dependencies.
-     for node, node_info in nodes.items():
-         if node_info.n_blockers == 0:
-             ready_queue.put((key(node), node))
-     # At the end of the algorithm, result_list will have a topologically sorted listed of nodes.
-     result_list = []
-
-     def process_nodes(node_buffer, queue):
-         for node in node_buffer:
-             result_list.append(node)
-             for successor in nodes[node].successors:
-                 s_info = nodes[successor]
-                 s_info.n_blockers -= 1
-                 if s_info.n_blockers == 0:
-                     queue.put((key(successor), successor))
-
-     while not ready_queue.empty():
-         if _all_items_in_queue_should_be_grouped(ready_queue, should_be_grouped):
-             batch_buffer = []
-             while not ready_queue.empty():
-                 _, node = ready_queue.get()
-                 batch_buffer.append(node)
-             process_nodes(batch_buffer, ready_queue)
-         else:
-             _, node = ready_queue.get()
-             process_nodes([node], ready_queue)
-     return result_list
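
For reference, a minimal sketch exercising the removed topological_sort and find_cycle helpers on a plain dependency dict, assuming the pre-0.1.6 import path shown above.

# Illustrative only: exercises the helpers removed in this release.
from feature_store.utils.topological_sort import find_cycle, topological_sort

# Each key lists the nodes it depends on.
graph = {
    "f3": ["s1", "s2"],   # feature column depends on two source columns
    "o4": ["f3"],         # on-demand column depends on the feature column
    "s1": [],
    "s2": [],
}

# No cycle in this graph.
assert find_cycle(graph) is None

# Dependencies come out before their dependents; ties are broken by the key.
print(topological_sort(graph, key=lambda node: node))
# -> ['s1', 's2', 'f3', 'o4']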