awx-zipline-ai 0.2.1__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (96)
  1. agent/ttypes.py +6 -6
  2. ai/chronon/airflow_helpers.py +20 -23
  3. ai/chronon/cli/__init__.py +0 -0
  4. ai/chronon/cli/compile/__init__.py +0 -0
  5. ai/chronon/cli/compile/column_hashing.py +40 -17
  6. ai/chronon/cli/compile/compile_context.py +13 -17
  7. ai/chronon/cli/compile/compiler.py +59 -36
  8. ai/chronon/cli/compile/conf_validator.py +251 -99
  9. ai/chronon/cli/compile/display/__init__.py +0 -0
  10. ai/chronon/cli/compile/display/class_tracker.py +6 -16
  11. ai/chronon/cli/compile/display/compile_status.py +10 -10
  12. ai/chronon/cli/compile/display/diff_result.py +79 -14
  13. ai/chronon/cli/compile/fill_templates.py +3 -8
  14. ai/chronon/cli/compile/parse_configs.py +10 -17
  15. ai/chronon/cli/compile/parse_teams.py +38 -34
  16. ai/chronon/cli/compile/serializer.py +3 -9
  17. ai/chronon/cli/compile/version_utils.py +42 -0
  18. ai/chronon/cli/git_utils.py +2 -13
  19. ai/chronon/cli/logger.py +0 -2
  20. ai/chronon/constants.py +1 -1
  21. ai/chronon/group_by.py +47 -47
  22. ai/chronon/join.py +46 -32
  23. ai/chronon/logger.py +1 -2
  24. ai/chronon/model.py +9 -4
  25. ai/chronon/query.py +2 -2
  26. ai/chronon/repo/__init__.py +1 -2
  27. ai/chronon/repo/aws.py +17 -31
  28. ai/chronon/repo/cluster.py +121 -50
  29. ai/chronon/repo/compile.py +14 -8
  30. ai/chronon/repo/constants.py +1 -1
  31. ai/chronon/repo/default_runner.py +32 -54
  32. ai/chronon/repo/explore.py +70 -73
  33. ai/chronon/repo/extract_objects.py +6 -9
  34. ai/chronon/repo/gcp.py +89 -88
  35. ai/chronon/repo/gitpython_utils.py +3 -2
  36. ai/chronon/repo/hub_runner.py +145 -55
  37. ai/chronon/repo/hub_uploader.py +2 -1
  38. ai/chronon/repo/init.py +12 -5
  39. ai/chronon/repo/join_backfill.py +19 -5
  40. ai/chronon/repo/run.py +42 -39
  41. ai/chronon/repo/serializer.py +4 -12
  42. ai/chronon/repo/utils.py +72 -63
  43. ai/chronon/repo/zipline.py +3 -19
  44. ai/chronon/repo/zipline_hub.py +211 -39
  45. ai/chronon/resources/__init__.py +0 -0
  46. ai/chronon/resources/gcp/__init__.py +0 -0
  47. ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
  48. ai/chronon/resources/gcp/group_bys/test/data.py +13 -17
  49. ai/chronon/resources/gcp/joins/__init__.py +0 -0
  50. ai/chronon/resources/gcp/joins/test/data.py +4 -8
  51. ai/chronon/resources/gcp/sources/__init__.py +0 -0
  52. ai/chronon/resources/gcp/sources/test/data.py +9 -6
  53. ai/chronon/resources/gcp/teams.py +9 -21
  54. ai/chronon/source.py +2 -4
  55. ai/chronon/staging_query.py +60 -19
  56. ai/chronon/types.py +3 -2
  57. ai/chronon/utils.py +21 -68
  58. ai/chronon/windows.py +2 -4
  59. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.1.dist-info}/METADATA +48 -24
  60. awx_zipline_ai-0.3.1.dist-info/RECORD +96 -0
  61. awx_zipline_ai-0.3.1.dist-info/top_level.txt +4 -0
  62. gen_thrift/__init__.py +0 -0
  63. {ai/chronon → gen_thrift}/api/ttypes.py +327 -197
  64. {ai/chronon/api → gen_thrift}/common/ttypes.py +9 -39
  65. gen_thrift/eval/ttypes.py +660 -0
  66. {ai/chronon → gen_thrift}/hub/ttypes.py +12 -131
  67. {ai/chronon → gen_thrift}/observability/ttypes.py +343 -180
  68. {ai/chronon → gen_thrift}/planner/ttypes.py +326 -45
  69. ai/chronon/eval/__init__.py +0 -122
  70. ai/chronon/eval/query_parsing.py +0 -19
  71. ai/chronon/eval/sample_tables.py +0 -100
  72. ai/chronon/eval/table_scan.py +0 -186
  73. ai/chronon/orchestration/ttypes.py +0 -4406
  74. ai/chronon/resources/gcp/README.md +0 -174
  75. ai/chronon/resources/gcp/zipline-cli-install.sh +0 -54
  76. awx_zipline_ai-0.2.1.dist-info/RECORD +0 -93
  77. awx_zipline_ai-0.2.1.dist-info/licenses/LICENSE +0 -202
  78. awx_zipline_ai-0.2.1.dist-info/top_level.txt +0 -3
  79. /jars/__init__.py → /__init__.py +0 -0
  80. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.1.dist-info}/WHEEL +0 -0
  81. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.1.dist-info}/entry_points.txt +0 -0
  82. {ai/chronon → gen_thrift}/api/__init__.py +0 -0
  83. {ai/chronon/api/common → gen_thrift/api}/constants.py +0 -0
  84. {ai/chronon/api → gen_thrift}/common/__init__.py +0 -0
  85. {ai/chronon/api → gen_thrift/common}/constants.py +0 -0
  86. {ai/chronon/fetcher → gen_thrift/eval}/__init__.py +0 -0
  87. {ai/chronon/fetcher → gen_thrift/eval}/constants.py +0 -0
  88. {ai/chronon/hub → gen_thrift/fetcher}/__init__.py +0 -0
  89. {ai/chronon/hub → gen_thrift/fetcher}/constants.py +0 -0
  90. {ai/chronon → gen_thrift}/fetcher/ttypes.py +0 -0
  91. {ai/chronon/observability → gen_thrift/hub}/__init__.py +0 -0
  92. {ai/chronon/observability → gen_thrift/hub}/constants.py +0 -0
  93. {ai/chronon/orchestration → gen_thrift/observability}/__init__.py +0 -0
  94. {ai/chronon/orchestration → gen_thrift/observability}/constants.py +0 -0
  95. {ai/chronon → gen_thrift}/planner/__init__.py +0 -0
  96. {ai/chronon → gen_thrift}/planner/constants.py +0 -0
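Beyond the file-level churn, the listing shows the generated Thrift modules moving out of the ai.chronon namespace into a new top-level gen_thrift package (entries 62-68 and 82-96). A minimal sketch of how a downstream import would shift, assuming the generated classes keep their names across the move:

    # 0.2.1 layout
    import ai.chronon.api.ttypes as api

    # 0.3.1 layout (assumed equivalent; generated class names such as GroupBy/Join are expected to be unchanged)
    import gen_thrift.api.ttypes as api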
@@ -1,122 +0,0 @@
- from typing import Any, List
-
- from pyspark.sql import DataFrame, SparkSession
-
- import ai.chronon.api.ttypes as chronon
- from ai.chronon.eval.query_parsing import get_tables_from_query
- from ai.chronon.eval.sample_tables import sample_tables, sample_with_query
- from ai.chronon.eval.table_scan import (
-     TableScan,
-     clean_table_name,
-     table_scans_in_group_by,
-     table_scans_in_join,
-     table_scans_in_source,
- )
-
-
- def eval(obj: Any) -> List[DataFrame]:
-
-     if isinstance(obj, chronon.Source):
-         return _run_table_scans(table_scans_in_source(obj))
-
-     elif isinstance(obj, chronon.GroupBy):
-         return _run_table_scans(table_scans_in_group_by(obj))
-
-     elif isinstance(obj, chronon.Join):
-         return _run_table_scans(table_scans_in_join(obj))
-
-     elif isinstance(obj, chronon.StagingQuery):
-         return _sample_and_eval_query(_render_staging_query(obj))
-
-     elif isinstance(obj, str):
-         has_white_spaces = any(char.isspace() for char in obj)
-         if has_white_spaces:
-             return _sample_and_eval_query(obj)
-         else:
-             return _sample_and_eval_query(f"SELECT * FROM {obj} LIMIT 1000")
-
-     elif isinstance(obj, chronon.Model):
-         _run_table_scans(table_scans_in_source(obj.source))
-
-     else:
-         raise Exception(f"Unsupported object type for: {obj}")
-
-
- def _sample_and_eval_query(query: str) -> DataFrame:
-
-     table_names = get_tables_from_query(query)
-     sample_tables(table_names)
-
-     clean_query = query
-     for table_name in table_names:
-         clean_name = clean_table_name(table_name)
-         clean_query = clean_query.replace(table_name, clean_name)
-
-     return _run_query(clean_query)
-
-
- def _run_query(query: str) -> DataFrame:
-     spark = _get_spark()
-     return spark.sql(query)
-
-
- def _sample_table_scan(table_scan: TableScan) -> str:
-     table = table_scan.table
-     output_path = table_scan.output_path()
-     query = table_scan.raw_scan_query(local_table_view=False)
-     return sample_with_query(table, query, output_path)
-
-
- def _run_table_scans(table_scans: List[TableScan]) -> List[DataFrame]:
-     spark = _get_spark()
-     df_list = []
-
-     for table_scan in table_scans:
-         output_path = table_scan.output_path()
-
-         status = " (exists)" if output_path.exists() else ""
-         print(
-             f"table: {table_scan.table}\n"
-             f"view: {table_scan.view_name()}\n"
-             f"local_file: {output_path}{status}\n"
-         )
-
-     for table_scan in table_scans:
-
-         view_name = table_scan.view_name()
-         output_path = _sample_table_scan(table_scan)
-
-         print(f"Creating view {view_name} from parquet file {output_path}")
-         df = spark.read.parquet(str(output_path))
-         df.createOrReplaceTempView(view_name)
-
-         scan_query = table_scan.scan_query(local_table_view=True)
-         print(f"Scanning {table_scan.table} with query: \n{scan_query}\n")
-         df = spark.sql(scan_query)
-         df.show(5)
-         df_list.append(df)
-
-     return df_list
-
-
- _spark: SparkSession = None
-
-
- def _get_spark() -> SparkSession:
-     global _spark
-     if not _spark:
-         _spark = (
-             SparkSession.builder.appName("Chronon Evaluator")
-             .config("spark.driver.bindAddress", "127.0.0.1")
-             .config("spark.driver.host", "127.0.0.1")
-             .config("spark.sql.parquet.columnarReaderBatchSize", "16")
-             .config("spark.executor.memory", "4g")
-             .config("spark.driver.memory", "4g")
-             .config("spark.driver.maxResultSize", "2g")
-             .getOrCreate()
-         )
-     return _spark
-
-
- def _render_staging_query(staging_query: chronon.StagingQuery) -> str:
-     raise NotImplementedError("Staging query evals are not yet implemented")
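The hunk above removes the local evaluation entry point (ai/chronon/eval/__init__.py) that sampled upstream tables and replayed their scans through a local Spark session. For reference, a rough sketch of how it was called against the 0.2.1 layout (my_group_by is a hypothetical GroupBy definition; pyspark and warehouse credentials are assumed to be available):

    from ai.chronon.eval import eval as chronon_eval

    # samples each source table to local parquet, registers temp views,
    # and returns one DataFrame per table scan
    dfs = chronon_eval(my_group_by)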
@@ -1,19 +0,0 @@
- from typing import List
-
-
- def get_tables_from_query(sql_query) -> List[str]:
-     import sqlglot
-
-     # Parse the query
-     parsed = sqlglot.parse_one(sql_query, dialect="bigquery")
-
-     # Extract all table references
-     tables = parsed.find_all(sqlglot.exp.Table)
-
-     table_names = []
-     for table in tables:
-         name_parts = [part for part in [table.catalog, table.db, table.name] if part]
-         table_name = ".".join(name_parts)
-         table_names.append(table_name)
-
-     return table_names
@@ -1,100 +0,0 @@
- import os
- from typing import List
-
- from ai.chronon.eval.table_scan import local_warehouse
-
-
- def sample_with_query(table, query, output_path) -> str:
-     # if file exists, skip
-     if os.path.exists(output_path):
-         print(f"File {output_path} already exists. Skipping sampling.")
-         return output_path
-
-     raw_scan_query = query
-     print(f"Sampling {table} with query: {raw_scan_query}")
-
-     _sample_internal(raw_scan_query, output_path)
-     return output_path
-
-
- def sample_tables(table_names: List[str]) -> None:
-
-     for table in table_names:
-         query = f"SELECT * FROM {table} LIMIT 10000"
-         sample_with_query(table, query, local_warehouse / f"{table}.parquet")
-
-
- _sampling_engine = os.getenv("CHRONON_SAMPLING_ENGINE", "bigquery")
-
-
- def _sample_internal(query, output_path) -> str:
-     if _sampling_engine == "bigquery":
-         _sample_bigquery(query, output_path)
-     elif _sampling_engine == "trino":
-         _sample_trino(query, output_path)
-     else:
-         raise ValueError("Invalid sampling engine")
-
-
- def _sample_trino(query, output_path):
-     raise NotImplementedError("Trino sampling is not yet implemented")
-
-
- def _sample_bigquery(query, output_path):
-
-     from google.cloud import bigquery
-
-     project_id = os.getenv("GCP_PROJECT_ID")
-     assert project_id, "Please set the GCP_PROJECT_ID environment variable"
-
-     client = bigquery.Client(project=project_id)
-
-     results = client.query_and_wait(query)
-
-     df = results.to_dataframe()
-     df.to_parquet(output_path)
-
-
- def _sample_bigquery_fast(query, destination_path):
-     import os
-
-     import pyarrow.parquet as pq
-     from google.cloud import bigquery
-     from google.cloud.bigquery_storage import BigQueryReadClient
-     from google.cloud.bigquery_storage_v1.types import DataFormat, ReadSession
-
-     project_id = os.getenv("GCP_PROJECT_ID")
-     assert project_id, "Please set the GCP_PROJECT_ID environment variable"
-
-     client = bigquery.Client(project=project_id)
-     bqstorage_client = BigQueryReadClient()
-
-     # Create query job
-     query_job = client.query(query)
-     table_ref = query_job.destination
-
-     # Create read session
-     read_session = ReadSession()
-     read_session.table = table_ref.to_bqstorage()
-     read_session.data_format = DataFormat.ARROW
-
-     print("Fetching from BigQuery... (this might take a while)")
-
-     session = bqstorage_client.create_read_session(
-         parent=f"projects/{client.project}",
-         read_session=read_session,
-         max_stream_count=1,
-     )
-
-     print("Writing to local parquet file...")
-
-     # Read using Arrow
-     stream = bqstorage_client.read_rows(session.streams[0].name)
-     table = stream.to_arrow(read_session=session)
-
-     # Write to Parquet directly
-     pq.write_table(table, destination_path)
-
-     print(f"Wrote results to {destination_path}")
-
-     return destination_path
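The sampling module removed above was configured entirely through environment variables. A short sketch of the knobs it read, with defaults taken from the code (the project id value is a placeholder):

    import os

    os.environ.setdefault("CHRONON_SAMPLING_ENGINE", "bigquery")  # "trino" existed only as a stub
    os.environ.setdefault("GCP_PROJECT_ID", "my-gcp-project")     # required by the BigQuery path; placeholder value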
@@ -1,186 +0,0 @@
- import hashlib
- import os
- import re
- from dataclasses import dataclass
- from datetime import datetime, timedelta
- from pathlib import Path
- from typing import List, Tuple
-
- import ai.chronon.api.ttypes as chronon
-
-
- def clean_table_name(name: str) -> str:
-     return re.sub(r"[^a-zA-Z0-9_]", "_", name)
-
-
- local_warehouse = Path(os.getenv("CHRONON_ROOT", os.getcwd())) / "local_warehouse"
- limit = int(os.getenv("SAMPLE_LIMIT", "100"))
- # create local_warehouse if it doesn't exist
- local_warehouse.mkdir(parents=True, exist_ok=True)
-
-
- @dataclass
- class TableScan:
-     table: str
-     partition_col: str
-     partition_date: str
-     query: chronon.Query
-     is_mutations: bool = False
-
-     def output_path(self) -> str:
-         return Path(local_warehouse) / f"{self.view_name()}.parquet"
-
-     def view_name(self) -> str:
-         return clean_table_name(self.table) + "_" + self.where_id()
-
-     def table_name(self, local_table_view) -> str:
-         return self.view_name() if local_table_view else self.table
-
-     def where_id(self) -> str:
-         return "_" + hashlib.md5(self.where_block().encode()).hexdigest()[:3]
-
-     def where_block(self) -> str:
-         wheres = []
-         partition_scan = f"{self.partition_col} = '{self.partition_date}'"
-         wheres.append(partition_scan)
-
-         if self.query.wheres:
-             wheres.extend(self.query.wheres)
-
-         return " AND\n ".join([f"({where})" for where in wheres])
-
-     def raw_scan_query(self, local_table_view: bool = True) -> str:
-         return f"""
-         SELECT * FROM {self.table_name(local_table_view)}
-         WHERE
-             {self.where_block()}
-         LIMIT {limit}
-         """
-
-     def scan_query(self, local_table_view=True) -> str:
-         selects = []
-         base_selects = self.query.selects.copy()
-
-         if self.is_mutations:
-             base_selects["is_before"] = coalesce(self.query.reversalColumn, "is_before")
-             base_selects["mutation_ts"] = coalesce(
-                 self.query.mutationTimeColumn, "mutation_ts"
-             )
-
-         if self.query.timeColumn:
-             base_selects["ts"] = coalesce(self.query.timeColumn, "ts")
-
-         for k, v in base_selects.items():
-             selects.append(f"{v} as {k}")
-         select_clauses = ",\n ".join(selects)
-
-         return f"""
-         SELECT
-             {select_clauses}
-         FROM
-             {self.table_name(local_table_view)}
-         WHERE
-             {self.where_block()}
-         LIMIT
-             {limit}
-         """
-
-
- # TODO: use teams.py to get the default date column
- DEFAULT_DATE_COLUMN = "_date"
- DEFAULT_DATE_FORMAT = "%Y-%m-%d"
-
- two_days_ago = (datetime.now() - timedelta(days=2)).strftime(DEFAULT_DATE_FORMAT)
-
- _sample_date = os.getenv("SAMPLE_DATE", two_days_ago)
-
-
- def get_date(query: chronon.Query) -> Tuple[str, str]:
-     assert query and query.selects, "please specify source.query.selects"
-
-     partition_col = query.selects.get("ds", DEFAULT_DATE_COLUMN)
-     partition_date = coalesce(query.endPartition, _sample_date)
-
-     return (partition_col, partition_date)
-
-
- def coalesce(*args):
-     for arg in args:
-         if arg:
-             return arg
-
-
- def table_scans_in_source(source: chronon.Source) -> List[TableScan]:
-     result = []
-
-     if not source:
-         return result
-
-     if source.entities:
-         query: chronon.Query = source.entities.query
-         col, date = get_date(query)
-
-         snapshot = TableScan(source.entities.snapshotTable, col, date, query)
-         result.append(snapshot)
-
-         if source.entities.mutationTable:
-             mutations = TableScan(source.entities.mutationTable, col, date, query, True)
-             result.append(mutations)
-
-     if source.events:
-         query = source.events.query
-         col, date = get_date(query)
-         table = TableScan(source.events.table, col, date, query)
-         result.append(table)
-
-     if source.joinSource:
-         result.extend(table_scans_in_source(source.joinSource.join.left))
-
-     return result
-
-
- def table_scans_in_sources(sources: List[chronon.Source]) -> List[TableScan]:
-     result = []
-
-     for source in sources:
-         result.extend(table_scans_in_source(source))
-
-     return result
-
-
- def table_scans_in_group_by(gb: chronon.GroupBy) -> List[TableScan]:
-     if not gb:
-         return []
-
-     return table_scans_in_sources(gb.sources)
-
-
- def table_scans_in_join(join: chronon.Join) -> List[TableScan]:
-
-     result = []
-
-     if not join:
-         return result
-
-     result.extend(table_scans_in_source(join.left))
-
-     parts: List[chronon.JoinPart] = join.joinParts
-     if parts:
-         for part in parts:
-             result.extend(table_scans_in_group_by(part.groupBy))
-
-     bootstraps: List[chronon.BootstrapPart] = join.bootstrapParts
-     if bootstraps:
-         for bootstrap in bootstraps:
-             query = bootstrap.query
-             col, date = get_date(query)
-             bootstrap = TableScan(bootstrap.table, col, date, query)
-
-             result.append(bootstrap)
-
-     if join.labelParts:
-         labelParts: List[chronon.JoinPart] = join.labelParts.labels
-         for part in labelParts:
-             result.extend(table_scans_in_sources(part.groupBy))
-
-     return result
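TableScan, defined in the hunk above, is the unit the removed evaluation code was built around: it pins one partition date and renders a LIMITed scan of a source table. A rough illustration against the 0.2.1 code, using a hypothetical table and Query (field names are taken from the code above; the exact whitespace of the rendered SQL differs):

    from ai.chronon.api.ttypes import Query
    from ai.chronon.eval.table_scan import TableScan

    # hypothetical source: one select column plus a user where-clause
    q = Query(selects={"user_id": "user_id"}, wheres=["event_type = 'purchase'"])
    scan = TableScan(table="project.dataset.events", partition_col="ds",
                     partition_date="2025-01-01", query=q)

    print(scan.raw_scan_query(local_table_view=False))
    # roughly: SELECT * FROM project.dataset.events
    #          WHERE (ds = '2025-01-01') AND (event_type = 'purchase')
    #          LIMIT 100   (limit comes from SAMPLE_LIMIT, default 100)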