awx-zipline-ai 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +1 -0
- agent/constants.py +15 -0
- agent/ttypes.py +1684 -0
- ai/__init__.py +0 -0
- ai/chronon/__init__.py +0 -0
- ai/chronon/airflow_helpers.py +251 -0
- ai/chronon/api/__init__.py +1 -0
- ai/chronon/api/common/__init__.py +1 -0
- ai/chronon/api/common/constants.py +15 -0
- ai/chronon/api/common/ttypes.py +1844 -0
- ai/chronon/api/constants.py +15 -0
- ai/chronon/api/ttypes.py +3624 -0
- ai/chronon/cli/compile/column_hashing.py +313 -0
- ai/chronon/cli/compile/compile_context.py +177 -0
- ai/chronon/cli/compile/compiler.py +160 -0
- ai/chronon/cli/compile/conf_validator.py +590 -0
- ai/chronon/cli/compile/display/class_tracker.py +112 -0
- ai/chronon/cli/compile/display/compile_status.py +95 -0
- ai/chronon/cli/compile/display/compiled_obj.py +12 -0
- ai/chronon/cli/compile/display/console.py +3 -0
- ai/chronon/cli/compile/display/diff_result.py +46 -0
- ai/chronon/cli/compile/fill_templates.py +40 -0
- ai/chronon/cli/compile/parse_configs.py +141 -0
- ai/chronon/cli/compile/parse_teams.py +238 -0
- ai/chronon/cli/compile/serializer.py +115 -0
- ai/chronon/cli/git_utils.py +156 -0
- ai/chronon/cli/logger.py +61 -0
- ai/chronon/constants.py +3 -0
- ai/chronon/eval/__init__.py +122 -0
- ai/chronon/eval/query_parsing.py +19 -0
- ai/chronon/eval/sample_tables.py +100 -0
- ai/chronon/eval/table_scan.py +186 -0
- ai/chronon/fetcher/__init__.py +1 -0
- ai/chronon/fetcher/constants.py +15 -0
- ai/chronon/fetcher/ttypes.py +127 -0
- ai/chronon/group_by.py +692 -0
- ai/chronon/hub/__init__.py +1 -0
- ai/chronon/hub/constants.py +15 -0
- ai/chronon/hub/ttypes.py +1228 -0
- ai/chronon/join.py +566 -0
- ai/chronon/logger.py +24 -0
- ai/chronon/model.py +35 -0
- ai/chronon/observability/__init__.py +1 -0
- ai/chronon/observability/constants.py +15 -0
- ai/chronon/observability/ttypes.py +2192 -0
- ai/chronon/orchestration/__init__.py +1 -0
- ai/chronon/orchestration/constants.py +15 -0
- ai/chronon/orchestration/ttypes.py +4406 -0
- ai/chronon/planner/__init__.py +1 -0
- ai/chronon/planner/constants.py +15 -0
- ai/chronon/planner/ttypes.py +1686 -0
- ai/chronon/query.py +126 -0
- ai/chronon/repo/__init__.py +40 -0
- ai/chronon/repo/aws.py +298 -0
- ai/chronon/repo/cluster.py +65 -0
- ai/chronon/repo/compile.py +56 -0
- ai/chronon/repo/constants.py +164 -0
- ai/chronon/repo/default_runner.py +291 -0
- ai/chronon/repo/explore.py +421 -0
- ai/chronon/repo/extract_objects.py +137 -0
- ai/chronon/repo/gcp.py +585 -0
- ai/chronon/repo/gitpython_utils.py +14 -0
- ai/chronon/repo/hub_runner.py +171 -0
- ai/chronon/repo/hub_uploader.py +108 -0
- ai/chronon/repo/init.py +53 -0
- ai/chronon/repo/join_backfill.py +105 -0
- ai/chronon/repo/run.py +293 -0
- ai/chronon/repo/serializer.py +141 -0
- ai/chronon/repo/team_json_utils.py +46 -0
- ai/chronon/repo/utils.py +472 -0
- ai/chronon/repo/zipline.py +51 -0
- ai/chronon/repo/zipline_hub.py +105 -0
- ai/chronon/resources/gcp/README.md +174 -0
- ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/test/data.py +34 -0
- ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
- ai/chronon/resources/gcp/joins/test/data.py +30 -0
- ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
- ai/chronon/resources/gcp/sources/test/data.py +23 -0
- ai/chronon/resources/gcp/teams.py +70 -0
- ai/chronon/resources/gcp/zipline-cli-install.sh +54 -0
- ai/chronon/source.py +88 -0
- ai/chronon/staging_query.py +185 -0
- ai/chronon/types.py +57 -0
- ai/chronon/utils.py +557 -0
- ai/chronon/windows.py +50 -0
- awx_zipline_ai-0.2.0.dist-info/METADATA +173 -0
- awx_zipline_ai-0.2.0.dist-info/RECORD +93 -0
- awx_zipline_ai-0.2.0.dist-info/WHEEL +5 -0
- awx_zipline_ai-0.2.0.dist-info/entry_points.txt +2 -0
- awx_zipline_ai-0.2.0.dist-info/licenses/LICENSE +202 -0
- awx_zipline_ai-0.2.0.dist-info/top_level.txt +3 -0
- jars/__init__.py +0 -0

ai/chronon/cli/compile/conf_validator.py
@@ -0,0 +1,590 @@
+"""Object for checking whether a Chronon API thrift object is consistent with other entities in the repo.
+"""
+
+# Copyright (C) 2023 The Chronon Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import re
+import textwrap
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import ai.chronon.api.common.ttypes as common
+from ai.chronon.api.ttypes import (
+    Accuracy,
+    Aggregation,
+    Derivation,
+    EventSource,
+    GroupBy,
+    Join,
+    JoinPart,
+    Source,
+)
+from ai.chronon.cli.compile.column_hashing import (
+    compute_group_by_columns_hashes,
+    get_pre_derived_group_by_columns,
+    get_pre_derived_group_by_features,
+    get_pre_derived_join_features,
+    get_pre_derived_source_keys,
+)
+from ai.chronon.logger import get_logger
+from ai.chronon.repo.serializer import thrift_simple_json
+from ai.chronon.utils import get_query, get_root_source
+
+# Fields that indicate status of the entities.
+SKIPPED_FIELDS = frozenset(["metaData"])
+EXTERNAL_KEY = "onlineExternalParts"
+
+
+def _filter_skipped_fields_from_join(json_obj: Dict, skipped_fields):
+    for join_part in json_obj["joinParts"]:
+        group_by = join_part["groupBy"]
+        for field in skipped_fields:
+            group_by.pop(field, None)
+    if EXTERNAL_KEY in json_obj:
+        json_obj.pop(EXTERNAL_KEY, None)
+
+
+def _is_batch_upload_needed(group_by: GroupBy) -> bool:
+    if group_by.metaData.online or group_by.backfillStartDate:
+        return True
+    else:
+        return False
+
+
+def is_identifier(s: str) -> bool:
+    identifier_regex = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
+    return re.fullmatch(identifier_regex, s) is not None
+
+
+def _source_has_topic(source: Source) -> bool:
+    if source.events:
+        return source.events.topic is not None
+    elif source.entities:
+        return source.entities.mutationTopic is not None
+    elif source.joinSource:
+        return _source_has_topic(source.joinSource.join.left)
+    return False
+
+
+def _group_by_has_topic(groupBy: GroupBy) -> bool:
+    return any(_source_has_topic(source) for source in groupBy.sources)
+
+
+def _group_by_has_hourly_windows(groupBy: GroupBy) -> bool:
+    aggs: List[Aggregation] = groupBy.aggregations
+
+    if not aggs:
+        return False
+
+    for agg in aggs:
+
+        if not agg.windows:
+            return False
+
+        for window in agg.windows:
+            if window.timeUnit == common.TimeUnit.HOURS:
+                return True
+
+    return False
+
+
+def detect_feature_name_collisions(
+    group_bys: List[Tuple[GroupBy, str]],
+    entity_set_type: str,
+    name: str
+) -> BaseException | None:
+    # Build a map of output_column -> set of group_by names
+    output_col_to_gbs = {}
+    for gb, prefix in group_bys:
+        jp_prefix_str = f"{prefix}_" if prefix else ""
+        key_str = "_".join(gb.keyColumns)
+        prefix_str = jp_prefix_str + key_str + "_"
+        base_cols = compute_group_by_columns_hashes(gb, exclude_keys=True) or {}
+        cols = {
+            f"{prefix_str}{base_col}"
+            for base_col in base_cols.keys()
+        }
+        gb_name = gb.metaData.name
+        for col in cols:
+            if col not in output_col_to_gbs:
+                output_col_to_gbs[col] = set()
+            output_col_to_gbs[col].add(gb_name)
+
+    # Find output columns that are produced by more than one group_by
+    collisions = {
+        col: gb_names
+        for col, gb_names in output_col_to_gbs.items()
+        if len(gb_names) > 1
+    }
+
+    if not collisions:
+        return None  # no collisions
+
+    # Assemble an error message listing the conflicting GroupBys by name
+    lines = [
+        f"{entity_set_type} for Join: {name} has the following output name collisions:\n"
+    ]
+    for col, gb_names in collisions.items():
+        names_str = ", ".join(sorted(gb_names))
+        lines.append(f" - [{col}] has collisions from: [{names_str}]")
+
+    lines.append(
+        "\nConsider assigning distinct `prefix` values to the conflicting parts to avoid collisions."
+    )
+    return ValueError("\n".join(lines))
+
+
+
class ConfValidator(object):
|
|
154
|
+
"""
|
|
155
|
+
applies repo wide validation rules
|
|
156
|
+
"""
|
|
157
|
+
|
|
158
|
+
def __init__(
|
|
159
|
+
self,
|
|
160
|
+
input_root,
|
|
161
|
+
output_root,
|
|
162
|
+
existing_gbs,
|
|
163
|
+
existing_joins,
|
|
164
|
+
log_level=logging.INFO,
|
|
165
|
+
):
|
|
166
|
+
|
|
167
|
+
self.chronon_root_path = input_root
|
|
168
|
+
self.output_root = output_root
|
|
169
|
+
|
|
170
|
+
self.log_level = log_level
|
|
171
|
+
self.logger = get_logger(log_level)
|
|
172
|
+
|
|
173
|
+
# we keep the objs in the list not in a set since thrift does not
|
|
174
|
+
# implement __hash__ for ttypes object.
|
|
175
|
+
|
|
176
|
+
self.old_objs = defaultdict(dict)
|
|
177
|
+
self.old_group_bys = existing_gbs
|
|
178
|
+
self.old_joins = existing_joins
|
|
179
|
+
self.old_objs["GroupBy"] = self.old_group_bys
|
|
180
|
+
self.old_objs["Join"] = self.old_joins
|
|
181
|
+
|
|
182
|
+
def _get_old_obj(self, obj_class: type, obj_name: str) -> object:
|
|
183
|
+
"""
|
|
184
|
+
returns:
|
|
185
|
+
materialized version of the obj given the object's name.
|
|
186
|
+
"""
|
|
187
|
+
class_name = obj_class.__name__
|
|
188
|
+
|
|
189
|
+
if class_name not in self.old_objs:
|
|
190
|
+
return None
|
|
191
|
+
obj_map = self.old_objs[class_name]
|
|
192
|
+
|
|
193
|
+
if obj_name not in obj_map:
|
|
194
|
+
return None
|
|
195
|
+
return obj_map[obj_name]
|
|
196
|
+
|
|
197
|
+
def _get_old_joins_with_group_by(self, group_by: GroupBy) -> List[Join]:
|
|
198
|
+
"""
|
|
199
|
+
returns:
|
|
200
|
+
materialized joins including the group_by as dicts.
|
|
201
|
+
"""
|
|
202
|
+
joins = []
|
|
203
|
+
for join in self.old_joins.values():
|
|
204
|
+
if join.joinParts is not None and group_by.metaData.name in [
|
|
205
|
+
rp.groupBy.metaData.name for rp in join.joinParts
|
|
206
|
+
]:
|
|
207
|
+
joins.append(join)
|
|
208
|
+
return joins
|
|
209
|
+
|
|
210
|
+
def can_skip_materialize(self, obj: object) -> List[str]:
|
|
211
|
+
"""
|
|
212
|
+
Check if the object can be skipped to be materialized and return reasons
|
|
213
|
+
if it can be.
|
|
214
|
+
"""
|
|
215
|
+
reasons = []
|
|
216
|
+
if isinstance(obj, GroupBy):
|
|
217
|
+
if not _is_batch_upload_needed(obj):
|
|
218
|
+
reasons.append(
|
|
219
|
+
"GroupBys should not be materialized if batch upload job is not needed"
|
|
220
|
+
)
|
|
221
|
+
# Otherwise group_bys included in online join or are marked explicitly
|
|
222
|
+
# online itself are materialized.
|
|
223
|
+
elif not any(
|
|
224
|
+
join.metaData.online for join in self._get_old_joins_with_group_by(obj)
|
|
225
|
+
) and not _is_batch_upload_needed(obj):
|
|
226
|
+
reasons.append(
|
|
227
|
+
"is not marked online/production nor is included in any online join"
|
|
228
|
+
)
|
|
229
|
+
return reasons
|
|
230
|
+
|
|
231
|
+
def validate_obj(self, obj: object) -> List[BaseException]:
|
|
232
|
+
"""
|
|
233
|
+
Validate Chronon API obj against other entities in the repo.
|
|
234
|
+
|
|
235
|
+
returns:
|
|
236
|
+
list of errors.
|
|
237
|
+
"""
|
|
238
|
+
if isinstance(obj, GroupBy):
|
|
239
|
+
return self._validate_group_by(obj)
|
|
240
|
+
elif isinstance(obj, Join):
|
|
241
|
+
return self._validate_join(obj)
|
|
242
|
+
return []
|
|
243
|
+
|
|
+    def _has_diff(
+        self, obj: object, old_obj: object, skipped_fields=SKIPPED_FIELDS
+    ) -> bool:
+        new_json = {
+            k: v
+            for k, v in json.loads(thrift_simple_json(obj)).items()
+            if k not in skipped_fields
+        }
+        old_json = {
+            k: v
+            for k, v in json.loads(thrift_simple_json(old_obj)).items()
+            if k not in skipped_fields
+        }
+        if isinstance(obj, Join):
+            _filter_skipped_fields_from_join(new_json, skipped_fields)
+            _filter_skipped_fields_from_join(old_json, skipped_fields)
+        return new_json != old_json
+
+    def safe_to_overwrite(self, obj: object) -> bool:
+        """When an object is already materialized as online, it is no longer safe
+        to materialize and overwrite the old conf.
+        """
+        old_obj = self._get_old_obj(type(obj), obj.metaData.name)
+        return (
+            not old_obj
+            or not self._has_diff(obj, old_obj)
+            or not old_obj.metaData.online
+        )
+
+    def _validate_derivations(
+        self, pre_derived_cols: List[str], derivations: List[Derivation]
+    ) -> List[BaseException]:
+        """
+        Validate that a join's/group_by's derivations are defined correctly.
+
+        Returns:
+            list of validation errors.
+        """
+        errors = []
+        derived_columns = set(pre_derived_cols)
+
+        wild_card_derivation_included = any(
+            derivation.expression == "*" for derivation in derivations
+        )
+        if not wild_card_derivation_included:
+            derived_columns.clear()
+        for derivation in derivations:
+            # if the derivation is a renaming derivation, check whether the expression is in the pre-derived schema
+            if is_identifier(derivation.expression):
+                # for wildcard derivations we want to remove the original column if there is a renaming operation
+                # applied on it
+                if wild_card_derivation_included:
+                    if derivation.expression in derived_columns:
+                        derived_columns.remove(derivation.expression)
+                if (
+                    derivation.expression not in pre_derived_cols
+                    and derivation.expression not in ("ds", "ts")
+                ):
+                    errors.append(
+                        ValueError(
+                            "Incorrect derivation expression {}, expression not found in pre-derived columns {}".format(
+                                derivation.expression, pre_derived_cols
+                            )
+                        )
+                    )
+            if derivation.name != "*":
+                if derivation.name in derived_columns:
+                    errors.append(
+                        ValueError(
+                            "Incorrect derivation name {} due to output column name conflict".format(
+                                derivation.name
+                            )
+                        )
+                    )
+                else:
+                    derived_columns.add(derivation.name)
+        return errors
+
+    def _validate_join_part_keys(self, join_part: JoinPart, left_cols: List[str]) -> BaseException:
+        keys = []
+
+        key_mapping = join_part.keyMapping if join_part.keyMapping else {}
+        for key in join_part.groupBy.keyColumns:
+            keys.append(key_mapping.get(key, key))
+
+        missing = [k for k in keys if k not in left_cols]
+
+        err_string = ""
+        left_cols_as_str = ", ".join(left_cols)
+        group_by_name = join_part.groupBy.metaData.name
+        if missing:
+            key_mapping_str = f"Key Mapping: {key_mapping}" if key_mapping else ""
+            err_string += textwrap.dedent(f"""
+                - Join is missing keys {missing} on left side. Required for JoinPart: {group_by_name}.
+                  Existing columns on left side: {left_cols_as_str}
+                  All required Keys: {join_part.groupBy.keyColumns}
+                  {key_mapping_str}
+                  Consider renaming a column on the left, or including the key_mapping argument to your join_part.""")
+
+        if key_mapping:
+            # Left side of key mapping should include columns on the left
+            key_map_keys_missing_from_left = [k for k in key_mapping.keys() if k not in left_cols]
+            if key_map_keys_missing_from_left:
+                err_string += f"\n- The following keys in your key_mapping: {str(key_map_keys_missing_from_left)} for JoinPart {group_by_name} are not included in the left side of the join: {left_cols_as_str}"
+
+            # Right side of key mapping should only include keys in GroupBy
+            keys_missing_from_key_map_values = [v for v in key_mapping.values() if v not in join_part.groupBy.keyColumns]
+            if keys_missing_from_key_map_values:
+                err_string += f"\n- The following values in your key_mapping: {str(keys_missing_from_key_map_values)} for JoinPart {group_by_name} do not cover any group by key columns: {join_part.groupBy.keyColumns}"
+
+            if key_map_keys_missing_from_left or keys_missing_from_key_map_values:
+                err_string += "\n(Key Mapping should be formatted as column_from_left -> group_by_key)"
+
+        if err_string:
+            return ValueError(err_string)
+
+
+    def _validate_keys(self, join: Join) -> List[BaseException]:
+        left = join.left
+
+        left_selects = None
+        if left.events:
+            left_selects = left.events.query.selects
+        elif left.entities:
+            left_selects = left.entities.query.selects
+        elif left.joinSource:
+            left_selects = left.joinSource.query.selects
+            # TODO -- if selects are not selected here, get output cols from join
+
+        left_cols = []
+
+        if left_selects:
+            left_cols = left_selects.keys()
+
+        errors = []
+
+        if left_cols:
+            join_parts = join.joinParts
+
+            # Add label_parts to join_parts to validate if set
+            label_parts = join.labelParts
+            if label_parts:
+                for label_jp in label_parts.labels:
+                    join_parts.append(label_jp)
+
+            # Validate join_parts
+            for join_part in join_parts:
+                join_part_err = self._validate_join_part_keys(join_part, left_cols)
+                if join_part_err:
+                    errors.append(join_part_err)
+
+        return errors
+
+
+    def _validate_join(self, join: Join) -> List[BaseException]:
+        """
+        Validate join's status with materialized versions of group_bys
+        included by the join.
+
+        Returns:
+            list of validation errors.
+        """
+        included_group_bys_and_prefixes = [(rp.groupBy, rp.prefix) for rp in join.joinParts]
+        # TODO: Remove label parts check in future PR that deprecates label_parts
+        included_label_parts_and_prefixes = [(lp.groupBy, lp.prefix) for lp in join.labelParts.labels] if join.labelParts else []
+        included_group_bys = [tup[0] for tup in included_group_bys_and_prefixes]
+
+        offline_included_group_bys = [
+            gb.metaData.name
+            for gb in included_group_bys
+            if not gb.metaData or gb.metaData.online is False
+        ]
+        errors = []
+        old_group_bys = [
+            group_by
+            for group_by in included_group_bys
+            if self._get_old_obj(GroupBy, group_by.metaData.name)
+        ]
+        non_prod_old_group_bys = [
+            group_by.metaData.name
+            for group_by in old_group_bys
+            if group_by.metaData.production is False
+        ]
+        # Check if the underlying groupBy is valid
+        group_by_errors = [
+            self._validate_group_by(group_by) for group_by in included_group_bys
+        ]
+        errors += [
+            ValueError(f"join {join.metaData.name}'s underlying {error}")
+            for errors in group_by_errors
+            for error in errors
+        ]
+        # Check if the production join is using non production groupBy
+        if join.metaData.production and non_prod_old_group_bys:
+            errors.append(
+                ValueError(
+                    "join {} is production but includes the following non production group_bys: {}".format(
+                        join.metaData.name, ", ".join(non_prod_old_group_bys)
+                    )
+                )
+            )
+        # Check if the online join is using the offline groupBy
+        if join.metaData.online:
+            if offline_included_group_bys:
+                errors.append(
+                    ValueError(
+                        "join {} is online but includes the following offline group_bys: {}".format(
+                            join.metaData.name, ", ".join(offline_included_group_bys)
+                        )
+                    )
+                )
+        # Only validate the join derivation when the underlying groupBy is valid
+        group_by_correct = all(not errors for errors in group_by_errors)
+        if join.derivations and group_by_correct:
+            features = list(get_pre_derived_join_features(join).keys())
+            # For online joins keys are not included in output schema
+            if join.metaData.online:
+                columns = features
+            else:
+                keys = list(get_pre_derived_source_keys(join.left).keys())
+                columns = features + keys
+            errors.extend(self._validate_derivations(columns, join.derivations))
+
+
+        errors.extend(self._validate_keys(join))
+
+        # If the join is using "short" names, ensure that there are no collisions
+        if join.useLongNames is False:
+            right_part_collisions = detect_feature_name_collisions(included_group_bys_and_prefixes, "right parts", join.metaData.name)
+            if right_part_collisions:
+                errors.append(right_part_collisions)
+
+            label_part_collisions = detect_feature_name_collisions(included_label_parts_and_prefixes, "label parts", join.metaData.name)
+            if label_part_collisions:
+                errors.append(label_part_collisions)
+
+        return errors
+
+
+
+    def _validate_group_by(self, group_by: GroupBy) -> List[BaseException]:
+        """
+        Validate group_by's status with materialized versions of joins
+        including the group_by.
+
+        Return:
+            List of validation errors.
+        """
+        joins = self._get_old_joins_with_group_by(group_by)
+        online_joins = [
+            join.metaData.name for join in joins if join.metaData.online is True
+        ]
+        prod_joins = [
+            join.metaData.name for join in joins if join.metaData.production is True
+        ]
+        errors = []
+
+        non_temporal = (
+            group_by.accuracy is None or group_by.accuracy == Accuracy.SNAPSHOT
+        )
+
+        no_topic = not _group_by_has_topic(group_by)
+        has_hourly_windows = _group_by_has_hourly_windows(group_by)
+
+        # batch features cannot contain hourly windows
+        if (no_topic and non_temporal) and has_hourly_windows:
+            errors.append(
+                ValueError(
+                    f"group_by {group_by.metaData.name} is defined to be daily refreshed but contains "
+                    f"hourly windows. "
+                )
+            )
+
+        def _validate_bounded_event_source():
+            if group_by.aggregations is None:
+                return
+
+            unbounded_event_sources = [
+                str(src)
+                for src in group_by.sources
+                if isinstance(get_root_source(src), EventSource) and get_query(src).startPartition is None
+            ]
+
+            if not unbounded_event_sources:
+                return
+
+            unwindowed_aggregations = [str(agg) for agg in group_by.aggregations if agg.windows is None]
+
+            if not unwindowed_aggregations:
+                return
+
+            nln = "\n"
+
+            errors.append(
+                ValueError(
+                    f"""group_by {group_by.metaData.name} uses unwindowed aggregations [{nln}{f",{nln}".join(unwindowed_aggregations)}{nln}]
+                    on unbounded event sources: [{nln}{f",{nln}".join(unbounded_event_sources)}{nln}].
+                    Please set a start_partition on the source, or a window on the aggregation."""
+                )
+            )
+
+        _validate_bounded_event_source()
+
+        # group by that are marked explicitly offline should not be present in
+        # materialized online joins.
+        if group_by.metaData.online is False and online_joins:
+            errors.append(
+                ValueError(
+                    "group_by {} is explicitly marked offline but included in "
+                    "the following online joins: {}".format(
+                        group_by.metaData.name, ", ".join(online_joins)
+                    )
+                )
+            )
+        # group by that are marked explicitly non-production should not be
+        # present in materialized production joins.
+        if prod_joins:
+            if group_by.metaData.production is False:
+                errors.append(
+                    ValueError(
+                        "group_by {} is explicitly marked as non-production but included in the following production "
+                        "joins: {}".format(group_by.metaData.name, ", ".join(prod_joins))
+                    )
+                )
+            # if the group by is included in any of materialized production join,
+            # set it to production in the materialized output.
+            else:
+                group_by.metaData.production = True
+
+        # validate the derivations are defined correctly
+        if group_by.derivations:
+            # For online group_by keys are not included in output schema
+            if group_by.metaData.online:
+                columns = list(get_pre_derived_group_by_features(group_by).keys())
+            else:
+                columns = list(get_pre_derived_group_by_columns(group_by).keys())
+            errors.extend(self._validate_derivations(columns, group_by.derivations))
+
+        for source in group_by.sources:
+            src: Source = source
+            if (
+                src.events
+                and src.events.isCumulative
+                and (src.events.query.timeColumn is None)
+            ):
+                errors.append(
+                    ValueError(
+                        "Please set query.timeColumn for Cumulative Events Table: {}".format(
+                            src.events.table
+                        )
+                    )
+                )
+            elif src.joinSource:
+                join_obj = src.joinSource.join
+                if join_obj.metaData.name is None or join_obj.metaData.team is None:
+                    errors.append(
+                        ValueError(f"Join must be defined with team and name: {join_obj}")
+                    )
+        return errors
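
For orientation, the sketch below shows how the `ConfValidator` added above might be driven by compile tooling. Only the constructor keywords (`input_root`, `output_root`, `existing_gbs`, `existing_joins`) and the `validate_obj`, `can_skip_materialize`, and `safe_to_overwrite` methods are taken from the diff; the `existing_*` dictionaries, the `parsed_objects` list, and the paths are hypothetical placeholders for whatever the compiler actually supplies.

```python
# Minimal, hypothetical driver for ConfValidator (not part of the package).
from ai.chronon.cli.compile.conf_validator import ConfValidator

# Assumed inputs: name -> previously materialized GroupBy / Join thrift objects,
# plus a list of freshly parsed objects to validate. How these are loaded is
# not shown in this diff.
existing_gbs = {}
existing_joins = {}
parsed_objects = []

validator = ConfValidator(
    input_root="chronon_root",   # assumed input path
    output_root="compiled",      # assumed output path
    existing_gbs=existing_gbs,
    existing_joins=existing_joins,
)

for obj in parsed_objects:
    errors = validator.validate_obj(obj)             # list of exceptions, empty if valid
    skip_reasons = validator.can_skip_materialize(obj)
    if errors:
        for err in errors:
            print(f"{obj.metaData.name}: {err}")
    elif not skip_reasons and validator.safe_to_overwrite(obj):
        pass  # safe to write the compiled conf for this object
```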

ai/chronon/cli/compile/display/class_tracker.py
@@ -0,0 +1,112 @@
+import difflib
+from typing import Any, Dict, List
+
+from rich.text import Text
+
+from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
+from ai.chronon.cli.compile.display.diff_result import DiffResult
+
+
+class ClassTracker:
+    """
+    Tracker object per class - Join, StagingQuery, GroupBy etc
+    """
+
+    def __init__(self):
+        self.existing_objs: Dict[str, CompiledObj] = {}  # name to obj
+        self.files_to_obj: Dict[str, List[Any]] = {}
+        self.files_to_errors: Dict[str, List[Exception]] = {}
+        self.new_objs: Dict[str, CompiledObj] = {}  # name to obj
+        self.diff_result = DiffResult()
+        self.deleted_names: List[str] = []
+
+    def add_existing(self, obj: CompiledObj) -> None:
+        self.existing_objs[obj.name] = obj
+
+    def add(self, compiled: CompiledObj) -> None:
+
+        if compiled.errors:
+
+            if compiled.file not in self.files_to_errors:
+                self.files_to_errors[compiled.file] = []
+
+            self.files_to_errors[compiled.file].extend(compiled.errors)
+
+        else:
+            if compiled.file not in self.files_to_obj:
+                self.files_to_obj[compiled.file] = []
+
+            self.files_to_obj[compiled.file].append(compiled.obj)
+
+        self.new_objs[compiled.name] = compiled
+        self._update_diff(compiled)
+
+    def _update_diff(self, compiled: CompiledObj) -> None:
+        if compiled.name in self.existing_objs:
+
+            existing_json = self.existing_objs[compiled.name].tjson
+            new_json = compiled.tjson
+
+            if existing_json != new_json:
+
+                diff = difflib.unified_diff(
+                    existing_json.splitlines(keepends=True),
+                    new_json.splitlines(keepends=True),
+                    n=2,
+                )
+
+                print(f"Updated object: {compiled.name} in file {compiled.file}")
+                print("".join(diff))
+                print("\n")
+
+                self.diff_result.updated.append(compiled.name)
+
+        else:
+            if not compiled.errors:
+                self.diff_result.added.append(compiled.name)
+
+    def close(self) -> None:
+        self.closed = True
+        self.recent_file = None
+        self.deleted_names = list(self.existing_objs.keys() - self.new_objs.keys())
+
+    def to_status(self) -> Text:
+        text = Text(overflow="fold", no_wrap=False)
+
+        if self.existing_objs:
+            text.append(
+                f" Parsed {len(self.existing_objs)} previously compiled objects.\n"
+            )
+
+        if self.files_to_obj:
+            text.append(" Compiled ")
+            text.append(f"{len(self.new_objs)} ", style="bold green")
+            text.append("objects from ")
+            text.append(f"{len(self.files_to_obj)} ", style="bold green")
+            text.append("files.\n")
+
+        if self.files_to_errors:
+            text.append(" Failed to compile ")
+            text.append(f"{len(self.files_to_errors)} ", style="red")
+            text.append("files.\n")
+
+        return text
+
+    def to_errors(self) -> Text:
+        text = Text(overflow="fold", no_wrap=False)
+
+        if self.files_to_errors:
+            for file, errors in self.files_to_errors.items():
+                text.append(" ERROR ", style="bold red")
+                text.append(f"- {file}:\n")
+
+                for error in errors:
+                    # Format each error properly, handling newlines
+                    error_msg = str(error)
+                    text.append(f" {error_msg}\n", style="red")
+
+        return text
+
+    # doesn't make sense to show deletes until the very end of compilation
+    def diff(self) -> Text:
+        return self.diff_result.render(deleted_names=self.deleted_names)