wedata-feature-engineering 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wedata/__init__.py +1 -1
- wedata/feature_store/client.py +113 -41
- wedata/feature_store/constants/constants.py +19 -0
- wedata/feature_store/entities/column_info.py +4 -4
- wedata/feature_store/entities/feature_lookup.py +5 -1
- wedata/feature_store/entities/feature_spec.py +46 -46
- wedata/feature_store/entities/feature_table.py +42 -99
- wedata/feature_store/entities/training_set.py +13 -12
- wedata/feature_store/feature_table_client/feature_table_client.py +85 -30
- wedata/feature_store/spark_client/spark_client.py +30 -56
- wedata/feature_store/training_set_client/training_set_client.py +209 -38
- wedata/feature_store/utils/common_utils.py +213 -3
- wedata/feature_store/utils/feature_lookup_utils.py +6 -6
- wedata/feature_store/utils/feature_spec_utils.py +6 -6
- wedata/feature_store/utils/feature_utils.py +5 -5
- wedata/feature_store/utils/on_demand_utils.py +107 -0
- wedata/feature_store/utils/schema_utils.py +1 -1
- wedata/feature_store/utils/signature_utils.py +205 -0
- wedata/feature_store/utils/training_set_utils.py +18 -19
- wedata/feature_store/utils/uc_utils.py +1 -1
- {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.6.dist-info}/METADATA +1 -1
- wedata_feature_engineering-0.1.6.dist-info/RECORD +43 -0
- feature_store/__init__.py +0 -6
- feature_store/client.py +0 -169
- feature_store/constants/__init__.py +0 -0
- feature_store/constants/constants.py +0 -28
- feature_store/entities/__init__.py +0 -0
- feature_store/entities/column_info.py +0 -117
- feature_store/entities/data_type.py +0 -92
- feature_store/entities/environment_variables.py +0 -55
- feature_store/entities/feature.py +0 -53
- feature_store/entities/feature_column_info.py +0 -64
- feature_store/entities/feature_function.py +0 -55
- feature_store/entities/feature_lookup.py +0 -179
- feature_store/entities/feature_spec.py +0 -454
- feature_store/entities/feature_spec_constants.py +0 -25
- feature_store/entities/feature_table.py +0 -164
- feature_store/entities/feature_table_info.py +0 -40
- feature_store/entities/function_info.py +0 -184
- feature_store/entities/on_demand_column_info.py +0 -44
- feature_store/entities/source_data_column_info.py +0 -21
- feature_store/entities/training_set.py +0 -134
- feature_store/feature_table_client/__init__.py +0 -0
- feature_store/feature_table_client/feature_table_client.py +0 -313
- feature_store/spark_client/__init__.py +0 -0
- feature_store/spark_client/spark_client.py +0 -286
- feature_store/training_set_client/__init__.py +0 -0
- feature_store/training_set_client/training_set_client.py +0 -196
- feature_store/utils/__init__.py +0 -0
- feature_store/utils/common_utils.py +0 -96
- feature_store/utils/feature_lookup_utils.py +0 -570
- feature_store/utils/feature_spec_utils.py +0 -286
- feature_store/utils/feature_utils.py +0 -73
- feature_store/utils/schema_utils.py +0 -117
- feature_store/utils/topological_sort.py +0 -158
- feature_store/utils/training_set_utils.py +0 -580
- feature_store/utils/uc_utils.py +0 -281
- feature_store/utils/utils.py +0 -252
- feature_store/utils/validation_utils.py +0 -55
- wedata/feature_store/utils/utils.py +0 -252
- wedata_feature_engineering-0.1.5.dist-info/RECORD +0 -79
- {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.6.dist-info}/WHEEL +0 -0
- {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.6.dist-info}/top_level.txt +0 -0
wedata/feature_store/utils/common_utils.py
@@ -1,12 +1,217 @@
 """
 Common utility functions
 """
-
+import os
 from collections import Counter
-from
+from datetime import datetime, timezone
+from functools import wraps
+from typing import Any, Dict, List, Optional
+from urllib.parse import urlparse
 
+import mlflow
+from mlflow.exceptions import RestException
+from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
 from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository
 from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
+from mlflow.utils import databricks_utils
+
+from wedata.feature_store.constants.constants import MODEL_DATA_PATH_ROOT
+
+
+def validate_table_name(name: str):
+    """
+    Validate the feature table name. Only a bare table name is supported; it must not contain dots (e.g. <catalog>.<schema>.<table>).
+
+    Args:
+        name: The table name to validate.
+
+    Raises:
+        ValueError: If the table name contains dots or otherwise violates the naming rules.
+    """
+    if not name or not isinstance(name, str):
+        raise ValueError("Table name must be a non-empty string")
+    if name.count('.') > 0:
+        raise ValueError("Feature table name only supports single table name, cannot contain dots (e.g. <catalog>.<schema>.<table>)")
+    if not name[0].isalpha():
+        raise ValueError("Table name must start with a letter")
+    if not all(c.isalnum() or c == '_' for c in name):
+        raise ValueError("Table name can only contain letters, numbers and underscores")
+
+
+def build_full_table_name(table_name: str) -> str:
+    """
+    Build the full table name, formatted as `<database>.<table>`.
+
+    Args:
+        table_name: The input table name (either a short table name or a full table name).
+
+    Returns:
+        The full table name (`<database>.<table>`).
+    """
+
+    # Get the current owner account (UIN) from the environment
+    owner_uin = os.environ.get("WEDATA_OWNER_UIN", "default")
+
+    # Raise an error if owner_uin is empty
+    if not owner_uin:
+        raise ValueError("WEDATA_OWNER_UIN environment variable is not set")
+
+    feature_store_database = f"{owner_uin}.{table_name}"
+
+    return feature_store_database
+
+
+def enable_if(condition):
+    """
+    A decorator that conditionally enables a function based on a condition.
+    If the condition is not truthy, calling the function raises a NotImplementedError.
+
+    :param condition: A callable that returns a truthy or falsy value.
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if not condition():
+                raise NotImplementedError
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def is_empty(target: str):
+    return target is None or len(target.strip()) == 0
+
+
+class _NoDbutilsError(Exception):
+    pass
+
+
+def _get_dbutils():
+    try:
+        import IPython
+
+        ip_shell = IPython.get_ipython()
+        if ip_shell is None:
+            raise _NoDbutilsError
+        return ip_shell.ns_table["user_global"]["dbutils"]
+    except ImportError:
+        raise _NoDbutilsError
+    except KeyError:
+        raise _NoDbutilsError
+
+
+def utc_timestamp_ms_from_iso_datetime_string(date_string: str) -> int:
+    dt = datetime.fromisoformat(date_string)
+    utc_dt = dt.replace(tzinfo=timezone.utc)
+    return 1000 * utc_dt.timestamp()
+
+
+def pip_depependency_pinned_major_version(pip_package_name, major_version):
+    """
+    Generate a pip dependency string that is pinned to a major version, for example: "databricks-feature-lookup==0.*"
+    """
+    return f"{pip_package_name}=={major_version}.*"
+
+
+def add_mlflow_pip_depependency(conda_env, pip_package_name):
+    """
+    Add a new pip dependency to the conda environment taken from the raw MLflow model.
+    """
+    if pip_package_name is None or len(pip_package_name) == 0:
+        raise ValueError(
+            "Unexpected input: missing or empty pip_package_name parameter"
+        )
+
+    found_pip_dependency = False
+    if conda_env is not None:
+        for dep in conda_env["dependencies"]:
+            if isinstance(dep, dict) and "pip" in dep:
+                found_pip_dependency = True
+                pip_deps = dep["pip"]
+                if pip_package_name not in pip_deps:
+                    pip_deps.append(pip_package_name)
+        if "dependencies" in conda_env and not found_pip_dependency:
+            raise ValueError(
+                "Unexpected input: mlflow conda_env did not contain pip as a dependency"
+            )
+
+
+def download_model_artifacts(model_uri, dir):
+    """
+    Downloads model artifacts from model_uri to dir.
+    """
+    if not is_artifact_uri(model_uri):
+        raise ValueError(
+            f"Invalid model URI '{model_uri}'."
+            f"Use ``models:/model_name>/<version_number>`` or "
+            f"``runs:/<mlflow_run_id>/run-relative/path/to/model``."
+        )
+
+    try:
+        repo = get_artifact_repository(model_uri)
+    except RestException as e:
+        raise ValueError(f"The model at '{model_uri}' does not exist.", e)
+
+    artifact_path = os.path.join(mlflow.pyfunc.DATA, MODEL_DATA_PATH_ROOT)
+    if len(repo.list_artifacts(artifact_path)) == 0:
+        raise ValueError(
+            f"No suitable model found at '{model_uri}'. Either no model exists in this "
+            f"artifact location or an existing model was not packaged with Feature Store metadata. "
+            f"Only models logged by FeatureStoreClient.log_model can be used in inference."
+        )
+
+    return repo.download_artifacts(artifact_path="", dst_path=dir)
+
+
+def validate_params_non_empty(params: Dict[str, Any], expected_params: List[str]):
+    """
+    Validate that none of the expected parameters are empty.
+    """
+    for expected_param in expected_params:
+        if expected_param not in params:
+            raise ValueError(
+                f'Internal error: expected parameter "{expected_param}" not found in params dictionary'
+            )
+        param_value = params[expected_param]
+        if not param_value:
+            raise ValueError(f'Parameter "{expected_param}" cannot be empty')
+
+
+def is_in_databricks_job():
+    """
+    Overrides the behavior of the mlflow databricks_utils.is_in_databricks_job().
+    """
+    try:
+        return databricks_utils.get_job_id() is not None
+    except Exception:
+        return False
+
+
+def get_workspace_url() -> Optional[str]:
+    """
+    Overrides the behavior of the mlflow.utils.databricks_utils.get_workspace_url().
+    """
+    workspace_url = databricks_utils.get_workspace_url()
+    if workspace_url and not urlparse(workspace_url).scheme:
+        workspace_url = "https://" + workspace_url
+    return workspace_url
+
+
+def is_in_databricks_env():
+    """
+    Determine if we are running in a Databricks environment.
+    """
+    try:
+        return (
+            is_in_databricks_job()
+            or databricks_utils.is_in_databricks_notebook()
+            or databricks_utils.is_in_databricks_runtime()
+        )
+    except Exception:
+        return False
 
 
 def is_artifact_uri(uri):
@@ -18,6 +223,7 @@ def is_artifact_uri(uri):
         uri
     ) or RunsArtifactRepository.is_runs_uri(uri)
 
+
 def as_list(obj, default=None):
     if not obj:
         return default
@@ -26,6 +232,7 @@ def as_list(obj, default=None):
     else:
         return [obj]
 
+
 def get_duplicates(elements: List[Any]) -> List[Any]:
     """
     Returns duplicate elements in the order they first appear.
@@ -37,6 +244,7 @@ def get_duplicates(elements: List[Any]) -> List[Any]:
             duplicates.append(e)
     return duplicates
 
+
 def validate_strings_unique(strings: List[str], error_template: str):
     """
     Validates all strings are unique, otherwise raise ValueError with the error template and duplicates.
@@ -47,6 +255,7 @@ def validate_strings_unique(strings: List[str], error_template: str):
         duplicates_formatted = ", ".join([f"'{s}'" for s in duplicate_strings])
         raise ValueError(error_template.format(duplicates_formatted))
 
+
 def sanitize_identifier(identifier: str):
     """
     Sanitize and wrap an identifier with backquotes. For example, "a`b" becomes "`a``b`".
@@ -89,8 +298,9 @@ def unsanitize_identifier(identifier: str):
 def escape_sql_string(input_str: str) -> str:
     return input_str.replace("\\", "\\\\").replace("'", "\\'")
 
+
 def get_unique_list_order(elements: List[Any]) -> List[Any]:
     """
     Returns unique elements in the order they first appear.
     """
-    return list(dict.fromkeys(elements))
+    return list(dict.fromkeys(elements))
wedata/feature_store/utils/feature_lookup_utils.py
@@ -10,13 +10,13 @@ from pyspark.sql import DataFrame, Window
 from pyspark.sql import functions as F
 from pyspark.sql.functions import sum, unix_timestamp
 
-from feature_store.entities.environment_variables import BROADCAST_JOIN_THRESHOLD
-from feature_store.entities.feature_column_info import FeatureColumnInfo
-from feature_store.entities.feature_lookup import FeatureLookup
-from feature_store.entities.feature_spec import FeatureSpec
-from feature_store.entities.feature_table import FeatureTable
+from wedata.feature_store.entities.environment_variables import BROADCAST_JOIN_THRESHOLD
+from wedata.feature_store.entities.feature_column_info import FeatureColumnInfo
+from wedata.feature_store.entities.feature_lookup import FeatureLookup
+from wedata.feature_store.entities.feature_spec import FeatureSpec
+from wedata.feature_store.entities.feature_table import FeatureTable
 
-from feature_store.utils import common_utils, validation_utils, uc_utils
+from wedata.feature_store.utils import common_utils, validation_utils, uc_utils
 
 _logger = logging.getLogger(__name__)
 
wedata/feature_store/utils/feature_spec_utils.py
@@ -6,12 +6,12 @@ from typing import Dict, List, Set, Tuple, Type, Union
 import yaml
 from mlflow.utils.file_utils import YamlSafeDumper
 
-from feature_store.entities.column_info import ColumnInfo
-from feature_store.entities.feature_column_info import FeatureColumnInfo
-from feature_store.entities.feature_spec import FeatureSpec
-from feature_store.entities.on_demand_column_info import OnDemandColumnInfo
-from feature_store.entities.source_data_column_info import SourceDataColumnInfo
-from feature_store.utils.topological_sort import topological_sort
+from wedata.feature_store.entities.column_info import ColumnInfo
+from wedata.feature_store.entities.feature_column_info import FeatureColumnInfo
+from wedata.feature_store.entities.feature_spec import FeatureSpec
+from wedata.feature_store.entities.on_demand_column_info import OnDemandColumnInfo
+from wedata.feature_store.entities.source_data_column_info import SourceDataColumnInfo
+from wedata.feature_store.utils.topological_sort import topological_sort
 
 DEFAULT_GRAPH_DEPTH_LIMIT = 5
 
wedata/feature_store/utils/feature_utils.py
@@ -1,11 +1,11 @@
 import copy
 from typing import List, Union
 
-from feature_store.entities.feature_function import FeatureFunction
-from feature_store.entities.feature_lookup import FeatureLookup
-from feature_store.spark_client.spark_client import SparkClient
-from feature_store.utils import uc_utils
-from feature_store.utils.feature_lookup_utils import get_feature_lookups_with_full_table_names
+from wedata.feature_store.entities.feature_function import FeatureFunction
+from wedata.feature_store.entities.feature_lookup import FeatureLookup
+from wedata.feature_store.spark_client.spark_client import SparkClient
+from wedata.feature_store.utils import uc_utils
+from wedata.feature_store.utils.feature_lookup_utils import get_feature_lookups_with_full_table_names
 
 
 def format_feature_lookups_and_functions(
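The three import hunks above reflect the repackaging visible in the file list: the top-level feature_store package shipped in 0.1.5 is removed, and all modules now live under the wedata namespace. Downstream code that imported the old top-level package would need the new prefix, roughly:

# 0.1.5 layout (removed in this release):
# from feature_store.entities.feature_lookup import FeatureLookup

# 0.1.6 layout:
from wedata.feature_store.entities.feature_lookup import FeatureLookup
from wedata.feature_store.entities.feature_function import FeatureFunction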
wedata/feature_store/utils/on_demand_utils.py
@@ -0,0 +1,107 @@
+import copy
+from typing import Dict, List
+
+from pyspark.sql import DataFrame
+from pyspark.sql.functions import expr
+
+from wedata.feature_store.entities.feature_function import FeatureFunction
+from wedata.feature_store.entities.function_info import FunctionInfo
+from wedata.feature_store.entities.on_demand_column_info import OnDemandColumnInfo
+from wedata.feature_store.utils import common_utils, uc_utils
+
+
+def _udf_expr(udf_name: str, arguments: List[str]) -> expr:
+    """
+    Generate a Spark SQL expression, e.g. expr("udf_name(col1, col2)")
+    """
+    arguments_str = ", ".join(common_utils.sanitize_identifiers(arguments))
+    return expr(f"{udf_name}({arguments_str})")
+
+
+def _validate_apply_functions_df(
+    df: DataFrame,
+    functions_to_apply: List[OnDemandColumnInfo],
+    uc_function_infos: Dict[str, FunctionInfo],
+):
+    """
+    Validate the following:
+    1. On-demand input columns specified by functions_to_apply exist in the DataFrame.
+    2. On-demand input columns have data types that match those of UDF parameters.
+    """
+    for odci in functions_to_apply:
+        function_info = uc_function_infos[odci.udf_name]
+        types_dict = dict(df.dtypes)
+
+        for p in function_info.input_params:
+            arg_column = odci.input_bindings[p.name]
+            if arg_column not in df.columns:
+                raise ValueError(
+                    f"FeatureFunction argument column '{arg_column}' for UDF '{odci.udf_name}' parameter '{p.name}' "
+                    f"does not exist in provided DataFrame with schema '{df.schema}'."
+                )
+            if types_dict[arg_column] != p.type_text:
+                raise ValueError(
+                    f"FeatureFunction argument column '{arg_column}' for UDF '{odci.udf_name}' parameter '{p.name}' "
+                    f"does not have the expected type. Argument column '{arg_column}' has type "
+                    f"'{types_dict[arg_column]}' and parameter '{p.name}' has type '{p.type_text}'."
+                )
+
+
+def apply_functions_if_not_overridden(
+    df: DataFrame,
+    functions_to_apply: List[OnDemandColumnInfo],
+    uc_function_infos: Dict[str, FunctionInfo],
+) -> DataFrame:
+    """
+    For all on-demand features, in the order defined by the FeatureSpec:
+    If the feature does not already exist, append the evaluated UDF expression.
+    Existing column values or column positions are not modified.
+
+    `_validate_apply_functions_df` validates UDFs can be applied on `df` schema.
+
+    The caller should validate:
+    1. FeatureFunction bound argument columns for UDF parameters exist in FeatureSpec defined features.
+    2. FeatureFunction output feature names are unique.
+    """
+    _validate_apply_functions_df(
+        df=df,
+        functions_to_apply=functions_to_apply,
+        uc_function_infos=uc_function_infos,
+    )
+
+    columns = {}
+    for odci in functions_to_apply:
+        if odci.output_name not in df.columns:
+            function_info = uc_function_infos[odci.udf_name]
+            # Resolve the bound arguments in the UDF parameter order
+            udf_arguments = [
+                odci.input_bindings[p.name] for p in function_info.input_params
+            ]
+            columns[odci.output_name] = _udf_expr(odci.udf_name, udf_arguments)
+    return df.withColumns(columns)
+
+
+def get_feature_functions_with_full_udf_names(
+    feature_functions: List[FeatureFunction], current_catalog: str, current_schema: str
+):
+    """
+    Takes in a list of FeatureFunctions, and returns copies with:
+    1. Fully qualified UDF names.
+    2. If output_name is empty, fully qualified UDF names as output_name.
+    """
+    udf_names = {ff.udf_name for ff in feature_functions}
+    uc_utils._check_qualified_udf_names(udf_names)
+    uc_utils._verify_all_udfs_in_uc(udf_names, current_catalog, current_schema)
+
+    standardized_feature_functions = []
+    for ff in feature_functions:
+        ff_copy = copy.deepcopy(ff)
+        del ff
+
+        ff_copy._udf_name = uc_utils.get_full_udf_name(
+            ff_copy.udf_name, current_catalog, current_schema
+        )
+        if not ff_copy.output_name:
+            ff_copy._output_name = ff_copy.udf_name
+        standardized_feature_functions.append(ff_copy)
+    return standardized_feature_functions
wedata/feature_store/utils/signature_utils.py
@@ -0,0 +1,205 @@
+import logging
+from typing import Any, Dict, Optional
+
+import mlflow
+from mlflow.models import ModelSignature
+from mlflow.types import ColSpec
+from mlflow.types import DataType as MlflowDataType
+from mlflow.types import ParamSchema, Schema
+
+from wedata.feature_store.entities.feature_column_info import FeatureColumnInfo
+from wedata.feature_store.entities.feature_spec import FeatureSpec
+from wedata.feature_store.entities.on_demand_column_info import OnDemandColumnInfo
+from wedata.feature_store.entities.source_data_column_info import SourceDataColumnInfo
+
+_logger = logging.getLogger(__name__)
+
+# Some types (array, map, decimal, timestamp_ntz) are unsupported due to MLflow signatures
+# lacking any equivalent types. We thus cannot construct a ColSpec for any column
+# that uses these types.
+SUPPORTED_TYPE_MAP = {
+    "smallint": MlflowDataType.integer,  # Upcast to integer
+    "int": MlflowDataType.integer,
+    "bigint": MlflowDataType.long,
+    "float": MlflowDataType.float,
+    "double": MlflowDataType.double,
+    "boolean": MlflowDataType.boolean,
+    "date": MlflowDataType.datetime,
+    "timestamp": MlflowDataType.datetime,
+    "string": MlflowDataType.string,
+    "binary": MlflowDataType.binary,
+}
+
+
+def is_unsupported_type(type_str: str):
+    return type_str not in SUPPORTED_TYPE_MAP
+
+
+def convert_spark_data_type_to_mlflow_signature_type(spark_type):
+    """
+    Maps Databricks SQL types to MLflow signature types.
+    docs.databricks.com/sql/language-manual/sql-ref-datatypes.html#language-mappings
+    """
+    return SUPPORTED_TYPE_MAP.get(spark_type)
+
+
+def get_input_schema_from_feature_spec(feature_spec: FeatureSpec) -> Schema:
+    """
+    Produces an MLflow signature schema from a feature spec.
+    Source data columns are marked as required inputs and feature columns
+    (both lookups and on-demand features) are marked as optional inputs.
+
+    :param feature_spec: FeatureSpec object with datatypes for each column.
+    """
+    # If we're missing any data types for any column, we are likely dealing with a
+    # malformed feature spec and should halt signature construction.
+    if any([ci.data_type is None for ci in feature_spec.column_infos]):
+        raise Exception("Training set does not contain column data types.")
+
+    source_data_cols = [
+        ci
+        for ci in feature_spec.column_infos
+        if isinstance(ci.info, SourceDataColumnInfo)
+    ]
+    # Don't create signature if any source data columns (required) are of complex types.
+    if any(
+        [
+            ci.data_type is None or is_unsupported_type(ci.data_type)
+            for ci in source_data_cols
+        ]
+    ):
+        raise Exception(
+            "Input DataFrame contains column data types not supported by "
+            "MLflow model signatures."
+        )
+    required_input_colspecs = [
+        ColSpec(
+            convert_spark_data_type_to_mlflow_signature_type(ci.data_type),
+            ci.info.output_name,
+            required=True,
+        )
+        for ci in source_data_cols
+    ]
+    feature_cols = [
+        ci
+        for ci in feature_spec.column_infos
+        if isinstance(ci.info, (FeatureColumnInfo, OnDemandColumnInfo))
+    ]
+    unsupported_feature_cols = [
+        ci for ci in feature_cols if is_unsupported_type(ci.data_type)
+    ]
+    optional_input_colspecs = [
+        ColSpec(
+            convert_spark_data_type_to_mlflow_signature_type(ci.data_type),
+            ci.output_name,
+            required=False,
+        )
+        for ci in feature_cols
+        if not is_unsupported_type(ci.data_type)
+    ]
+    if unsupported_feature_cols:
+        feat_string = ", ".join(
+            [f"{ci.output_name} ({ci.data_type})" for ci in unsupported_feature_cols]
+        )
+        _logger.warning(
+            f"The following features will not be included in the input schema because their"
+            f" data types are not supported by MLflow model signatures: {feat_string}. "
+            f"These features cannot be overridden during model serving."
+        )
+    return Schema(required_input_colspecs + optional_input_colspecs)
+
+
+def get_output_schema_from_labels(label_type_map: Optional[Dict[str, str]]) -> Schema:
+    """
+    Produces an MLflow signature schema from the provided label type map.
+    :param label_type_map: Map label column name -> data type
+    """
+    if not label_type_map:
+        raise Exception("Training set does not contain a label.")
+    if any([is_unsupported_type(dtype) for dtype in label_type_map.values()]):
+        raise Exception(
+            "Labels are of data types not supported by MLflow model signatures."
+        )
+    else:
+        output_colspecs = [
+            ColSpec(
+                convert_spark_data_type_to_mlflow_signature_type(spark_type),
+                col_name,
+                required=True,
+            )
+            for col_name, spark_type in label_type_map.items()
+        ]
+        return Schema(output_colspecs)
+
+
+def get_mlflow_signature_from_feature_spec(
+    feature_spec: FeatureSpec,
+    label_type_map: Optional[Dict[str, str]],
+    override_output_schema: Optional[Schema],
+    params: Optional[Dict[str, Any]] = None,
+) -> Optional[ModelSignature]:
+    """
+    Produce an MLflow signature from a feature spec and label type map.
+    Source data columns are marked as required inputs and feature columns
+    (both lookups and on-demand features) are marked as optional inputs.
+
+    Reads output types from the cached label -> datatype map in the training set.
+    If override_output_schema is provided, it will always be used as the output schema.
+
+    :param feature_spec: FeatureSpec object with datatypes for each column.
+    :param label_type_map: Map of label column name -> datatype
+    :param override_output_schema: User-provided output schema to use if provided.
+    """
+    kwargs = {}
+    kwargs["inputs"] = get_input_schema_from_feature_spec(feature_spec)
+    try:
+        output_schema = override_output_schema or get_output_schema_from_labels(
+            label_type_map
+        )
+        kwargs["outputs"] = output_schema
+    except Exception as e:
+        _logger.warning(f"Could not infer an output schema: {e}")
+
+    if params:
+        try:
+            from mlflow.types.utils import _infer_param_schema
+
+            kwargs["params"] = _infer_param_schema(params)
+        except Exception as e:
+            _logger.warning(f"Could not infer params schema: {e}")
+
+    return mlflow.models.ModelSignature(**kwargs)
+
+
+def drop_signature_inputs_and_invalid_params(signature):
+    """
+    Drop ModelSignature inputs field and invalid params from params field.
+    This is useful for feature store model's raw_model.
+    Feature store model's input schema does not apply to raw_model's input,
+    so we drop the inputs field of raw_model's signature.
+    Feature store model's result_type param enables setting and overriding
+    a default result_type for predictions, but this interferes with params
+    passed to MLflow's predict function, so we drop result_type from
+    the params field of raw_model's signature.
+
+    :param signature: ModelSignature object.
+    """
+    if signature:
+        outputs_schema = signature.outputs
+        params_schema = signature.params if hasattr(signature, "params") else None
+        try:
+            # Only for mlflow>=2.6.0 ModelSignature contains params attribute
+            if params_schema:
+                updated_params_schema = ParamSchema(
+                    [param for param in params_schema if param.name != "result_type"]
+                )
+                return ModelSignature(
+                    outputs=outputs_schema, params=updated_params_schema
+                )
+            if outputs_schema:
+                return ModelSignature(outputs=outputs_schema)
+        except TypeError:
+            _logger.warning(
+                "ModelSignature without inputs is not supported, please upgrade "
+                "mlflow >= 2.7.0 to use the feature."
+            )