awx-zipline-ai 0.2.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. agent/__init__.py +1 -0
  2. agent/constants.py +15 -0
  3. agent/ttypes.py +1684 -0
  4. ai/__init__.py +0 -0
  5. ai/chronon/__init__.py +0 -0
  6. ai/chronon/airflow_helpers.py +251 -0
  7. ai/chronon/api/__init__.py +1 -0
  8. ai/chronon/api/common/__init__.py +1 -0
  9. ai/chronon/api/common/constants.py +15 -0
  10. ai/chronon/api/common/ttypes.py +1844 -0
  11. ai/chronon/api/constants.py +15 -0
  12. ai/chronon/api/ttypes.py +3624 -0
  13. ai/chronon/cli/compile/column_hashing.py +313 -0
  14. ai/chronon/cli/compile/compile_context.py +177 -0
  15. ai/chronon/cli/compile/compiler.py +160 -0
  16. ai/chronon/cli/compile/conf_validator.py +590 -0
  17. ai/chronon/cli/compile/display/class_tracker.py +112 -0
  18. ai/chronon/cli/compile/display/compile_status.py +95 -0
  19. ai/chronon/cli/compile/display/compiled_obj.py +12 -0
  20. ai/chronon/cli/compile/display/console.py +3 -0
  21. ai/chronon/cli/compile/display/diff_result.py +46 -0
  22. ai/chronon/cli/compile/fill_templates.py +40 -0
  23. ai/chronon/cli/compile/parse_configs.py +141 -0
  24. ai/chronon/cli/compile/parse_teams.py +238 -0
  25. ai/chronon/cli/compile/serializer.py +115 -0
  26. ai/chronon/cli/git_utils.py +156 -0
  27. ai/chronon/cli/logger.py +61 -0
  28. ai/chronon/constants.py +3 -0
  29. ai/chronon/eval/__init__.py +122 -0
  30. ai/chronon/eval/query_parsing.py +19 -0
  31. ai/chronon/eval/sample_tables.py +100 -0
  32. ai/chronon/eval/table_scan.py +186 -0
  33. ai/chronon/fetcher/__init__.py +1 -0
  34. ai/chronon/fetcher/constants.py +15 -0
  35. ai/chronon/fetcher/ttypes.py +127 -0
  36. ai/chronon/group_by.py +692 -0
  37. ai/chronon/hub/__init__.py +1 -0
  38. ai/chronon/hub/constants.py +15 -0
  39. ai/chronon/hub/ttypes.py +1228 -0
  40. ai/chronon/join.py +566 -0
  41. ai/chronon/logger.py +24 -0
  42. ai/chronon/model.py +35 -0
  43. ai/chronon/observability/__init__.py +1 -0
  44. ai/chronon/observability/constants.py +15 -0
  45. ai/chronon/observability/ttypes.py +2192 -0
  46. ai/chronon/orchestration/__init__.py +1 -0
  47. ai/chronon/orchestration/constants.py +15 -0
  48. ai/chronon/orchestration/ttypes.py +4406 -0
  49. ai/chronon/planner/__init__.py +1 -0
  50. ai/chronon/planner/constants.py +15 -0
  51. ai/chronon/planner/ttypes.py +1686 -0
  52. ai/chronon/query.py +126 -0
  53. ai/chronon/repo/__init__.py +40 -0
  54. ai/chronon/repo/aws.py +298 -0
  55. ai/chronon/repo/cluster.py +65 -0
  56. ai/chronon/repo/compile.py +56 -0
  57. ai/chronon/repo/constants.py +164 -0
  58. ai/chronon/repo/default_runner.py +291 -0
  59. ai/chronon/repo/explore.py +421 -0
  60. ai/chronon/repo/extract_objects.py +137 -0
  61. ai/chronon/repo/gcp.py +585 -0
  62. ai/chronon/repo/gitpython_utils.py +14 -0
  63. ai/chronon/repo/hub_runner.py +171 -0
  64. ai/chronon/repo/hub_uploader.py +108 -0
  65. ai/chronon/repo/init.py +53 -0
  66. ai/chronon/repo/join_backfill.py +105 -0
  67. ai/chronon/repo/run.py +293 -0
  68. ai/chronon/repo/serializer.py +141 -0
  69. ai/chronon/repo/team_json_utils.py +46 -0
  70. ai/chronon/repo/utils.py +472 -0
  71. ai/chronon/repo/zipline.py +51 -0
  72. ai/chronon/repo/zipline_hub.py +105 -0
  73. ai/chronon/resources/gcp/README.md +174 -0
  74. ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
  75. ai/chronon/resources/gcp/group_bys/test/data.py +34 -0
  76. ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
  77. ai/chronon/resources/gcp/joins/test/data.py +30 -0
  78. ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
  79. ai/chronon/resources/gcp/sources/test/data.py +23 -0
  80. ai/chronon/resources/gcp/teams.py +70 -0
  81. ai/chronon/resources/gcp/zipline-cli-install.sh +54 -0
  82. ai/chronon/source.py +88 -0
  83. ai/chronon/staging_query.py +185 -0
  84. ai/chronon/types.py +57 -0
  85. ai/chronon/utils.py +557 -0
  86. ai/chronon/windows.py +50 -0
  87. awx_zipline_ai-0.2.0.dist-info/METADATA +173 -0
  88. awx_zipline_ai-0.2.0.dist-info/RECORD +93 -0
  89. awx_zipline_ai-0.2.0.dist-info/WHEEL +5 -0
  90. awx_zipline_ai-0.2.0.dist-info/entry_points.txt +2 -0
  91. awx_zipline_ai-0.2.0.dist-info/licenses/LICENSE +202 -0
  92. awx_zipline_ai-0.2.0.dist-info/top_level.txt +3 -0
  93. jars/__init__.py +0 -0
ai/chronon/cli/compile/conf_validator.py
@@ -0,0 +1,590 @@
1
+ """Object for checking whether a Chronon API thrift object is consistent with other
2
+ """
3
+
4
+ # Copyright (C) 2023 The Chronon Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import json
19
+ import logging
20
+ import re
21
+ import textwrap
22
+ from collections import defaultdict
23
+ from typing import Dict, List, Tuple
24
+
25
+ import ai.chronon.api.common.ttypes as common
26
+ from ai.chronon.api.ttypes import (
27
+ Accuracy,
28
+ Aggregation,
29
+ Derivation,
30
+ EventSource,
31
+ GroupBy,
32
+ Join,
33
+ JoinPart,
34
+ Source,
35
+ )
36
+ from ai.chronon.cli.compile.column_hashing import (
37
+ compute_group_by_columns_hashes,
38
+ get_pre_derived_group_by_columns,
39
+ get_pre_derived_group_by_features,
40
+ get_pre_derived_join_features,
41
+ get_pre_derived_source_keys,
42
+ )
43
+ from ai.chronon.logger import get_logger
44
+ from ai.chronon.repo.serializer import thrift_simple_json
45
+ from ai.chronon.utils import get_query, get_root_source
46
+
47
+ # Fields that indicate status of the entities.
48
+ SKIPPED_FIELDS = frozenset(["metaData"])
49
+ EXTERNAL_KEY = "onlineExternalParts"
50
+
51
+
52
+ def _filter_skipped_fields_from_join(json_obj: Dict, skipped_fields):
53
+ for join_part in json_obj["joinParts"]:
54
+ group_by = join_part["groupBy"]
55
+ for field in skipped_fields:
56
+ group_by.pop(field, None)
57
+ if EXTERNAL_KEY in json_obj:
58
+ json_obj.pop(EXTERNAL_KEY, None)
59
+
60
+
61
+ def _is_batch_upload_needed(group_by: GroupBy) -> bool:
62
+ if group_by.metaData.online or group_by.backfillStartDate:
63
+ return True
64
+ else:
65
+ return False
66
+
67
+
68
+ def is_identifier(s: str) -> bool:
69
+ identifier_regex = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
70
+ return re.fullmatch(identifier_regex, s) is not None
71
+
72
+
73
+ def _source_has_topic(source: Source) -> bool:
74
+ if source.events:
75
+ return source.events.topic is not None
76
+ elif source.entities:
77
+ return source.entities.mutationTopic is not None
78
+ elif source.joinSource:
79
+ return _source_has_topic(source.joinSource.join.left)
80
+ return False
81
+
82
+
83
+ def _group_by_has_topic(groupBy: GroupBy) -> bool:
84
+ return any(_source_has_topic(source) for source in groupBy.sources)
85
+
86
+
87
+ def _group_by_has_hourly_windows(groupBy: GroupBy) -> bool:
88
+ aggs: List[Aggregation] = groupBy.aggregations
89
+
90
+ if not aggs:
91
+ return False
92
+
93
+ for agg in aggs:
94
+
95
+ if not agg.windows:
96
+ return False
97
+
98
+ for window in agg.windows:
99
+ if window.timeUnit == common.TimeUnit.HOURS:
100
+ return True
101
+
102
+ return False
103
+
104
+
105
+ def detect_feature_name_collisions(
106
+ group_bys: List[Tuple[GroupBy, str]],
107
+ entity_set_type: str,
108
+ name: str
109
+ ) -> BaseException | None:
110
+ # Build a map of output_column -> set of group_by name
111
+ output_col_to_gbs = {}
112
+ for gb, prefix in group_bys:
113
+ jp_prefix_str = f"{prefix}_" if prefix else ""
114
+ key_str = "_".join(gb.keyColumns)
115
+ prefix_str = jp_prefix_str + key_str + "_"
116
+ # Compute the pre-derived output columns once, then apply the join-part
117
+ # prefix and key prefix to each output column name.
118
+ base_cols = compute_group_by_columns_hashes(gb, exclude_keys=True)
119
+ cols = {
120
+ f"{prefix_str}{base_col}"
121
+ for base_col in base_cols
122
+ }
123
+ gb_name = gb.metaData.name
124
+ for col in cols:
125
+ if col not in output_col_to_gbs:
126
+ output_col_to_gbs[col] = set()
127
+ output_col_to_gbs[col].add(gb_name)
128
+
129
+ # Find output columns that are produced by more than one group_by
130
+ collisions = {
131
+ col: gb_names
132
+ for col, gb_names in output_col_to_gbs.items()
133
+ if len(gb_names) > 1
134
+ }
135
+
136
+ if not collisions:
137
+ return None # no collisions
138
+
139
+ # Assemble an error message listing the conflicting GroupBys by name
140
+ lines = [
141
+ f"{entity_set_type} for Join: {name} has the following output name collisions:\n"
142
+ ]
143
+ for col, gb_names in collisions.items():
144
+ names_str = ", ".join(sorted(gb_names))
145
+ lines.append(f" - [{col}] has collisions from: [{names_str}]")
146
+
147
+ lines.append(
148
+ "\nConsider assigning distinct `prefix` values to the conflicting parts to avoid collisions."
149
+ )
150
+ return ValueError("\n".join(lines))
151
+
152
+
153
+ class ConfValidator(object):
154
+ """
155
+ Applies repo-wide validation rules.
156
+ """
157
+
158
+ def __init__(
159
+ self,
160
+ input_root,
161
+ output_root,
162
+ existing_gbs,
163
+ existing_joins,
164
+ log_level=logging.INFO,
165
+ ):
166
+
167
+ self.chronon_root_path = input_root
168
+ self.output_root = output_root
169
+
170
+ self.log_level = log_level
171
+ self.logger = get_logger(log_level)
172
+
173
+ # We keep the objects in a list rather than a set, since thrift does not
174
+ # implement __hash__ for ttypes objects.
175
+
176
+ self.old_objs = defaultdict(dict)
177
+ self.old_group_bys = existing_gbs
178
+ self.old_joins = existing_joins
179
+ self.old_objs["GroupBy"] = self.old_group_bys
180
+ self.old_objs["Join"] = self.old_joins
181
+
182
+ def _get_old_obj(self, obj_class: type, obj_name: str) -> object:
183
+ """
184
+ returns:
185
+ materialized version of the obj given the object's name.
186
+ """
187
+ class_name = obj_class.__name__
188
+
189
+ if class_name not in self.old_objs:
190
+ return None
191
+ obj_map = self.old_objs[class_name]
192
+
193
+ if obj_name not in obj_map:
194
+ return None
195
+ return obj_map[obj_name]
196
+
197
+ def _get_old_joins_with_group_by(self, group_by: GroupBy) -> List[Join]:
198
+ """
199
+ returns:
200
+ materialized joins including the group_by as dicts.
201
+ """
202
+ joins = []
203
+ for join in self.old_joins.values():
204
+ if join.joinParts is not None and group_by.metaData.name in [
205
+ rp.groupBy.metaData.name for rp in join.joinParts
206
+ ]:
207
+ joins.append(join)
208
+ return joins
209
+
210
+ def can_skip_materialize(self, obj: object) -> List[str]:
211
+ """
212
+ Check whether materialization of the object can be skipped, and return the
214
+ reasons if so.
214
+ """
215
+ reasons = []
216
+ if isinstance(obj, GroupBy):
217
+ if not _is_batch_upload_needed(obj):
218
+ reasons.append(
219
+ "GroupBys should not be materialized if batch upload job is not needed"
220
+ )
221
+ # Otherwise, group_bys that are included in an online join, or are explicitly
222
+ # marked online themselves, are materialized.
223
+ elif not any(
224
+ join.metaData.online for join in self._get_old_joins_with_group_by(obj)
225
+ ) and not _is_batch_upload_needed(obj):
226
+ reasons.append(
227
+ "is not marked online/production nor is included in any online join"
228
+ )
229
+ return reasons
230
+
231
+ def validate_obj(self, obj: object) -> List[BaseException]:
232
+ """
233
+ Validate Chronon API obj against other entities in the repo.
234
+
235
+ returns:
236
+ list of errors.
237
+ """
238
+ if isinstance(obj, GroupBy):
239
+ return self._validate_group_by(obj)
240
+ elif isinstance(obj, Join):
241
+ return self._validate_join(obj)
242
+ return []
243
+
244
+ def _has_diff(
245
+ self, obj: object, old_obj: object, skipped_fields=SKIPPED_FIELDS
246
+ ) -> bool:
247
+ new_json = {
248
+ k: v
249
+ for k, v in json.loads(thrift_simple_json(obj)).items()
250
+ if k not in skipped_fields
251
+ }
252
+ old_json = {
253
+ k: v
254
+ for k, v in json.loads(thrift_simple_json(old_obj)).items()
255
+ if k not in skipped_fields
256
+ }
257
+ if isinstance(obj, Join):
258
+ _filter_skipped_fields_from_join(new_json, skipped_fields)
259
+ _filter_skipped_fields_from_join(old_json, skipped_fields)
260
+ return new_json != old_json
261
+
262
+ def safe_to_overwrite(self, obj: object) -> bool:
263
+ """When an object is already materialized as online, it is no more safe
264
+ to materialize and overwrite the old conf.
265
+ """
266
+ old_obj = self._get_old_obj(type(obj), obj.metaData.name)
267
+ return (
268
+ not old_obj
269
+ or not self._has_diff(obj, old_obj)
270
+ or not old_obj.metaData.online
271
+ )
272
+
273
+ def _validate_derivations(
274
+ self, pre_derived_cols: List[str], derivations: List[Derivation]
275
+ ) -> List[BaseException]:
276
+ """
277
+ Validate join/groupBy's derivation is defined correctly.
278
+
279
+ Returns:
280
+ list of validation errors.
281
+ """
282
+ errors = []
283
+ derived_columns = set(pre_derived_cols)
284
+
285
+ wild_card_derivation_included = any(
286
+ derivation.expression == "*" for derivation in derivations
287
+ )
288
+ if not wild_card_derivation_included:
289
+ derived_columns.clear()
290
+ for derivation in derivations:
291
+ # if the derivation is a renaming derivation, check whether the expression is in pre-derived schema
292
+ if is_identifier(derivation.expression):
293
+ # for wildcard derivation we want to remove the original column if there is a renaming operation
294
+ # applied on it
295
+ if wild_card_derivation_included:
296
+ if derivation.expression in derived_columns:
297
+ derived_columns.remove(derivation.expression)
298
+ if (
299
+ derivation.expression not in pre_derived_cols
300
+ and derivation.expression not in ("ds", "ts")
301
+ ):
302
+ errors.append(
303
+ ValueError("Incorrect derivation expression {}, expression not found in pre-derived columns {}"
304
+ .format(
305
+ derivation.expression, pre_derived_cols
306
+ ))
307
+ )
308
+ if derivation.name != "*":
309
+ if derivation.name in derived_columns:
310
+ errors.append(
311
+ ValueError("Incorrect derivation name {} due to output column name conflict".format(
312
+ derivation.name
313
+ )
314
+ ))
315
+ else:
316
+ derived_columns.add(derivation.name)
317
+ return errors
318
+
319
+ def _validate_join_part_keys(self, join_part: JoinPart, left_cols: List[str]) -> BaseException:
320
+ keys = []
321
+
322
+ key_mapping = join_part.keyMapping if join_part.keyMapping else {}
323
+ for key in join_part.groupBy.keyColumns:
324
+ keys.append(key_mapping.get(key, key))
325
+
326
+ missing = [k for k in keys if k not in left_cols]
327
+
328
+ err_string = ""
329
+ left_cols_as_str = ", ".join(left_cols)
330
+ group_by_name = join_part.groupBy.metaData.name
331
+ if missing:
332
+ key_mapping_str = f"Key Mapping: {key_mapping}" if key_mapping else ""
333
+ err_string += textwrap.dedent(f"""
334
+ - Join is missing keys {missing} on left side. Required for JoinPart: {group_by_name}.
335
+ Existing columns on left side: {left_cols_as_str}
336
+ All required Keys: {join_part.groupBy.keyColumns}
337
+ {key_mapping_str}
338
+ Consider renaming a column on the left, or including the key_mapping argument to your join_part.""")
339
+
340
+ if key_mapping:
341
+ # Left side of key mapping should include columns on the left
342
+ key_map_keys_missing_from_left = [k for k in key_mapping.keys() if k not in left_cols]
343
+ if key_map_keys_missing_from_left:
344
+ err_string += f"\n- The following keys in your key_mapping: {str(key_map_keys_missing_from_left)} for JoinPart {group_by_name} are not included in the left side of the join: {left_cols_as_str}"
345
+
346
+ # Right side of key mapping should only include keys in GroupBy
347
+ keys_missing_from_key_map_values = [v for v in key_mapping.values() if v not in join_part.groupBy.keyColumns]
348
+ if keys_missing_from_key_map_values:
349
+ err_string += f"\n- The following values in your key_mapping: {str(keys_missing_from_key_map_values)} for JoinPart {group_by_name} do not cover any group by key columns: {join_part.groupBy.keyColumns}"
350
+
351
+ if key_map_keys_missing_from_left or keys_missing_from_key_map_values:
352
+ err_string += "\n(Key Mapping should be formatted as column_from_left -> group_by_key)"
353
+
354
+ if err_string:
355
+ return ValueError(err_string)
356
+
357
+
358
+ def _validate_keys(self, join: Join) -> List[BaseException]:
359
+ left = join.left
360
+
361
+ left_selects = None
362
+ if left.events:
363
+ left_selects = left.events.query.selects
364
+ elif left.entities:
365
+ left_selects = left.entities.query.selects
366
+ elif left.joinSource:
367
+ left_selects = left.joinSource.query.selects
368
+ # TODO -- if selects are not specified here, get output cols from the join
369
+
370
+ left_cols = []
371
+
372
+ if left_selects:
373
+ left_cols = left_selects.keys()
374
+
375
+ errors = []
376
+
377
+ if left_cols:
378
+ join_parts = join.joinParts
379
+
380
+ # Add label_parts to join_parts to validate if set
381
+ label_parts = join.labelParts
382
+ if label_parts:
383
+ for label_jp in label_parts.labels:
384
+ join_parts.append(label_jp)
385
+
386
+ # Validate join_parts
387
+ for join_part in join_parts:
388
+ join_part_err = self._validate_join_part_keys(join_part, left_cols)
389
+ if join_part_err:
390
+ errors.append(join_part_err)
391
+
392
+ return errors
393
+
394
+
395
+ def _validate_join(self, join: Join) -> List[BaseException]:
396
+ """
397
+ Validate join's status with materialized versions of group_bys
398
+ included by the join.
399
+
400
+ Returns:
401
+ list of validation errors.
402
+ """
403
+ included_group_bys_and_prefixes = [(rp.groupBy, rp.prefix) for rp in join.joinParts]
404
+ # TODO: Remove label parts check in future PR that deprecates label_parts
405
+ included_label_parts_and_prefixes = [(lp.groupBy, lp.prefix) for lp in join.labelParts.labels] if join.labelParts else []
406
+ included_group_bys = [tup[0] for tup in included_group_bys_and_prefixes]
407
+
408
+ offline_included_group_bys = [
409
+ gb.metaData.name
410
+ for gb in included_group_bys
411
+ if not gb.metaData or gb.metaData.online is False
412
+ ]
413
+ errors = []
414
+ old_group_bys = [
415
+ group_by
416
+ for group_by in included_group_bys
417
+ if self._get_old_obj(GroupBy, group_by.metaData.name)
418
+ ]
419
+ non_prod_old_group_bys = [
420
+ group_by.metaData.name
421
+ for group_by in old_group_bys
422
+ if group_by.metaData.production is False
423
+ ]
424
+ # Check if the underlying groupBy is valid
425
+ group_by_errors = [
426
+ self._validate_group_by(group_by) for group_by in included_group_bys
427
+ ]
428
+ errors += [
429
+ ValueError(f"join {join.metaData.name}'s underlying {error}")
430
+ for errors in group_by_errors
431
+ for error in errors
432
+ ]
433
+ # Check if the production join is using non production groupBy
434
+ if join.metaData.production and non_prod_old_group_bys:
435
+ errors.append(
436
+ ValueError("join {} is production but includes the following non production group_bys: {}".format(
437
+ join.metaData.name, ", ".join(non_prod_old_group_bys)
438
+ )
439
+ ))
440
+ # Check if the online join is using the offline groupBy
441
+ if join.metaData.online:
442
+ if offline_included_group_bys:
443
+ errors.append(
444
+ ValueError("join {} is online but includes the following offline group_bys: {}".format(
445
+ join.metaData.name, ", ".join(offline_included_group_bys)
446
+ )
447
+ ))
448
+ # Only validate the join derivation when the underlying groupBy is valid
449
+ group_by_correct = all(not errors for errors in group_by_errors)
450
+ if join.derivations and group_by_correct:
451
+ features = list(get_pre_derived_join_features(join).keys())
452
+ # For online joins keys are not included in output schema
453
+ if join.metaData.online:
454
+ columns = features
455
+ else:
456
+ keys = list(get_pre_derived_source_keys(join.left).keys())
457
+ columns = features + keys
458
+ errors.extend(self._validate_derivations(columns, join.derivations))
459
+
460
+
461
+ errors.extend(self._validate_keys(join))
462
+
463
+ # If the join is using "short" names, ensure that there are no collisions
464
+ if join.useLongNames is False:
465
+ right_part_collisions = detect_feature_name_collisions(included_group_bys_and_prefixes, "right parts", join.metaData.name)
466
+ if right_part_collisions:
467
+ errors.append(right_part_collisions)
468
+
469
+ label_part_collisions = detect_feature_name_collisions(included_label_parts_and_prefixes, "label parts", join.metaData.name)
470
+ if label_part_collisions:
471
+ errors.append(label_part_collisions)
472
+
473
+ return errors
474
+
475
+
476
+
477
+ def _validate_group_by(self, group_by: GroupBy) -> List[BaseException]:
478
+ """
479
+ Validate group_by's status with materialized versions of joins
480
+ including the group_by.
481
+
482
+ Return:
483
+ List of validation errors.
484
+ """
485
+ joins = self._get_old_joins_with_group_by(group_by)
486
+ online_joins = [
487
+ join.metaData.name for join in joins if join.metaData.online is True
488
+ ]
489
+ prod_joins = [
490
+ join.metaData.name for join in joins if join.metaData.production is True
491
+ ]
492
+ errors = []
493
+
494
+ non_temporal = (
495
+ group_by.accuracy is None or group_by.accuracy == Accuracy.SNAPSHOT
496
+ )
497
+
498
+ no_topic = not _group_by_has_topic(group_by)
499
+ has_hourly_windows = _group_by_has_hourly_windows(group_by)
500
+
501
+ # batch features cannot contain hourly windows
502
+ if (no_topic and non_temporal) and has_hourly_windows:
503
+ errors.append(
504
+ ValueError(f"group_by {group_by.metaData.name} is defined to be daily refreshed but contains "
505
+ f"hourly windows. "
506
+ )
507
+ )
508
+
509
+ def _validate_bounded_event_source():
510
+ if group_by.aggregations is None:
511
+ return
512
+
513
+ unbounded_event_sources = [
514
+ str(src)
515
+ for src in group_by.sources
516
+ if isinstance(get_root_source(src), EventSource) and get_query(src).startPartition is None
517
+ ]
518
+
519
+ if not unbounded_event_sources:
520
+ return
521
+
522
+ unwindowed_aggregations = [str(agg) for agg in group_by.aggregations if agg.windows is None]
523
+
524
+ if not unwindowed_aggregations:
525
+ return
526
+
527
+ nln = "\n"
528
+
529
+ errors.append(
530
+ ValueError(
531
+ f"""group_by {group_by.metaData.name} uses unwindowed aggregations [{nln}{f",{nln}".join(unwindowed_aggregations)}{nln}]
532
+ on unbounded event sources: [{nln}{f",{nln}".join(unbounded_event_sources)}{nln}].
533
+ Please set a start_partition on the source, or a window on the aggregation.""")
534
+ )
535
+
536
+ _validate_bounded_event_source()
537
+
538
+ # group by that are marked explicitly offline should not be present in
539
+ # materialized online joins.
540
+ if group_by.metaData.online is False and online_joins:
541
+ errors.append(
542
+ ValueError("group_by {} is explicitly marked offline but included in "
543
+ "the following online joins: {}".format(
544
+ group_by.metaData.name, ", ".join(online_joins)
545
+ ))
546
+ )
547
+ # group by that are marked explicitly non-production should not be
548
+ # present in materialized production joins.
549
+ if prod_joins:
550
+ if group_by.metaData.production is False:
551
+ errors.append(
552
+ ValueError(
553
+ "group_by {} is explicitly marked as non-production but included in the following production "
554
+ "joins: {}".format(group_by.metaData.name, ", ".join(prod_joins))
555
+ ))
556
+ # if the group by is included in any of materialized production join,
557
+ # set it to production in the materialized output.
558
+ else:
559
+ group_by.metaData.production = True
560
+
561
+ # validate the derivations are defined correctly
562
+ if group_by.derivations:
563
+ # For online group_by keys are not included in output schema
564
+ if group_by.metaData.online:
565
+ columns = list(get_pre_derived_group_by_features(group_by).keys())
566
+ else:
567
+ columns = list(get_pre_derived_group_by_columns(group_by).keys())
568
+ errors.extend(self._validate_derivations(columns, group_by.derivations))
569
+
570
+ for source in group_by.sources:
571
+ src: Source = source
572
+ if (
573
+ src.events
574
+ and src.events.isCumulative
575
+ and (src.events.query.timeColumn is None)
576
+ ):
577
+ errors.append(
578
+ ValueError("Please set query.timeColumn for Cumulative Events Table: {}".format(
579
+ src.events.table
580
+ ))
581
+ )
582
+ elif (
583
+ src.joinSource
584
+ ):
585
+ join_obj = src.joinSource.join
586
+ if join_obj.metaData.name is None or join_obj.metaData.team is None:
587
+ errors.append(
588
+ ValueError(f"Join must be defined with team and name: {join_obj}")
589
+ )
590
+ return errors
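
A minimal sketch of how the derivation rules above behave, assuming placeholder roots, empty maps of previously compiled confs, and made-up column names; it calls the internal _validate_derivations helper directly, so treat it as an illustration rather than the package's documented API.

    from ai.chronon.api.ttypes import Derivation
    from ai.chronon.cli.compile.conf_validator import ConfValidator

    validator = ConfValidator(
        input_root="chronon/",    # placeholder repo root
        output_root="compiled/",  # placeholder output root
        existing_gbs={},          # no previously materialized GroupBys
        existing_joins={},        # no previously materialized Joins
    )

    pre_derived = ["price", "quantity", "user_id"]  # hypothetical pre-derived columns

    # Non-identifier expressions are passed through; only the output name is tracked.
    ok = validator._validate_derivations(
        pre_derived, [Derivation(name="total", expression="price * quantity")]
    )
    assert ok == []

    # A renaming derivation must reference a pre-derived column (or ds/ts).
    bad = validator._validate_derivations(
        pre_derived, [Derivation(name="renamed", expression="not_a_column")]
    )
    assert len(bad) == 1 and isinstance(bad[0], ValueError)
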
ai/chronon/cli/compile/display/class_tracker.py
@@ -0,0 +1,112 @@
1
+ import difflib
2
+ from typing import Any, Dict, List
3
+
4
+ from rich.text import Text
5
+
6
+ from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
7
+ from ai.chronon.cli.compile.display.diff_result import DiffResult
8
+
9
+
10
+ class ClassTracker:
11
+ """
12
+ Tracker object per class - Join, StagingQuery, GroupBy etc
13
+ """
14
+
15
+ def __init__(self):
16
+ self.existing_objs: Dict[str, CompiledObj] = {} # name to obj
17
+ self.files_to_obj: Dict[str, List[Any]] = {}
18
+ self.files_to_errors: Dict[str, List[Exception]] = {}
19
+ self.new_objs: Dict[str, CompiledObj] = {} # name to obj
20
+ self.diff_result = DiffResult()
21
+ self.deleted_names: List[str] = []
22
+
23
+ def add_existing(self, obj: CompiledObj) -> None:
24
+ self.existing_objs[obj.name] = obj
25
+
26
+ def add(self, compiled: CompiledObj) -> None:
27
+
28
+ if compiled.errors:
29
+
30
+ if compiled.file not in self.files_to_errors:
31
+ self.files_to_errors[compiled.file] = []
32
+
33
+ self.files_to_errors[compiled.file].extend(compiled.errors)
34
+
35
+ else:
36
+ if compiled.file not in self.files_to_obj:
37
+ self.files_to_obj[compiled.file] = []
38
+
39
+ self.files_to_obj[compiled.file].append(compiled.obj)
40
+
41
+ self.new_objs[compiled.name] = compiled
42
+ self._update_diff(compiled)
43
+
44
+ def _update_diff(self, compiled: CompiledObj) -> None:
45
+ if compiled.name in self.existing_objs:
46
+
47
+ existing_json = self.existing_objs[compiled.name].tjson
48
+ new_json = compiled.tjson
49
+
50
+ if existing_json != new_json:
51
+
52
+ diff = difflib.unified_diff(
53
+ existing_json.splitlines(keepends=True),
54
+ new_json.splitlines(keepends=True),
55
+ n=2,
56
+ )
57
+
58
+ print(f"Updated object: {compiled.name} in file {compiled.file}")
59
+ print("".join(diff))
60
+ print("\n")
61
+
62
+ self.diff_result.updated.append(compiled.name)
63
+
64
+ else:
65
+ if not compiled.errors:
66
+ self.diff_result.added.append(compiled.name)
67
+
68
+ def close(self) -> None:
69
+ self.closed = True
70
+ self.recent_file = None
71
+ self.deleted_names = list(self.existing_objs.keys() - self.new_objs.keys())
72
+
73
+ def to_status(self) -> Text:
74
+ text = Text(overflow="fold", no_wrap=False)
75
+
76
+ if self.existing_objs:
77
+ text.append(
78
+ f" Parsed {len(self.existing_objs)} previously compiled objects.\n"
79
+ )
80
+
81
+ if self.files_to_obj:
82
+ text.append(" Compiled ")
83
+ text.append(f"{len(self.new_objs)} ", style="bold green")
84
+ text.append("objects from ")
85
+ text.append(f"{len(self.files_to_obj)} ", style="bold green")
86
+ text.append("files.\n")
87
+
88
+ if self.files_to_errors:
89
+ text.append(" Failed to compile ")
90
+ text.append(f"{len(self.files_to_errors)} ", style="red")
91
+ text.append("files.\n")
92
+
93
+ return text
94
+
95
+ def to_errors(self) -> Text:
96
+ text = Text(overflow="fold", no_wrap=False)
97
+
98
+ if self.files_to_errors:
99
+ for file, errors in self.files_to_errors.items():
100
+ text.append(" ERROR ", style="bold red")
101
+ text.append(f"- {file}:\n")
102
+
103
+ for error in errors:
104
+ # Format each error properly, handling newlines
105
+ error_msg = str(error)
106
+ text.append(f" {error_msg}\n", style="red")
107
+
108
+ return text
109
+
110
+ # doesn't make sense to show deletes until the very end of compilation
111
+ def diff(self) -> Text:
112
+ return self.diff_result.render(deleted_names=self.deleted_names)
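
A minimal sketch of the add/diff flow above, assuming a SimpleNamespace stand-in for CompiledObj (the tracker only reads the name, file, errors, obj and tjson attributes) and made-up object names, file paths, and JSON payloads.

    from types import SimpleNamespace

    from ai.chronon.cli.compile.display.class_tracker import ClassTracker

    def fake_compiled(name, tjson, file="group_bys/test/data.py", errors=None):
        # Stand-in exposing only the attributes ClassTracker reads.
        return SimpleNamespace(name=name, obj={"name": name}, file=file, errors=errors, tjson=tjson)

    tracker = ClassTracker()

    # Previously compiled objects found on disk.
    tracker.add_existing(fake_compiled("test.data.v0", tjson='{"version": 0}'))
    tracker.add_existing(fake_compiled("test.data.v1", tjson='{"version": 1}'))

    # Freshly compiled objects: v1 changed, v2 is new, v0 was removed.
    tracker.add(fake_compiled("test.data.v1", tjson='{"version": 2}'))  # prints a unified diff
    tracker.add(fake_compiled("test.data.v2", tjson='{"version": 1}'))

    tracker.close()             # records test.data.v0 as deleted
    print(tracker.to_status())  # summary counts of parsed/compiled objects
    print(tracker.diff())       # added / updated / deleted rendered via DiffResult
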