squirrels 0.1.0-py3-none-any.whl → 0.6.0.post0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. dateutils/__init__.py +6 -0
  2. dateutils/_enums.py +25 -0
  3. squirrels/dateutils.py → dateutils/_implementation.py +409 -380
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +21 -18
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +337 -0
  8. squirrels/_api_routes/base.py +196 -0
  9. squirrels/_api_routes/dashboards.py +156 -0
  10. squirrels/_api_routes/data_management.py +148 -0
  11. squirrels/_api_routes/datasets.py +220 -0
  12. squirrels/_api_routes/project.py +289 -0
  13. squirrels/_api_server.py +552 -134
  14. squirrels/_arguments/__init__.py +0 -0
  15. squirrels/_arguments/init_time_args.py +83 -0
  16. squirrels/_arguments/run_time_args.py +111 -0
  17. squirrels/_auth.py +777 -0
  18. squirrels/_command_line.py +239 -107
  19. squirrels/_compile_prompts.py +147 -0
  20. squirrels/_connection_set.py +94 -0
  21. squirrels/_constants.py +141 -64
  22. squirrels/_dashboards.py +179 -0
  23. squirrels/_data_sources.py +570 -0
  24. squirrels/_dataset_types.py +91 -0
  25. squirrels/_env_vars.py +209 -0
  26. squirrels/_exceptions.py +29 -0
  27. squirrels/_http_error_responses.py +52 -0
  28. squirrels/_initializer.py +319 -110
  29. squirrels/_logging.py +121 -0
  30. squirrels/_manifest.py +357 -187
  31. squirrels/_mcp_server.py +578 -0
  32. squirrels/_model_builder.py +69 -0
  33. squirrels/_model_configs.py +74 -0
  34. squirrels/_model_queries.py +52 -0
  35. squirrels/_models.py +1201 -0
  36. squirrels/_package_data/base_project/.env +7 -0
  37. squirrels/_package_data/base_project/.env.example +44 -0
  38. squirrels/_package_data/base_project/connections.yml +16 -0
  39. squirrels/_package_data/base_project/dashboards/dashboard_example.py +40 -0
  40. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
  41. squirrels/_package_data/base_project/docker/.dockerignore +16 -0
  42. squirrels/_package_data/base_project/docker/Dockerfile +16 -0
  43. squirrels/_package_data/base_project/docker/compose.yml +7 -0
  44. squirrels/_package_data/base_project/duckdb_init.sql +10 -0
  45. squirrels/_package_data/base_project/gitignore +13 -0
  46. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  47. squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
  48. squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
  49. squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
  50. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
  51. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
  52. squirrels/_package_data/base_project/models/federates/federate_example.py +51 -0
  53. squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
  54. squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
  55. squirrels/_package_data/base_project/models/sources.yml +38 -0
  56. squirrels/_package_data/base_project/parameters.yml +142 -0
  57. squirrels/_package_data/base_project/pyconfigs/connections.py +19 -0
  58. squirrels/_package_data/base_project/pyconfigs/context.py +96 -0
  59. squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
  60. squirrels/_package_data/base_project/pyconfigs/user.py +56 -0
  61. squirrels/_package_data/base_project/resources/expenses.db +0 -0
  62. squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
  63. squirrels/_package_data/base_project/resources/weather.db +0 -0
  64. squirrels/_package_data/base_project/seeds/seed_categories.csv +6 -0
  65. squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
  66. squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
  67. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
  68. squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
  69. squirrels/_package_data/base_project/tmp/.gitignore +2 -0
  70. squirrels/_package_data/templates/login_successful.html +53 -0
  71. squirrels/_package_data/templates/squirrels_studio.html +22 -0
  72. squirrels/_package_loader.py +29 -0
  73. squirrels/_parameter_configs.py +592 -0
  74. squirrels/_parameter_options.py +348 -0
  75. squirrels/_parameter_sets.py +207 -0
  76. squirrels/_parameters.py +1703 -0
  77. squirrels/_project.py +796 -0
  78. squirrels/_py_module.py +122 -0
  79. squirrels/_request_context.py +33 -0
  80. squirrels/_schemas/__init__.py +0 -0
  81. squirrels/_schemas/auth_models.py +83 -0
  82. squirrels/_schemas/query_param_models.py +70 -0
  83. squirrels/_schemas/request_models.py +26 -0
  84. squirrels/_schemas/response_models.py +286 -0
  85. squirrels/_seeds.py +97 -0
  86. squirrels/_sources.py +112 -0
  87. squirrels/_utils.py +540 -149
  88. squirrels/_version.py +1 -3
  89. squirrels/arguments.py +7 -0
  90. squirrels/auth.py +4 -0
  91. squirrels/connections.py +3 -0
  92. squirrels/dashboards.py +3 -0
  93. squirrels/data_sources.py +14 -282
  94. squirrels/parameter_options.py +13 -189
  95. squirrels/parameters.py +14 -801
  96. squirrels/types.py +18 -0
  97. squirrels-0.6.0.post0.dist-info/METADATA +148 -0
  98. squirrels-0.6.0.post0.dist-info/RECORD +101 -0
  99. {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -2
  100. {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +1 -0
  101. squirrels-0.6.0.post0.dist-info/licenses/LICENSE +201 -0
  102. squirrels/_credentials_manager.py +0 -87
  103. squirrels/_module_loader.py +0 -37
  104. squirrels/_parameter_set.py +0 -151
  105. squirrels/_renderer.py +0 -286
  106. squirrels/_timed_imports.py +0 -37
  107. squirrels/connection_set.py +0 -126
  108. squirrels/package_data/base_project/.gitignore +0 -4
  109. squirrels/package_data/base_project/connections.py +0 -21
  110. squirrels/package_data/base_project/database/sample_database.db +0 -0
  111. squirrels/package_data/base_project/database/seattle_weather.db +0 -0
  112. squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -8
  113. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -23
  114. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -7
  115. squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -10
  116. squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -2
  117. squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -30
  118. squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -6
  119. squirrels/package_data/base_project/squirrels.yaml +0 -26
  120. squirrels/package_data/static/favicon.ico +0 -0
  121. squirrels/package_data/static/script.js +0 -234
  122. squirrels/package_data/static/style.css +0 -110
  123. squirrels/package_data/templates/index.html +0 -32
  124. squirrels-0.1.0.dist-info/LICENSE +0 -22
  125. squirrels-0.1.0.dist-info/METADATA +0 -67
  126. squirrels-0.1.0.dist-info/RECORD +0 -40
  127. squirrels-0.1.0.dist-info/top_level.txt +0 -1
squirrels/_utils.py CHANGED
@@ -1,149 +1,540 @@
1
- from typing import List, Dict, Optional, Union, Any
2
- from types import ModuleType
3
- from pathlib import Path
4
- from importlib.machinery import SourceFileLoader
5
- import json
6
-
7
- from squirrels._timed_imports import jinja2 as j2, pandas as pd, pd_types
8
-
9
- FilePath = Union[str, Path]
10
-
11
-
12
- # Custom Exceptions
13
- class InvalidInputError(Exception):
14
- pass
15
-
16
- class ConfigurationError(Exception):
17
- pass
18
-
19
- class AbstractMethodCallError(NotImplementedError):
20
- def __init__(self, cls, method, more_message = ""):
21
- message = f"Abstract method {method}() not implemented in {cls.__name__}."
22
- super().__init__(message + more_message)
23
-
24
-
25
- # Utility functions/variables
26
- j2_env = j2.Environment(loader=j2.FileSystemLoader('.'))
27
-
28
-
29
- def import_file_as_module(filepath: Optional[FilePath]) -> ModuleType:
30
- """
31
- Imports a python file as a module.
32
-
33
- Parameters:
34
- filepath: The path to the file to import.
35
-
36
- Returns:
37
- The imported module.
38
- """
39
- filepath = str(filepath) if filepath is not None else None
40
- return SourceFileLoader(filepath, filepath).load_module() if filepath is not None else None
41
-
42
-
43
- def join_paths(*paths: FilePath) -> Path:
44
- """
45
- Joins paths together.
46
-
47
- Parameters:
48
- paths: The paths to join.
49
-
50
- Returns:
51
- The joined path.
52
- """
53
- return Path(*paths)
54
-
55
-
56
- def normalize_name(name: str) -> str:
57
- """
58
- Normalizes names to the convention of the squirrels manifest file.
59
-
60
- Parameters:
61
- name: The name to normalize.
62
-
63
- Returns:
64
- The normalized name.
65
- """
66
- return name.replace('-', '_')
67
-
68
-
69
- def normalize_name_for_api(name: str) -> str:
70
- """
71
- Normalizes names to the REST API convention.
72
-
73
- Parameters:
74
- name: The name to normalize.
75
-
76
- Returns:
77
- The normalized name.
78
- """
79
- return name.replace('_', '-')
80
-
81
-
82
- def get_row_value(row: pd.Series, value: str) -> Any:
83
- """
84
- Gets the value of a row from a pandas Series.
85
-
86
- Parameters:
87
- row: The row to get the value from.
88
- value: The name of the column to get the value from.
89
-
90
- Returns:
91
- The value of the column.
92
-
93
- Raises:
94
- ConfigurationError: If the column does not exist.
95
- """
96
- try:
97
- result = row[value]
98
- except KeyError as e:
99
- raise ConfigurationError(f'Column name "{value}" does not exist') from e
100
- return result
101
-
102
-
103
- def df_to_json(df: pd.DataFrame, dimensions: List[str] = None) -> Dict[str, Any]:
104
- """
105
- Convert a pandas DataFrame to the same JSON format that the dataset result API of Squirrels outputs.
106
-
107
- Parameters:
108
- df: The dataframe to convert into JSON
109
- dimensions: The list of declared dimensions. If None, all non-numeric columns are assumed as dimensions
110
-
111
- Returns:
112
- The JSON response of a Squirrels dataset result API
113
- """
114
- in_df_json = json.loads(df.to_json(orient='table', index=False))
115
- out_fields = []
116
- non_numeric_fields = []
117
- for in_column in in_df_json["schema"]["fields"]:
118
- col_name: str = in_column["name"]
119
- out_column = {"name": col_name, "type": in_column["type"]}
120
- out_fields.append(out_column)
121
-
122
- if not pd_types.is_numeric_dtype(df[col_name].dtype):
123
- non_numeric_fields.append(col_name)
124
-
125
- out_dimensions = non_numeric_fields if dimensions is None else dimensions
126
- out_schema = {"fields": out_fields, "dimensions": out_dimensions}
127
- return {"response_version": 0, "schema": out_schema, "data": in_df_json["data"]}
128
-
129
-
130
- def load_json_or_comma_delimited_str_as_list(input_str: str) -> List[str]:
131
- """
132
- Given a string, load it as a list either by json string or comma delimited value
133
-
134
- Parameters:
135
- input_str: The input string
136
-
137
- Returns:
138
- The list representation of the input string
139
- """
140
- output = None
141
- try:
142
- output = json.loads(input_str)
143
- except json.decoder.JSONDecodeError:
144
- pass
145
-
146
- if isinstance(output, list):
147
- return output
148
- else:
149
- return [] if input_str == "" else input_str.split(",")
1
+ from typing import Sequence, Optional, Union, TypeVar, Callable, Iterable, Literal, Any
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+ import os, time, logging, json, duckdb, polars as pl, yaml
5
+ import jinja2 as j2, jinja2.nodes as j2_nodes
6
+ import sqlglot, sqlglot.expressions, asyncio, hashlib, inspect, base64
7
+
8
+ from . import _constants as c
9
+ from ._exceptions import ConfigurationError
10
+
11
+ FilePath = Union[str, Path]
12
+
13
+ # Polars <-> Squirrels dtypes mappings (except Decimal)
14
+ polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
15
+ pl.String: ["string", "varchar", "char", "text"],
16
+ pl.Int8: ["tinyint", "int1"],
17
+ pl.Int16: ["smallint", "short", "int2"],
18
+ pl.Int32: ["integer", "int", "int4"],
19
+ pl.Int64: ["bigint", "long", "int8"],
20
+ pl.Float32: ["float", "float4", "real"],
21
+ pl.Float64: ["double", "float8"],
22
+ pl.Boolean: ["boolean", "bool", "logical"],
23
+ pl.Date: ["date"],
24
+ pl.Time: ["time"],
25
+ pl.Datetime: ["timestamp", "datetime"],
26
+ pl.Duration: ["interval"],
27
+ pl.Binary: ["blob", "binary", "varbinary"]
28
+ }
29
+
30
+ sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = {
31
+ sqrl_type: k for k, v in polars_dtypes_to_sqrl_dtypes.items() for sqrl_type in v
32
+ }
33
+
34
+
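As an illustration (not part of the package source), these mappings can translate between Squirrels column type names and Polars dtypes; the import path squirrels._utils is an assumption based on this file's location:

    import polars as pl
    from squirrels._utils import sqrl_dtypes_to_polars_dtypes, polars_dtypes_to_sqrl_dtypes

    # "bigint" is listed above as an alias of pl.Int64
    assert sqrl_dtypes_to_polars_dtypes["bigint"] is pl.Int64

    # the forward mapping lists every accepted alias for a Polars dtype
    print(polars_dtypes_to_sqrl_dtypes[pl.Int64])  # ['bigint', 'long', 'int8']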
35
+ ## Other utility classes
36
+
37
+ class Logger(logging.Logger):
38
+ def info(self, msg: str, *, data: dict[str, Any] = {}, **kwargs) -> None:
39
+ super().info(msg, extra={"data": data}, **kwargs)
40
+
41
+ def log_activity_time(self, activity: str, start_timestamp: float, *, additional_data: dict[str, Any] = {}) -> None:
42
+ end_timestamp = time.time()
43
+ time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
44
+ data = {
45
+ "activity": activity,
46
+ "start_timestamp": start_timestamp,
47
+ "end_timestamp": end_timestamp,
48
+ "time_taken_ms": time_taken,
49
+ **additional_data
50
+ }
51
+ self.info(f'Time taken for "{activity}": {time_taken}ms', data=data)
52
+
53
+
54
+ class EnvironmentWithMacros(j2.Environment):
55
+ def __init__(self, logger: logging.Logger, loader: j2.FileSystemLoader, *args, **kwargs):
56
+ super().__init__(*args, loader=loader, **kwargs)
57
+ self._logger = logger
58
+ self._macros = self._load_macro_templates(logger)
59
+
60
+ def _load_macro_templates(self, logger: logging.Logger) -> str:
61
+ macros_dirs = self._get_macro_folders_from_packages()
62
+ macro_templates = []
63
+ for macros_dir in macros_dirs:
64
+ for root, _, files in os.walk(macros_dir):
65
+ files: list[str]
66
+ for filename in files:
67
+ if any(filename.endswith(x) for x in [".sql", ".j2", ".jinja", ".jinja2"]):
68
+ filepath = Path(root, filename)
69
+ logger.info(f"Loaded macros from: {filepath}")
70
+ with open(filepath, 'r') as f:
71
+ content = f.read()
72
+ macro_templates.append(content)
73
+ return '\n'.join(macro_templates)
74
+
75
+ def _get_macro_folders_from_packages(self) -> list[Path]:
76
+ assert isinstance(self.loader, j2.FileSystemLoader)
77
+ packages_folder = Path(self.loader.searchpath[0], c.PACKAGES_FOLDER)
78
+
79
+ subdirectories = []
80
+ if os.path.exists(packages_folder):
81
+ for item in os.listdir(packages_folder):
82
+ item_path = Path(packages_folder, item)
83
+ if os.path.isdir(item_path):
84
+ subdirectories.append(Path(item_path, c.MACROS_FOLDER))
85
+
86
+ subdirectories.append(Path(self.loader.searchpath[0], c.MACROS_FOLDER))
87
+ return subdirectories
88
+
89
+ def _parse(self, source: str, name: str | None, filename: str | None) -> j2_nodes.Template:
90
+ source = self._macros + source
91
+ return super()._parse(source, name, filename)
92
+
93
+
94
+ ## Utility functions/variables
95
+
96
+ def render_string(raw_str: str, *, project_path: str = ".", **kwargs) -> str:
97
+ """
98
+ Given a template string, render it with the given keyword arguments
99
+
100
+ Arguments:
101
+ raw_str: The template string
102
+ kwargs: The keyword arguments to render with; project_path: the Jinja loader search path (defaults to the current directory)
103
+
104
+ Returns:
105
+ The rendered string
106
+ """
107
+ j2_env = j2.Environment(loader=j2.FileSystemLoader(project_path))
108
+ template = j2_env.from_string(raw_str)
109
+ return template.render(kwargs)
110
+
111
+
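An illustrative call (hypothetical, not from the package docs) showing how render_string fills a Jinja template string with keyword arguments; the import path is an assumption:

    from squirrels._utils import render_string

    sql = render_string("SELECT * FROM {{ table }} LIMIT {{ n }}", table="orders", n=10)
    print(sql)  # SELECT * FROM orders LIMIT 10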
112
+ def read_file(filepath: FilePath) -> str:
113
+ """
114
+ Reads a file and returns its content, raising an error if the file does not exist
115
+
116
+ Arguments:
117
+ filepath (str | pathlib.Path): The path to the file to read
118
+
119
+ Returns:
120
+ The content of the file (a ConfigurationError is raised if the file is not found)
121
+ """
122
+ try:
123
+ with open(filepath, 'r') as f:
124
+ return f.read()
125
+ except FileNotFoundError as e:
126
+ raise ConfigurationError(f"Required file not found: '{str(filepath)}'") from e
127
+
128
+
129
+ def normalize_name(name: str) -> str:
130
+ """
131
+ Normalizes names to the convention of the squirrels manifest file (with underscores instead of dashes).
132
+
133
+ Arguments:
134
+ name: The name to normalize.
135
+
136
+ Returns:
137
+ The normalized name.
138
+ """
139
+ return name.replace('-', '_')
140
+
141
+
142
+ def normalize_name_for_api(name: str) -> str:
143
+ """
144
+ Normalizes names to the REST API convention (with dashes instead of underscores).
145
+
146
+ Arguments:
147
+ name: The name to normalize.
148
+
149
+ Returns:
150
+ The normalized name.
151
+ """
152
+ return name.replace('_', '-')
153
+
154
+
155
+ def load_json_or_comma_delimited_str_as_list(input_str: Union[str, Sequence]) -> Sequence[str]:
156
+ """
157
+ Given a string, load it as a list from either a JSON string or a comma-delimited value. If the input is already a sequence, it is returned unchanged.
158
+
159
+ Arguments:
160
+ input_str: The input string
161
+
162
+ Returns:
163
+ The list representation of the input string
164
+ """
165
+ if not isinstance(input_str, str):
166
+ return input_str
167
+
168
+ output = None
169
+ try:
170
+ output = json.loads(input_str)
171
+ except json.decoder.JSONDecodeError:
172
+ pass
173
+
174
+ if isinstance(output, list):
175
+ return output
176
+ elif input_str == "":
177
+ return []
178
+ else:
179
+ return [x.strip() for x in input_str.split(",")]
180
+
181
+
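A hypothetical call illustrating the two accepted string forms and the empty-string case (import path assumed):

    from squirrels._utils import load_json_or_comma_delimited_str_as_list as as_list

    as_list('["a", "b"]')  # ['a', 'b']  (parsed as a JSON list)
    as_list("a, b")        # ['a', 'b']  (comma-delimited, whitespace stripped)
    as_list("")            # []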
182
+ X = TypeVar('X'); Y = TypeVar('Y')
183
+ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) -> Optional[Y]:
184
+ """
185
+ Given an input value and a function that processes it, return the output of the function, or None if the input is None
186
+
187
+ Arguments:
188
+ input_val: The input value
189
+ processor: The function that processes the input value
190
+
191
+ Returns:
192
+ The output of "processor", or None if the input value is None
193
+ """
194
+ if input_val is None:
195
+ return None
196
+ return processor(input_val)
197
+
198
+
199
+ def _read_duckdb_init_sql(
200
+ *,
201
+ datalake_db_path: str | None = None,
202
+ ) -> str:
203
+ """
204
+ Reads the duckdb init file content (global and project-level) and appends ATTACH/USE statements for the datalake catalog if provided.
205
+ Returns an empty string if no init files exist.
206
+ """
207
+ try:
208
+ init_contents = []
209
+ global_init_path = Path(os.path.expanduser('~'), c.GLOBAL_ENV_FOLDER, c.DUCKDB_INIT_FILE)
210
+ if global_init_path.exists():
211
+ with open(global_init_path, 'r') as f:
212
+ init_contents.append(f.read())
213
+
214
+ if Path(c.DUCKDB_INIT_FILE).exists():
215
+ with open(c.DUCKDB_INIT_FILE, 'r') as f:
216
+ init_contents.append(f.read())
217
+
218
+ if datalake_db_path:
219
+ attach_stmt = f"ATTACH '{datalake_db_path}' AS vdl (READ_ONLY);"
220
+ init_contents.append(attach_stmt)
221
+ use_stmt = f"USE vdl;"
222
+ init_contents.append(use_stmt)
223
+
224
+ init_sql = "\n\n".join(init_contents).strip()
225
+ return init_sql
226
+ except Exception as e:
227
+ raise ConfigurationError(f"Failed to read {c.DUCKDB_INIT_FILE}: {str(e)}") from e
228
+
229
+ def create_duckdb_connection(
230
+ db_path: str | Path = ":memory:",
231
+ *,
232
+ datalake_db_path: str | None = None
233
+ ) -> duckdb.DuckDBPyConnection:
234
+ """
235
+ Creates a DuckDB connection and initializes it with statements from duckdb init file
236
+
237
+ Arguments:
238
+ db_path: Path to the DuckDB database file. Defaults to an in-memory database.
239
+ datalake_db_path: The path to the VDL catalog database if applicable. If exists, this is attached as 'vdl' (READ_ONLY). Default is None.
240
+
241
+ Returns:
242
+ A DuckDB connection (which must be closed after use)
243
+ """
244
+ conn = duckdb.connect(db_path)
245
+
246
+ try:
247
+ init_sql = _read_duckdb_init_sql(datalake_db_path=datalake_db_path)
248
+ conn.execute(init_sql)
249
+ except Exception as e:
250
+ conn.close()
251
+ raise ConfigurationError(f"Failed to execute {c.DUCKDB_INIT_FILE}: {str(e)}") from e
252
+
253
+ return conn
254
+
255
+
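A minimal usage sketch (hypothetical; requires the duckdb package and assumes the project's duckdb init SQL, such as the base project's duckdb_init.sql, resolves from the working directory). The connection must be closed by the caller, as the docstring notes:

    from squirrels._utils import create_duckdb_connection

    conn = create_duckdb_connection()  # in-memory database; runs any duckdb init SQL that is found
    try:
        print(conn.execute("SELECT 42 AS answer").fetchone())  # (42,)
    finally:
        conn.close()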
256
+ def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> pl.DataFrame:
257
+ """
258
+ Runs a SQL query against a collection of dataframes
259
+
260
+ Arguments:
261
+ sql_query: The SQL query to run
262
+ dataframes: A dictionary of table names to their polars LazyFrame
263
+
264
+ Returns:
265
+ The result of running the query, as a polars DataFrame
266
+ """
267
+ duckdb_conn = create_duckdb_connection()
268
+
269
+ try:
270
+ for name, df in dataframes.items():
271
+ duckdb_conn.register(name, df)
272
+
273
+ result_df = duckdb_conn.sql(sql_query).pl()
274
+ finally:
275
+ duckdb_conn.close()
276
+
277
+ return result_df
278
+
279
+
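A hedged sketch of calling run_sql_on_dataframes with Polars LazyFrames, matching the type hints above (assumes duckdb and polars are installed and that any duckdb init SQL resolves from the working directory):

    import polars as pl
    from squirrels._utils import run_sql_on_dataframes

    frames = {"orders": pl.LazyFrame({"id": [1, 2, 3], "amount": [10.0, 25.5, 7.25]})}
    result = run_sql_on_dataframes("SELECT SUM(amount) AS total FROM orders", frames)
    print(result)  # 1-row polars DataFrame with column "total"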
280
+ async def run_polars_sql_on_dataframes(
281
+ sql_query: str, dataframes: dict[str, pl.LazyFrame], *, timeout_seconds: float = 2.0, max_rows: int | None = None
282
+ ) -> pl.DataFrame:
283
+ """
284
+ Runs a SQL query against a collection of dataframes using Polars SQL (more secure than DuckDB for user input).
285
+
286
+ Arguments:
287
+ sql_query: The SQL query to run (Polars SQL dialect)
288
+ dataframes: A dictionary of table names to their polars LazyFrame
289
+ timeout_seconds: Maximum execution time in seconds (default 2.0)
290
+ max_rows: Maximum number of rows to collect. Collects at most max_rows + 1 rows
291
+ to allow overflow detection without loading unbounded results into memory.
292
+
293
+ Returns:
294
+ The result as a polars DataFrame from running the query (limited to max_rows + 1)
295
+
296
+ Raises:
297
+ ConfigurationError: If the query is invalid or insecure
298
+ """
299
+ # Validate the SQL query
300
+ _validate_sql_query_security(sql_query, dataframes)
301
+
302
+ # Execute with timeout
303
+ try:
304
+ loop = asyncio.get_event_loop()
305
+ result = await asyncio.wait_for(
306
+ loop.run_in_executor(None, _run_polars_sql_sync, sql_query, dataframes, max_rows),
307
+ timeout=timeout_seconds
308
+ )
309
+ return result
310
+ except asyncio.TimeoutError as e:
311
+ raise ConfigurationError(f"SQL query execution exceeded timeout of {timeout_seconds} seconds") from e
312
+
313
+
314
+ def _run_polars_sql_sync(sql_query: str, dataframes: dict[str, pl.LazyFrame], max_rows: int | None) -> pl.DataFrame:
315
+ """
316
+ Synchronous execution of Polars SQL.
317
+
318
+ Arguments:
319
+ sql_query: The SQL query to run
320
+ dataframes: A dictionary of table names to their polars LazyFrame
321
+ max_rows: Maximum number of rows to collect.
322
+ """
323
+ ctx = pl.SQLContext(**dataframes)
324
+ result = ctx.execute(sql_query, eager=False)
325
+ if max_rows is not None:
326
+ result = result.limit(max_rows)
327
+ return result.collect()
328
+
329
+
330
+ def _validate_sql_query_security(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> None:
331
+ """
332
+ Validates that a SQL query is safe to execute.
333
+
334
+ Enforces:
335
+ - Single statement only
336
+ - Read-only operations (SELECT/WITH/UNION)
337
+ - Table references limited to the registered frames and CTE names defined within the query
338
+
339
+ Arguments:
340
+ sql_query: The SQL query to validate
341
+ dataframes: Dictionary of allowed table names
342
+
343
+ Raises:
344
+ ConfigurationError: If validation fails
345
+ """
346
+ try:
347
+ parsed = sqlglot.parse(sql_query)
348
+ except Exception as e:
349
+ raise ConfigurationError(f"Failed to parse SQL query: {str(e)}") from e
350
+
351
+ # Enforce single statement
352
+ if len(parsed) != 1:
353
+ raise ConfigurationError(f"Only single SQL statements are allowed. Found {len(parsed)} statements.")
354
+
355
+ statement = parsed[0]
356
+
357
+ # Enforce read-only: allow SELECT, WITH (CTE), UNION, INTERSECT, EXCEPT
358
+ allowed_types = (
359
+ sqlglot.expressions.Select,
360
+ sqlglot.expressions.Union,
361
+ sqlglot.expressions.Intersect,
362
+ sqlglot.expressions.Except,
363
+ )
364
+
365
+ if not isinstance(statement, allowed_types):
366
+ raise ConfigurationError(
367
+ f"Only read-only SQL statements (SELECT, WITH, UNION, INTERSECT, EXCEPT) are allowed. "
368
+ f"Found: {type(statement).__name__}"
369
+ )
370
+
371
+ # Collect CTE names (these are temporary tables created by WITH clauses)
372
+ cte_names: set[str] = set()
373
+ for cte in statement.find_all(sqlglot.expressions.CTE):
374
+ if cte.alias:
375
+ cte_names.add(cte.alias)
376
+
377
+ # Validate table references (excluding CTE names)
378
+ allowed_tables = set(dataframes.keys()) | cte_names
379
+ for table in statement.find_all(sqlglot.expressions.Table):
380
+ table_name = table.name
381
+ if table_name not in allowed_tables:
382
+ raise ConfigurationError(
383
+ f"Table reference '{table_name}' is not allowed. "
384
+ f"Only the following tables are available: {sorted(dataframes.keys())}"
385
+ )
386
+
387
+
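An illustrative sketch (not from the package) of the validated Polars SQL path, combining run_polars_sql_on_dataframes with the checks above; the import paths are assumptions:

    import asyncio
    import polars as pl
    from squirrels._utils import run_polars_sql_on_dataframes
    from squirrels._exceptions import ConfigurationError

    frames = {"orders": pl.LazyFrame({"id": [1, 2], "amount": [10.0, 5.0]})}

    # a read-only query against a registered frame is allowed
    df = asyncio.run(run_polars_sql_on_dataframes("SELECT id FROM orders", frames, max_rows=100))

    # referencing an unregistered table fails validation before anything is executed
    try:
        asyncio.run(run_polars_sql_on_dataframes("SELECT * FROM secrets", frames))
    except ConfigurationError as err:
        print(err)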
388
+ def load_yaml_config(filepath: FilePath) -> dict:
389
+ """
390
+ Loads a YAML config file
391
+
392
+ Arguments:
393
+ filepath: The path to the YAML file
394
+
395
+ Returns:
396
+ A dictionary representation of the YAML file
397
+ """
398
+ try:
399
+ with open(filepath, 'r') as f:
400
+ content = yaml.safe_load(f)
401
+ content = content if content else {}
402
+
403
+ if not isinstance(content, dict):
404
+ raise yaml.YAMLError(f"Parsed content from YAML file must be a dictionary. Got: {content}")
405
+
406
+ return content
407
+ except yaml.YAMLError as e:
408
+ raise ConfigurationError(f"Failed to parse yaml file: {filepath}") from e
409
+
410
+
411
+ def run_duckdb_stmt(
412
+ logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None,
413
+ model_name: str | None = None, redacted_values: list[str] = []
414
+ ) -> duckdb.DuckDBPyConnection:
415
+ """
416
+ Runs a statement on a DuckDB connection
417
+
418
+ Arguments:
419
+ logger: The logger to use
420
+ duckdb_conn: The DuckDB connection
421
+ stmt: The statement to run
422
+ params: The parameters to use
423
+ model_name: The model name to mention in log messages, if any; redacted_values: The values to redact from the logged statement
424
+ """
425
+ redacted_stmt = stmt
426
+ for value in redacted_values:
427
+ redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")
428
+
429
+ for_model_name = f" for model '{model_name}'" if model_name is not None else ""
430
+ logger.debug(f"Running SQL statement{for_model_name}:\n{redacted_stmt}")
431
+ try:
432
+ return duckdb_conn.execute(stmt, params)
433
+ except duckdb.ParserException as e:
434
+ logger.error(f"Failed to run statement: {redacted_stmt}", exc_info=e)
435
+ raise e
436
+
437
+
438
+ def get_current_time() -> str:
439
+ """
440
+ Returns the current time in the format HH:MM:SS.ms
441
+ """
442
+ return datetime.now().strftime('%H:%M:%S.%f')[:-3]
443
+
444
+
445
+ def parse_dependent_tables(sql_query: str, all_table_names: Iterable[str]) -> tuple[set[str], sqlglot.Expression]:
446
+ """
447
+ Parses the dependent tables from a SQL query
448
+
449
+ Arguments:
450
+ sql_query: The SQL query to parse
451
+ all_table_names: The list of all table names
452
+
453
+ Returns:
454
+ A tuple of the set of dependent table names and the parsed SQL expression
455
+ """
456
+ # Parse the SQL query and extract all table references
457
+ parsed = sqlglot.parse_one(sql_query)
458
+ dependencies = set()
459
+
460
+ # Collect all table references from the parsed SQL
461
+ for table in parsed.find_all(sqlglot.expressions.Table):
462
+ if table.name in set(all_table_names):
463
+ dependencies.add(table.name)
464
+
465
+ return dependencies, parsed
466
+
467
+
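A short hypothetical call showing the dependency extraction (import path assumed):

    from squirrels._utils import parse_dependent_tables

    deps, parsed = parse_dependent_tables(
        "SELECT a.x FROM model_a AS a JOIN model_b USING (x)",
        all_table_names=["model_a", "model_b", "model_c"],
    )
    print(deps)  # {'model_a', 'model_b'}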
468
+ async def asyncio_gather(coroutines: list):
469
+ tasks = [asyncio.create_task(coro) for coro in coroutines]
470
+
471
+ try:
472
+ return await asyncio.gather(*tasks)
473
+ except BaseException:
474
+ # Cancel all tasks
475
+ for task in tasks:
476
+ if not task.done():
477
+ task.cancel()
478
+ # Wait for tasks to be cancelled
479
+ await asyncio.gather(*tasks, return_exceptions=True)
480
+ raise
481
+
482
+
483
+ def hash_string(input_str: str, salt: str) -> str:
484
+ """
485
+ Hashes a string using SHA-256
486
+ """
487
+ return hashlib.sha256((input_str + salt).encode()).hexdigest()
488
+
489
+
490
+ T = TypeVar('T')
491
+ def call_func(func: Callable[..., T], **kwargs) -> T:
492
+ """
493
+ Calls a function with only the keyword arguments that appear in its signature, silently dropping any others
494
+ """
495
+ sig = inspect.signature(func)
496
+ # Filter kwargs to only include parameters that the function accepts
497
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
498
+ return func(**filtered_kwargs)
499
+
500
+
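A hypothetical example showing how call_func drops keyword arguments that the target function does not declare:

    from squirrels._utils import call_func

    def my_model(ctx, limit=10):
        return f"ctx={ctx}, limit={limit}"

    print(call_func(my_model, ctx="prod", limit=5, unused="ignored"))  # ctx=prod, limit=5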
501
+ def generate_pkce_challenge(code_verifier: str) -> str:
502
+ """Generate PKCE code challenge from code verifier"""
503
+ # Generate SHA256 hash of code_verifier
504
+ verifier_hash = hashlib.sha256(code_verifier.encode('utf-8')).digest()
505
+ # Base64 URL encode (without padding)
506
+ expected_challenge = base64.urlsafe_b64encode(verifier_hash).decode('utf-8').rstrip('=')
507
+ return expected_challenge
508
+
509
+
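For context, a hypothetical round trip showing that generate_pkce_challenge matches the standard S256 transform from RFC 7636:

    import base64, hashlib, secrets
    from squirrels._utils import generate_pkce_challenge

    verifier = secrets.token_urlsafe(32)
    challenge = generate_pkce_challenge(verifier)
    expected = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=")
    assert challenge == expected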
510
+ def to_title_case(input_str: str) -> str:
511
+ """Convert a string to title case"""
512
+ spaced_str = input_str.replace('_', ' ').replace('-', ' ')
513
+ return spaced_str.title()
514
+
515
+
516
+ def to_bool(val: object) -> bool:
517
+ """Convert common truthy/falsey representations to a boolean.
518
+
519
+ Accepted truthy values (case-insensitive): "1", "true", "t", "yes", "y", "on".
520
+ All other values are considered falsey. None is falsey.
521
+ """
522
+ if isinstance(val, bool):
523
+ return val
524
+ if val is None:
525
+ return False
526
+ s = str(val).strip().lower()
527
+ return s in ("1", "true", "t", "yes", "y", "on")
528
+
529
+
530
+ ACCESS_LEVEL = Literal["admin", "member", "guest"]
531
+
532
+ def get_access_level_rank(access_level: ACCESS_LEVEL) -> int:
533
+ """Get the rank of an access level. Lower ranks have more privileges."""
534
+ return { "admin": 1, "member": 2, "guest": 3 }.get(access_level.lower(), 1)
535
+
536
+ def user_has_elevated_privileges(user_access_level: ACCESS_LEVEL, required_access_level: ACCESS_LEVEL) -> bool:
537
+ """Check if a user has privilege to access a resource"""
538
+ user_access_level_rank = get_access_level_rank(user_access_level)
539
+ required_access_level_rank = get_access_level_rank(required_access_level)
540
+ return user_access_level_rank <= required_access_level_rank
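To illustrate the ranking (hypothetical calls; a lower rank means more privilege):

    from squirrels._utils import user_has_elevated_privileges

    user_has_elevated_privileges("admin", "member")   # True  (rank 1 <= 2)
    user_has_elevated_privileges("guest", "member")   # False (rank 3 <= 2 is False)
    user_has_elevated_privileges("member", "member")  # True  (rank 2 <= 2)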