awx-zipline-ai 0.0.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. __init__.py +0 -0
  2. agent/__init__.py +1 -0
  3. agent/constants.py +15 -0
  4. agent/ttypes.py +1684 -0
  5. ai/__init__.py +0 -0
  6. ai/chronon/__init__.py +0 -0
  7. ai/chronon/airflow_helpers.py +248 -0
  8. ai/chronon/cli/__init__.py +0 -0
  9. ai/chronon/cli/compile/__init__.py +0 -0
  10. ai/chronon/cli/compile/column_hashing.py +336 -0
  11. ai/chronon/cli/compile/compile_context.py +173 -0
  12. ai/chronon/cli/compile/compiler.py +183 -0
  13. ai/chronon/cli/compile/conf_validator.py +742 -0
  14. ai/chronon/cli/compile/display/__init__.py +0 -0
  15. ai/chronon/cli/compile/display/class_tracker.py +102 -0
  16. ai/chronon/cli/compile/display/compile_status.py +95 -0
  17. ai/chronon/cli/compile/display/compiled_obj.py +12 -0
  18. ai/chronon/cli/compile/display/console.py +3 -0
  19. ai/chronon/cli/compile/display/diff_result.py +111 -0
  20. ai/chronon/cli/compile/fill_templates.py +35 -0
  21. ai/chronon/cli/compile/parse_configs.py +134 -0
  22. ai/chronon/cli/compile/parse_teams.py +242 -0
  23. ai/chronon/cli/compile/serializer.py +109 -0
  24. ai/chronon/cli/compile/version_utils.py +42 -0
  25. ai/chronon/cli/git_utils.py +145 -0
  26. ai/chronon/cli/logger.py +59 -0
  27. ai/chronon/constants.py +3 -0
  28. ai/chronon/group_by.py +692 -0
  29. ai/chronon/join.py +580 -0
  30. ai/chronon/logger.py +23 -0
  31. ai/chronon/model.py +40 -0
  32. ai/chronon/query.py +126 -0
  33. ai/chronon/repo/__init__.py +39 -0
  34. ai/chronon/repo/aws.py +284 -0
  35. ai/chronon/repo/cluster.py +136 -0
  36. ai/chronon/repo/compile.py +62 -0
  37. ai/chronon/repo/constants.py +164 -0
  38. ai/chronon/repo/default_runner.py +269 -0
  39. ai/chronon/repo/explore.py +418 -0
  40. ai/chronon/repo/extract_objects.py +134 -0
  41. ai/chronon/repo/gcp.py +586 -0
  42. ai/chronon/repo/gitpython_utils.py +15 -0
  43. ai/chronon/repo/hub_runner.py +261 -0
  44. ai/chronon/repo/hub_uploader.py +109 -0
  45. ai/chronon/repo/init.py +60 -0
  46. ai/chronon/repo/join_backfill.py +119 -0
  47. ai/chronon/repo/run.py +296 -0
  48. ai/chronon/repo/serializer.py +133 -0
  49. ai/chronon/repo/team_json_utils.py +46 -0
  50. ai/chronon/repo/utils.py +481 -0
  51. ai/chronon/repo/zipline.py +35 -0
  52. ai/chronon/repo/zipline_hub.py +277 -0
  53. ai/chronon/resources/__init__.py +0 -0
  54. ai/chronon/resources/gcp/__init__.py +0 -0
  55. ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
  56. ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
  57. ai/chronon/resources/gcp/group_bys/test/data.py +30 -0
  58. ai/chronon/resources/gcp/joins/__init__.py +0 -0
  59. ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
  60. ai/chronon/resources/gcp/joins/test/data.py +26 -0
  61. ai/chronon/resources/gcp/sources/__init__.py +0 -0
  62. ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
  63. ai/chronon/resources/gcp/sources/test/data.py +26 -0
  64. ai/chronon/resources/gcp/teams.py +58 -0
  65. ai/chronon/source.py +86 -0
  66. ai/chronon/staging_query.py +226 -0
  67. ai/chronon/types.py +58 -0
  68. ai/chronon/utils.py +510 -0
  69. ai/chronon/windows.py +48 -0
  70. awx_zipline_ai-0.0.32.dist-info/METADATA +197 -0
  71. awx_zipline_ai-0.0.32.dist-info/RECORD +96 -0
  72. awx_zipline_ai-0.0.32.dist-info/WHEEL +5 -0
  73. awx_zipline_ai-0.0.32.dist-info/entry_points.txt +2 -0
  74. awx_zipline_ai-0.0.32.dist-info/top_level.txt +4 -0
  75. gen_thrift/__init__.py +0 -0
  76. gen_thrift/api/__init__.py +1 -0
  77. gen_thrift/api/constants.py +15 -0
  78. gen_thrift/api/ttypes.py +3754 -0
  79. gen_thrift/common/__init__.py +1 -0
  80. gen_thrift/common/constants.py +15 -0
  81. gen_thrift/common/ttypes.py +1814 -0
  82. gen_thrift/eval/__init__.py +1 -0
  83. gen_thrift/eval/constants.py +15 -0
  84. gen_thrift/eval/ttypes.py +660 -0
  85. gen_thrift/fetcher/__init__.py +1 -0
  86. gen_thrift/fetcher/constants.py +15 -0
  87. gen_thrift/fetcher/ttypes.py +127 -0
  88. gen_thrift/hub/__init__.py +1 -0
  89. gen_thrift/hub/constants.py +15 -0
  90. gen_thrift/hub/ttypes.py +1109 -0
  91. gen_thrift/observability/__init__.py +1 -0
  92. gen_thrift/observability/constants.py +15 -0
  93. gen_thrift/observability/ttypes.py +2355 -0
  94. gen_thrift/planner/__init__.py +1 -0
  95. gen_thrift/planner/constants.py +15 -0
  96. gen_thrift/planner/ttypes.py +1967 -0
@@ -0,0 +1,742 @@
+ """Object for checking whether a Chronon API thrift object is consistent with others in the repo."""
+
+ # Copyright (C) 2023 The Chronon Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+ import re
+ import sys
+ import textwrap
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from typing import Dict, List, Tuple
+
+ import gen_thrift.common.ttypes as common
+ from gen_thrift.api.ttypes import (
+     Accuracy,
+     Aggregation,
+     Derivation,
+     EventSource,
+     GroupBy,
+     Join,
+     JoinPart,
+     Source,
+ )
+
+ from ai.chronon.cli.compile.column_hashing import (
+     compute_group_by_columns_hashes,
+     get_pre_derived_group_by_columns,
+     get_pre_derived_group_by_features,
+     get_pre_derived_join_features,
+     get_pre_derived_source_keys,
+ )
+ from ai.chronon.cli.compile.version_utils import is_version_change
+ from ai.chronon.logger import get_logger
+ from ai.chronon.repo.serializer import thrift_simple_json
+ from ai.chronon.utils import get_query, get_root_source
+
+ # Fields that indicate status of the entities.
+ SKIPPED_FIELDS = frozenset(["metaData"])
+ EXTERNAL_KEY = "onlineExternalParts"
+
+
+ @dataclass
+ class ConfigChange:
+     """Represents a change to a compiled config object."""
+
+     name: str
+     obj_type: str
+     online: bool = False
+     production: bool = False
+     base_name: str = None
+     old_version: int = None
+     new_version: int = None
+     is_version_change: bool = False
+
+
+ def _filter_skipped_fields_from_join(json_obj: Dict, skipped_fields):
+     for join_part in json_obj["joinParts"]:
+         group_by = join_part["groupBy"]
+         for field in skipped_fields:
+             group_by.pop(field, None)
+     if EXTERNAL_KEY in json_obj:
+         json_obj.pop(EXTERNAL_KEY, None)
+
+
+ def _is_batch_upload_needed(group_by: GroupBy) -> bool:
+     if group_by.metaData.online or group_by.backfillStartDate:
+         return True
+     else:
+         return False
+
+
+ def is_identifier(s: str) -> bool:
+     identifier_regex = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
+     return re.fullmatch(identifier_regex, s) is not None
+
+
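+ # For example, is_identifier("user_id") is True, while is_identifier("7_days") and
+ # is_identifier("a - b") are False; _validate_derivations uses this to tell simple
+ # column renames apart from arbitrary derivation expressions.
+
+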
+ def _source_has_topic(source: Source) -> bool:
+     if source.events:
+         return source.events.topic is not None
+     elif source.entities:
+         return source.entities.mutationTopic is not None
+     elif source.joinSource:
+         return _source_has_topic(source.joinSource.join.left)
+     return False
+
+
+ def _group_by_has_topic(groupBy: GroupBy) -> bool:
+     return any(_source_has_topic(source) for source in groupBy.sources)
+
+
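+ # A topic (or mutation topic) is what marks a source as streaming; a group_by whose
+ # sources carry no topics is treated as batch / daily-refreshed by the hourly-window
+ # check in _validate_group_by below.
+
+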
+ def _group_by_has_hourly_windows(groupBy: GroupBy) -> bool:
+     aggs: List[Aggregation] = groupBy.aggregations
+
+     if not aggs:
+         return False
+
+     for agg in aggs:
+         if not agg.windows:
+             return False
+
+         for window in agg.windows:
+             if window.timeUnit == common.TimeUnit.HOURS:
+                 return True
+
+     return False
+
+
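+ # Note: the scan above short-circuits to False at the first unwindowed aggregation;
+ # unwindowed aggregations are instead checked against unbounded event sources in
+ # _validate_bounded_event_source.
+
+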
+ def detect_feature_name_collisions(
+     group_bys: List[Tuple[GroupBy, str]], entity_set_type: str, name: str
+ ) -> BaseException | None:
+     # Build a map of output_column -> set of group_by names
+     output_col_to_gbs = {}
+     for gb, prefix in group_bys:
+         jp_prefix_str = f"{prefix}_" if prefix else ""
+         key_str = "_".join(gb.keyColumns)
+         prefix_str = jp_prefix_str + key_str + "_"
+         # Hash-derived output columns for this group_by; treat a missing result as empty.
+         base_cols = compute_group_by_columns_hashes(gb, exclude_keys=True) or {}
+         cols = {f"{prefix_str}{base_col}" for base_col in base_cols}
+         gb_name = gb.metaData.name
+         for col in cols:
+             if col not in output_col_to_gbs:
+                 output_col_to_gbs[col] = set()
+             output_col_to_gbs[col].add(gb_name)
+
+     # Find output columns that are produced by more than one group_by
+     collisions = {col: gb_names for col, gb_names in output_col_to_gbs.items() if len(gb_names) > 1}
+
+     if not collisions:
+         return None  # no collisions
+
+     # Assemble an error message listing the conflicting GroupBys by name
+     lines = [f"{entity_set_type} for Join: {name} has the following output name collisions:\n"]
+     for col, gb_names in collisions.items():
+         names_str = ", ".join(sorted(gb_names))
+         lines.append(f" - [{col}] has collisions from: [{names_str}]")
+
+     lines.append(
+         "\nConsider assigning distinct `prefix` values to the conflicting parts to avoid collisions."
+     )
+     return ValueError("\n".join(lines))
+
+
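+ # Illustrative collision (hypothetical names): two join parts over group_bys
+ # "views_v0" and "clicks_v0", both keyed on user_id, both emitting an output column
+ # "count_7d", and neither given a prefix, would each produce "user_id_count_7d" and
+ # be reported here.
+
+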
+ class ConfValidator(object):
+     """
+     Applies repo-wide validation rules.
+     """
+
+     def __init__(
+         self,
+         input_root,
+         output_root,
+         existing_gbs,
+         existing_joins,
+         existing_staging_queries,
+         log_level=logging.INFO,
+     ):
+         self.chronon_root_path = input_root
+         self.output_root = output_root
+
+         self.log_level = log_level
+         self.logger = get_logger(log_level)
+
+         # We keep the objs in dicts rather than sets, since thrift does not
+         # implement __hash__ for ttypes objects.
+         self.old_objs = defaultdict(dict)
+         self.old_group_bys = existing_gbs
+         self.old_joins = existing_joins
+         self.old_staging_queries = existing_staging_queries
+         self.old_objs["GroupBy"] = self.old_group_bys
+         self.old_objs["Join"] = self.old_joins
+         self.old_objs["StagingQuery"] = self.old_staging_queries
+
+     def _get_old_obj(self, obj_class: type, obj_name: str) -> object:
+         """
+         returns:
+             materialized version of the obj given the object's name.
+         """
+         class_name = obj_class.__name__
+
+         if class_name not in self.old_objs:
+             return None
+         obj_map = self.old_objs[class_name]
+
+         if obj_name not in obj_map:
+             return None
+         return obj_map[obj_name]
+
+     def _get_old_joins_with_group_by(self, group_by: GroupBy) -> List[Join]:
+         """
+         returns:
+             materialized joins that include the group_by.
+         """
+         joins = []
+         for join in self.old_joins.values():
+             if join.joinParts is not None and group_by.metaData.name in [
+                 rp.groupBy.metaData.name for rp in join.joinParts
+             ]:
+                 joins.append(join)
+         return joins
+
+     def can_skip_materialize(self, obj: object) -> List[str]:
+         """
+         Check whether materializing the object can be skipped, and return the
+         reasons if so.
+         """
+         reasons = []
+         if isinstance(obj, GroupBy):
+             if not _is_batch_upload_needed(obj):
+                 reasons.append(
+                     "GroupBys should not be materialized if batch upload job is not needed"
+                 )
+             # Otherwise, group_bys included in an online join, or marked
+             # explicitly online themselves, are materialized.
+             elif not any(
+                 join.metaData.online for join in self._get_old_joins_with_group_by(obj)
+             ) and not _is_batch_upload_needed(obj):
+                 reasons.append("is not marked online/production nor is included in any online join")
+         return reasons
+
+     def validate_obj(self, obj: object) -> List[BaseException]:
+         """
+         Validate Chronon API obj against other entities in the repo.
+
+         returns:
+             list of errors.
+         """
+         if isinstance(obj, GroupBy):
+             return self._validate_group_by(obj)
+         elif isinstance(obj, Join):
+             return self._validate_join(obj)
+         return []
+
+     def _has_diff(self, obj: object, old_obj: object, skipped_fields=SKIPPED_FIELDS) -> bool:
+         new_json = {
+             k: v for k, v in json.loads(thrift_simple_json(obj)).items() if k not in skipped_fields
+         }
+         old_json = {
+             k: v
+             for k, v in json.loads(thrift_simple_json(old_obj)).items()
+             if k not in skipped_fields
+         }
+         if isinstance(obj, Join):
+             _filter_skipped_fields_from_join(new_json, skipped_fields)
+             _filter_skipped_fields_from_join(old_json, skipped_fields)
+
+         return new_json != old_json
+
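+     # Because "metaData" is in SKIPPED_FIELDS, _has_diff only registers semantic
+     # changes: edits limited to status fields such as online/production, on the
+     # object itself or on a join part's groupBy, do not count as a diff.
+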
+     def safe_to_overwrite(self, obj: object) -> bool:
+         """Once an object has been materialized as online, it is no longer safe
+         to materialize and overwrite the old conf.
+         """
+         old_obj = self._get_old_obj(type(obj), obj.metaData.name)
+         return not old_obj or not self._has_diff(obj, old_obj) or not old_obj.metaData.online
+
+     def _validate_derivations(
+         self, pre_derived_cols: List[str], derivations: List[Derivation]
+     ) -> List[BaseException]:
+         """
+         Validate that a join's/groupBy's derivations are defined correctly.
+
+         Returns:
+             list of validation errors.
+         """
+         errors = []
+         derived_columns = set(pre_derived_cols)
+
+         wild_card_derivation_included = any(
+             derivation.expression == "*" for derivation in derivations
+         )
+         if not wild_card_derivation_included:
+             derived_columns.clear()
+         for derivation in derivations:
+             # if the derivation is a renaming derivation, check whether the expression is in the pre-derived schema
+             if is_identifier(derivation.expression):
+                 # for wildcard derivations we want to remove the original column if there is a
+                 # renaming operation applied on it
+                 if wild_card_derivation_included:
+                     if derivation.expression in derived_columns:
+                         derived_columns.remove(derivation.expression)
+                 if derivation.expression not in pre_derived_cols and derivation.expression not in (
+                     "ds",
+                     "ts",
+                 ):
+                     errors.append(
+                         ValueError(
+                             "Incorrect derivation expression {}, expression not found in pre-derived columns {}".format(
+                                 derivation.expression, pre_derived_cols
+                             )
+                         )
+                     )
+             if derivation.name != "*":
+                 if derivation.name in derived_columns:
+                     errors.append(
+                         ValueError(
+                             "Incorrect derivation name {} due to output column name conflict".format(
+                                 derivation.name
+                             )
+                         )
+                     )
+                 else:
+                     derived_columns.add(derivation.name)
+         return errors
+
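+     # Worked example (hypothetical columns): with pre_derived_cols = ["a", "b"] and
+     # derivations [(name="*", expression="*"), (name="a_renamed", expression="a")],
+     # the surviving output set is {"b", "a_renamed"}: the wildcard keeps everything,
+     # and renaming "a" drops the original column, so no name conflict is reported.
+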
+     def _validate_join_part_keys(self, join_part: JoinPart, left_cols: List[str]) -> BaseException:
+         keys = []
+
+         # keyMapping is formatted as left_column -> group_by_key (see the message
+         # below), so invert it to map each group_by key to its expected left column.
+         key_mapping = join_part.keyMapping if join_part.keyMapping else {}
+         gb_key_to_left_col = {v: k for k, v in key_mapping.items()}
+         for key in join_part.groupBy.keyColumns:
+             keys.append(gb_key_to_left_col.get(key, key))
+
+         missing = [k for k in keys if k not in left_cols]
+
+         err_string = ""
+         left_cols_as_str = ", ".join(left_cols)
+         group_by_name = join_part.groupBy.metaData.name
+         if missing:
+             key_mapping_str = f"Key Mapping: {key_mapping}" if key_mapping else ""
+             err_string += textwrap.dedent(f"""
+             - Join is missing keys {missing} on left side. Required for JoinPart: {group_by_name}.
+             Existing columns on left side: {left_cols_as_str}
+             All required Keys: {join_part.groupBy.keyColumns}
+             {key_mapping_str}
+             Consider renaming a column on the left, or including the key_mapping argument to your join_part.""")
+
+         if key_mapping:
+             # Left side of key mapping should include columns on the left
+             key_map_keys_missing_from_left = [k for k in key_mapping.keys() if k not in left_cols]
+             if key_map_keys_missing_from_left:
+                 err_string += f"\n- The following keys in your key_mapping: {str(key_map_keys_missing_from_left)} for JoinPart {group_by_name} are not included in the left side of the join: {left_cols_as_str}"
+
+             # Right side of key mapping should only include keys in GroupBy
+             keys_missing_from_key_map_values = [
+                 v for v in key_mapping.values() if v not in join_part.groupBy.keyColumns
+             ]
+             if keys_missing_from_key_map_values:
+                 err_string += f"\n- The following values in your key_mapping: {str(keys_missing_from_key_map_values)} for JoinPart {group_by_name} do not cover any group by key columns: {join_part.groupBy.keyColumns}"
+
+             if key_map_keys_missing_from_left or keys_missing_from_key_map_values:
+                 err_string += (
+                     "\n(Key Mapping should be formatted as column_from_left -> group_by_key)"
+                 )
+
+         if err_string:
+             return ValueError(err_string)
+
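+     # Example (hypothetical names): a group_by keyed on "user_id" joined to a left
+     # source that exposes "user" should pass key_mapping={"user": "user_id"}; the
+     # reversed form {"user_id": "user"} fails both key_mapping checks above.
+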
+     def _validate_keys(self, join: Join) -> List[BaseException]:
+         left = join.left
+
+         left_selects = None
+         if left.events:
+             left_selects = left.events.query.selects
+         elif left.entities:
+             left_selects = left.entities.query.selects
+         elif left.joinSource:
+             left_selects = left.joinSource.query.selects
+         # TODO -- if selects are not set here, get output cols from join
+
+         left_cols = []
+
+         if left_selects:
+             left_cols = left_selects.keys()
+
+         errors = []
+
+         if left_cols:
+             join_parts = list(join.joinParts)  # Create a copy to avoid modifying the original
+
+             # Add label_parts to join_parts to validate if set
+             label_parts = join.labelParts
+             if label_parts:
+                 for label_jp in label_parts.labels:
+                     join_parts.append(label_jp)
+
+             # Validate join_parts
+             for join_part in join_parts:
+                 join_part_err = self._validate_join_part_keys(join_part, left_cols)
+                 if join_part_err:
+                     errors.append(join_part_err)
+
+         return errors
+
+     def _validate_join(self, join: Join) -> List[BaseException]:
+         """
+         Validate the join's status against materialized versions of the group_bys
+         included in the join.
+
+         Returns:
+             list of validation errors.
+         """
+         included_group_bys_and_prefixes = [(rp.groupBy, rp.prefix) for rp in join.joinParts]
+         # TODO: Remove label parts check in future PR that deprecates label_parts
+         included_label_parts_and_prefixes = (
+             [(lp.groupBy, lp.prefix) for lp in join.labelParts.labels] if join.labelParts else []
+         )
+         included_group_bys = [tup[0] for tup in included_group_bys_and_prefixes]
+
+         offline_included_group_bys = [
+             gb.metaData.name
+             for gb in included_group_bys
+             if not gb.metaData or gb.metaData.online is False
+         ]
+         errors = []
+         old_group_bys = [
+             group_by
+             for group_by in included_group_bys
+             if self._get_old_obj(GroupBy, group_by.metaData.name)
+         ]
+         non_prod_old_group_bys = [
+             group_by.metaData.name
+             for group_by in old_group_bys
+             if group_by.metaData.production is False
+         ]
+         # Check that each underlying groupBy is valid
+         group_by_errors = [self._validate_group_by(group_by) for group_by in included_group_bys]
+         errors += [
+             ValueError(f"join {join.metaData.name}'s underlying {error}")
+             for errors in group_by_errors
+             for error in errors
+         ]
+         # Check if the production join is using non-production groupBys
+         if join.metaData.production and non_prod_old_group_bys:
+             errors.append(
+                 ValueError(
+                     "join {} is production but includes the following non production group_bys: {}".format(
+                         join.metaData.name, ", ".join(non_prod_old_group_bys)
+                     )
+                 )
+             )
+         # Check if the online join is using offline groupBys
+         if join.metaData.online:
+             if offline_included_group_bys:
+                 errors.append(
+                     ValueError(
+                         "join {} is online but includes the following offline group_bys: {}".format(
+                             join.metaData.name, ", ".join(offline_included_group_bys)
+                         )
+                     )
+                 )
+         # Only validate the join derivations when the underlying groupBys are valid
+         group_by_correct = all(not errors for errors in group_by_errors)
+         if join.derivations and group_by_correct:
+             features = list(get_pre_derived_join_features(join).keys())
+             # For online joins keys are not included in the output schema
+             if join.metaData.online:
+                 columns = features
+             else:
+                 keys = list(get_pre_derived_source_keys(join.left).keys())
+                 columns = features + keys
+             errors.extend(self._validate_derivations(columns, join.derivations))
+
+         errors.extend(self._validate_keys(join))
+
+         # If the join is using "short" names, ensure that there are no collisions
+         if join.useLongNames is False:
+             right_part_collisions = detect_feature_name_collisions(
+                 included_group_bys_and_prefixes, "right parts", join.metaData.name
+             )
+             if right_part_collisions:
+                 errors.append(right_part_collisions)
+
+             label_part_collisions = detect_feature_name_collisions(
+                 included_label_parts_and_prefixes, "label parts", join.metaData.name
+             )
+             if label_part_collisions:
+                 errors.append(label_part_collisions)
+
+         return errors
+
+     def _validate_group_by(self, group_by: GroupBy) -> List[BaseException]:
+         """
+         Validate the group_by's status against materialized versions of the joins
+         that include it.
+
+         Return:
+             List of validation errors.
+         """
+         joins = self._get_old_joins_with_group_by(group_by)
+         online_joins = [join.metaData.name for join in joins if join.metaData.online is True]
+         prod_joins = [join.metaData.name for join in joins if join.metaData.production is True]
+         errors = []
+
+         non_temporal = group_by.accuracy is None or group_by.accuracy == Accuracy.SNAPSHOT
+
+         no_topic = not _group_by_has_topic(group_by)
+         has_hourly_windows = _group_by_has_hourly_windows(group_by)
+
+         # batch features cannot contain hourly windows
+         if (no_topic and non_temporal) and has_hourly_windows:
+             errors.append(
+                 ValueError(
+                     f"group_by {group_by.metaData.name} is defined to be daily refreshed but contains "
+                     f"hourly windows. "
+                 )
+             )
+
+         def _validate_bounded_event_source():
+             if group_by.aggregations is None:
+                 return
+
+             unbounded_event_sources = [
+                 str(src)
+                 for src in group_by.sources
+                 if isinstance(get_root_source(src), EventSource)
+                 and get_query(src).startPartition is None
+             ]
+
+             if not unbounded_event_sources:
+                 return
+
+             unwindowed_aggregations = [
+                 str(agg) for agg in group_by.aggregations if agg.windows is None
+             ]
+
+             if not unwindowed_aggregations:
+                 return
+
+             nln = "\n"
+
+             errors.append(
+                 ValueError(
+                     f"""group_by {group_by.metaData.name} uses unwindowed aggregations [{nln}{f",{nln}".join(unwindowed_aggregations)}{nln}]
+                     on unbounded event sources: [{nln}{f",{nln}".join(unbounded_event_sources)}{nln}].
+                     Please set a start_partition on the source, or a window on the aggregation."""
+                 )
+             )
+
+         _validate_bounded_event_source()
+
+         # group_bys that are marked explicitly offline should not be present in
+         # materialized online joins.
+         if group_by.metaData.online is False and online_joins:
+             errors.append(
+                 ValueError(
+                     "group_by {} is explicitly marked offline but included in "
+                     "the following online joins: {}".format(
+                         group_by.metaData.name, ", ".join(online_joins)
+                     )
+                 )
+             )
+         # group_bys that are marked explicitly non-production should not be
+         # present in materialized production joins.
+         if prod_joins:
+             if group_by.metaData.production is False:
+                 errors.append(
+                     ValueError(
+                         "group_by {} is explicitly marked as non-production but included in the following production "
+                         "joins: {}".format(group_by.metaData.name, ", ".join(prod_joins))
+                     )
+                 )
+             # if the group_by is included in any materialized production join,
+             # set it to production in the materialized output.
+             else:
+                 group_by.metaData.production = True
+
+         # validate that the derivations are defined correctly
+         if group_by.derivations:
+             # For online group_bys keys are not included in the output schema
+             if group_by.metaData.online:
+                 columns = list(get_pre_derived_group_by_features(group_by).keys())
+             else:
+                 columns = list(get_pre_derived_group_by_columns(group_by).keys())
+             errors.extend(self._validate_derivations(columns, group_by.derivations))
+
+         for source in group_by.sources:
+             src: Source = source
+             if src.events and src.events.isCumulative and (src.events.query.timeColumn is None):
+                 errors.append(
+                     ValueError(
+                         "Please set query.timeColumn for Cumulative Events Table: {}".format(
+                             src.events.table
+                         )
+                     )
+                 )
+             elif src.joinSource:
+                 join_obj = src.joinSource.join
+                 if join_obj.metaData.name is None or join_obj.metaData.team is None:
+                     errors.append(
+                         ValueError(f"Join must be defined with team and name: {join_obj}")
+                     )
+         return errors
+
+     def validate_changes(self, results):
+         """
+         Validate changes in compiled objects against existing materialized objects.
+
+         Args:
+             results: List of CompiledObj from parse_configs
+
+         Returns:
+             None. Changes are recorded on the validator for a later confirmation check.
+         """
+
+         # Filter out results with errors and only process GroupBy/Join/StagingQuery
+         valid_results = []
+         for result in results:
+             if result.errors:
+                 continue  # Skip results with errors
+             if result.obj_type not in ["GroupBy", "Join", "StagingQuery"]:
+                 continue  # Skip unsupported object types
+             valid_results.append(result)
+
+         # Categorize changes
+         changed_objects = {"changed": [], "deleted": [], "added": []}
+
+         # Process each valid result
+         for result in valid_results:
+             obj_name = result.obj.metaData.name
+             obj_type = result.obj_type
+
+             # Check if the object exists in the old objects
+             old_obj = self._get_old_obj(type(result.obj), obj_name)
+
+             if old_obj is None:
+                 # New object
+                 change = self._create_config_change(result.obj, obj_type)
+                 changed_objects["added"].append(change)
+             elif self._has_diff(result.obj, old_obj):
+                 # Modified object
+                 change = self._create_config_change(result.obj, obj_type)
+                 changed_objects["changed"].append(change)
+
+         # Check for deleted objects
+         # Build a set of current objects from valid_results for lookup
+         current_objects = {(result.obj_type, result.obj.metaData.name) for result in valid_results}
+
+         for obj_type in ["GroupBy", "Join", "StagingQuery"]:
+             old_objs = self.old_objs.get(obj_type, {})
+             for obj_name, old_obj in old_objs.items():
+                 if (obj_type, obj_name) not in current_objects:
+                     # Object was deleted
+                     change = self._create_config_change(old_obj, obj_type)
+                     changed_objects["deleted"].append(change)
+
+         # Store changes for later confirmation check
+         self._pending_changes = {
+             "changed": changed_objects["changed"],
+             "deleted": changed_objects["deleted"],
+             "added": changed_objects["added"],
+         }
+
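+     # The pending changes recorded above are consumed by
+     # check_pending_changes_confirmation, which prompts before any in-place
+     # (non-version-bump) overwrite is allowed to proceed.
+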
+     def _filter_non_version_changes(self, existing_changes, added_changes):
+         """Filter out version changes from existing changes.
+
+         Returns the list of changes that are NOT version bumps and require confirmation.
+         """
+         added_names = {change.name for change in added_changes}
+         non_version_changes = []
+
+         for change in existing_changes:
+             # Check if this changed/deleted config has a corresponding added config that represents a version bump
+             is_version_bump = any(
+                 is_version_change(change.name, added_name) for added_name in added_names
+             )
+
+             if not is_version_bump:
+                 non_version_changes.append(change)
+
+         return non_version_changes
+
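+     # A change is exempt from the confirmation prompt only when is_version_change
+     # pairs its name with one of the newly added names, i.e. the "change" is really
+     # a re-version of the same config rather than an in-place semantic edit.
+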
+     def check_pending_changes_confirmation(self, compile_status):
+         """Check if user confirmation is needed for pending changes after display."""
+         from ai.chronon.cli.compile.display.console import console
+
+         # Skip confirmation if there are compilation errors
+         if self._has_compilation_errors(compile_status):
+             return  # Don't prompt when there are errors
+
+         if not hasattr(self, "_pending_changes"):
+             return  # No pending changes
+
+         # Check if we need user confirmation (only for non-version-bump changes)
+         non_version_changes = self._filter_non_version_changes(
+             self._pending_changes["changed"] + self._pending_changes["deleted"],
+             self._pending_changes["added"],
+         )
+
+         if non_version_changes:
+             if not self._prompt_user_confirmation():
+                 console.print("❌ Compilation cancelled by user.")
+                 sys.exit(1)
+
+     def _has_compilation_errors(self, compile_status):
+         """Check if there are any compilation errors across all class trackers."""
+         for tracker in compile_status.cls_to_tracker.values():
+             if tracker.files_to_errors:
+                 return True
+         return False
+
+     def _create_config_change(self, obj, obj_type):
+         """Create a ConfigChange object from a thrift object."""
+         return ConfigChange(
+             name=obj.metaData.name,
+             obj_type=obj_type,
+             online=obj.metaData.online if obj.metaData.online else False,
+             production=obj.metaData.production if obj.metaData.production else False,
+         )
+
+     def _prompt_user_confirmation(self) -> bool:
+         """
+         Prompt the user for Y/N confirmation to proceed with overwriting existing configs.
+         Returns True if the user confirms, False otherwise.
+         """
+         from ai.chronon.cli.compile.display.console import console
+
+         console.print(
+             "\n❗ [bold red3]Some configs are changing in-place (changing semantics without changing the version).[/bold red3]"
+         )
+         console.print(
+             "[dim]Note that changes can be caused by directly modifying a config, or by changing the version of an upstream object which changes the input source.[/dim]"
+         )
+         console.print("")
+         console.print(
+             "[dim]This can be a safe operation if you are sure that the downstream configs are unaffected by the change (i.e. an upstream config is versioning up in a way that does not affect the downstream), or if the objects in question are not yet used by critical workloads.[/dim]"
+         )
+         console.print("")
+         console.print("❓ [bold]Do you want to proceed? (y/N):[/bold]", end=" ")
+
+         try:
+             response = input().strip().lower()
+             return response in ["y", "yes"]
+         except (EOFError, KeyboardInterrupt):
+             console.print("\n❌ Compilation cancelled.")
+             return False
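+
+
+ # A minimal usage sketch (illustrative only, assuming the caller has already loaded
+ # previously materialized configs as name -> thrift-object dicts; paths are hypothetical):
+ #
+ #     validator = ConfValidator(
+ #         input_root="/path/to/chronon_root",
+ #         output_root="/path/to/compiled_output",
+ #         existing_gbs=old_group_bys,                # Dict[str, GroupBy]
+ #         existing_joins=old_joins,                  # Dict[str, Join]
+ #         existing_staging_queries=old_staging_queries,
+ #     )
+ #     for err in validator.validate_obj(compiled_group_by):
+ #         print(err)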