wedata-feature-engineering 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- feature_store/utils/__init__.py +0 -0
- feature_store/utils/common_utils.py +96 -0
- feature_store/utils/feature_lookup_utils.py +570 -0
- feature_store/utils/feature_spec_utils.py +286 -0
- feature_store/utils/feature_utils.py +73 -0
- feature_store/utils/schema_utils.py +117 -0
- feature_store/utils/topological_sort.py +158 -0
- feature_store/utils/training_set_utils.py +580 -0
- feature_store/utils/uc_utils.py +281 -0
- feature_store/utils/utils.py +252 -0
- feature_store/utils/validation_utils.py +55 -0
- wedata/__init__.py +6 -0
- wedata/feature_store/__init__.py +0 -0
- wedata/feature_store/client.py +169 -0
- wedata/feature_store/constants/__init__.py +0 -0
- wedata/feature_store/constants/constants.py +28 -0
- wedata/feature_store/entities/__init__.py +0 -0
- wedata/feature_store/entities/column_info.py +117 -0
- wedata/feature_store/entities/data_type.py +92 -0
- wedata/feature_store/entities/environment_variables.py +55 -0
- wedata/feature_store/entities/feature.py +53 -0
- wedata/feature_store/entities/feature_column_info.py +64 -0
- wedata/feature_store/entities/feature_function.py +55 -0
- wedata/feature_store/entities/feature_lookup.py +179 -0
- wedata/feature_store/entities/feature_spec.py +454 -0
- wedata/feature_store/entities/feature_spec_constants.py +25 -0
- wedata/feature_store/entities/feature_table.py +164 -0
- wedata/feature_store/entities/feature_table_info.py +40 -0
- wedata/feature_store/entities/function_info.py +184 -0
- wedata/feature_store/entities/on_demand_column_info.py +44 -0
- wedata/feature_store/entities/source_data_column_info.py +21 -0
- wedata/feature_store/entities/training_set.py +134 -0
- wedata/feature_store/feature_table_client/__init__.py +0 -0
- wedata/feature_store/feature_table_client/feature_table_client.py +313 -0
- wedata/feature_store/spark_client/__init__.py +0 -0
- wedata/feature_store/spark_client/spark_client.py +286 -0
- wedata/feature_store/training_set_client/__init__.py +0 -0
- wedata/feature_store/training_set_client/training_set_client.py +196 -0
- wedata/feature_store/utils/__init__.py +0 -0
- wedata/feature_store/utils/common_utils.py +96 -0
- wedata/feature_store/utils/feature_lookup_utils.py +570 -0
- wedata/feature_store/utils/feature_spec_utils.py +286 -0
- wedata/feature_store/utils/feature_utils.py +73 -0
- wedata/feature_store/utils/schema_utils.py +117 -0
- wedata/feature_store/utils/topological_sort.py +158 -0
- wedata/feature_store/utils/training_set_utils.py +580 -0
- wedata/feature_store/utils/uc_utils.py +281 -0
- wedata/feature_store/utils/utils.py +252 -0
- wedata/feature_store/utils/validation_utils.py +55 -0
- {wedata_feature_engineering-0.1.3.dist-info → wedata_feature_engineering-0.1.5.dist-info}/METADATA +1 -1
- wedata_feature_engineering-0.1.5.dist-info/RECORD +79 -0
- wedata_feature_engineering-0.1.5.dist-info/top_level.txt +1 -0
- wedata_feature_engineering-0.1.3.dist-info/RECORD +0 -30
- wedata_feature_engineering-0.1.3.dist-info/top_level.txt +0 -1
- {wedata_feature_engineering-0.1.3.dist-info → wedata_feature_engineering-0.1.5.dist-info}/WHEEL +0 -0
@@ -0,0 +1,286 @@
+import logging
+from dataclasses import dataclass
+from functools import reduce
+from typing import Dict, List, Set, Tuple, Type, Union
+
+import yaml
+from mlflow.utils.file_utils import YamlSafeDumper
+
+from feature_store.entities.column_info import ColumnInfo
+from feature_store.entities.feature_column_info import FeatureColumnInfo
+from feature_store.entities.feature_spec import FeatureSpec
+from feature_store.entities.on_demand_column_info import OnDemandColumnInfo
+from feature_store.entities.source_data_column_info import SourceDataColumnInfo
+from feature_store.utils.topological_sort import topological_sort
+
+DEFAULT_GRAPH_DEPTH_LIMIT = 5
+
+COLUMN_INFO_TYPE_SOURCE = "SOURCE"
+COLUMN_INFO_TYPE_ON_DEMAND = "ON_DEMAND"
+COLUMN_INFO_TYPE_FEATURE = "FEATURE"
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FeatureExecutionGroup:
+    type: str  # could be FEATURE, ON_DEMAND, SOURCE
+    features: Union[
+        List[FeatureColumnInfo], List[OnDemandColumnInfo], List[SourceDataColumnInfo]
+    ]
+
+
+# Small number has high priority. Besides SOURCE, preferring FEATURE over ON_DEMAND in topological
+# sorting to make sure ON_DEMAND columns after FEATURE in simple cases to align with previous
+# assumption before implementing TLT.
+# NOTE: changing this priority may cause performance regression, proceed with caution.
+COLUMN_TYPE_PRIORITY = {
+    COLUMN_INFO_TYPE_SOURCE: 0,
+    COLUMN_INFO_TYPE_ON_DEMAND: 1,
+    COLUMN_INFO_TYPE_FEATURE: 2,
+}
+
+
+class _GraphNode:
+    def __init__(self, column_info: ColumnInfo):
+        info = column_info.info
+        self.column_info = column_info
+        self.output_name = info.output_name
+
+        if isinstance(column_info.info, SourceDataColumnInfo):
+            self.input_names = set()
+            self.type = COLUMN_INFO_TYPE_SOURCE
+        elif isinstance(column_info.info, FeatureColumnInfo):
+            self.input_names = set(info.lookup_key)
+            self.type = COLUMN_INFO_TYPE_FEATURE
+        elif isinstance(column_info.info, OnDemandColumnInfo):
+            self.input_names = set(info.input_bindings.values())
+            self.type = COLUMN_INFO_TYPE_ON_DEMAND
+        else:
+            raise ValueError("unknown column info type")
+
+    def __str__(self):
+        return "node<" + self.output_name + ">"
+
+    def __repr__(self):
+        return str(self)
+
+
+def _column_info_sort_key(node: _GraphNode) -> Tuple[int, str]:
+    """
+    Returns a tuple of an int and a str as the sorting key for _GraphNode. Priority is determined by
+    the first element and then use the second element to break ties.
+    """
+    return COLUMN_TYPE_PRIORITY[node.type], node.output_name
+
+
+def _should_be_grouped(node: _GraphNode) -> bool:
+    """
+    Returns True if the given node is of type that should be grouped together as much as possible.
+    """
+    return node.type == COLUMN_INFO_TYPE_FEATURE
+
+
+def _validate_graph_depth(nodes: List[_GraphNode], depth_limit: int):
+    name_to_node = {node.output_name: node for node in nodes}
+    visited_depth = {}
+
+    def dfs(node: _GraphNode, depth: int):
+        if depth > depth_limit:
+            raise ValueError(
+                f"The given graph contains a dependency path longer than the limit {depth_limit}"
+            )
+        if (
+            node.output_name in visited_depth
+            and depth <= visited_depth[node.output_name]
+        ):
+            return
+        visited_depth[node.output_name] = depth
+        for column_name in node.input_names:
+            dependency = name_to_node[column_name]
+            dfs(dependency, depth + 1)
+
+    for node in nodes:
+        dfs(node, 1)
+
+
+def get_encoded_graph_map(column_infos: List[ColumnInfo]) -> Dict[str, List[str]]:
+    """
+    Creates a dictionary of columns with their dependency columns for metric use. Columns are
+    encoded with a string representing the type and index. For example:
+    {
+        "f3": ["s1", "s2"],
+        "o4": ["f3"],
+        "o5": []
+    }
+    "s1" and "s2" are SourceColumnInfos, "f3" is FeatureColumnInfo and "o4", "o5" are
+    OnDemandColumnInfos. "f3" depends on "s1" and "s2", "o5" doesn't depend on any column, etc.
+    :param column_infos: A list of ColumnInfos.
+    """
+    nodes = {info.output_name: _GraphNode(info) for info in column_infos}
+    next_node_index = 0
+    # A map from column info's output_name to its label.
+    node_label = {}
+
+    def get_node_label(node):
+        nonlocal next_node_index
+        output_name = node.output_name
+        if output_name not in node_label:
+            if node.type == COLUMN_INFO_TYPE_SOURCE:
+                type_simple_str = "s"
+            if node.type == COLUMN_INFO_TYPE_FEATURE:
+                type_simple_str = "f"
+            if node.type == COLUMN_INFO_TYPE_ON_DEMAND:
+                type_simple_str = "o"
+            new_label = type_simple_str + str(next_node_index)
+            next_node_index += 1
+            node_label[output_name] = new_label
+        return node_label[output_name]
+
+    graph_map = {}
+    for node in nodes.values():
+        label = get_node_label(node)
+        dependencies = []
+        for dep_name in sorted(node.input_names):
+            if dep_name not in nodes:
+                # skip the column if it's not in the feature spec.
+                continue
+            dep = get_node_label(nodes[dep_name])
+            dependencies.append(dep)
+        graph_map[label] = dependencies
+    return graph_map
+
+
+def assign_topological_ordering(
+    column_infos: List[ColumnInfo],
+    allow_missing_source_columns=False,
+    graph_depth_limit=DEFAULT_GRAPH_DEPTH_LIMIT,
+) -> List[ColumnInfo]:
+    """
+    Assigns the topological ordering for each ColumnInfo of the input. Returns a list of new
+    ColumnInfo objects with topological_ordering set to an integer.
+
+    :param column_infos: a list of ColumnInfos.
+    :param allow_missing_source_columns: ONLY USED BY FSE TEMPORARILY. Allow lookup key or
+        function input be missing from source columns. If true, this method will assign
+        topological_ordering to columns as if the missing sources are added in the column_infos.
+    :param graph_depth_limit raises if the given graph exceed the limit.
+    :raises ValueError if there is a cycle in the graph.
+    """
+    nodes = list(map(lambda c: _GraphNode(c), column_infos))
+    # allow_missing_source_columns is used when feature_serving_endpoint_client creates training
+    # sets. It doesn't include source columns in the dataframe.
+    # TODO[ML-33809]: clean up allow_missing_source_columns.
+    all_output_names = set([n.output_name for n in nodes])
+    all_input_names = reduce(lambda a, b: a | b, [n.input_names for n in nodes])
+    missing_inputs = all_input_names - all_output_names
+    if allow_missing_source_columns:
+        for input_name in missing_inputs:
+            if input_name not in all_output_names:
+                nodes.append(
+                    _GraphNode(ColumnInfo(SourceDataColumnInfo(input_name), False))
+                )
+    elif len(missing_inputs) > 0:
+        missing_input_names_str = ", ".join(
+            [f"'{name}'" for name in sorted(missing_inputs)]
+        )
+        raise ValueError(
+            f"Input columns {missing_input_names_str} required by FeatureLookups or "
+            "FeatureFunctions are not provided by input DataFrame or other FeatureFunctions and "
+            "FeatureLookups"
+        )
+    output_name_to_node = {node.output_name: node for node in nodes}
+    graph = {
+        node: [output_name_to_node[input_name] for input_name in node.input_names]
+        for node in nodes
+    }
+    sorted_nodes = topological_sort(graph, _column_info_sort_key, _should_be_grouped)
+    # validate depth after sorting the graph because cycle is detected during sorting.
+    _validate_graph_depth(nodes, graph_depth_limit)
+    name_to_ordering = {node.output_name: i for i, node in enumerate(sorted_nodes)}
+    return [
+        column.with_topological_ordering(name_to_ordering[column.output_name])
+        for column in column_infos
+    ]
+
+
+def get_feature_execution_groups(
+    feature_spec: FeatureSpec, df_columns: List[str] = []
+) -> List[FeatureExecutionGroup]:
+    """
+    Splits the list of column_infos in feature_spec into groups based on the topological_ordering of
+    the column_infos such that each group contains only one type of feature columns and columns
+    don't depend on other columns in the same group. The type of feature column is equivalent to the
+    class type of column_info.info field.
+    Example:
+        Given FeatureSpec with some columns, after sorting the columns by topological_ordering,
+        assuming the sorted list:
+        [source_1, feature_2, feature_3, on_demand_4, on_demand_5]
+        where feature_2 depends on feature_3. The resulting groups will be:
+        [
+            group(SOURCE, [source_1]),
+            group(FEATURE, [feature_2]),
+            group(FEATURE, [feature_3]),
+            group(ON_DEMAND, [on_demand_4, on_demand_5]),
+        ]
+
+    :param feature_spec: A FeatureSpec with topologically sorted column_infos.
+    :param df_columns: the columns from the DF used to create_training_set or score_batch.
+    """
+    # convert column infos into _GraphNode
+    nodes = list(map(lambda c: _GraphNode(c), feature_spec.column_infos))
+    if any(info.topological_ordering is None for info in feature_spec.column_infos):
+        # The old version of feature_spec may not have topological_ordering, we can safely assume
+        # they are already sorted because of validations during the feature_spec creation.
+        _logger.warning(
+            "Processing a feature spec that at least one of the column_infos has no "
+            "topological_ordering"
+        )
+    else:
+        # sort nodes by topological_ordering
+        nodes = sorted(nodes, key=lambda n: n.column_info.topological_ordering)
+    # A buffer holding the columns in a group.
+    buffer = []
+    # output names of columns in the current buffer.
+    buffered_output_names = set()
+    # Used to validate the topological sorting.
+    # df_columns is used to be backward compatible. In old FeatureSpecs, source columns might not
+    # exist. So we need to consider the df as initial resolved columns.
+    resolved_columns = set(df_columns)
+    result_list = []
+    last_type = None
+    for node in nodes:
+        if not node.input_names.issubset(resolved_columns):
+            raise ValueError(
+                "The column_infos in the FeatureSpec is not topologically sorted"
+            )
+        if node.type != last_type or buffered_output_names.intersection(
+            node.input_names
+        ):
+            # split group if the current node has a different type from the previous node OR
+            # any of the inputs are from the nodes in the current group.
+            if buffer:
+                result_list.append(FeatureExecutionGroup(last_type, buffer))
+                buffer = []
+                buffered_output_names.clear()
+            last_type = node.type
+        buffer.append(node.column_info.info)
+        resolved_columns.add(node.output_name)
+        buffered_output_names.add(node.output_name)
+    if buffer:
+        result_list.append(FeatureExecutionGroup(last_type, buffer))
+    return result_list
+
+
+def convert_to_yaml_string(feature_spec: FeatureSpec) -> str:
+    """
+    Converts the given FeatureSpec to a YAML string.
+    """
+    feature_spec_dict = feature_spec._to_dict()
+    return yaml.dump(
+        feature_spec_dict,
+        default_flow_style=False,
+        allow_unicode=True,
+        sort_keys=False,
+        Dumper=YamlSafeDumper,
+    )
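
The grouping contract above is easiest to see from the consuming side. Below is a minimal consumption sketch: it assumes this hunk is the new feature_store/utils/feature_spec_utils.py listed above, and join_feature_lookups / apply_feature_functions are hypothetical stand-ins for whatever lookup-join and UDF-application helpers the client actually uses.

from feature_store.entities.feature_spec import FeatureSpec
from feature_store.utils.feature_spec_utils import (
    COLUMN_INFO_TYPE_FEATURE,
    COLUMN_INFO_TYPE_ON_DEMAND,
    COLUMN_INFO_TYPE_SOURCE,
    get_feature_execution_groups,
)


def materialize(feature_spec: FeatureSpec, df):
    # Groups come back in execution order; each group holds a single column type
    # and no column in a group depends on another column of the same group.
    for group in get_feature_execution_groups(feature_spec, df.columns):
        if group.type == COLUMN_INFO_TYPE_SOURCE:
            continue  # source columns are already present in the input DataFrame
        elif group.type == COLUMN_INFO_TYPE_FEATURE:
            df = join_feature_lookups(df, group.features)      # hypothetical helper
        elif group.type == COLUMN_INFO_TYPE_ON_DEMAND:
            df = apply_feature_functions(df, group.features)   # hypothetical helper
    return df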
@@ -0,0 +1,73 @@
+import copy
+from typing import List, Union
+
+from feature_store.entities.feature_function import FeatureFunction
+from feature_store.entities.feature_lookup import FeatureLookup
+from feature_store.spark_client.spark_client import SparkClient
+from feature_store.utils import uc_utils
+from feature_store.utils.feature_lookup_utils import get_feature_lookups_with_full_table_names
+
+
+def format_feature_lookups_and_functions(
+    _spark_client: SparkClient, features: List[Union[FeatureLookup, FeatureFunction]]
+):
+    fl_idx = []
+    ff_idx = []
+    feature_lookups = []
+    feature_functions = []
+    for idx, feature in enumerate(features):
+        if isinstance(feature, FeatureLookup):
+            fl_idx.append(idx)
+            feature_lookups.append(feature)
+        elif isinstance(feature, FeatureFunction):
+            ff_idx.append(idx)
+            feature_functions.append(feature)
+        else:
+            raise ValueError(
+                f"Expected a list of FeatureLookups for 'feature_lookups', but received type '{type(feature)}'."
+            )
+
+    # FeatureLookups and FeatureFunctions must have fully qualified table, UDF names
+    feature_lookups = get_feature_lookups_with_full_table_names(
+        feature_lookups,
+        _spark_client.get_current_catalog(),
+        _spark_client.get_current_database(),
+    )
+    feature_functions = get_feature_functions_with_full_udf_names(
+        feature_functions,
+        _spark_client.get_current_catalog(),
+        _spark_client.get_current_database(),
+    )
+
+    # Restore original order of FeatureLookups, FeatureFunctions. Copy to avoid mutating original list.
+    features = features.copy()
+    for idx, feature in zip(fl_idx + ff_idx, feature_lookups + feature_functions):
+        features[idx] = feature
+
+    return features
+
+
+def get_feature_functions_with_full_udf_names(
+    feature_functions: List[FeatureFunction], current_catalog: str, current_schema: str
+):
+    """
+    Takes in a list of FeatureFunctions, and returns copies with:
+    1. Fully qualified UDF names.
+    2. If output_name is empty, fully qualified UDF names as output_name.
+    """
+    udf_names = {ff.udf_name for ff in feature_functions}
+    uc_utils._check_qualified_udf_names(udf_names)
+    uc_utils._verify_all_udfs_in_uc(udf_names, current_catalog, current_schema)
+
+    standardized_feature_functions = []
+    for ff in feature_functions:
+        ff_copy = copy.deepcopy(ff)
+        del ff
+
+        ff_copy._udf_name = uc_utils.get_full_udf_name(
+            ff_copy.udf_name, current_catalog, current_schema
+        )
+        if not ff_copy.output_name:
+            ff_copy._output_name = ff_copy.udf_name
+        standardized_feature_functions.append(ff_copy)
+    return standardized_feature_functions
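
The only non-obvious step in format_feature_lookups_and_functions is how the original ordering of the mixed list is restored after the per-type processing. A standalone sketch of that pattern using plain Python values (no Feature Store objects involved):

# Stand-ins: strings play the role of FeatureLookups, tuples of FeatureFunctions.
features = ["lookup_a", ("udf", "x"), "lookup_b"]

fl_idx, ff_idx = [], []
lookups, functions = [], []
for idx, feature in enumerate(features):
    if isinstance(feature, str):
        fl_idx.append(idx)
        lookups.append(feature)
    else:
        ff_idx.append(idx)
        functions.append(feature)

# Per-type processing (here just uppercasing) runs on the partitioned lists...
lookups = [f.upper() for f in lookups]
functions = [(name, arg.upper()) for name, arg in functions]

# ...and zip(fl_idx + ff_idx, ...) writes each result back to its original slot.
restored = features.copy()
for idx, feature in zip(fl_idx + ff_idx, lookups + functions):
    restored[idx] = feature

print(restored)  # ['LOOKUP_A', ('udf', 'X'), 'LOOKUP_B']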
@@ -0,0 +1,117 @@
+import logging
+
+from feature_store.constants.constants import _ERROR, _WARN
+
+_logger = logging.getLogger(__name__)
+
+
+def catalog_matches_delta_schema(catalog_features, df_schema, column_filter=None):
+    """
+    Confirm that the column names and column types are the same.
+
+    Returns True if identical, False if there is a mismatch.
+
+    If column_filter is not None, only columns in column_filter must match.
+    """
+    if column_filter is not None:
+        catalog_features = [c for c in catalog_features if c.name in column_filter]
+        df_schema = [c for c in df_schema if c.name in column_filter]
+
+    catalog_schema = {
+        feature.name: feature.data_type
+        for feature in catalog_features
+    }
+    delta_schema = {
+        feature.name: feature.dataType
+        for feature in df_schema
+    }
+
+    complex_catalog_schema = get_complex_catalog_schema(
+        catalog_features, catalog_schema
+    )
+    complex_delta_schema = get_complex_delta_schema(df_schema, delta_schema)
+
+    return (
+        catalog_schema == delta_schema
+        and complex_catalog_schema == complex_delta_schema
+    )
+
+
+def get_complex_delta_schema(delta_features, delta_feature_names_to_fs_types):
+    """
+    1. Filter delta features to features that have complex datatypes.
+    2. Take the existing Spark DataType stored on the Delta features. This is later used for
+       comparison against the Catalog schema's complex Spark DataTypes.
+    3. Return a mapping of feature name to their respective complex Spark DataTypes.
+
+    :param delta_features: List[Feature]. List of features stored in Delta.
+    :param delta_feature_names_to_fs_types: Map[str, feature_store.DataType]. A mapping of feature
+        names to their respective Feature Store DataTypes.
+    :return: Map[str, spark.sql.types.DataType]. A mapping of feature names to their respective
+        Spark DataTypes.
+    """
+    complex_delta_features = [
+        feature
+        for feature in delta_features
+        if delta_feature_names_to_fs_types[feature.name] in DATA_TYPES_REQUIRES_DETAILS
+    ]
+    complex_delta_feature_names_to_spark_types = {
+        feature.name: feature.dataType for feature in complex_delta_features
+    }
+    return complex_delta_feature_names_to_spark_types
+
+
+def get_complex_catalog_schema(catalog_features, catalog_feature_names_to_fs_types):
+    """
+    1. Filter catalog features to features that have complex datatypes.
+    2. Convert the JSON string stored in each feature's data_type_details to the corresponding
+       Spark DataType. This is later used for comparison against the Delta schema's complex Spark
+       DataTypes.
+    3. Return a mapping of feature name to their respective complex Spark DataTypes.
+
+    :param catalog_features: List[Feature]. List of features stored in the Catalog.
+    :param catalog_feature_names_to_fs_types: Map[str, feature_store.DataType]. A mapping of feature
+        names to their respective Feature Store DataTypes.
+    :return: Map[str, spark.sql.types.DataType]. A mapping of feature names to their respective
+        Spark DataTypes.
+    """
+    complex_catalog_features = [
+        feature
+        for feature in catalog_features
+        if catalog_feature_names_to_fs_types[feature.name]
+        in DATA_TYPES_REQUIRES_DETAILS
+    ]
+    complex_catalog_feature_names_to_spark_types = {
+        feature.name: feature.data_type_details
+        for feature in complex_catalog_features
+    }
+    return complex_catalog_feature_names_to_spark_types
+
+
+def log_catalog_schema_not_match_delta_schema(catalog_features, df_schema, level):
+    """
+    Log the catalog schema does not match the delta table schema.
+
+    Example warning:
+        Expected recorded schema from Feature Catalog to be identical with
+        schema in delta table.Feature Catalog's schema is
+        '{'id': 'INTEGER', 'feat1': 'INTEGER'}' while delta table's
+        schema is '{'id': 'INTEGER', 'feat1': 'FLOAT'}'
+    """
+    catalog_schema = {feature.name: feature.data_type for feature in catalog_features}
+    delta_schema = {
+        feature.name: feature.dataType
+        for feature in df_schema
+    }
+    msg = (
+        f"Expected recorded schema from Feature Catalog to be identical with schema "
+        f"in Delta table. "
+        f"Feature Catalog's schema is '{catalog_schema}' while Delta table's schema "
+        f"is '{delta_schema}'"
+    )
+    if level == _WARN:
+        _logger.warning(msg)
+    elif level == _ERROR:
+        raise RuntimeError(msg)
+    else:
+        _logger.info(msg)
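
A small sketch exercising the logging helper, which only needs the imports visible in this hunk. It assumes the hunk is the new feature_store/utils/schema_utils.py and uses namedtuples as stand-ins for the catalog Feature and Spark StructField objects, since only .name, .data_type and .dataType are read.

from collections import namedtuple

from feature_store.constants.constants import _WARN
from feature_store.utils.schema_utils import log_catalog_schema_not_match_delta_schema

CatalogFeature = namedtuple("CatalogFeature", ["name", "data_type"])
DeltaColumn = namedtuple("DeltaColumn", ["name", "dataType"])

catalog_features = [CatalogFeature("id", "INTEGER"), CatalogFeature("feat1", "INTEGER")]
df_schema = [DeltaColumn("id", "INTEGER"), DeltaColumn("feat1", "FLOAT")]

# feat1 is INTEGER in the catalog but FLOAT in the Delta table, so this logs the
# mismatch as a warning; passing _ERROR instead would raise a RuntimeError.
log_catalog_schema_not_match_delta_schema(catalog_features, df_schema, _WARN)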
@@ -0,0 +1,158 @@
+from collections import defaultdict, deque
+from queue import PriorityQueue
+from typing import Callable, Dict, Hashable, List, Optional
+
+__all__ = ["find_cycle", "topological_sort"]
+
+
+class _NodeInfo:
+    def __init__(self, node):
+        self.node = node
+        # number of non-processed predecessors.
+        self.n_blockers = 0
+        # list of nodes that depend on this node.
+        self.successors = []
+
+
+def find_cycle(
+    node_dependencies: Dict[Hashable, List[Hashable]]
+) -> Optional[List[Hashable]]:
+    """
+    Finds a cycle in the node_dependencies graph. Returns a list of node(s) that forms a cycle or
+    None if no cycle can be found.
+    :param node_dependencies: A dict with hashable objects as keys and their list of dependency
+        nodes as values.
+    """
+    # A stack used to perform DFS on the graph.
+    stack = deque()
+    # Another stack storing the path from root to current node in DFS. Used to detect cycle.
+    backtrack_stack = []
+    # A set of nodes that no cycle can be found starting from these nodes.
+    resolved = set()
+    # Create a copy of the dependency graph with defaultdict for convenience.
+    default_dependency = defaultdict(list)
+    default_dependency.update(node_dependencies)
+
+    # Perform DFS on every node in the graph.
+    for node in node_dependencies.keys():
+        if node in resolved:
+            # Skip a node if it's already resolved.
+            continue
+        # DFS from the node
+        stack.append(node)
+        while stack:
+            top = stack[-1]
+            if top not in backtrack_stack:
+                # First time visiting this node. There will be a second visit after the dependencies
+                # are resolved if it has dependencies.
+                backtrack_stack.append(top)
+            # If not expended after traversing the dependencies, meaning there is no dependency or
+            # all dependencies are resolved.
+            expanded = False
+            for depend in default_dependency[top]:
+                if depend in backtrack_stack:
+                    # found a cycle
+                    index = backtrack_stack.index(depend)
+                    return backtrack_stack[index:]
+                if depend in resolved:
+                    continue
+                # Only adding node to stack. backtrack_stack only contains nodes in the current DFS
+                # path.
+                stack.append(depend)
+                expanded = True
+            if not expanded:
+                stack.pop()
+                resolved.add(top)
+                backtrack_stack.pop()
+    return None
+
+
+def _all_items_in_queue_should_be_grouped(
+    queue: PriorityQueue, should_be_grouped: Callable
+) -> bool:
+    temp = []
+    should_group = True
+    # note: avoid using queue.qsize() because it's not guaranteed to be accurate.
+    while not queue.empty():
+        k, node = queue.get()
+        temp.append((k, node))
+        if not should_be_grouped(node):
+            should_group = False
+    for item in temp:
+        queue.put(item)
+    return should_group
+
+
+def topological_sort(
+    node_dependencies: Dict[Hashable, List[Hashable]],
+    key: Callable = None,
+    should_be_grouped: Callable = None,
+) -> List[Hashable]:
+    """
+    Topological sort the given node_dependencies graph. Returns a sorted list of nodes.
+    :param node_dependencies: A dict with hashable objects as keys and their list of dependency
+        nodes as values.
+    :param key: a Callable that returns a sort key when called with a hashable object. The key is
+        used to break ties in topological sorting. An object with smaller key is added
+        to the result list first.
+    :raises ValueError if a cycle is found in the graph.
+    """
+    # Calling a dedicated find_cycle function to be able to give a detailed error message.
+    cycle = find_cycle(node_dependencies)
+    if cycle is not None:
+        raise ValueError(
+            "Following nodes form a cycle: ",
+            cycle,
+            ". Please resolve any circular dependencies before calling Feature Store.",
+        )
+
+    # A priority-queue storing the nodes whose dependency has been resolved.
+    # priority is determined by the given key function.
+    ready_queue = PriorityQueue()
+    # Map from node to _NodeInfo.
+    nodes = {}
+    if key is None:
+        key = hash  # use the built-in hash function by default
+    if should_be_grouped is None:
+        should_be_grouped = lambda _: False
+    # Perform Kahn's algorithm by traversing the graph starting from nodes without dependency.
+    # Node is removed from its successors' dependency once resolved. And node whose dependency gets
+    # all resolved is added to the priority queue.
+    for node, dependencies in node_dependencies.items():
+        # Initialize the graph to topologically sort based on the input node_dependencies.
+        # All nodes, its successors and number of predecessors should be populated.
+        if node not in nodes:
+            nodes[node] = _NodeInfo(node)
+        for dependency in dependencies:
+            if dependency not in nodes:
+                nodes[dependency] = _NodeInfo(dependency)
+            nodes[dependency].successors.append(node)
+        if len(dependencies):
+            nodes[node].n_blockers = len(dependencies)
+    # Initialize the ready_queue to start traversing the graph from nodes without any dependencies.
+    for node, node_info in nodes.items():
+        if node_info.n_blockers == 0:
+            ready_queue.put((key(node), node))
+    # At the end of the algorithm, result_list will have a topologically sorted listed of nodes.
+    result_list = []
+
+    def process_nodes(node_buffer, queue):
+        for node in node_buffer:
+            result_list.append(node)
+            for successor in nodes[node].successors:
+                s_info = nodes[successor]
+                s_info.n_blockers -= 1
+                if s_info.n_blockers == 0:
+                    queue.put((key(successor), successor))
+
+    while not ready_queue.empty():
+        if _all_items_in_queue_should_be_grouped(ready_queue, should_be_grouped):
+            batch_buffer = []
+            while not ready_queue.empty():
+                _, node = ready_queue.get()
+                batch_buffer.append(node)
+            process_nodes(batch_buffer, ready_queue)
+        else:
+            _, node = ready_queue.get()
+            process_nodes([node], ready_queue)
+    return result_list
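
The sorter is generic over hashable nodes, so it can be exercised directly with strings. A minimal sketch; the import path assumes this hunk is the new feature_store/utils/topological_sort.py (mirrored under wedata/feature_store/utils/).

from feature_store.utils.topological_sort import find_cycle, topological_sort

# "f_price" is looked up with key "id"; "o_discount" is computed from "f_price".
deps = {
    "id": [],
    "f_price": ["id"],
    "o_discount": ["f_price"],
}
print(topological_sort(deps, key=lambda name: name))
# ['id', 'f_price', 'o_discount']

# A circular dependency is detected up front and reported as the offending path.
print(find_cycle({"a": ["b"], "b": ["a"]}))
# ['a', 'b']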