tencent-wedata-feature-engineering-dev 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tencent-wedata-feature-engineering-dev might be problematic. Click here for more details.
- tencent_wedata_feature_engineering_dev-0.1.0.dist-info/METADATA +19 -0
- tencent_wedata_feature_engineering_dev-0.1.0.dist-info/RECORD +64 -0
- tencent_wedata_feature_engineering_dev-0.1.0.dist-info/WHEEL +5 -0
- tencent_wedata_feature_engineering_dev-0.1.0.dist-info/top_level.txt +1 -0
- wedata/__init__.py +9 -0
- wedata/feature_store/__init__.py +0 -0
- wedata/feature_store/client.py +462 -0
- wedata/feature_store/cloud_sdk_client/__init__.py +0 -0
- wedata/feature_store/cloud_sdk_client/client.py +86 -0
- wedata/feature_store/cloud_sdk_client/models.py +686 -0
- wedata/feature_store/cloud_sdk_client/utils.py +32 -0
- wedata/feature_store/common/__init__.py +0 -0
- wedata/feature_store/common/protos/__init__.py +0 -0
- wedata/feature_store/common/protos/feature_store_pb2.py +49 -0
- wedata/feature_store/common/store_config/__init__.py +0 -0
- wedata/feature_store/common/store_config/redis.py +48 -0
- wedata/feature_store/constants/__init__.py +0 -0
- wedata/feature_store/constants/constants.py +59 -0
- wedata/feature_store/constants/engine_types.py +34 -0
- wedata/feature_store/entities/__init__.py +0 -0
- wedata/feature_store/entities/column_info.py +138 -0
- wedata/feature_store/entities/environment_variables.py +55 -0
- wedata/feature_store/entities/feature.py +53 -0
- wedata/feature_store/entities/feature_column_info.py +72 -0
- wedata/feature_store/entities/feature_function.py +55 -0
- wedata/feature_store/entities/feature_lookup.py +200 -0
- wedata/feature_store/entities/feature_spec.py +489 -0
- wedata/feature_store/entities/feature_spec_constants.py +25 -0
- wedata/feature_store/entities/feature_table.py +111 -0
- wedata/feature_store/entities/feature_table_info.py +49 -0
- wedata/feature_store/entities/function_info.py +90 -0
- wedata/feature_store/entities/on_demand_column_info.py +57 -0
- wedata/feature_store/entities/source_data_column_info.py +24 -0
- wedata/feature_store/entities/training_set.py +135 -0
- wedata/feature_store/feast_client/__init__.py +0 -0
- wedata/feature_store/feast_client/feast_client.py +482 -0
- wedata/feature_store/feature_table_client/__init__.py +0 -0
- wedata/feature_store/feature_table_client/feature_table_client.py +969 -0
- wedata/feature_store/mlflow_model.py +17 -0
- wedata/feature_store/spark_client/__init__.py +0 -0
- wedata/feature_store/spark_client/spark_client.py +289 -0
- wedata/feature_store/training_set_client/__init__.py +0 -0
- wedata/feature_store/training_set_client/training_set_client.py +572 -0
- wedata/feature_store/utils/__init__.py +0 -0
- wedata/feature_store/utils/common_utils.py +352 -0
- wedata/feature_store/utils/env_utils.py +86 -0
- wedata/feature_store/utils/feature_lookup_utils.py +564 -0
- wedata/feature_store/utils/feature_spec_utils.py +286 -0
- wedata/feature_store/utils/feature_utils.py +73 -0
- wedata/feature_store/utils/on_demand_utils.py +107 -0
- wedata/feature_store/utils/schema_utils.py +117 -0
- wedata/feature_store/utils/signature_utils.py +202 -0
- wedata/feature_store/utils/topological_sort.py +158 -0
- wedata/feature_store/utils/training_set_utils.py +579 -0
- wedata/feature_store/utils/uc_utils.py +296 -0
- wedata/feature_store/utils/validation_utils.py +79 -0
- wedata/tempo/__init__.py +0 -0
- wedata/tempo/interpol.py +448 -0
- wedata/tempo/intervals.py +1331 -0
- wedata/tempo/io.py +61 -0
- wedata/tempo/ml.py +129 -0
- wedata/tempo/resample.py +318 -0
- wedata/tempo/tsdf.py +1720 -0
- wedata/tempo/utils.py +254 -0
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import re
|
|
3
|
+
from typing import Optional, Set, Any, List
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
|
|
6
|
+
from wedata.feature_store.entities.feature_spec import FeatureSpec
|
|
7
|
+
|
|
8
|
+
# One dot-free name segment: any character except a dot, a space, a forward
# slash, ASCII control characters (0x00-0x1F) and DEL (0x7F).
_NAME_SEGMENT = r"[^\. \/\x00-\x1F\x7F]+"

# Anchored patterns for <name>, <schema>.<name> and <catalog>.<schema>.<name>.
SINGLE_LEVEL_NAMESPACE_REGEX = rf"^{_NAME_SEGMENT}$"
TWO_LEVEL_NAMESPACE_REGEX = rf"^{_NAME_SEGMENT}(\.{_NAME_SEGMENT})$"
THREE_LEVEL_NAMESPACE_REGEX = (
    rf"^{_NAME_SEGMENT}(\.{_NAME_SEGMENT})(\.{_NAME_SEGMENT})$"
)

HIVE_METASTORE_NAME = "hive_metastore"
# Both catalog names point to the workspace-local default Hive metastore (HMS).
LOCAL_METASTORE_NAMES = [HIVE_METASTORE_NAME, "spark_catalog"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_full_table_name(
    table_name: str,
    current_catalog: str,
    current_schema: str,
) -> str:
    """Resolve a user-supplied table name to ``<catalog>.<schema>.<table>``.

    Single-level names are completed with ``current_catalog`` and
    ``current_schema``. Two-level names are completed with ``current_catalog``.
    Catalogs listed in ``LOCAL_METASTORE_NAMES`` are normalized to
    ``hive_metastore``.

    Raises:
        ValueError: if ``table_name`` is not a valid 1-, 2- or 3-level name, or
            if the current catalog/schema is not a single-level name.
    """
    _check_qualified_table_names({table_name})
    return _get_full_name_for_entity(
        table_name, current_catalog, current_schema, "table"
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def get_full_udf_name(
    udf_name: str,
    current_catalog: str,
    current_schema: str,
) -> str:
    """Resolve a user-supplied UDF name to ``<catalog>.<schema>.<udf>``.

    Single-level names are completed with ``current_catalog`` and
    ``current_schema``. Two-level names are completed with ``current_catalog``.
    Catalogs listed in ``LOCAL_METASTORE_NAMES`` are normalized to
    ``hive_metastore``.

    Raises:
        ValueError: if ``udf_name`` is not a valid 1-, 2- or 3-level name, or
            if the current catalog/schema is not a single-level name.
    """
    _check_qualified_udf_names({udf_name})
    return _get_full_name_for_entity(
        udf_name, current_catalog, current_schema, "UDF"
    )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _get_full_name_for_entity(
    name: str,
    current_catalog: str,
    current_schema: str,
    entity_type: str,
) -> str:
    """Expand ``name`` to a three-level ``<catalog>.<schema>.<entity>`` name.

    Args:
        name: A 1-, 2- or 3-level dotted name.
        current_catalog: Catalog used to complete 1- and 2-level names.
        current_schema: Schema used to complete 1-level names.
        entity_type: Label used in error messages (e.g. ``"table"``, ``"UDF"``).

    Returns:
        The three-level name. Its catalog is rewritten to ``hive_metastore``
        when it is one of ``LOCAL_METASTORE_NAMES``.

    Raises:
        ValueError: if the current catalog/schema is not single-level, or if
            ``name`` is not a valid 1-, 2- or 3-level name.
    """
    defaults_are_valid = _is_single_level_name(
        current_catalog
    ) and _is_single_level_name(current_schema)
    if not defaults_are_valid:
        raise ValueError(
            f"Invalid catalog '{current_catalog}' or "
            f"schema '{current_schema}' name for {entity_type} '{name}'."
        )

    if _is_single_level_name(name):
        qualified = f"{current_catalog}.{current_schema}.{name}"
    elif _is_two_level_name(name):
        qualified = f"{current_catalog}.{name}"
    elif _is_three_level_name(name):
        qualified = name
    else:
        raise _invalid_names_error({name}, entity_type)

    catalog, schema, entity = qualified.split(".")
    # Every alias of the local metastore is normalized to one canonical catalog.
    if catalog in LOCAL_METASTORE_NAMES:
        return f"{HIVE_METASTORE_NAME}.{schema}.{entity}"
    return qualified
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _replace_catalog_name(full_name: str, catalog: Optional[str]) -> str:
|
|
80
|
+
if catalog is None:
|
|
81
|
+
return full_name
|
|
82
|
+
name_sec = full_name.split(".")
|
|
83
|
+
name_sec[0] = catalog
|
|
84
|
+
return ".".join(name_sec)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def get_feature_spec_with_full_table_names(
    feature_spec: FeatureSpec, catalog_name_override: Optional[str] = None
) -> FeatureSpec:
    """Return a deep copy of ``feature_spec`` with every table name in 3L form.

    Local metastore tables are stored in feature_spec.yaml with two-level
    names. Those names are prefixed with ``hive_metastore`` so they are not
    mistaken for Unity Catalog tables when data is read.

    When ``catalog_name_override`` is given, the catalog segment of every
    feature-column table, on-demand UDF, table-info table and function-info UDF
    name is replaced with it. The input spec is not modified.

    Raises:
        ValueError: if a feature-column or table-info table name is not a valid
            qualified name, or is only single-level.
    """
    feature_column_tables = [
        info.table_name for info in feature_spec.feature_column_infos
    ]
    spec_tables = [info.table_name for info in feature_spec.table_infos]
    _check_qualified_table_names(set(feature_column_tables))
    _check_qualified_table_names(set(spec_tables))

    # A single-level name carries no schema, so it cannot be made 3L here.
    single_level_tables = [
        table
        for table in feature_column_tables + spec_tables
        if _is_single_level_name(table)
    ]
    if single_level_tables:
        raise _invalid_names_error(set(single_level_tables), "table")

    result = copy.deepcopy(feature_spec)

    for info in result.feature_column_infos:
        if _is_two_level_name(info.table_name):
            info._table_name = f"{HIVE_METASTORE_NAME}.{info.table_name}"
        info._table_name = _replace_catalog_name(
            info.table_name, catalog_name_override
        )

    for info in result.on_demand_column_infos:
        if _is_two_level_name(info.udf_name):
            info._udf_name = f"{HIVE_METASTORE_NAME}.{info.udf_name}"
        info._udf_name = _replace_catalog_name(
            info.udf_name, catalog_name_override
        )

    for info in result.table_infos:
        if _is_two_level_name(info.table_name):
            info._table_name = f"{HIVE_METASTORE_NAME}.{info.table_name}"
        info._table_name = _replace_catalog_name(
            info.table_name, catalog_name_override
        )

    for info in result.function_infos:
        info._udf_name = _replace_catalog_name(
            info.udf_name, catalog_name_override
        )

    return result
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def reformat_full_table_name(full_table_name: str) -> str:
    """Collapse a local-metastore 3L table name to ``<schema>.<table>``.

    This form is used when interacting with the catalog client and when
    serializing a workspace-local feature spec for scoring. Names whose catalog
    is not in ``LOCAL_METASTORE_NAMES`` are returned unchanged.

    Raises:
        ValueError: if ``full_table_name`` is not a three-level name.
    """
    if not _is_three_level_name(full_table_name):
        raise _invalid_names_error({full_table_name}, "table")
    catalog, schema_and_table = full_table_name.split(".", 1)
    return schema_and_table if catalog in LOCAL_METASTORE_NAMES else full_table_name
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_feature_spec_with_reformat_full_table_names(
    feature_spec: FeatureSpec,
) -> FeatureSpec:
    """Return a deep copy of ``feature_spec`` with table names passed through
    ``reformat_full_table_name``.

    Only feature-column and table-info table names are rewritten. Local
    metastore tables become two-level; other tables are left as-is. The input
    spec is not modified.

    Raises:
        ValueError: if any of those table names is not a valid qualified name,
            or is not three-level.
    """
    feature_column_tables = [
        info.table_name for info in feature_spec.feature_column_infos
    ]
    spec_tables = [info.table_name for info in feature_spec.table_infos]
    _check_qualified_table_names(set(feature_column_tables))
    _check_qualified_table_names(set(spec_tables))

    not_three_level = [
        table
        for table in feature_column_tables + spec_tables
        if not _is_three_level_name(table)
    ]
    if not_three_level:
        raise _invalid_names_error(set(not_three_level), "table")

    result = copy.deepcopy(feature_spec)
    for info in result.feature_column_infos:
        info._table_name = reformat_full_table_name(info.table_name)
    for info in result.table_infos:
        info._table_name = reformat_full_table_name(info.table_name)
    return result
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _invalid_names_error(invalid_names: Set[str], entity_type: str) -> ValueError:
|
|
170
|
+
return ValueError(
|
|
171
|
+
f"Invalid {entity_type} name{'s' if len(invalid_names) > 1 else ''} '{', '.join(invalid_names)}'."
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _is_qualified_entity_name(name) -> bool:
    """Return True if ``name`` is a string that is a 1-, 2- or 3-level name."""
    if not isinstance(name, str):
        return False
    level_checks = (_is_single_level_name, _is_two_level_name, _is_three_level_name)
    return any(check(name) for check in level_checks)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _is_single_level_name(name) -> bool:
    """Return True if ``name`` is a string consisting of one name segment.

    ``re.fullmatch`` is used rather than ``re.match``. The pattern ends in
    ``$``, and under ``re.match`` that also matches just before a trailing
    newline, so a name ending in ``"\\n"`` would be accepted even though
    control characters are excluded from the segment character class.
    """
    return (
        isinstance(name, str)
        and re.fullmatch(SINGLE_LEVEL_NAMESPACE_REGEX, name) is not None
    )
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _is_two_level_name(name) -> bool:
    """Return True if ``name`` is a string of the form ``<schema>.<name>``.

    ``re.fullmatch`` is used so that a trailing newline, which ``$`` would
    tolerate under ``re.match``, is rejected.
    """
    return (
        isinstance(name, str)
        and re.fullmatch(TWO_LEVEL_NAMESPACE_REGEX, name) is not None
    )
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _is_three_level_name(name) -> bool:
    """Return True if ``name`` is a string of the form ``<catalog>.<schema>.<name>``.

    ``re.fullmatch`` is used so that a trailing newline, which ``$`` would
    tolerate under ``re.match``, is rejected.
    """
    return (
        isinstance(name, str)
        and re.fullmatch(THREE_LEVEL_NAMESPACE_REGEX, name) is not None
    )
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def unsupported_api_error_uc(api_name):
    """Build (not raise) a ValueError saying ``api_name`` is unavailable for UC tables."""
    message = f"{api_name} is not supported for Unity Catalog tables."
    return ValueError(message)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def is_uc_entity(full_entity_name) -> bool:
    """Return True if a three-level entity name is not in the local metastore.

    Any entity whose catalog is not one of ``LOCAL_METASTORE_NAMES`` is treated
    as a Unity Catalog entity. The three-part shape is still enforced:
    ``is_default_hms_table`` raises ``ValueError`` when the name does not split
    into exactly three dot-separated parts.
    """
    # The former local unpack here was unused and duplicated the shape check
    # already performed by is_default_hms_table.
    return not is_default_hms_table(full_entity_name)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def is_default_hms_table(full_table_name) -> bool:
    """Return True if a three-level table name lives in the local default HMS.

    Raises:
        ValueError: if ``full_table_name`` does not split into exactly three
            dot-separated parts.
    """
    # Unpacking into exactly three names doubles as the shape check.
    catalog_name, _schema_name, _table_name = full_table_name.split(".")
    return catalog_name in LOCAL_METASTORE_NAMES
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _check_qualified_udf_names(udf_names: Set[str]):
    """Raise ValueError unless every UDF name is a valid 1-, 2- or 3-level name.

    All offending names are listed in the error message.
    """
    bad_names = [
        candidate
        for candidate in udf_names
        if not _is_qualified_entity_name(candidate)
    ]
    if not bad_names:
        return
    plural = "s" if len(bad_names) > 1 else ""
    raise ValueError(
        f"UDF name{plural} "
        f"'{', '.join(map(str, bad_names))}' must have the form "
        "<catalog_name>.<schema_name>.<udf_name>, <schema_name>.<udf_name>, "
        "or <udf_name> and cannot include space or forward-slash."
    )
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _check_qualified_table_names(feature_table_names: Set[str]):
    """Raise ValueError unless every table name is a valid 1-, 2- or 3-level name.

    All offending names are listed in the error message.
    """
    bad_names = [
        candidate
        for candidate in feature_table_names
        if not _is_qualified_entity_name(candidate)
    ]
    if not bad_names:
        return
    plural = "s" if len(bad_names) > 1 else ""
    raise ValueError(
        f"Feature table name{plural} "
        f"'{', '.join(map(str, bad_names))}' must have the form "
        "<catalog_name>.<schema_name>.<table_name>, <database_name>.<table_name>, "
        "or <table_name> and cannot include space or forward-slash."
    )
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _verify_all_tables_are_either_in_uc_or_in_hms(
    table_names: Set[str], current_catalog: str, current_schema: str
):
    """Ensure the tables are all in UC or all in the local default HMS.

    APIs such as create_training_set and score_batch require every table to be
    either in a UC catalog (shareable across workspaces) or in the default HMS
    (intended for the current workspace only). Mixing the two is rejected.

    Raises:
        ValueError: if the tables span both locations, or if a name cannot be
            resolved to a full table name.
    """
    resolved = [
        get_full_table_name(name, current_catalog, current_schema)
        for name in table_names
    ]
    if all(is_uc_entity(full) for full in resolved) or all(
        is_default_hms_table(full) for full in resolved
    ):
        return
    raise ValueError(
        f"Feature table names '{', '.join(table_names)}' "
        "must all be in UC or the local default hive metastore. "
        "Mixing feature tables from two different storage locations is not allowed."
    )
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def _verify_all_udfs_in_uc(
    udf_names: Set[str], current_catalog: str, current_schema: str
):
    """Ensure every UDF resolves to a Unity Catalog (non-local-metastore) name.

    APIs such as create_training_set with FeatureFunctions only support UC UDFs.
    All names are resolved first, so an unresolvable name is reported before
    any location mismatch.

    Raises:
        ValueError: if a name is invalid or any UDF is not in Unity Catalog.
    """
    resolved = [
        get_full_udf_name(name, current_catalog, current_schema)
        for name in udf_names
    ]
    if not all(is_uc_entity(full) for full in resolved):
        raise ValueError("UDFs must all be in Unity Catalog.")
|
|
284
|
+
|
|
285
|
+
def utc_timestamp_ms_from_iso_datetime_string(date_string: str) -> int:
    """Convert an ISO-8601 datetime string to integer milliseconds since the epoch.

    The Feature Store backend returns timestamps in milliseconds, so the result
    can be compared with backend values directly.

    Naive strings (no UTC offset) are interpreted as UTC. A string that carries
    an explicit offset (e.g. ``+08:00``) keeps that offset instead of having it
    overwritten with UTC.

    Fractional milliseconds are truncated toward zero, as the previous
    ``int(1000 * dt.timestamp())`` did. Exact integer arithmetic is used
    because the float product can land just below a whole number and drop a
    millisecond (1.005 s gives 1004.999...).
    """
    dt = datetime.fromisoformat(date_string)
    if dt.utcoffset() is None:
        dt = dt.replace(tzinfo=timezone.utc)
    delta = dt - datetime(1970, 1, 1, tzinfo=timezone.utc)
    micros = (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
    millis = abs(micros) // 1_000
    return millis if micros >= 0 else -millis
|
|
291
|
+
|
|
292
|
+
def get_unique_list_order(elements: List[Any]) -> List[Any]:
    """Return the distinct items of ``elements``, keeping first-occurrence order.

    Items must be hashable.
    """
    seen = set()
    ordered = []
    for item in elements:
        if item not in seen:
            seen.add(item)
            ordered.append(item)
    return ordered
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List, Union, Any
|
|
3
|
+
from collections import Counter
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from pyspark.sql import DataFrame
|
|
7
|
+
|
|
8
|
+
# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def standardize_checkpoint_location(checkpoint_location):
    """Normalize a checkpoint location.

    ``None`` is passed through. Otherwise surrounding whitespace is stripped,
    and a location that is empty after stripping becomes ``None``.
    """
    if checkpoint_location is None:
        return None
    cleaned = checkpoint_location.strip()
    return None if cleaned == "" else cleaned
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _is_spark_connect_data_frame(df):
|
|
21
|
+
# We cannot directly pyspark.sql.connect.dataframe.DataFrame as it requires Spark 3.4, which
|
|
22
|
+
# is not installed on DBR 12.2 and earlier. Instead, we string match on the type.
|
|
23
|
+
return (
|
|
24
|
+
type(df).__name__ == "DataFrame"
|
|
25
|
+
and type(df).__module__ == "pyspark.sql.connect.dataframe"
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def check_dataframe_type(df):
    """Raise ValueError unless ``df`` is a PySpark DataFrame.

    Both a classic DataFrame and a Spark Connect DataFrame are accepted.
    """
    if isinstance(df, DataFrame) or _is_spark_connect_data_frame(df):
        return
    raise ValueError(
        f"Unsupported DataFrame type: {type(df)}. DataFrame must be a PySpark DataFrame."
    )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def check_kwargs_empty(the_kwargs, method_name):
    """Raise TypeError naming ``method_name`` if ``the_kwargs`` is non-empty."""
    if len(the_kwargs) == 0:
        return
    unexpected = list(the_kwargs.keys())
    raise TypeError(
        f"{method_name}() got unexpected keyword argument(s): {unexpected}"
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def check_duplicate_keys(keys: Union[str, List[str]], key_name: str) -> None:
    """Raise ValueError if a list of keys contains the same key more than once.

    Only non-empty lists are inspected; a single string key, or any other
    non-list value, is accepted as-is. The first key seen a second time is the
    one reported.
    """
    if not (keys and isinstance(keys, list)):
        return
    seen_keys = set()
    for key in keys:
        if key in seen_keys:
            raise ValueError(
                f"Found duplicated key '{key}' in {key_name}. {key_name} must be unique."
            )
        seen_keys.add(key)
|
|
58
|
+
|
|
59
|
+
def get_duplicates(elements: List[Any]) -> List[Any]:
    """Return each element that occurs more than once, in first-occurrence order."""
    occurrences = Counter(elements)
    return [element for element, count in occurrences.items() if count > 1]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def validate_strings_unique(strings: List[str], error_template: str):
    """Raise ValueError if ``strings`` contains duplicates.

    The duplicates are single-quoted, joined with ", " (in first-occurrence
    order) and passed as the single positional argument to
    ``error_template.format``.
    """
    duplicates = get_duplicates(strings)
    if not duplicates:
        return
    quoted = ", ".join(f"'{value}'" for value in duplicates)
    raise ValueError(error_template.format(quoted))
|
wedata/tempo/__init__.py
ADDED
|
File without changes
|