awx-zipline-ai 0.0.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- agent/__init__.py +1 -0
- agent/constants.py +15 -0
- agent/ttypes.py +1684 -0
- ai/__init__.py +0 -0
- ai/chronon/__init__.py +0 -0
- ai/chronon/airflow_helpers.py +248 -0
- ai/chronon/cli/__init__.py +0 -0
- ai/chronon/cli/compile/__init__.py +0 -0
- ai/chronon/cli/compile/column_hashing.py +336 -0
- ai/chronon/cli/compile/compile_context.py +173 -0
- ai/chronon/cli/compile/compiler.py +183 -0
- ai/chronon/cli/compile/conf_validator.py +742 -0
- ai/chronon/cli/compile/display/__init__.py +0 -0
- ai/chronon/cli/compile/display/class_tracker.py +102 -0
- ai/chronon/cli/compile/display/compile_status.py +95 -0
- ai/chronon/cli/compile/display/compiled_obj.py +12 -0
- ai/chronon/cli/compile/display/console.py +3 -0
- ai/chronon/cli/compile/display/diff_result.py +111 -0
- ai/chronon/cli/compile/fill_templates.py +35 -0
- ai/chronon/cli/compile/parse_configs.py +134 -0
- ai/chronon/cli/compile/parse_teams.py +242 -0
- ai/chronon/cli/compile/serializer.py +109 -0
- ai/chronon/cli/compile/version_utils.py +42 -0
- ai/chronon/cli/git_utils.py +145 -0
- ai/chronon/cli/logger.py +59 -0
- ai/chronon/constants.py +3 -0
- ai/chronon/group_by.py +692 -0
- ai/chronon/join.py +580 -0
- ai/chronon/logger.py +23 -0
- ai/chronon/model.py +40 -0
- ai/chronon/query.py +126 -0
- ai/chronon/repo/__init__.py +39 -0
- ai/chronon/repo/aws.py +284 -0
- ai/chronon/repo/cluster.py +136 -0
- ai/chronon/repo/compile.py +62 -0
- ai/chronon/repo/constants.py +164 -0
- ai/chronon/repo/default_runner.py +269 -0
- ai/chronon/repo/explore.py +418 -0
- ai/chronon/repo/extract_objects.py +134 -0
- ai/chronon/repo/gcp.py +586 -0
- ai/chronon/repo/gitpython_utils.py +15 -0
- ai/chronon/repo/hub_runner.py +261 -0
- ai/chronon/repo/hub_uploader.py +109 -0
- ai/chronon/repo/init.py +60 -0
- ai/chronon/repo/join_backfill.py +119 -0
- ai/chronon/repo/run.py +296 -0
- ai/chronon/repo/serializer.py +133 -0
- ai/chronon/repo/team_json_utils.py +46 -0
- ai/chronon/repo/utils.py +481 -0
- ai/chronon/repo/zipline.py +35 -0
- ai/chronon/repo/zipline_hub.py +277 -0
- ai/chronon/resources/__init__.py +0 -0
- ai/chronon/resources/gcp/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/test/data.py +30 -0
- ai/chronon/resources/gcp/joins/__init__.py +0 -0
- ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
- ai/chronon/resources/gcp/joins/test/data.py +26 -0
- ai/chronon/resources/gcp/sources/__init__.py +0 -0
- ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
- ai/chronon/resources/gcp/sources/test/data.py +26 -0
- ai/chronon/resources/gcp/teams.py +58 -0
- ai/chronon/source.py +86 -0
- ai/chronon/staging_query.py +226 -0
- ai/chronon/types.py +58 -0
- ai/chronon/utils.py +510 -0
- ai/chronon/windows.py +48 -0
- awx_zipline_ai-0.0.32.dist-info/METADATA +197 -0
- awx_zipline_ai-0.0.32.dist-info/RECORD +96 -0
- awx_zipline_ai-0.0.32.dist-info/WHEEL +5 -0
- awx_zipline_ai-0.0.32.dist-info/entry_points.txt +2 -0
- awx_zipline_ai-0.0.32.dist-info/top_level.txt +4 -0
- gen_thrift/__init__.py +0 -0
- gen_thrift/api/__init__.py +1 -0
- gen_thrift/api/constants.py +15 -0
- gen_thrift/api/ttypes.py +3754 -0
- gen_thrift/common/__init__.py +1 -0
- gen_thrift/common/constants.py +15 -0
- gen_thrift/common/ttypes.py +1814 -0
- gen_thrift/eval/__init__.py +1 -0
- gen_thrift/eval/constants.py +15 -0
- gen_thrift/eval/ttypes.py +660 -0
- gen_thrift/fetcher/__init__.py +1 -0
- gen_thrift/fetcher/constants.py +15 -0
- gen_thrift/fetcher/ttypes.py +127 -0
- gen_thrift/hub/__init__.py +1 -0
- gen_thrift/hub/constants.py +15 -0
- gen_thrift/hub/ttypes.py +1109 -0
- gen_thrift/observability/__init__.py +1 -0
- gen_thrift/observability/constants.py +15 -0
- gen_thrift/observability/ttypes.py +2355 -0
- gen_thrift/planner/__init__.py +1 -0
- gen_thrift/planner/constants.py +15 -0
- gen_thrift/planner/ttypes.py +1967 -0
ai/chronon/cli/compile/conf_validator.py
@@ -0,0 +1,742 @@
"""Object for checking whether a Chronon API thrift object is consistent with the other entities in the repo."""

# Copyright (C) 2023 The Chronon Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import re
import sys
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import gen_thrift.common.ttypes as common
from gen_thrift.api.ttypes import (
    Accuracy,
    Aggregation,
    Derivation,
    EventSource,
    GroupBy,
    Join,
    JoinPart,
    Source,
)

from ai.chronon.cli.compile.column_hashing import (
    compute_group_by_columns_hashes,
    get_pre_derived_group_by_columns,
    get_pre_derived_group_by_features,
    get_pre_derived_join_features,
    get_pre_derived_source_keys,
)
from ai.chronon.cli.compile.version_utils import is_version_change
from ai.chronon.logger import get_logger
from ai.chronon.repo.serializer import thrift_simple_json
from ai.chronon.utils import get_query, get_root_source

# Fields that indicate status of the entities.
SKIPPED_FIELDS = frozenset(["metaData"])
EXTERNAL_KEY = "onlineExternalParts"


@dataclass
class ConfigChange:
    """Represents a change to a compiled config object."""

    name: str
    obj_type: str
    online: bool = False
    production: bool = False
    base_name: Optional[str] = None
    old_version: Optional[int] = None
    new_version: Optional[int] = None
    is_version_change: bool = False


def _filter_skipped_fields_from_join(json_obj: Dict, skipped_fields):
    for join_part in json_obj["joinParts"]:
        group_by = join_part["groupBy"]
        for field in skipped_fields:
            group_by.pop(field, None)
    if EXTERNAL_KEY in json_obj:
        json_obj.pop(EXTERNAL_KEY, None)


def _is_batch_upload_needed(group_by: GroupBy) -> bool:
    return bool(group_by.metaData.online or group_by.backfillStartDate)


def is_identifier(s: str) -> bool:
    identifier_regex = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
    return re.fullmatch(identifier_regex, s) is not None


def _source_has_topic(source: Source) -> bool:
    if source.events:
        return source.events.topic is not None
    elif source.entities:
        return source.entities.mutationTopic is not None
    elif source.joinSource:
        return _source_has_topic(source.joinSource.join.left)
    return False


def _group_by_has_topic(groupBy: GroupBy) -> bool:
    return any(_source_has_topic(source) for source in groupBy.sources)


def _group_by_has_hourly_windows(groupBy: GroupBy) -> bool:
    aggs: List[Aggregation] = groupBy.aggregations

    if not aggs:
        return False

    for agg in aggs:
        if not agg.windows:
            return False

        for window in agg.windows:
            if window.timeUnit == common.TimeUnit.HOURS:
                return True

    return False


def detect_feature_name_collisions(
    group_bys: List[Tuple[GroupBy, str]], entity_set_type: str, name: str
) -> BaseException | None:
    # Build a map of output_column -> set of group_by names
    output_col_to_gbs = {}
    for gb, prefix in group_bys:
        jp_prefix_str = f"{prefix}_" if prefix else ""
        key_str = "_".join(gb.keyColumns)
        prefix_str = jp_prefix_str + key_str + "_"
        cols = {
            f"{prefix_str}{base_col}"
            for base_col in compute_group_by_columns_hashes(gb, exclude_keys=True).keys()
        }
        gb_name = gb.metaData.name
        for col in cols:
            if col not in output_col_to_gbs:
                output_col_to_gbs[col] = set()
            output_col_to_gbs[col].add(gb_name)

    # Find output columns that are produced by more than one group_by
    collisions = {col: gb_names for col, gb_names in output_col_to_gbs.items() if len(gb_names) > 1}

    if not collisions:
        return None  # no collisions

    # Assemble an error message listing the conflicting GroupBys by name
    lines = [f"{entity_set_type} for Join: {name} has the following output name collisions:\n"]
    for col, gb_names in collisions.items():
        names_str = ", ".join(sorted(gb_names))
        lines.append(f"  - [{col}] has collisions from: [{names_str}]")

    lines.append(
        "\nConsider assigning distinct `prefix` values to the conflicting parts to avoid collisions."
    )
    return ValueError("\n".join(lines))
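

# A hedged illustration of the short-name scheme checked above (the column,
# key, and prefix names are hypothetical, not from this package): when a join
# sets useLongNames to False, a join part's output column is roughly
# "{prefix}_{key_columns}_{feature}". Two GroupBys keyed on `user_id` that
# both emit a feature named `txn_count_7d`, with no prefixes, would both
# produce `user_id_txn_count_7d`, which detect_feature_name_collisions reports
# as a collision; giving one part prefix="payments" yields the distinct column
# `payments_user_id_txn_count_7d`.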


class ConfValidator(object):
    """
    Applies repo-wide validation rules.
    """

    def __init__(
        self,
        input_root,
        output_root,
        existing_gbs,
        existing_joins,
        existing_staging_queries,
        log_level=logging.INFO,
    ):
        self.chronon_root_path = input_root
        self.output_root = output_root

        self.log_level = log_level
        self.logger = get_logger(log_level)

        # We keep the objs in dicts keyed by name rather than in a set, since
        # thrift does not implement __hash__ for ttypes objects.
        self.old_objs = defaultdict(dict)
        self.old_group_bys = existing_gbs
        self.old_joins = existing_joins
        self.old_staging_queries = existing_staging_queries
        self.old_objs["GroupBy"] = self.old_group_bys
        self.old_objs["Join"] = self.old_joins
        self.old_objs["StagingQuery"] = self.old_staging_queries

    def _get_old_obj(self, obj_class: type, obj_name: str) -> object:
        """
        returns:
            materialized version of the obj given the object's name.
        """
        class_name = obj_class.__name__

        if class_name not in self.old_objs:
            return None
        obj_map = self.old_objs[class_name]

        if obj_name not in obj_map:
            return None
        return obj_map[obj_name]

    def _get_old_joins_with_group_by(self, group_by: GroupBy) -> List[Join]:
        """
        returns:
            materialized joins that include the group_by.
        """
        joins = []
        for join in self.old_joins.values():
            if join.joinParts is not None and group_by.metaData.name in [
                rp.groupBy.metaData.name for rp in join.joinParts
            ]:
                joins.append(join)
        return joins

    def can_skip_materialize(self, obj: object) -> List[str]:
        """
        Check whether materializing the object can be skipped, and return the
        reasons if it can be.
        """
        reasons = []
        if isinstance(obj, GroupBy):
            if not _is_batch_upload_needed(obj):
                reasons.append(
                    "GroupBys should not be materialized if batch upload job is not needed"
                )
            # Otherwise, group_bys included in an online join or marked
            # explicitly online themselves are materialized.
            elif not any(
                join.metaData.online for join in self._get_old_joins_with_group_by(obj)
            ) and not _is_batch_upload_needed(obj):
                reasons.append("is not marked online/production nor is included in any online join")
        return reasons

    def validate_obj(self, obj: object) -> List[BaseException]:
        """
        Validate a Chronon API obj against the other entities in the repo.

        returns:
            list of errors.
        """
        if isinstance(obj, GroupBy):
            return self._validate_group_by(obj)
        elif isinstance(obj, Join):
            return self._validate_join(obj)
        return []

    def _has_diff(self, obj: object, old_obj: object, skipped_fields=SKIPPED_FIELDS) -> bool:
        new_json = {
            k: v for k, v in json.loads(thrift_simple_json(obj)).items() if k not in skipped_fields
        }
        old_json = {
            k: v
            for k, v in json.loads(thrift_simple_json(old_obj)).items()
            if k not in skipped_fields
        }
        if isinstance(obj, Join):
            _filter_skipped_fields_from_join(new_json, skipped_fields)
            _filter_skipped_fields_from_join(old_json, skipped_fields)

        return new_json != old_json
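
    # A hedged note on the diff semantics above: because "metaData" is in
    # SKIPPED_FIELDS, edits that only touch metadata (e.g. a description or a
    # team tag) do not count as a semantic diff, while changing, say, an
    # aggregation window does. For a Join, the same metaData stripping is also
    # applied to each joinPart's groupBy, and onlineExternalParts is ignored.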

    def safe_to_overwrite(self, obj: object) -> bool:
        """When an object is already materialized as online, it is no longer
        safe to materialize and overwrite the old conf.
        """
        old_obj = self._get_old_obj(type(obj), obj.metaData.name)
        return not old_obj or not self._has_diff(obj, old_obj) or not old_obj.metaData.online

    def _validate_derivations(
        self, pre_derived_cols: List[str], derivations: List[Derivation]
    ) -> List[BaseException]:
        """
        Validate that a join's/groupBy's derivations are defined correctly.

        Returns:
            list of validation errors.
        """
        errors = []
        derived_columns = set(pre_derived_cols)

        wild_card_derivation_included = any(
            derivation.expression == "*" for derivation in derivations
        )
        if not wild_card_derivation_included:
            derived_columns.clear()
        for derivation in derivations:
            # If the derivation is a renaming derivation, check whether the
            # expression is in the pre-derived schema.
            if is_identifier(derivation.expression):
                # For wildcard derivations we want to remove the original column
                # if there is a renaming operation applied on it.
                if wild_card_derivation_included:
                    if derivation.expression in derived_columns:
                        derived_columns.remove(derivation.expression)
                if derivation.expression not in pre_derived_cols and derivation.expression not in (
                    "ds",
                    "ts",
                ):
                    errors.append(
                        ValueError(
                            "Incorrect derivation expression {}, expression not found in pre-derived columns {}".format(
                                derivation.expression, pre_derived_cols
                            )
                        )
                    )
            if derivation.name != "*":
                if derivation.name in derived_columns:
                    errors.append(
                        ValueError(
                            "Incorrect derivation name {} due to output column name conflict".format(
                                derivation.name
                            )
                        )
                    )
                else:
                    derived_columns.add(derivation.name)
        return errors
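
    # A hedged example of the derivation rules above (column names are made
    # up): given pre-derived columns ["amount_sum_7d", "amount_avg_7d"], the
    # derivations
    #   [Derivation(name="*", expression="*"),
    #    Derivation(name="weekly_spend", expression="amount_sum_7d")]
    # validate cleanly, and "amount_sum_7d" is removed from the wildcard output
    # because it was renamed. Without the wildcard, naming two derivations
    # "weekly_spend" would raise a name-conflict ValueError, and referencing a
    # column such as "amount_sum_30d" that is not in the pre-derived schema
    # (and is not "ds"/"ts") would raise an expression ValueError.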

    def _validate_join_part_keys(self, join_part: JoinPart, left_cols: List[str]) -> BaseException:
        keys = []

        key_mapping = join_part.keyMapping if join_part.keyMapping else {}
        for key in join_part.groupBy.keyColumns:
            keys.append(key_mapping.get(key, key))

        missing = [k for k in keys if k not in left_cols]

        err_string = ""
        left_cols_as_str = ", ".join(left_cols)
        group_by_name = join_part.groupBy.metaData.name
        if missing:
            key_mapping_str = f"Key Mapping: {key_mapping}" if key_mapping else ""
            err_string += textwrap.dedent(f"""
                - Join is missing keys {missing} on left side. Required for JoinPart: {group_by_name}.
                Existing columns on left side: {left_cols_as_str}
                All required Keys: {join_part.groupBy.keyColumns}
                {key_mapping_str}
                Consider renaming a column on the left, or including the key_mapping argument to your join_part.""")

        if key_mapping:
            # The left side of the key mapping should only include columns on the left
            key_map_keys_missing_from_left = [k for k in key_mapping.keys() if k not in left_cols]
            if key_map_keys_missing_from_left:
                err_string += f"\n- The following keys in your key_mapping: {str(key_map_keys_missing_from_left)} for JoinPart {group_by_name} are not included in the left side of the join: {left_cols_as_str}"

            # The right side of the key mapping should only include keys in the GroupBy
            keys_missing_from_key_map_values = [
                v for v in key_mapping.values() if v not in join_part.groupBy.keyColumns
            ]
            if keys_missing_from_key_map_values:
                err_string += f"\n- The following values in your key_mapping: {str(keys_missing_from_key_map_values)} for JoinPart {group_by_name} do not cover any group by key columns: {join_part.groupBy.keyColumns}"

            if key_map_keys_missing_from_left or keys_missing_from_key_map_values:
                err_string += (
                    "\n(Key Mapping should be formatted as column_from_left -> group_by_key)"
                )

        if err_string:
            return ValueError(err_string)
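
    # A hedged illustration of the two key_mapping checks above (names are made
    # up): for a GroupBy keyed on ["user_id"] joined to a left source whose
    # selects are {"user": ..., "ts": ...}, key_mapping={"user": "user_id"}
    # passes both checks -- "user" exists on the left and "user_id" is a
    # GroupBy key column. key_mapping={"usr": "user_id"} fails the left-side
    # check, and key_mapping={"user": "uid"} fails the right-side check.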

    def _validate_keys(self, join: Join) -> List[BaseException]:
        left = join.left

        left_selects = None
        if left.events:
            left_selects = left.events.query.selects
        elif left.entities:
            left_selects = left.entities.query.selects
        elif left.joinSource:
            left_selects = left.joinSource.query.selects
            # TODO -- if selects are not specified here, get output cols from the join

        left_cols = []

        if left_selects:
            left_cols = left_selects.keys()

        errors = []

        if left_cols:
            join_parts = list(join.joinParts)  # Create a copy to avoid modifying the original

            # Add label_parts to join_parts to validate, if set
            label_parts = join.labelParts
            if label_parts:
                for label_jp in label_parts.labels:
                    join_parts.append(label_jp)

            # Validate join_parts
            for join_part in join_parts:
                join_part_err = self._validate_join_part_keys(join_part, left_cols)
                if join_part_err:
                    errors.append(join_part_err)

        return errors

    def _validate_join(self, join: Join) -> List[BaseException]:
        """
        Validate the join's status against the materialized versions of the
        group_bys included by the join.

        Returns:
            list of validation errors.
        """
        included_group_bys_and_prefixes = [(rp.groupBy, rp.prefix) for rp in join.joinParts]
        # TODO: Remove label parts check in future PR that deprecates label_parts
        included_label_parts_and_prefixes = (
            [(lp.groupBy, lp.prefix) for lp in join.labelParts.labels] if join.labelParts else []
        )
        included_group_bys = [tup[0] for tup in included_group_bys_and_prefixes]

        offline_included_group_bys = [
            gb.metaData.name
            for gb in included_group_bys
            if not gb.metaData or gb.metaData.online is False
        ]
        errors = []
        old_group_bys = [
            group_by
            for group_by in included_group_bys
            if self._get_old_obj(GroupBy, group_by.metaData.name)
        ]
        non_prod_old_group_bys = [
            group_by.metaData.name
            for group_by in old_group_bys
            if group_by.metaData.production is False
        ]
        # Check that each underlying groupBy is valid
        group_by_errors = [self._validate_group_by(group_by) for group_by in included_group_bys]
        errors += [
            ValueError(f"join {join.metaData.name}'s underlying {error}")
            for errors in group_by_errors
            for error in errors
        ]
        # Check whether a production join is using non-production group_bys
        if join.metaData.production and non_prod_old_group_bys:
            errors.append(
                ValueError(
                    "join {} is production but includes the following non production group_bys: {}".format(
                        join.metaData.name, ", ".join(non_prod_old_group_bys)
                    )
                )
            )
        # Check whether an online join is using offline group_bys
        if join.metaData.online:
            if offline_included_group_bys:
                errors.append(
                    ValueError(
                        "join {} is online but includes the following offline group_bys: {}".format(
                            join.metaData.name, ", ".join(offline_included_group_bys)
                        )
                    )
                )
        # Only validate the join derivations when the underlying groupBys are valid
        group_by_correct = all(not errors for errors in group_by_errors)
        if join.derivations and group_by_correct:
            features = list(get_pre_derived_join_features(join).keys())
            # For online joins, keys are not included in the output schema
            if join.metaData.online:
                columns = features
            else:
                keys = list(get_pre_derived_source_keys(join.left).keys())
                columns = features + keys
            errors.extend(self._validate_derivations(columns, join.derivations))

        errors.extend(self._validate_keys(join))

        # If the join is using "short" names, ensure that there are no collisions
        if join.useLongNames is False:
            right_part_collisions = detect_feature_name_collisions(
                included_group_bys_and_prefixes, "right parts", join.metaData.name
            )
            if right_part_collisions:
                errors.append(right_part_collisions)

            label_part_collisions = detect_feature_name_collisions(
                included_label_parts_and_prefixes, "label parts", join.metaData.name
            )
            if label_part_collisions:
                errors.append(label_part_collisions)

        return errors

    def _validate_group_by(self, group_by: GroupBy) -> List[BaseException]:
        """
        Validate the group_by's status against the materialized versions of the
        joins that include the group_by.

        Return:
            List of validation errors.
        """
        joins = self._get_old_joins_with_group_by(group_by)
        online_joins = [join.metaData.name for join in joins if join.metaData.online is True]
        prod_joins = [join.metaData.name for join in joins if join.metaData.production is True]
        errors = []

        non_temporal = group_by.accuracy is None or group_by.accuracy == Accuracy.SNAPSHOT

        no_topic = not _group_by_has_topic(group_by)
        has_hourly_windows = _group_by_has_hourly_windows(group_by)

        # Batch features cannot contain hourly windows
        if (no_topic and non_temporal) and has_hourly_windows:
            errors.append(
                ValueError(
                    f"group_by {group_by.metaData.name} is defined to be daily refreshed but contains "
                    f"hourly windows. "
                )
            )

        def _validate_bounded_event_source():
            if group_by.aggregations is None:
                return

            unbounded_event_sources = [
                str(src)
                for src in group_by.sources
                if isinstance(get_root_source(src), EventSource)
                and get_query(src).startPartition is None
            ]

            if not unbounded_event_sources:
                return

            unwindowed_aggregations = [
                str(agg) for agg in group_by.aggregations if agg.windows is None
            ]

            if not unwindowed_aggregations:
                return

            nln = "\n"

            errors.append(
                ValueError(
                    f"""group_by {group_by.metaData.name} uses unwindowed aggregations [{nln}{f",{nln}".join(unwindowed_aggregations)}{nln}]
on unbounded event sources: [{nln}{f",{nln}".join(unbounded_event_sources)}{nln}].
Please set a start_partition on the source, or a window on the aggregation."""
                )
            )

        _validate_bounded_event_source()
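
        # A hedged example of the rule enforced above (values are illustrative):
        # an unwindowed Sum over an EventSource with no start_partition would
        # have to scan the table's entire unbounded history on every run.
        # Bounding either side fixes it -- e.g. start_partition="2023-01-01" on
        # the source's query, or a windowed aggregation such as a 7-day window.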

        # group_bys that are explicitly marked offline should not be present in
        # materialized online joins.
        if group_by.metaData.online is False and online_joins:
            errors.append(
                ValueError(
                    "group_by {} is explicitly marked offline but included in "
                    "the following online joins: {}".format(
                        group_by.metaData.name, ", ".join(online_joins)
                    )
                )
            )
        # group_bys that are explicitly marked non-production should not be
        # present in materialized production joins.
        if prod_joins:
            if group_by.metaData.production is False:
                errors.append(
                    ValueError(
                        "group_by {} is explicitly marked as non-production but included in the following production "
                        "joins: {}".format(group_by.metaData.name, ", ".join(prod_joins))
                    )
                )
            # If the group_by is included in any materialized production join,
            # set it to production in the materialized output.
            else:
                group_by.metaData.production = True

        # Validate that the derivations are defined correctly
        if group_by.derivations:
            # For online group_bys, keys are not included in the output schema
            if group_by.metaData.online:
                columns = list(get_pre_derived_group_by_features(group_by).keys())
            else:
                columns = list(get_pre_derived_group_by_columns(group_by).keys())
            errors.extend(self._validate_derivations(columns, group_by.derivations))

        for source in group_by.sources:
            src: Source = source
            if src.events and src.events.isCumulative and (src.events.query.timeColumn is None):
                errors.append(
                    ValueError(
                        "Please set query.timeColumn for Cumulative Events Table: {}".format(
                            src.events.table
                        )
                    )
                )
            elif src.joinSource:
                join_obj = src.joinSource.join
                if join_obj.metaData.name is None or join_obj.metaData.team is None:
                    errors.append(
                        ValueError(f"Join must be defined with team and name: {join_obj}")
                    )
        return errors

    def validate_changes(self, results):
        """
        Validate changes in compiled objects against existing materialized objects.

        Args:
            results: List of CompiledObj from parse_configs

        Returns:
            None (exits on user cancellation)
        """

        # Filter out results with errors and only process GroupBy/Join/StagingQuery
        valid_results = []
        for result in results:
            if result.errors:
                continue  # Skip results with errors
            if result.obj_type not in ["GroupBy", "Join", "StagingQuery"]:
                continue  # Skip other object types
            valid_results.append(result)

        # Categorize changes
        changed_objects = {"changed": [], "deleted": [], "added": []}

        # Process each valid result
        for result in valid_results:
            obj_name = result.obj.metaData.name
            obj_type = result.obj_type

            # Check whether the object exists in the old objects
            old_obj = self._get_old_obj(type(result.obj), obj_name)

            if old_obj is None:
                # New object
                change = self._create_config_change(result.obj, obj_type)
                changed_objects["added"].append(change)
            elif self._has_diff(result.obj, old_obj):
                # Modified object
                change = self._create_config_change(result.obj, obj_type)
                changed_objects["changed"].append(change)

        # Check for deleted objects.
        # Build a set of current objects from valid_results for lookup.
        current_objects = {(result.obj_type, result.obj.metaData.name) for result in valid_results}

        for obj_type in ["GroupBy", "Join", "StagingQuery"]:
            old_objs = self.old_objs.get(obj_type, {})
            for obj_name, old_obj in old_objs.items():
                if (obj_type, obj_name) not in current_objects:
                    # Object was deleted
                    change = self._create_config_change(old_obj, obj_type)
                    changed_objects["deleted"].append(change)

        # Store changes for the later confirmation check
        self._pending_changes = {
            "changed": changed_objects["changed"],
            "deleted": changed_objects["deleted"],
            "added": changed_objects["added"],
        }

    def _filter_non_version_changes(self, existing_changes, added_changes):
        """Filter out version changes from existing changes.

        Returns the list of changes that are NOT version bumps and therefore
        require confirmation.
        """
        added_names = {change.name for change in added_changes}
        non_version_changes = []

        for change in existing_changes:
            # Check if this changed/deleted config has a corresponding added
            # config that represents a version bump
            is_version_bump = any(
                is_version_change(change.name, added_name) for added_name in added_names
            )

            if not is_version_bump:
                non_version_changes.append(change)

        return non_version_changes
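
    # A hedged example of the version-bump filter above, assuming the versioned
    # naming convention that is_version_change checks (e.g. a trailing "__N"
    # version suffix): deleting "team.checkouts__0" while adding
    # "team.checkouts__1" is treated as a version bump and needs no
    # confirmation, whereas editing "team.checkouts__0" in place is a semantic
    # change and triggers the confirmation prompt.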

    def check_pending_changes_confirmation(self, compile_status):
        """Check if user confirmation is needed for pending changes after display."""
        from ai.chronon.cli.compile.display.console import console

        # Skip confirmation if there are compilation errors
        if self._has_compilation_errors(compile_status):
            return  # Don't prompt when there are errors

        if not hasattr(self, "_pending_changes"):
            return  # No pending changes

        # Check if we need user confirmation (only for non-version-bump changes)
        non_version_changes = self._filter_non_version_changes(
            self._pending_changes["changed"] + self._pending_changes["deleted"],
            self._pending_changes["added"],
        )

        if non_version_changes:
            if not self._prompt_user_confirmation():
                console.print("❌ Compilation cancelled by user.")
                sys.exit(1)

    def _has_compilation_errors(self, compile_status):
        """Check if there are any compilation errors across all class trackers."""
        for tracker in compile_status.cls_to_tracker.values():
            if tracker.files_to_errors:
                return True
        return False

    def _create_config_change(self, obj, obj_type):
        """Create a ConfigChange object from a thrift object."""
        return ConfigChange(
            name=obj.metaData.name,
            obj_type=obj_type,
            online=obj.metaData.online if obj.metaData.online else False,
            production=obj.metaData.production if obj.metaData.production else False,
        )

    def _prompt_user_confirmation(self) -> bool:
        """
        Prompt the user for Y/N confirmation to proceed with overwriting existing configs.
        Returns True if the user confirms, False otherwise.
        """
        from ai.chronon.cli.compile.display.console import console

        console.print(
            "\n❗ [bold red3]Some configs are changing in-place (changing semantics without changing the version).[/bold red3]"
        )
        console.print(
            "[dim]Note that changes can be caused by directly modifying a config, or by changing the version of an upstream object, which changes the input source.[/dim]"
        )
        console.print("")
        console.print(
            "[dim]This can be a safe operation if you are sure that the downstream configs are unaffected by the change (i.e. the upstream config is versioning up in a way that does not affect the downstream), or if the objects in question are not yet used by critical workloads.[/dim]"
        )
        console.print("")
        console.print("❓ [bold]Do you want to proceed? (y/N):[/bold]", end=" ")

        try:
            response = input().strip().lower()
            return response in ["y", "yes"]
        except (EOFError, KeyboardInterrupt):
            console.print("\n❌ Compilation cancelled.")
            return False