awx-zipline-ai 0.2.1__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they were published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- agent/ttypes.py +6 -6
- ai/chronon/airflow_helpers.py +20 -23
- ai/chronon/cli/__init__.py +0 -0
- ai/chronon/cli/compile/__init__.py +0 -0
- ai/chronon/cli/compile/column_hashing.py +40 -17
- ai/chronon/cli/compile/compile_context.py +13 -17
- ai/chronon/cli/compile/compiler.py +59 -36
- ai/chronon/cli/compile/conf_validator.py +251 -99
- ai/chronon/cli/compile/display/__init__.py +0 -0
- ai/chronon/cli/compile/display/class_tracker.py +6 -16
- ai/chronon/cli/compile/display/compile_status.py +10 -10
- ai/chronon/cli/compile/display/diff_result.py +79 -14
- ai/chronon/cli/compile/fill_templates.py +3 -8
- ai/chronon/cli/compile/parse_configs.py +10 -17
- ai/chronon/cli/compile/parse_teams.py +38 -34
- ai/chronon/cli/compile/serializer.py +3 -9
- ai/chronon/cli/compile/version_utils.py +42 -0
- ai/chronon/cli/git_utils.py +2 -13
- ai/chronon/cli/logger.py +0 -2
- ai/chronon/constants.py +1 -1
- ai/chronon/group_by.py +47 -47
- ai/chronon/join.py +46 -32
- ai/chronon/logger.py +1 -2
- ai/chronon/model.py +9 -4
- ai/chronon/query.py +2 -2
- ai/chronon/repo/__init__.py +1 -2
- ai/chronon/repo/aws.py +17 -31
- ai/chronon/repo/cluster.py +121 -50
- ai/chronon/repo/compile.py +14 -8
- ai/chronon/repo/constants.py +1 -1
- ai/chronon/repo/default_runner.py +32 -54
- ai/chronon/repo/explore.py +70 -73
- ai/chronon/repo/extract_objects.py +6 -9
- ai/chronon/repo/gcp.py +89 -88
- ai/chronon/repo/gitpython_utils.py +3 -2
- ai/chronon/repo/hub_runner.py +145 -55
- ai/chronon/repo/hub_uploader.py +2 -1
- ai/chronon/repo/init.py +12 -5
- ai/chronon/repo/join_backfill.py +19 -5
- ai/chronon/repo/run.py +42 -39
- ai/chronon/repo/serializer.py +4 -12
- ai/chronon/repo/utils.py +72 -63
- ai/chronon/repo/zipline.py +3 -19
- ai/chronon/repo/zipline_hub.py +211 -39
- ai/chronon/resources/__init__.py +0 -0
- ai/chronon/resources/gcp/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/test/data.py +13 -17
- ai/chronon/resources/gcp/joins/__init__.py +0 -0
- ai/chronon/resources/gcp/joins/test/data.py +4 -8
- ai/chronon/resources/gcp/sources/__init__.py +0 -0
- ai/chronon/resources/gcp/sources/test/data.py +9 -6
- ai/chronon/resources/gcp/teams.py +9 -21
- ai/chronon/source.py +2 -4
- ai/chronon/staging_query.py +60 -19
- ai/chronon/types.py +3 -2
- ai/chronon/utils.py +21 -68
- ai/chronon/windows.py +2 -4
- {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.1.dist-info}/METADATA +48 -24
- awx_zipline_ai-0.3.1.dist-info/RECORD +96 -0
- awx_zipline_ai-0.3.1.dist-info/top_level.txt +4 -0
- gen_thrift/__init__.py +0 -0
- {ai/chronon → gen_thrift}/api/ttypes.py +327 -197
- {ai/chronon/api → gen_thrift}/common/ttypes.py +9 -39
- gen_thrift/eval/ttypes.py +660 -0
- {ai/chronon → gen_thrift}/hub/ttypes.py +12 -131
- {ai/chronon → gen_thrift}/observability/ttypes.py +343 -180
- {ai/chronon → gen_thrift}/planner/ttypes.py +326 -45
- ai/chronon/eval/__init__.py +0 -122
- ai/chronon/eval/query_parsing.py +0 -19
- ai/chronon/eval/sample_tables.py +0 -100
- ai/chronon/eval/table_scan.py +0 -186
- ai/chronon/orchestration/ttypes.py +0 -4406
- ai/chronon/resources/gcp/README.md +0 -174
- ai/chronon/resources/gcp/zipline-cli-install.sh +0 -54
- awx_zipline_ai-0.2.1.dist-info/RECORD +0 -93
- awx_zipline_ai-0.2.1.dist-info/licenses/LICENSE +0 -202
- awx_zipline_ai-0.2.1.dist-info/top_level.txt +0 -3
- /jars/__init__.py → /__init__.py +0 -0
- {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.1.dist-info}/WHEEL +0 -0
- {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.1.dist-info}/entry_points.txt +0 -0
- {ai/chronon → gen_thrift}/api/__init__.py +0 -0
- {ai/chronon/api/common → gen_thrift/api}/constants.py +0 -0
- {ai/chronon/api → gen_thrift}/common/__init__.py +0 -0
- {ai/chronon/api → gen_thrift/common}/constants.py +0 -0
- {ai/chronon/fetcher → gen_thrift/eval}/__init__.py +0 -0
- {ai/chronon/fetcher → gen_thrift/eval}/constants.py +0 -0
- {ai/chronon/hub → gen_thrift/fetcher}/__init__.py +0 -0
- {ai/chronon/hub → gen_thrift/fetcher}/constants.py +0 -0
- {ai/chronon → gen_thrift}/fetcher/ttypes.py +0 -0
- {ai/chronon/observability → gen_thrift/hub}/__init__.py +0 -0
- {ai/chronon/observability → gen_thrift/hub}/constants.py +0 -0
- {ai/chronon/orchestration → gen_thrift/observability}/__init__.py +0 -0
- {ai/chronon/orchestration → gen_thrift/observability}/constants.py +0 -0
- {ai/chronon → gen_thrift}/planner/__init__.py +0 -0
- {ai/chronon → gen_thrift}/planner/constants.py +0 -0
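The headline change in 0.3.1 is structural: the Thrift-generated modules move out of the ai.chronon namespace into a new top-level gen_thrift package, ai.chronon.eval and ai.chronon.orchestration are dropped, and the compile CLI gains a new version_utils.py plus StagingQuery-aware validation. Downstream code that imports generated types needs a one-line update per import. A minimal migration sketch, with the imported names taken from the conf_validator.py diff below:

    # 0.2.1: generated Thrift types lived under ai.chronon
    # import ai.chronon.api.common.ttypes as common
    # from ai.chronon.api.ttypes import Accuracy, JoinPart, Source

    # 0.3.1: the same generated modules now live under gen_thrift
    import gen_thrift.common.ttypes as common
    from gen_thrift.api.ttypes import Accuracy, JoinPart, Source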
--- a/ai/chronon/cli/compile/conf_validator.py
+++ b/ai/chronon/cli/compile/conf_validator.py
@@ -1,5 +1,4 @@
-"""Object for checking whether a Chronon API thrift object is consistent with other
-"""
+"""Object for checking whether a Chronon API thrift object is consistent with other"""
 
 # Copyright (C) 2023 The Chronon Authors.
 #
@@ -18,12 +17,14 @@
 import json
 import logging
 import re
+import sys
 import textwrap
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
-import ai.chronon.api.common.ttypes as common
-from ai.chronon.api.ttypes import (
+import gen_thrift.common.ttypes as common
+from gen_thrift.api.ttypes import (
     Accuracy,
     Aggregation,
     Derivation,
@@ -33,6 +34,7 @@ from ai.chronon.api.ttypes import (
     JoinPart,
     Source,
 )
+
 from ai.chronon.cli.compile.column_hashing import (
     compute_group_by_columns_hashes,
     get_pre_derived_group_by_columns,
@@ -40,6 +42,7 @@ from ai.chronon.cli.compile.column_hashing import (
     get_pre_derived_join_features,
     get_pre_derived_source_keys,
 )
+from ai.chronon.cli.compile.version_utils import is_version_change
 from ai.chronon.logger import get_logger
 from ai.chronon.repo.serializer import thrift_simple_json
 from ai.chronon.utils import get_query, get_root_source
@@ -49,6 +52,20 @@ SKIPPED_FIELDS = frozenset(["metaData"])
 EXTERNAL_KEY = "onlineExternalParts"
 
 
+@dataclass
+class ConfigChange:
+    """Represents a change to a compiled config object."""
+
+    name: str
+    obj_type: str
+    online: bool = False
+    production: bool = False
+    base_name: str = None
+    old_version: int = None
+    new_version: int = None
+    is_version_change: bool = False
+
+
 def _filter_skipped_fields_from_join(json_obj: Dict, skipped_fields):
     for join_part in json_obj["joinParts"]:
         group_by = join_part["groupBy"]
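The new ConfigChange dataclass above is the bookkeeping unit for the overwrite-confirmation flow added at the bottom of this file: each added, changed, or deleted compiled object becomes one record. A rough illustration of how the validator fills one in, mirroring _create_config_change further down (the config name is invented for the example):

    # a hypothetical in-place change to an online GroupBy
    change = ConfigChange(
        name="search.user_clicks__1",  # invented config name
        obj_type="GroupBy",
        online=True,
        production=False,
    )
    # base_name, old_version, new_version and is_version_change keep their
    # defaults unless the change is later recognized as a version bump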
@@ -91,7 +108,6 @@ def _group_by_has_hourly_windows(groupBy: GroupBy) -> bool:
         return False
 
     for agg in aggs:
-
         if not agg.windows:
             return False
 
@@ -103,9 +119,7 @@
 
 
 def detect_feature_name_collisions(
-    group_bys: List[Tuple[GroupBy, str]],
-    entity_set_type: str,
-    name: str
+    group_bys: List[Tuple[GroupBy, str]], entity_set_type: str, name: str
 ) -> BaseException | None:
     # Build a map of output_column -> set of group_by name
     output_col_to_gbs = {}
@@ -115,7 +129,7 @@ def detect_feature_name_collisions(
         prefix_str = jp_prefix_str + key_str + "_"
         cols = compute_group_by_columns_hashes(gb, exclude_keys=True)
         if not cols:
-            print(
+            print("HERE")
             cols = {
                 f"{prefix_str}{base_col}"
                 for base_col in list(compute_group_by_columns_hashes(gb, exclude_keys=True).keys())
@@ -127,19 +141,13 @@ def detect_feature_name_collisions(
             output_col_to_gbs[col].add(gb_name)
 
     # Find output columns that are produced by more than one group_by
-    collisions = {
-        col: gb_names
-        for col, gb_names in output_col_to_gbs.items()
-        if len(gb_names) > 1
-    }
+    collisions = {col: gb_names for col, gb_names in output_col_to_gbs.items() if len(gb_names) > 1}
 
     if not collisions:
         return None  # no collisions
 
     # Assemble an error message listing the conflicting GroupBys by name
-    lines = [
-        f"{entity_set_type} for Join: {name} has the following output name collisions:\n"
-    ]
+    lines = [f"{entity_set_type} for Join: {name} has the following output name collisions:\n"]
     for col, gb_names in collisions.items():
         names_str = ", ".join(sorted(gb_names))
         lines.append(f"  - [{col}] has collisions from: [{names_str}]")
@@ -161,9 +169,9 @@ class ConfValidator(object):
         output_root,
         existing_gbs,
         existing_joins,
+        existing_staging_queries,
         log_level=logging.INFO,
     ):
-
         self.chronon_root_path = input_root
         self.output_root = output_root
 
@@ -176,8 +184,10 @@ class ConfValidator(object):
         self.old_objs = defaultdict(dict)
         self.old_group_bys = existing_gbs
         self.old_joins = existing_joins
+        self.old_staging_queries = existing_staging_queries
         self.old_objs["GroupBy"] = self.old_group_bys
         self.old_objs["Join"] = self.old_joins
+        self.old_objs["StagingQuery"] = self.old_staging_queries
 
     def _get_old_obj(self, obj_class: type, obj_name: str) -> object:
         """
@@ -223,9 +233,7 @@ class ConfValidator(object):
         elif not any(
             join.metaData.online for join in self._get_old_joins_with_group_by(obj)
         ) and not _is_batch_upload_needed(obj):
-            reasons.append(
-                "is not marked online/production nor is included in any online join"
-            )
+            reasons.append("is not marked online/production nor is included in any online join")
         return reasons
 
     def validate_obj(self, obj: object) -> List[BaseException]:
@@ -241,13 +249,9 @@ class ConfValidator(object):
             return self._validate_join(obj)
         return []
 
-    def _has_diff(
-        self, obj: object, old_obj: object, skipped_fields=SKIPPED_FIELDS
-    ) -> bool:
+    def _has_diff(self, obj: object, old_obj: object, skipped_fields=SKIPPED_FIELDS) -> bool:
         new_json = {
-            k: v
-            for k, v in json.loads(thrift_simple_json(obj)).items()
-            if k not in skipped_fields
+            k: v for k, v in json.loads(thrift_simple_json(obj)).items() if k not in skipped_fields
         }
         old_json = {
             k: v
@@ -257,6 +261,7 @@ class ConfValidator(object):
         if isinstance(obj, Join):
             _filter_skipped_fields_from_join(new_json, skipped_fields)
             _filter_skipped_fields_from_join(old_json, skipped_fields)
+
         return new_json != old_json
 
     def safe_to_overwrite(self, obj: object) -> bool:
@@ -264,11 +269,7 @@ class ConfValidator(object):
         to materialize and overwrite the old conf.
         """
         old_obj = self._get_old_obj(type(obj), obj.metaData.name)
-        return (
-            not old_obj
-            or not self._has_diff(obj, old_obj)
-            or not old_obj.metaData.online
-        )
+        return not old_obj or not self._has_diff(obj, old_obj) or not old_obj.metaData.online
 
     def _validate_derivations(
         self, pre_derived_cols: List[str], derivations: List[Derivation]
@@ -295,23 +296,26 @@ class ConfValidator(object):
             if wild_card_derivation_included:
                 if derivation.expression in derived_columns:
                     derived_columns.remove(derivation.expression)
-            if (
-
-
+            if derivation.expression not in pre_derived_cols and derivation.expression not in (
+                "ds",
+                "ts",
             ):
                 errors.append(
-                    ValueError(
-
-
-
+                    ValueError(
+                        "Incorrect derivation expression {}, expression not found in pre-derived columns {}".format(
+                            derivation.expression, pre_derived_cols
+                        )
+                    )
                 )
             if derivation.name != "*":
                 if derivation.name in derived_columns:
                     errors.append(
-                        ValueError(
-                            derivation.
+                        ValueError(
+                            "Incorrect derivation name {} due to output column name conflict".format(
+                                derivation.name
+                            )
                         )
-                    )
+                    )
                 else:
                     derived_columns.add(derivation.name)
         return errors
@@ -344,17 +348,20 @@ class ConfValidator(object):
             err_string += f"\n- The following keys in your key_mapping: {str(key_map_keys_missing_from_left)} for JoinPart {group_by_name} are not included in the left side of the join: {left_cols_as_str}"
 
         # Right side of key mapping should only include keys in GroupBy
-        keys_missing_from_key_map_values = [
+        keys_missing_from_key_map_values = [
+            v for v in key_mapping.values() if v not in join_part.groupBy.keyColumns
+        ]
         if keys_missing_from_key_map_values:
             err_string += f"\n- The following values in your key_mapping: {str(keys_missing_from_key_map_values)} for JoinPart {group_by_name} do not cover any group by key columns: {join_part.groupBy.keyColumns}"
 
         if key_map_keys_missing_from_left or keys_missing_from_key_map_values:
-            err_string +=
+            err_string += (
+                "\n(Key Mapping should be formatted as column_from_left -> group_by_key)"
+            )
 
         if err_string:
             return ValueError(err_string)
 
-
     def _validate_keys(self, join: Join) -> List[BaseException]:
         left = join.left
 
@@ -375,7 +382,7 @@ class ConfValidator(object):
         errors = []
 
         if left_cols:
-            join_parts = join.joinParts
+            join_parts = list(join.joinParts)  # Create a copy to avoid modifying the original
 
             # Add label_parts to join_parts to validate if set
             label_parts = join.labelParts
@@ -391,7 +398,6 @@ class ConfValidator(object):
 
         return errors
 
-
     def _validate_join(self, join: Join) -> List[BaseException]:
         """
         Validate join's status with materialized versions of group_bys
@@ -402,7 +408,9 @@ class ConfValidator(object):
         """
         included_group_bys_and_prefixes = [(rp.groupBy, rp.prefix) for rp in join.joinParts]
         # TODO: Remove label parts check in future PR that deprecates label_parts
-        included_label_parts_and_prefixes =
+        included_label_parts_and_prefixes = (
+            [(lp.groupBy, lp.prefix) for lp in join.labelParts.labels] if join.labelParts else []
+        )
         included_group_bys = [tup[0] for tup in included_group_bys_and_prefixes]
 
         offline_included_group_bys = [
@@ -422,9 +430,7 @@ class ConfValidator(object):
             if group_by.metaData.production is False
         ]
         # Check if the underlying groupBy is valid
-        group_by_errors = [
-            self._validate_group_by(group_by) for group_by in included_group_bys
-        ]
+        group_by_errors = [self._validate_group_by(group_by) for group_by in included_group_bys]
         errors += [
             ValueError(f"join {join.metaData.name}'s underlying {error}")
             for errors in group_by_errors
@@ -433,18 +439,22 @@ class ConfValidator(object):
         # Check if the production join is using non production groupBy
         if join.metaData.production and non_prod_old_group_bys:
             errors.append(
-                ValueError(
-                    join
+                ValueError(
+                    "join {} is production but includes the following non production group_bys: {}".format(
+                        join.metaData.name, ", ".join(non_prod_old_group_bys)
+                    )
                 )
-            )
+            )
         # Check if the online join is using the offline groupBy
         if join.metaData.online:
            if offline_included_group_bys:
                 errors.append(
-                    ValueError(
-                        join
+                    ValueError(
+                        "join {} is online but includes the following offline group_bys: {}".format(
+                            join.metaData.name, ", ".join(offline_included_group_bys)
+                        )
                     )
-                )
+                )
         # Only validate the join derivation when the underlying groupBy is valid
         group_by_correct = all(not errors for errors in group_by_errors)
         if join.derivations and group_by_correct:
@@ -457,22 +467,23 @@ class ConfValidator(object):
             columns = features + keys
             errors.extend(self._validate_derivations(columns, join.derivations))
 
-
         errors.extend(self._validate_keys(join))
 
         # If the join is using "short" names, ensure that there are no collisions
         if join.useLongNames is False:
-            right_part_collisions = detect_feature_name_collisions(
+            right_part_collisions = detect_feature_name_collisions(
+                included_group_bys_and_prefixes, "right parts", join.metaData.name
+            )
             if right_part_collisions:
                 errors.append(right_part_collisions)
 
-            label_part_collisions = detect_feature_name_collisions(
+            label_part_collisions = detect_feature_name_collisions(
+                included_label_parts_and_prefixes, "label parts", join.metaData.name
+            )
             if label_part_collisions:
                 errors.append(label_part_collisions)
 
         return errors
-
-
 
     def _validate_group_by(self, group_by: GroupBy) -> List[BaseException]:
         """
@@ -483,17 +494,11 @@ class ConfValidator(object):
         List of validation errors.
         """
         joins = self._get_old_joins_with_group_by(group_by)
-        online_joins = [
-            join.metaData.name for join in joins if join.metaData.online is True
-        ]
-        prod_joins = [
-            join.metaData.name for join in joins if join.metaData.production is True
-        ]
+        online_joins = [join.metaData.name for join in joins if join.metaData.online is True]
+        prod_joins = [join.metaData.name for join in joins if join.metaData.production is True]
         errors = []
 
-        non_temporal = (
-            group_by.accuracy is None or group_by.accuracy == Accuracy.SNAPSHOT
-        )
+        non_temporal = group_by.accuracy is None or group_by.accuracy == Accuracy.SNAPSHOT
 
         no_topic = not _group_by_has_topic(group_by)
         has_hourly_windows = _group_by_has_hourly_windows(group_by)
@@ -501,48 +506,55 @@ class ConfValidator(object):
         # batch features cannot contain hourly windows
         if (no_topic and non_temporal) and has_hourly_windows:
             errors.append(
-                ValueError(
-
-
+                ValueError(
+                    f"group_by {group_by.metaData.name} is defined to be daily refreshed but contains "
+                    f"hourly windows. "
+                )
             )
 
         def _validate_bounded_event_source():
             if group_by.aggregations is None:
                 return
-
+
             unbounded_event_sources = [
                 str(src)
-                for src in group_by.sources
-                if isinstance(get_root_source(src), EventSource)
+                for src in group_by.sources
+                if isinstance(get_root_source(src), EventSource)
+                and get_query(src).startPartition is None
             ]
 
             if not unbounded_event_sources:
                 return
-
-            unwindowed_aggregations = [
+
+            unwindowed_aggregations = [
+                str(agg) for agg in group_by.aggregations if agg.windows is None
+            ]
 
             if not unwindowed_aggregations:
                 return
-
+
             nln = "\n"
 
             errors.append(
                 ValueError(
                     f"""group_by {group_by.metaData.name} uses unwindowed aggregations [{nln}{f",{nln}".join(unwindowed_aggregations)}{nln}]
 on unbounded event sources: [{nln}{f",{nln}".join(unbounded_event_sources)}{nln}].
-Please set a start_partition on the source, or a window on the aggregation."""
+Please set a start_partition on the source, or a window on the aggregation."""
                 )
-
+            )
+
         _validate_bounded_event_source()
 
         # group by that are marked explicitly offline should not be present in
         # materialized online joins.
         if group_by.metaData.online is False and online_joins:
             errors.append(
-                ValueError(
-
-
-
+                ValueError(
+                    "group_by {} is explicitly marked offline but included in "
+                    "the following online joins: {}".format(
+                        group_by.metaData.name, ", ".join(online_joins)
+                    )
+                )
             )
         # group by that are marked explicitly non-production should not be
         # present in materialized production joins.
@@ -551,8 +563,9 @@ class ConfValidator(object):
             errors.append(
                 ValueError(
                     "group_by {} is explicitly marked as non-production but included in the following production "
-
-
+                    "joins: {}".format(group_by.metaData.name, ", ".join(prod_joins))
+                )
+            )
         # if the group by is included in any of materialized production join,
         # set it to production in the materialized output.
         else:
@@ -569,22 +582,161 @@ class ConfValidator(object):
 
         for source in group_by.sources:
             src: Source = source
-            if (
-                src.events
-                and src.events.isCumulative
-                and (src.events.query.timeColumn is None)
-            ):
+            if src.events and src.events.isCumulative and (src.events.query.timeColumn is None):
                 errors.append(
-                    ValueError(
-
-
+                    ValueError(
+                        "Please set query.timeColumn for Cumulative Events Table: {}".format(
+                            src.events.table
+                        )
+                    )
                 )
-            elif (
-                src.joinSource
-            ):
+            elif src.joinSource:
                 join_obj = src.joinSource.join
                 if join_obj.metaData.name is None or join_obj.metaData.team is None:
                     errors.append(
                         ValueError(f"Join must be defined with team and name: {join_obj}")
                     )
         return errors
+
+    def validate_changes(self, results):
+        """
+        Validate changes in compiled objects against existing materialized objects.
+
+        Args:
+            results: List of CompiledObj from parse_configs
+
+        Returns:
+            None (exits on user cancellation)
+        """
+
+        # Filter out results with errors and only process GroupBy/Join
+        valid_results = []
+        for result in results:
+            if result.errors:
+                continue  # Skip results with errors
+            if result.obj_type not in ["GroupBy", "Join", "StagingQuery"]:
+                continue  # Skip non-GroupBy/Join objects
+            valid_results.append(result)
+
+        # Categorize changes
+        changed_objects = {"changed": [], "deleted": [], "added": []}
+
+        # Process each valid result
+        for result in valid_results:
+            obj_name = result.obj.metaData.name
+            obj_type = result.obj_type
+
+            # Check if object exists in old objects
+            old_obj = self._get_old_obj(type(result.obj), obj_name)
+
+            if old_obj is None:
+                # New object
+                change = self._create_config_change(result.obj, obj_type)
+                changed_objects["added"].append(change)
+            elif self._has_diff(result.obj, old_obj):
+                # Modified object
+                change = self._create_config_change(result.obj, obj_type)
+                changed_objects["changed"].append(change)
+
+        # Check for deleted objects
+        # Build set of current objects from valid_results for lookup
+        current_objects = {(result.obj_type, result.obj.metaData.name) for result in valid_results}
+
+        for obj_type in ["GroupBy", "Join", "StagingQuery"]:
+            old_objs = self.old_objs.get(obj_type, {})
+            for obj_name, old_obj in old_objs.items():
+                if (obj_type, obj_name) not in current_objects:
+                    # Object was deleted
+                    change = self._create_config_change(old_obj, obj_type)
+                    changed_objects["deleted"].append(change)
+
+        # Store changes for later confirmation check
+        self._pending_changes = {
+            "changed": changed_objects["changed"],
+            "deleted": changed_objects["deleted"],
+            "added": changed_objects["added"],
+        }
+
+    def _filter_non_version_changes(self, existing_changes, added_changes):
+        """Filter out version changes from existing changes.
+
+        Returns list of changes that are NOT version bumps and require confirmation.
+        """
+        added_names = {change.name for change in added_changes}
+        non_version_changes = []
+
+        for change in existing_changes:
+            # Check if this deleted config has a corresponding added config that represents a version bump
+            is_version_bump = any(
+                is_version_change(change.name, added_name) for added_name in added_names
+            )
+
+            if not is_version_bump:
+                non_version_changes.append(change)
+
+        return non_version_changes
+
+    def check_pending_changes_confirmation(self, compile_status):
+        """Check if user confirmation is needed for pending changes after display."""
+        from ai.chronon.cli.compile.display.console import console
+
+        # Skip confirmation if there are compilation errors
+        if self._has_compilation_errors(compile_status):
+            return  # Don't prompt when there are errors
+
+        if not hasattr(self, "_pending_changes"):
+            return  # No pending changes
+
+        # Check if we need user confirmation (only for non-version-bump changes)
+        non_version_changes = self._filter_non_version_changes(
+            self._pending_changes["changed"] + self._pending_changes["deleted"],
+            self._pending_changes["added"],
+        )
+
+        if non_version_changes:
+            if not self._prompt_user_confirmation():
+                console.print("❌ Compilation cancelled by user.")
+                sys.exit(1)
+
+    def _has_compilation_errors(self, compile_status):
+        """Check if there are any compilation errors across all class trackers."""
+        for tracker in compile_status.cls_to_tracker.values():
+            if tracker.files_to_errors:
+                return True
+        return False
+
+    def _create_config_change(self, obj, obj_type):
+        """Create a ConfigChange object from a thrift object."""
+        return ConfigChange(
+            name=obj.metaData.name,
+            obj_type=obj_type,
+            online=obj.metaData.online if obj.metaData.online else False,
+            production=obj.metaData.production if obj.metaData.production else False,
+        )
+
+    def _prompt_user_confirmation(self) -> bool:
+        """
+        Prompt user for Y/N confirmation to proceed with overwriting existing configs.
+        Returns True if user confirms, False otherwise.
+        """
+        from ai.chronon.cli.compile.display.console import console
+
+        console.print(
+            "\n❗ [bold red3]Some configs are changing in-place (changing semantics without changing the version).[/bold red3]"
+        )
+        console.print(
+            "[dim]Note that changes can be caused by directly modifying a config, or by changing the version of an upstream object which changes the input source.[/dim]"
+        )
+        console.print("")
+        console.print(
+            "[dim]This can be a safe operation if you are sure that the configs that are unaffected by the change (i.e. upstream config is versioning up in a way that does not effect the downstream), or if the objects in question are not yet used by critical workloads.[/dim]"
+        )
+        console.print("")
+        console.print("❓ [bold]Do you want to proceed? (y/N):[/bold]", end=" ")
+
+        try:
+            response = input().strip().lower()
+            return response in ["y", "yes"]
+        except (EOFError, KeyboardInterrupt):
+            console.print("\n❌ Compilation cancelled.")
+            return False
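Taken together, the new methods give compilation an overwrite guard: validate_changes diffs every compiled GroupBy, Join, and StagingQuery against its previously materialized counterpart and buckets the results into added, changed, and deleted, while check_pending_changes_confirmation prompts before anything is overwritten in place. A deletion that pairs with an added config that is_version_change accepts is treated as a version bump and skips the prompt; version_utils.py is new in this release and its exact matching rules are not shown in this diff. A sketch of the expected driver-side call order, with validator and compile_status standing in for the compiler's own objects:

    # results: List[CompiledObj] produced by parse_configs
    validator.validate_changes(results)  # buckets added/changed/deleted configs
    # ... render per-class diffs to the console ...
    validator.check_pending_changes_confirmation(compile_status)
    # exits with status 1 if the user declines to overwrite in-place changes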
--- a/ai/chronon/cli/compile/display/class_tracker.py
+++ b/ai/chronon/cli/compile/display/class_tracker.py
@@ -24,9 +24,7 @@ class ClassTracker:
         self.existing_objs[obj.name] = obj
 
     def add(self, compiled: CompiledObj) -> None:
-
         if compiled.errors:
-
             if compiled.file not in self.files_to_errors:
                 self.files_to_errors[compiled.file] = []
 
@@ -43,12 +41,10 @@ class ClassTracker:
 
     def _update_diff(self, compiled: CompiledObj) -> None:
         if compiled.name in self.existing_objs:
-
             existing_json = self.existing_objs[compiled.name].tjson
             new_json = compiled.tjson
 
             if existing_json != new_json:
-
                 diff = difflib.unified_diff(
                     existing_json.splitlines(keepends=True),
                     new_json.splitlines(keepends=True),
@@ -74,16 +70,7 @@ class ClassTracker:
         text = Text(overflow="fold", no_wrap=False)
 
         if self.existing_objs:
-            text.append(
-                f"  Parsed {len(self.existing_objs)} previously compiled objects.\n"
-            )
-
-        if self.files_to_obj:
-            text.append("  Compiled ")
-            text.append(f"{len(self.new_objs)} ", style="bold green")
-            text.append("objects from ")
-            text.append(f"{len(self.files_to_obj)} ", style="bold green")
-            text.append("files.\n")
+            text.append(f"  Parsed {len(self.existing_objs)} previously compiled objects.\n")
 
         if self.files_to_errors:
             text.append("  Failed to compile ")
@@ -99,7 +86,7 @@ class ClassTracker:
         for file, errors in self.files_to_errors.items():
             text.append("    ERROR ", style="bold red")
             text.append(f"- {file}:\n")
-
+
             for error in errors:
                 # Format each error properly, handling newlines
                 error_msg = str(error)
@@ -108,5 +95,8 @@ class ClassTracker:
         return text
 
     # doesn't make sense to show deletes until the very end of compilation
-    def diff(self) -> Text:
+    def diff(self, ignore_python_errors: bool = False) -> Text:
+        # Don't show diff if there are compile errors - it's confusing
+        if self.files_to_errors and not ignore_python_errors:
+            return Text("\n❗Please fix python errors then retry compilation.\n", style="dim cyan")
         return self.diff_result.render(deleted_names=self.deleted_names)