awx-zipline-ai 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release.


This version of awx-zipline-ai might be problematic.

Files changed (96)
  1. agent/ttypes.py +6 -6
  2. ai/chronon/airflow_helpers.py +20 -23
  3. ai/chronon/cli/__init__.py +0 -0
  4. ai/chronon/cli/compile/__init__.py +0 -0
  5. ai/chronon/cli/compile/column_hashing.py +40 -17
  6. ai/chronon/cli/compile/compile_context.py +13 -17
  7. ai/chronon/cli/compile/compiler.py +59 -36
  8. ai/chronon/cli/compile/conf_validator.py +251 -99
  9. ai/chronon/cli/compile/display/__init__.py +0 -0
  10. ai/chronon/cli/compile/display/class_tracker.py +6 -16
  11. ai/chronon/cli/compile/display/compile_status.py +10 -10
  12. ai/chronon/cli/compile/display/diff_result.py +79 -14
  13. ai/chronon/cli/compile/fill_templates.py +3 -8
  14. ai/chronon/cli/compile/parse_configs.py +10 -17
  15. ai/chronon/cli/compile/parse_teams.py +38 -34
  16. ai/chronon/cli/compile/serializer.py +3 -9
  17. ai/chronon/cli/compile/version_utils.py +42 -0
  18. ai/chronon/cli/git_utils.py +2 -13
  19. ai/chronon/cli/logger.py +0 -2
  20. ai/chronon/constants.py +1 -1
  21. ai/chronon/group_by.py +47 -47
  22. ai/chronon/join.py +46 -32
  23. ai/chronon/logger.py +1 -2
  24. ai/chronon/model.py +9 -4
  25. ai/chronon/query.py +2 -2
  26. ai/chronon/repo/__init__.py +1 -2
  27. ai/chronon/repo/aws.py +17 -31
  28. ai/chronon/repo/cluster.py +121 -50
  29. ai/chronon/repo/compile.py +14 -8
  30. ai/chronon/repo/constants.py +1 -1
  31. ai/chronon/repo/default_runner.py +32 -54
  32. ai/chronon/repo/explore.py +70 -73
  33. ai/chronon/repo/extract_objects.py +6 -9
  34. ai/chronon/repo/gcp.py +89 -88
  35. ai/chronon/repo/gitpython_utils.py +3 -2
  36. ai/chronon/repo/hub_runner.py +145 -55
  37. ai/chronon/repo/hub_uploader.py +2 -1
  38. ai/chronon/repo/init.py +12 -5
  39. ai/chronon/repo/join_backfill.py +19 -5
  40. ai/chronon/repo/run.py +42 -39
  41. ai/chronon/repo/serializer.py +4 -12
  42. ai/chronon/repo/utils.py +72 -63
  43. ai/chronon/repo/zipline.py +3 -19
  44. ai/chronon/repo/zipline_hub.py +211 -39
  45. ai/chronon/resources/__init__.py +0 -0
  46. ai/chronon/resources/gcp/__init__.py +0 -0
  47. ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
  48. ai/chronon/resources/gcp/group_bys/test/data.py +13 -17
  49. ai/chronon/resources/gcp/joins/__init__.py +0 -0
  50. ai/chronon/resources/gcp/joins/test/data.py +4 -8
  51. ai/chronon/resources/gcp/sources/__init__.py +0 -0
  52. ai/chronon/resources/gcp/sources/test/data.py +9 -6
  53. ai/chronon/resources/gcp/teams.py +9 -21
  54. ai/chronon/source.py +2 -4
  55. ai/chronon/staging_query.py +60 -19
  56. ai/chronon/types.py +3 -2
  57. ai/chronon/utils.py +21 -68
  58. ai/chronon/windows.py +2 -4
  59. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.0.dist-info}/METADATA +47 -24
  60. awx_zipline_ai-0.3.0.dist-info/RECORD +96 -0
  61. awx_zipline_ai-0.3.0.dist-info/top_level.txt +4 -0
  62. gen_thrift/__init__.py +0 -0
  63. {ai/chronon → gen_thrift}/api/ttypes.py +327 -197
  64. {ai/chronon/api → gen_thrift}/common/ttypes.py +9 -39
  65. gen_thrift/eval/ttypes.py +660 -0
  66. {ai/chronon → gen_thrift}/hub/ttypes.py +12 -131
  67. {ai/chronon → gen_thrift}/observability/ttypes.py +343 -180
  68. {ai/chronon → gen_thrift}/planner/ttypes.py +326 -45
  69. ai/chronon/eval/__init__.py +0 -122
  70. ai/chronon/eval/query_parsing.py +0 -19
  71. ai/chronon/eval/sample_tables.py +0 -100
  72. ai/chronon/eval/table_scan.py +0 -186
  73. ai/chronon/orchestration/ttypes.py +0 -4406
  74. ai/chronon/resources/gcp/README.md +0 -174
  75. ai/chronon/resources/gcp/zipline-cli-install.sh +0 -54
  76. awx_zipline_ai-0.2.1.dist-info/RECORD +0 -93
  77. awx_zipline_ai-0.2.1.dist-info/licenses/LICENSE +0 -202
  78. awx_zipline_ai-0.2.1.dist-info/top_level.txt +0 -3
  79. /jars/__init__.py → /__init__.py +0 -0
  80. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.0.dist-info}/WHEEL +0 -0
  81. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.0.dist-info}/entry_points.txt +0 -0
  82. {ai/chronon → gen_thrift}/api/__init__.py +0 -0
  83. {ai/chronon/api/common → gen_thrift/api}/constants.py +0 -0
  84. {ai/chronon/api → gen_thrift}/common/__init__.py +0 -0
  85. {ai/chronon/api → gen_thrift/common}/constants.py +0 -0
  86. {ai/chronon/fetcher → gen_thrift/eval}/__init__.py +0 -0
  87. {ai/chronon/fetcher → gen_thrift/eval}/constants.py +0 -0
  88. {ai/chronon/hub → gen_thrift/fetcher}/__init__.py +0 -0
  89. {ai/chronon/hub → gen_thrift/fetcher}/constants.py +0 -0
  90. {ai/chronon → gen_thrift}/fetcher/ttypes.py +0 -0
  91. {ai/chronon/observability → gen_thrift/hub}/__init__.py +0 -0
  92. {ai/chronon/observability → gen_thrift/hub}/constants.py +0 -0
  93. {ai/chronon/orchestration → gen_thrift/observability}/__init__.py +0 -0
  94. {ai/chronon/orchestration → gen_thrift/observability}/constants.py +0 -0
  95. {ai/chronon → gen_thrift}/planner/__init__.py +0 -0
  96. {ai/chronon → gen_thrift}/planner/constants.py +0 -0
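The most visible change in this release is the relocation of the generated Thrift bindings: modules that previously lived under ai/chronon (api, common, hub, observability, planner, fetcher) now ship in a new top-level gen_thrift package, and the Python DSL modules below are rewritten to import from it. A rough sketch of the before/after import shape (whether downstream config repositories need the same rewrite is an assumption, not something this diff states):

    # 0.2.1 style (removed)
    # import ai.chronon.api.ttypes as ttypes
    # import ai.chronon.api.common.ttypes as common

    # 0.3.0 style, as seen in group_by.py, join.py, query.py, and model.py below
    import gen_thrift.api.ttypes as ttypes
    import gen_thrift.common.ttypes as common

    # the hand-written DSL helpers keep their old package
    import ai.chronon.utils as utils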
ai/chronon/group_by.py CHANGED
@@ -18,13 +18,25 @@ import logging
  from copy import deepcopy
  from typing import Callable, Dict, List, Optional, Tuple, Union

- import ai.chronon.api.common.ttypes as common
- import ai.chronon.api.ttypes as ttypes
+ import gen_thrift.api.ttypes as ttypes
+ import gen_thrift.common.ttypes as common
+
  import ai.chronon.utils as utils
  import ai.chronon.windows as window_utils

  OperationType = int # type(zthrift.Operation.FIRST)

+
+ def _get_output_table_name(obj, full_name: bool = False):
+ """
+ Group by backfill output table name
+ To be synced with api.Extensions.scala
+ """
+ if not obj.metaData.name:
+ utils.__set_name(obj, ttypes.GroupBy, "group_bys")
+ return utils.output_table_name(obj, full_name)
+
+
  # The GroupBy's default online/production status is None and it will inherit
  # online/production status from the Joins it is included.
  # If it is included in multiple joins, it is considered online/production
@@ -58,7 +70,6 @@ class Accuracy(ttypes.Accuracy):


  class Operation:
-
  MIN = ttypes.Operation.MIN
  """Minimum value in the column"""

@@ -143,9 +154,7 @@ class Operation:
  UNIQUE_TOP_K = collector(ttypes.Operation.UNIQUE_TOP_K)
  """Returns top k unique elements ranked by their values. Automatically deduplicates inputs. For structs, requires sort_key (String) and unique_id (Long) fields."""

- APPROX_PERCENTILE = generic_collector(
- ttypes.Operation.APPROX_PERCENTILE, ["percentiles"], k=20
- )
+ APPROX_PERCENTILE = generic_collector(ttypes.Operation.APPROX_PERCENTILE, ["percentiles"], k=20)
  """Approximate percentile calculation with configurable accuracy parameter k=20"""


@@ -169,9 +178,7 @@ def DefaultAggregation(keys, sources, operation=Operation.LAST, tags=None):
  "ds",
  query.timeColumn,
  ]
- aggregate_columns += [
- column for column in columns if column not in non_aggregate_columns
- ]
+ aggregate_columns += [column for column in columns if column not in non_aggregate_columns]
  return [
  Aggregation(operation=operation, input_column=column, tags=tags)
  for column in aggregate_columns
@@ -232,9 +239,7 @@ def Aggregation(
  elif isinstance(w, common.Window):
  return w
  else:
- raise Exception(
- "window should be either a string like '7d', '24h', or a Window type"
- )
+ raise Exception("window should be either a string like '7d', '24h', or a Window type")

  norm_windows = [normalize(w) for w in windows] if windows else None

@@ -279,8 +284,7 @@ def validate_group_by(group_by: ttypes.GroupBy):
  first_source_columns = set(utils.get_columns(sources[0]))
  # TODO undo this check after ml_models CI passes
  assert "ts" not in first_source_columns, (
- "'ts' is a reserved key word for Chronon,"
- " please specify the expression in timeColumn"
+ "'ts' is a reserved key word for Chronon, please specify the expression in timeColumn"
  )
  for src in sources:
  query = utils.get_query(src)
@@ -290,8 +294,7 @@ def validate_group_by(group_by: ttypes.GroupBy):
  "event source as it should be the same with timeColumn"
  )
  assert query.reversalColumn is None, (
- "reversalColumn should not be specified for event source "
- "as it won't have mutations"
+ "reversalColumn should not be specified for event source as it won't have mutations"
  )
  if group_by.accuracy != Accuracy.SNAPSHOT:
  assert query.timeColumn is not None, (
@@ -300,9 +303,9 @@ def validate_group_by(group_by: ttypes.GroupBy):
  )
  else:
  if contains_windowed_aggregation(aggregations):
- assert (
- query.timeColumn
- ), "Please specify timeColumn for entity source with windowed aggregations"
+ assert query.timeColumn, (
+ "Please specify timeColumn for entity source with windowed aggregations"
+ )

  column_set = None
  # all sources should select the same columns
@@ -310,7 +313,7 @@ def validate_group_by(group_by: ttypes.GroupBy):
  column_set = set(utils.get_columns(source))
  column_diff = column_set ^ first_source_columns
  assert not column_diff, f"""
- Mismatched columns among sources [1, {i+2}], Difference: {column_diff}
+ Mismatched columns among sources [1, {i + 2}], Difference: {column_diff}
  """

  # all keys should be present in the selected columns
@@ -325,10 +328,7 @@ Keys {unselected_keys}, are unselected in source
  has_mutations = (
  any(
  [
- (
- s.entities.mutationTable is not None
- or s.entities.mutationTopic is not None
- )
+ (s.entities.mutationTable is not None or s.entities.mutationTopic is not None)
  for s in sources
  if s.entities is not None
  ]
@@ -336,9 +336,9 @@ Keys {unselected_keys}, are unselected in source
  if not is_events
  else False
  )
- assert not (
- is_events or has_mutations
- ), "You can only set aggregations=None in an EntitySource without mutations"
+ assert not (is_events or has_mutations), (
+ "You can only set aggregations=None in an EntitySource without mutations"
+ )
  else:
  columns = set([c for src in sources for c in utils.get_columns(src)])
  for agg in aggregations:
@@ -355,9 +355,7 @@ Keys {unselected_keys}, are unselected in source
  try:
  percentile_array = json.loads(agg.argMap["percentiles"])
  assert isinstance(percentile_array, list)
- assert all(
- [float(p) >= 0 and float(p) <= 1 for p in percentile_array]
- )
+ assert all([float(p) >= 0 and float(p) <= 1 for p in percentile_array])
  except Exception as e:
  LOGGER.exception(e)
  raise ValueError(
@@ -388,9 +386,7 @@ Keys {unselected_keys}, are unselected in source
  )


- _ANY_SOURCE_TYPE = Union[
- ttypes.Source, ttypes.EventSource, ttypes.EntitySource, ttypes.JoinSource
- ]
+ _ANY_SOURCE_TYPE = Union[ttypes.Source, ttypes.EventSource, ttypes.EntitySource, ttypes.JoinSource]


  def _get_op_suffix(operation, argmap):
@@ -409,7 +405,9 @@ def _get_op_suffix(operation, argmap):


  def get_output_col_names(aggregation):
- base_name = f"{aggregation.inputColumn}_{_get_op_suffix(aggregation.operation, aggregation.argMap)}"
+ base_name = (
+ f"{aggregation.inputColumn}_{_get_op_suffix(aggregation.operation, aggregation.argMap)}"
+ )
  windowed_names = []
  if aggregation.windows:
  for window in aggregation.windows:
@@ -456,7 +454,7 @@ def GroupBy(
  :param sources:
  can be constructed as entities or events or joinSource::

- import ai.chronon.api.ttypes as chronon
+ import gen_thrift.api.ttypes as chronon
  events = chronon.Source(events=chronon.Events(
  table=YOUR_TABLE,
  topic=YOUR_TOPIC # <- OPTIONAL for serving
@@ -478,7 +476,7 @@ def GroupBy(

  Multiple sources can be supplied to backfill the historical values with their respective start and end
  partitions. However, only one source is allowed to be a streaming one.
- :type sources: List[ai.chronon.api.ttypes.Events|ai.chronon.api.ttypes.Entities]
+ :type sources: List[gen_thrift.api.ttypes.Events|gen_thrift.api.ttypes.Entities]
  :param keys:
  List of primary keys that defines the data that needs to be collected in the result table. Similar to the
  GroupBy in the SQL context.
@@ -486,12 +484,12 @@ def GroupBy(
  :param aggregations:
  List of aggregations that needs to be computed for the data following the grouping defined by the keys::

- import ai.chronon.api.ttypes as chronon
+ import gen_thrift.api.ttypes as chronon
  aggregations = [
  chronon.Aggregation(input_column="entity", operation=Operation.LAST),
  chronon.Aggregation(input_column="entity", operation=Operation.LAST, windows=['7d'])
  ],
- :type aggregations: List[ai.chronon.api.ttypes.Aggregation]
+ :type aggregations: List[gen_thrift.api.ttypes.Aggregation]
  :param online:
  Should we upload the result data of this conf into the KV store so that we can fetch/serve this GroupBy online.
  Once Online is set to True, you ideally should not change the conf.
@@ -533,7 +531,7 @@ def GroupBy(
  Defines the computing accuracy of the GroupBy.
  If "Snapshot" is selected, the aggregations are computed based on the partition identifier - "ds" time column.
  If "Temporal" is selected, the aggregations are computed based on the event time - "ts" time column.
- :type accuracy: ai.chronon.api.ttypes.SNAPSHOT or ai.chronon.api.ttypes.TEMPORAL
+ :type accuracy: gen_thrift.api.ttypes.SNAPSHOT or gen_thrift.api.ttypes.TEMPORAL
  :param lag:
  Param that goes into customJson. You can pull this out of the json at path "metaData.customJson.lag"
  This is used by airflow integration to pick an older hive partition to wait on.
@@ -555,7 +553,7 @@ def GroupBy(
  :param derivations:
  Derivation allows arbitrary SQL select clauses to be computed using columns from the output of group by backfill
  output schema. It is supported for offline computations for now.
- :type derivations: List[ai.chronon.api.ttypes.Drivation]
+ :type derivations: List[gen_thrift.api.ttypes.Drivation]
  :param kwargs:
  Additional properties that would be passed to run.py if specified under additional_args property.
  And provides an option to pass custom values to the processing logic.
@@ -585,6 +583,10 @@ def GroupBy(
  """
  assert sources, "Sources are not specified"

+ assert isinstance(version, int), (
+ f"Version must be an integer, but found {type(version).__name__}"
+ )
+
  agg_inputs = []
  if aggregations is not None:
  agg_inputs = [agg.inputColumn for agg in aggregations]
@@ -596,11 +598,7 @@ def GroupBy(
  query = (
  source.entities.query
  if source.entities is not None
- else (
- source.events.query
- if source.events is not None
- else source.joinSource.query
- )
+ else (source.events.query if source.events is not None else source.joinSource.query)
  )

  if query.selects is None:
@@ -665,13 +663,12 @@ def GroupBy(
  for output_col in get_output_col_names(agg):
  column_tags[output_col] = agg.tags

-
  metadata = ttypes.MetaData(
  online=online,
  production=production,
  outputNamespace=output_namespace,
  tableProperties=table_properties,
- team=team,
+ team=team,
  executionInfo=exec_info,
  tags=tags if tags else None,
  columnTags=column_tags if column_tags else None,
@@ -689,4 +686,7 @@ def GroupBy(
  )
  validate_group_by(group_by)

+ # Add the table property that calls the private function
+ group_by.__class__.table = property(lambda self: _get_output_table_name(self, full_name=True))
+
  return group_by
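Beyond the reformatting, two functional changes land in GroupBy(): the version argument is now asserted to be an int, and the returned object exposes a table property that resolves the backfill output table name through the new _get_output_table_name helper. A hedged usage sketch follows; the source, key, and aggregation values are illustrative placeholders rather than anything defined in this package:

    from ai.chronon.group_by import Aggregation, GroupBy, Operation

    # `payments_source` is a hypothetical gen_thrift.api.ttypes.Source defined elsewhere.
    payments_v1 = GroupBy(
        sources=[payments_source],
        keys=["user_id"],
        aggregations=[
            Aggregation(input_column="amount", operation=Operation.LAST, windows=["7d"]),
        ],
        version=1,  # must now be an int; passing "1" would trip the new assert
    )

    # New in 0.3.0: the backfill output table name is exposed as a property.
    print(payments_v1.table)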
ai/chronon/join.py CHANGED
@@ -19,14 +19,34 @@ import logging
  from collections import Counter
  from typing import Dict, List, Tuple, Union

- import ai.chronon.api.common.ttypes as common
- import ai.chronon.api.ttypes as api
+ import gen_thrift.api.ttypes as api
+ import gen_thrift.common.ttypes as common
+
  import ai.chronon.repo.extract_objects as eo
  import ai.chronon.utils as utils
+ from ai.chronon.cli.compile import parse_teams

  logging.basicConfig(level=logging.INFO)


+ def _get_output_table_name(join: api.Join, full_name: bool = False):
+ """generate output table name for join backfill job"""
+ # join sources could also be created inline alongside groupBy file
+ # so we specify fallback module as group_bys
+ if isinstance(join, api.Join):
+ utils.__set_name(join, api.Join, "joins")
+ # set output namespace
+ if not join.metaData.outputNamespace:
+ team_name = join.metaData.name.split(".")[0]
+ namespace = (
+ parse_teams.load_teams(utils.chronon_root_path, print=False)
+ .get(team_name)
+ .outputNamespace
+ )
+ join.metaData.outputNamespace = namespace
+ return utils.output_table_name(join, full_name=full_name)
+
+
  def JoinPart(
  group_by: api.GroupBy,
  key_mapping: Dict[str, str] = None,
@@ -57,9 +77,9 @@ def JoinPart(
  components like GroupBys.
  """

- assert isinstance(
- group_by, api.GroupBy
- ), f"Expecting GroupBy. But found {type(group_by).__name__}"
+ assert isinstance(group_by, api.GroupBy), (
+ f"Expecting GroupBy. But found {type(group_by).__name__}"
+ )

  # used for reset for next run
  import_copy = __builtins__["__import__"]
@@ -80,14 +100,10 @@

  if group_by_module_name:
  logging.debug(
- "group_by's module info from garbage collector {}".format(
- group_by_module_name
- )
+ "group_by's module info from garbage collector {}".format(group_by_module_name)
  )
  group_by_module = importlib.import_module(group_by_module_name)
- __builtins__["__import__"] = eo.import_module_set_name(
- group_by_module, api.GroupBy
- )
+ __builtins__["__import__"] = eo.import_module_set_name(group_by_module, api.GroupBy)
  else:
  if not group_by.metaData.name:
  logging.error("No group_by file or custom group_by name found")
@@ -133,9 +149,9 @@ class DataType:
  # TIMESTAMP = api.TDataType(api.DataKind.TIMESTAMP)

  def MAP(key_type: api.TDataType, value_type: api.TDataType) -> api.TDataType:
- assert key_type == api.TDataType(
- api.DataKind.STRING
- ), "key_type has to STRING for MAP types"
+ assert key_type == api.TDataType(api.DataKind.STRING), (
+ "key_type has to STRING for MAP types"
+ )

  return api.TDataType(
  api.DataKind.MAP,
@@ -143,9 +159,7 @@ class DataType:
  )

  def LIST(elem_type: api.TDataType) -> api.TDataType:
- return api.TDataType(
- api.DataKind.LIST, params=[api.DataField("elem", elem_type)]
- )
+ return api.TDataType(api.DataKind.LIST, params=[api.DataField("elem", elem_type)])

  def STRUCT(name: str, *fields: FieldsType) -> api.TDataType:
  return api.TDataType(
@@ -475,18 +489,19 @@ def Join(
  if isinstance(row_ids, str):
  row_ids = [row_ids]

+ assert isinstance(version, int), (
+ f"Version must be an integer, but found {type(version).__name__}"
+ )
+
  # create a deep copy for case: multiple LeftOuterJoin use the same left,
  # validation will fail after the first iteration
  updated_left = copy.deepcopy(left)
  if left.events and left.events.query.selects:
  assert "ts" not in left.events.query.selects.keys(), (
- "'ts' is a reserved key word for Chronon,"
- " please specify the expression in timeColumn"
+ "'ts' is a reserved key word for Chronon, please specify the expression in timeColumn"
  )
  # mapping ts to query.timeColumn to events only
- updated_left.events.query.selects.update(
- {"ts": updated_left.events.query.timeColumn}
- )
+ updated_left.events.query.selects.update({"ts": updated_left.events.query.timeColumn})

  if label_part:
  label_metadata = api.MetaData(
@@ -499,9 +514,7 @@
  metaData=label_metadata,
  )

- consistency_sample_percent = (
- consistency_sample_percent if check_consistency else None
- )
+ consistency_sample_percent = consistency_sample_percent if check_consistency else None

  # external parts need to be unique on (prefix, part.source.metaData.name)
  if online_external_parts:
@@ -513,15 +526,13 @@
  if count > 1:
  has_duplicates = True
  print(f"Found {count - 1} duplicate(s) for external part {key}")
- assert (
- has_duplicates is False
- ), "Please address all the above mentioned duplicates."
+ assert has_duplicates is False, "Please address all the above mentioned duplicates."

  if bootstrap_from_log:
  has_logging = sample_percent > 0 and online
- assert (
- has_logging
- ), "Join must be online with sample_percent set in order to use bootstrap_from_log option"
+ assert has_logging, (
+ "Join must be online with sample_percent set in order to use bootstrap_from_log option"
+ )
  bootstrap_parts = (bootstrap_parts or []) + [
  api.BootstrapPart(
  # templated values will be replaced when metaData.name is set at the end
@@ -535,7 +546,7 @@
  env=env_vars,
  stepDays=step_days,
  historicalBackfill=historical_backfill,
- clusterConf=cluster_conf
+ clusterConf=cluster_conf,
  )

  metadata = api.MetaData(
@@ -563,4 +574,7 @@
  useLongNames=use_long_names,
  )

+ # Add the table property that calls the private function
+ join.__class__.table = property(lambda self: _get_output_table_name(self, full_name=True))
+
  return join
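Join() picks up the same two additions as GroupBy(): an integer check on version and a table property. For joins, _get_output_table_name additionally fills metaData.outputNamespace from the matching team entry in the teams metadata (via the newly imported parse_teams) when it is not set. A sketch under the same caveats; left_source, checkout_gb, and the right_parts argument name are assumptions used only for illustration:

    from ai.chronon.join import Join, JoinPart

    checkout_v1 = Join(
        left=left_source,
        right_parts=[JoinPart(group_by=checkout_gb)],
        version=1,  # same integer assertion as GroupBy in this release
    )

    # Resolves the join's backfill output table, deriving the namespace from teams
    # metadata if the join did not set one explicitly.
    print(checkout_v1.table)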
ai/chronon/logger.py CHANGED
@@ -1,4 +1,3 @@
-
  # Copyright (C) 2023 The Chronon Authors.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +14,7 @@

  import logging

- LOG_FORMAT = '[%(asctime)-11s] %(levelname)s [%(filename)s:%(lineno)d] %(message)s'
+ LOG_FORMAT = "[%(asctime)-11s] %(levelname)s [%(filename)s:%(lineno)d] %(message)s"


  def get_logger(log_level=logging.INFO):
ai/chronon/model.py CHANGED
@@ -1,6 +1,6 @@
  from typing import Optional

- import ai.chronon.api.ttypes as ttypes
+ import gen_thrift.api.ttypes as ttypes


  class ModelType:
@@ -14,7 +14,7 @@ def Model(
  outputSchema: ttypes.TDataType,
  modelType: ModelType,
  name: str = None,
- modelParams: Optional[dict[str, str]] = None
+ modelParams: Optional[dict[str, str]] = None,
  ) -> ttypes.Model:
  if not isinstance(source, ttypes.Source):
  raise ValueError("Invalid source type")
@@ -31,5 +31,10 @@
  name=name,
  )

- return ttypes.Model(modelType=modelType, outputSchema=outputSchema, source=source,
- modelParams=modelParams, metaData=metaData)
+ return ttypes.Model(
+ modelType=modelType,
+ outputSchema=outputSchema,
+ source=source,
+ modelParams=modelParams,
+ metaData=metaData,
+ )
ai/chronon/query.py CHANGED
@@ -15,7 +15,7 @@
  from collections import OrderedDict
  from typing import Dict, List

- import ai.chronon.api.ttypes as api
+ import gen_thrift.api.ttypes as api


  def Query(
@@ -96,7 +96,7 @@ def Query(
  reversalColumn=reversal_column,
  partitionColumn=partition_column,
  subPartitionsToWaitFor=sub_partitions_to_wait_for,
- partitionFormat=partition_format
+ partitionFormat=partition_format,
  )


@@ -12,8 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from ai.chronon.api.ttypes import GroupBy, Join, Model, StagingQuery
- from ai.chronon.orchestration.ttypes import ConfType
+ from gen_thrift.api.ttypes import ConfType, GroupBy, Join, Model, StagingQuery

  JOIN_FOLDER_NAME = "joins"
  GROUP_BY_FOLDER_NAME = "group_bys"
ai/chronon/repo/aws.py CHANGED
@@ -38,17 +38,13 @@ class AwsRunner(Runner):
  service_jar_path = AwsRunner.download_zipline_aws_jar(
  ZIPLINE_DIRECTORY, get_customer_id(), args["version"], ZIPLINE_AWS_SERVICE_JAR
  )
- jar_path = (
- f"{service_jar_path}:{aws_jar_path}" if args['mode'] == "fetch" else aws_jar_path
- )
+ jar_path = f"{service_jar_path}:{aws_jar_path}" if args["mode"] == "fetch" else aws_jar_path
  self.version = args.get("version", "latest")

  super().__init__(args, os.path.expanduser(jar_path))

  @staticmethod
- def upload_s3_file(
- bucket_name: str, source_file_name: str, destination_blob_name: str
- ):
+ def upload_s3_file(bucket_name: str, source_file_name: str, destination_blob_name: str):
  """Uploads a file to the bucket."""
  obj = boto3.client("s3")
  try:
@@ -61,7 +57,9 @@ class AwsRunner(Runner):
  raise RuntimeError(f"Failed to upload {source_file_name}: {str(e)}") from e

  @staticmethod
- def download_zipline_aws_jar(destination_dir: str, customer_id: str, version: str, jar_name: str):
+ def download_zipline_aws_jar(
+ destination_dir: str, customer_id: str, version: str, jar_name: str
+ ):
  s3_client = boto3.client("s3")
  destination_path = f"{destination_dir}/{jar_name}"
  source_key_name = f"release/{version}/jars/{jar_name}"
@@ -78,9 +76,7 @@ class AwsRunner(Runner):
  if are_identical:
  print(f"{destination_path} matches S3 {bucket_name}/{source_key_name}")
  else:
- print(
- f"{destination_path} does NOT match S3 {bucket_name}/{source_key_name}"
- )
+ print(f"{destination_path} does NOT match S3 {bucket_name}/{source_key_name}")
  print(f"Downloading {jar_name} from S3...")

  s3_client.download_file(
@@ -122,9 +118,7 @@ class AwsRunner(Runner):
  return None

  @staticmethod
- def compare_s3_and_local_file_hashes(
- bucket_name: str, s3_file_path: str, local_file_path: str
- ):
+ def compare_s3_and_local_file_hashes(bucket_name: str, s3_file_path: str, local_file_path: str):
  try:
  s3_hash = AwsRunner.get_s3_file_hash(bucket_name, s3_file_path)
  local_hash = AwsRunner.get_local_file_hash(local_file_path)
@@ -144,9 +138,7 @@ class AwsRunner(Runner):
  s3_files = []
  for source_file in local_files_to_upload:
  # upload to `metadata` folder
- destination_file_path = (
- f"metadata/{extract_filename_from_path(source_file)}"
- )
+ destination_file_path = f"metadata/{extract_filename_from_path(source_file)}"
  s3_files.append(
  AwsRunner.upload_s3_file(
  customer_warehouse_bucket_name, source_file, destination_file_path
@@ -169,7 +161,9 @@ class AwsRunner(Runner):
  + f"/release/{self.version}/jars/{ZIPLINE_AWS_JAR_DEFAULT}"
  )

- final_args = "{user_args} --jar-uri={jar_uri} --job-type={job_type} --main-class={main_class}"
+ final_args = (
+ "{user_args} --jar-uri={jar_uri} --job-type={job_type} --main-class={main_class}"
+ )

  if job_type == JobType.FLINK:
  main_class = "ai.chronon.flink.FlinkJob"
@@ -197,7 +191,7 @@ class AwsRunner(Runner):
  main_class=main_class,
  )
  + f" --additional-conf-path={EMR_MOUNT_FILE_PREFIX}additional-confs.yaml"
- f" --files={s3_file_args}"
+ f" --files={s3_file_args}"
  )
  else:
  raise ValueError(f"Invalid job type: {job_type}")
@@ -240,15 +234,12 @@ class AwsRunner(Runner):
  end_ds=end_ds,
  # when we download files from s3 to emr, they'll be mounted at /mnt/zipline
  override_conf_path=(
- EMR_MOUNT_FILE_PREFIX
- + extract_filename_from_path(self.conf)
+ EMR_MOUNT_FILE_PREFIX + extract_filename_from_path(self.conf)
  if self.conf
  else None
  ),
  ),
- additional_args=os.environ.get(
- "CHRONON_CONFIG_ADDITIONAL_ARGS", ""
- ),
+ additional_args=os.environ.get("CHRONON_CONFIG_ADDITIONAL_ARGS", ""),
  )

  emr_args = self.generate_emr_submitter_args(
@@ -265,15 +256,12 @@ class AwsRunner(Runner):
  start_ds=self.start_ds,
  # when we download files from s3 to emr, they'll be mounted at /mnt/zipline
  override_conf_path=(
- EMR_MOUNT_FILE_PREFIX
- + extract_filename_from_path(self.conf)
+ EMR_MOUNT_FILE_PREFIX + extract_filename_from_path(self.conf)
  if self.conf
  else None
  ),
  ),
- additional_args=os.environ.get(
- "CHRONON_CONFIG_ADDITIONAL_ARGS", ""
- ),
+ additional_args=os.environ.get("CHRONON_CONFIG_ADDITIONAL_ARGS", ""),
  )

  emr_args = self.generate_emr_submitter_args(
@@ -288,9 +276,7 @@ class AwsRunner(Runner):
  # parallel backfill mode
  with multiprocessing.Pool(processes=int(self.parallelism)) as pool:
  LOG.info(
- "Running args list {} with pool size {}".format(
- command_list, self.parallelism
- )
+ "Running args list {} with pool size {}".format(command_list, self.parallelism)
  )
  pool.map(check_call, command_list)
  elif len(command_list) == 1: