squirrels 0.1.0__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dateutils/__init__.py +6 -0
- dateutils/_enums.py +25 -0
- squirrels/dateutils.py → dateutils/_implementation.py +409 -380
- dateutils/types.py +6 -0
- squirrels/__init__.py +21 -18
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +337 -0
- squirrels/_api_routes/base.py +196 -0
- squirrels/_api_routes/dashboards.py +156 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +220 -0
- squirrels/_api_routes/project.py +289 -0
- squirrels/_api_server.py +552 -134
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/init_time_args.py +83 -0
- squirrels/_arguments/run_time_args.py +111 -0
- squirrels/_auth.py +777 -0
- squirrels/_command_line.py +239 -107
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +94 -0
- squirrels/_constants.py +141 -64
- squirrels/_dashboards.py +179 -0
- squirrels/_data_sources.py +570 -0
- squirrels/_dataset_types.py +91 -0
- squirrels/_env_vars.py +209 -0
- squirrels/_exceptions.py +29 -0
- squirrels/_http_error_responses.py +52 -0
- squirrels/_initializer.py +319 -110
- squirrels/_logging.py +121 -0
- squirrels/_manifest.py +357 -187
- squirrels/_mcp_server.py +578 -0
- squirrels/_model_builder.py +69 -0
- squirrels/_model_configs.py +74 -0
- squirrels/_model_queries.py +52 -0
- squirrels/_models.py +1201 -0
- squirrels/_package_data/base_project/.env +7 -0
- squirrels/_package_data/base_project/.env.example +44 -0
- squirrels/_package_data/base_project/connections.yml +16 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +40 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/_package_data/base_project/docker/.dockerignore +16 -0
- squirrels/_package_data/base_project/docker/Dockerfile +16 -0
- squirrels/_package_data/base_project/docker/compose.yml +7 -0
- squirrels/_package_data/base_project/duckdb_init.sql +10 -0
- squirrels/_package_data/base_project/gitignore +13 -0
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
- squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
- squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
- squirrels/_package_data/base_project/models/federates/federate_example.py +51 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
- squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
- squirrels/_package_data/base_project/models/sources.yml +38 -0
- squirrels/_package_data/base_project/parameters.yml +142 -0
- squirrels/_package_data/base_project/pyconfigs/connections.py +19 -0
- squirrels/_package_data/base_project/pyconfigs/context.py +96 -0
- squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +56 -0
- squirrels/_package_data/base_project/resources/expenses.db +0 -0
- squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
- squirrels/_package_data/base_project/resources/weather.db +0 -0
- squirrels/_package_data/base_project/seeds/seed_categories.csv +6 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
- squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
- squirrels/_package_data/base_project/tmp/.gitignore +2 -0
- squirrels/_package_data/templates/login_successful.html +53 -0
- squirrels/_package_data/templates/squirrels_studio.html +22 -0
- squirrels/_package_loader.py +29 -0
- squirrels/_parameter_configs.py +592 -0
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +207 -0
- squirrels/_parameters.py +1703 -0
- squirrels/_project.py +796 -0
- squirrels/_py_module.py +122 -0
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +83 -0
- squirrels/_schemas/query_param_models.py +70 -0
- squirrels/_schemas/request_models.py +26 -0
- squirrels/_schemas/response_models.py +286 -0
- squirrels/_seeds.py +97 -0
- squirrels/_sources.py +112 -0
- squirrels/_utils.py +540 -149
- squirrels/_version.py +1 -3
- squirrels/arguments.py +7 -0
- squirrels/auth.py +4 -0
- squirrels/connections.py +3 -0
- squirrels/dashboards.py +3 -0
- squirrels/data_sources.py +14 -282
- squirrels/parameter_options.py +13 -189
- squirrels/parameters.py +14 -801
- squirrels/types.py +18 -0
- squirrels-0.6.0.post0.dist-info/METADATA +148 -0
- squirrels-0.6.0.post0.dist-info/RECORD +101 -0
- {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -2
- {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +1 -0
- squirrels-0.6.0.post0.dist-info/licenses/LICENSE +201 -0
- squirrels/_credentials_manager.py +0 -87
- squirrels/_module_loader.py +0 -37
- squirrels/_parameter_set.py +0 -151
- squirrels/_renderer.py +0 -286
- squirrels/_timed_imports.py +0 -37
- squirrels/connection_set.py +0 -126
- squirrels/package_data/base_project/.gitignore +0 -4
- squirrels/package_data/base_project/connections.py +0 -21
- squirrels/package_data/base_project/database/sample_database.db +0 -0
- squirrels/package_data/base_project/database/seattle_weather.db +0 -0
- squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -8
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -23
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -7
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -10
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -2
- squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -30
- squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -6
- squirrels/package_data/base_project/squirrels.yaml +0 -26
- squirrels/package_data/static/favicon.ico +0 -0
- squirrels/package_data/static/script.js +0 -234
- squirrels/package_data/static/style.css +0 -110
- squirrels/package_data/templates/index.html +0 -32
- squirrels-0.1.0.dist-info/LICENSE +0 -22
- squirrels-0.1.0.dist-info/METADATA +0 -67
- squirrels-0.1.0.dist-info/RECORD +0 -40
- squirrels-0.1.0.dist-info/top_level.txt +0 -1
squirrels/_utils.py
CHANGED
|
@@ -1,149 +1,540 @@
|
|
|
1
|
-
from typing import
|
|
2
|
-
from
|
|
3
|
-
from pathlib import Path
|
|
4
|
-
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
"""
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
1
|
+
from typing import Sequence, Optional, Union, TypeVar, Callable, Iterable, Literal, Any
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import os, time, logging, json, duckdb, polars as pl, yaml
|
|
5
|
+
import jinja2 as j2, jinja2.nodes as j2_nodes
|
|
6
|
+
import sqlglot, sqlglot.expressions, asyncio, hashlib, inspect, base64
|
|
7
|
+
|
|
8
|
+
from . import _constants as c
|
|
9
|
+
from ._exceptions import ConfigurationError
|
|
10
|
+
|
|
11
|
+
# Path-like value accepted by the file-reading helpers in this module
FilePath = Union[str, Path]

# Polars <-> Squirrels dtypes mappings (except Decimal)
# Maps each Polars dtype to every Squirrels type-name alias it corresponds to;
# the first alias in each list is the canonical name.
polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
    pl.String: ["string", "varchar", "char", "text"],
    pl.Int8: ["tinyint", "int1"],
    pl.Int16: ["smallint", "short", "int2"],
    pl.Int32: ["integer", "int", "int4"],
    pl.Int64: ["bigint", "long", "int8"],
    pl.Float32: ["float", "float4", "real"],
    pl.Float64: ["double", "float8"],
    pl.Boolean: ["boolean", "bool", "logical"],
    pl.Date: ["date"],
    pl.Time: ["time"],
    pl.Datetime: ["timestamp", "datetime"],
    pl.Duration: ["interval"],
    pl.Binary: ["blob", "binary", "varbinary"]
}

# Reverse lookup: every Squirrels type-name alias back to its Polars dtype
sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = {
    sqrl_type: k for k, v in polars_dtypes_to_sqrl_dtypes.items() for sqrl_type in v
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
## Other utility classes
|
|
36
|
+
|
|
37
|
+
class Logger(logging.Logger):
|
|
38
|
+
def info(self, msg: str, *, data: dict[str, Any] = {}, **kwargs) -> None:
|
|
39
|
+
super().info(msg, extra={"data": data}, **kwargs)
|
|
40
|
+
|
|
41
|
+
def log_activity_time(self, activity: str, start_timestamp: float, *, additional_data: dict[str, Any] = {}) -> None:
|
|
42
|
+
end_timestamp = time.time()
|
|
43
|
+
time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
|
|
44
|
+
data = {
|
|
45
|
+
"activity": activity,
|
|
46
|
+
"start_timestamp": start_timestamp,
|
|
47
|
+
"end_timestamp": end_timestamp,
|
|
48
|
+
"time_taken_ms": time_taken,
|
|
49
|
+
**additional_data
|
|
50
|
+
}
|
|
51
|
+
self.info(f'Time taken for "{activity}": {time_taken}ms', data=data)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class EnvironmentWithMacros(j2.Environment):
    """Jinja environment that loads all macro files once and prepends them to every template it parses."""

    def __init__(self, logger: logging.Logger, loader: j2.FileSystemLoader, *args, **kwargs):
        super().__init__(*args, loader=loader, **kwargs)
        self._logger = logger
        # Concatenated source of every macro file; prepended to templates in _parse
        self._macros = self._load_macro_templates(logger)

    def _load_macro_templates(self, logger: logging.Logger) -> str:
        """Read every .sql/.j2/.jinja/.jinja2 file under each macros folder and join their contents with newlines."""
        macros_dirs = self._get_macro_folders_from_packages()
        macro_templates = []
        for macros_dir in macros_dirs:
            # os.walk on a non-existent folder simply yields nothing, so missing dirs are harmless
            for root, _, files in os.walk(macros_dir):
                files: list[str]
                for filename in files:
                    if any(filename.endswith(x) for x in [".sql", ".j2", ".jinja", ".jinja2"]):
                        filepath = Path(root, filename)
                        logger.info(f"Loaded macros from: {filepath}")
                        with open(filepath, 'r') as f:
                            content = f.read()
                        macro_templates.append(content)
        return '\n'.join(macro_templates)

    def _get_macro_folders_from_packages(self) -> list[Path]:
        """Return the macros folder of each installed package, then the project's own macros folder last."""
        assert isinstance(self.loader, j2.FileSystemLoader)
        # searchpath[0] is the project root configured on the FileSystemLoader
        packages_folder = Path(self.loader.searchpath[0], c.PACKAGES_FOLDER)

        subdirectories = []
        if os.path.exists(packages_folder):
            for item in os.listdir(packages_folder):
                item_path = Path(packages_folder, item)
                if os.path.isdir(item_path):
                    subdirectories.append(Path(item_path, c.MACROS_FOLDER))

        # Project's own macros folder goes last so its definitions take precedence
        subdirectories.append(Path(self.loader.searchpath[0], c.MACROS_FOLDER))
        return subdirectories

    def _parse(self, source: str, name: str | None, filename: str | None) -> j2_nodes.Template:
        # Override of Jinja's internal parse hook: prepend the macro definitions
        # so every template can use them without explicit imports
        source = self._macros + source
        return super()._parse(source, name, filename)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
## Utility functions/variables
|
|
95
|
+
|
|
96
|
+
def render_string(raw_str: str, *, project_path: str = ".", **kwargs) -> str:
    """
    Render a Jinja template string using the given keyword arguments.

    Arguments:
        raw_str: The template string
        project_path: Root folder used by the template loader for includes/imports
        kwargs: The keyword arguments made available to the template

    Returns:
        The rendered string
    """
    loader = j2.FileSystemLoader(project_path)
    environment = j2.Environment(loader=loader)
    return environment.from_string(raw_str).render(kwargs)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def read_file(filepath: FilePath) -> str:
    """
    Read a file and return its full text content.

    Arguments:
        filepath (str | pathlib.Path): The path to the file to read

    Returns:
        Content of the file

    Raises:
        ConfigurationError: If the file does not exist
    """
    try:
        return Path(filepath).read_text()
    except FileNotFoundError as e:
        raise ConfigurationError(f"Required file not found: '{str(filepath)}'") from e
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def normalize_name(name: str) -> str:
    """
    Normalize a name to the squirrels manifest convention: dashes become underscores.

    Arguments:
        name: The name to normalize.

    Returns:
        The normalized name.
    """
    return '_'.join(name.split('-'))
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def normalize_name_for_api(name: str) -> str:
    """
    Normalize a name to the REST API convention: underscores become dashes.

    Arguments:
        name: The name to normalize.

    Returns:
        The normalized name.
    """
    return '-'.join(name.split('_'))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def load_json_or_comma_delimited_str_as_list(input_str: Union[str, Sequence]) -> Sequence[str]:
    """
    Interpret a string as a list, either as a JSON array or as comma-delimited values.

    Non-string sequences are returned unchanged. An empty string yields an empty list.

    Arguments:
        input_str: The input string (or an already-built sequence)

    Returns:
        The list representation of the input string
    """
    if not isinstance(input_str, str):
        return input_str

    # First try JSON; fall back to comma splitting on parse failure
    try:
        parsed = json.loads(input_str)
    except json.decoder.JSONDecodeError:
        parsed = None

    if isinstance(parsed, list):
        return parsed
    if input_str == "":
        return []
    return [piece.strip() for piece in input_str.split(",")]
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
X = TypeVar('X')
Y = TypeVar('Y')

def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) -> Optional[Y]:
    """
    Apply a processing function to a value, propagating None unchanged.

    Arguments:
        input_val: The input value (may be None)
        processor: The function applied to a non-None input value

    Returns:
        processor(input_val), or None when input_val is None
    """
    return None if input_val is None else processor(input_val)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _read_duckdb_init_sql(
    *,
    datalake_db_path: str | None = None,
) -> str:
    """
    Build the initialization SQL for a new DuckDB connection.

    Concatenates, in order: the global init file under the user's home config folder,
    the project-local init file, and — when datalake_db_path is given — statements
    that ATTACH the VDL catalog read-only and switch to it with USE.

    Arguments:
        datalake_db_path: Optional path to the VDL catalog database to attach as 'vdl'

    Returns:
        The combined init SQL, or an empty string when no init files exist and
        no datalake path is given. (Previous docstring incorrectly claimed caching
        and a None return.)

    Raises:
        ConfigurationError: If reading any init file fails
    """
    try:
        init_contents = []
        global_init_path = Path(os.path.expanduser('~'), c.GLOBAL_ENV_FOLDER, c.DUCKDB_INIT_FILE)
        if global_init_path.exists():
            with open(global_init_path, 'r') as f:
                init_contents.append(f.read())

        if Path(c.DUCKDB_INIT_FILE).exists():
            with open(c.DUCKDB_INIT_FILE, 'r') as f:
                init_contents.append(f.read())

        if datalake_db_path:
            # NOTE(review): datalake_db_path is interpolated directly into SQL;
            # assumed to come from trusted project configuration, not user input
            attach_stmt = f"ATTACH '{datalake_db_path}' AS vdl (READ_ONLY);"
            init_contents.append(attach_stmt)
            use_stmt = "USE vdl;"
            init_contents.append(use_stmt)

        init_sql = "\n\n".join(init_contents).strip()
        return init_sql
    except Exception as e:
        raise ConfigurationError(f"Failed to read {c.DUCKDB_INIT_FILE}: {str(e)}") from e
|
|
228
|
+
|
|
229
|
+
def create_duckdb_connection(
    db_path: str | Path = ":memory:",
    *,
    datalake_db_path: str | None = None
) -> duckdb.DuckDBPyConnection:
    """
    Open a DuckDB connection and run the statements from the duckdb init file(s) on it.

    Arguments:
        db_path: Path to the DuckDB database file. Defaults to in-memory database.
        datalake_db_path: The path to the VDL catalog database if applicable. If exists, this is attached as 'vdl' (READ_ONLY). Default is None.

    Returns:
        A DuckDB connection (which must be closed after use)

    Raises:
        ConfigurationError: If running the init statements fails (the connection is closed first)
    """
    connection = duckdb.connect(db_path)
    try:
        init_sql = _read_duckdb_init_sql(datalake_db_path=datalake_db_path)
        connection.execute(init_sql)
    except Exception as e:
        # Do not leak the connection when initialization fails
        connection.close()
        raise ConfigurationError(f"Failed to execute {c.DUCKDB_INIT_FILE}: {str(e)}") from e
    return connection
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> pl.DataFrame:
    """
    Execute a SQL query over a set of registered dataframes via DuckDB.

    Arguments:
        sql_query: The SQL query to run
        dataframes: A dictionary of table names to their polars LazyFrame

    Returns:
        The result as a polars Dataframe from running the query
    """
    conn = create_duckdb_connection()
    try:
        # Expose each LazyFrame to DuckDB under its table name
        for table_name, lazy_frame in dataframes.items():
            conn.register(table_name, lazy_frame)
        return conn.sql(sql_query).pl()
    finally:
        conn.close()
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
async def run_polars_sql_on_dataframes(
    sql_query: str, dataframes: dict[str, pl.LazyFrame], *, timeout_seconds: float = 2.0, max_rows: int | None = None
) -> pl.DataFrame:
    """
    Runs a SQL query against a collection of dataframes using Polars SQL (more secure than DuckDB for user input).

    Arguments:
        sql_query: The SQL query to run (Polars SQL dialect)
        dataframes: A dictionary of table names to their polars LazyFrame
        timeout_seconds: Maximum execution time in seconds (default 2.0)
        max_rows: Maximum number of rows to collect (the sync helper applies .limit(max_rows));
            callers wanting overflow detection should pass their limit + 1 themselves.

    Returns:
        The result as a polars DataFrame from running the query

    Raises:
        ConfigurationError: If the query is invalid or insecure, or execution times out
    """
    # Validate the SQL query before any execution happens
    _validate_sql_query_security(sql_query, dataframes)

    # Execute in a worker thread with a timeout (collect() is blocking)
    try:
        # get_running_loop() is the correct call inside a coroutine;
        # asyncio.get_event_loop() is deprecated here since Python 3.10
        loop = asyncio.get_running_loop()
        result = await asyncio.wait_for(
            loop.run_in_executor(None, _run_polars_sql_sync, sql_query, dataframes, max_rows),
            timeout=timeout_seconds
        )
        return result
    except asyncio.TimeoutError as e:
        raise ConfigurationError(f"SQL query execution exceeded timeout of {timeout_seconds} seconds") from e
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _run_polars_sql_sync(sql_query: str, dataframes: dict[str, pl.LazyFrame], max_rows: int | None) -> pl.DataFrame:
    """
    Execute Polars SQL synchronously against the registered frames.

    Arguments:
        sql_query: The SQL query to run
        dataframes: A dictionary of table names to their polars LazyFrame
        max_rows: Maximum number of rows to collect (None for no limit)
    """
    lazy_result = pl.SQLContext(**dataframes).execute(sql_query, eager=False)
    if max_rows is None:
        return lazy_result.collect()
    return lazy_result.limit(max_rows).collect()
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _validate_sql_query_security(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> None:
    """
    Validate that a SQL query is safe to execute.

    Enforces:
    - Single statement only
    - Read-only operations (SELECT/WITH/UNION)
    - Table references limited to registered frames (excluding CTE names)

    Arguments:
        sql_query: The SQL query to validate
        dataframes: Dictionary of allowed table names

    Raises:
        ConfigurationError: If validation fails
    """
    try:
        statements = sqlglot.parse(sql_query)
    except Exception as e:
        raise ConfigurationError(f"Failed to parse SQL query: {str(e)}") from e

    # Reject multi-statement input outright
    if len(statements) != 1:
        raise ConfigurationError(f"Only single SQL statements are allowed. Found {len(statements)} statements.")

    stmt = statements[0]

    # Read-only statement types: SELECT, WITH (CTE), UNION, INTERSECT, EXCEPT
    read_only_types = (
        sqlglot.expressions.Select,
        sqlglot.expressions.Union,
        sqlglot.expressions.Intersect,
        sqlglot.expressions.Except,
    )
    if not isinstance(stmt, read_only_types):
        raise ConfigurationError(
            f"Only read-only SQL statements (SELECT, WITH, UNION, INTERSECT, EXCEPT) are allowed. "
            f"Found: {type(stmt).__name__}"
        )

    # CTE names act as temporary tables and are permitted references
    cte_names: set[str] = {cte.alias for cte in stmt.find_all(sqlglot.expressions.CTE) if cte.alias}

    # Every referenced table must be a registered frame or a CTE
    permitted_tables = set(dataframes.keys()) | cte_names
    for table_expr in stmt.find_all(sqlglot.expressions.Table):
        if table_expr.name not in permitted_tables:
            raise ConfigurationError(
                f"Table reference '{table_expr.name}' is not allowed. "
                f"Only the following tables are available: {sorted(dataframes.keys())}"
            )
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def load_yaml_config(filepath: FilePath) -> dict:
    """
    Load a YAML config file as a dictionary.

    Arguments:
        filepath: The path to the YAML file

    Returns:
        A dictionary representation of the YAML file (empty dict for an empty file)

    Raises:
        ConfigurationError: If the file cannot be parsed as a YAML dictionary
    """
    try:
        with open(filepath, 'r') as f:
            parsed = yaml.safe_load(f)
        # Empty files parse to None; treat that as an empty config
        parsed = parsed if parsed else {}

        if not isinstance(parsed, dict):
            raise yaml.YAMLError(f"Parsed content from YAML file must be a dictionary. Got: {parsed}")

        return parsed
    except yaml.YAMLError as e:
        raise ConfigurationError(f"Failed to parse yaml file: {filepath}") from e
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def run_duckdb_stmt(
    logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None,
    model_name: str | None = None, redacted_values: list[str] = []
) -> duckdb.DuckDBPyConnection:
    """
    Runs a statement on a DuckDB connection

    Arguments:
        logger: The logger to use
        duckdb_conn: The DuckDB connection
        stmt: The statement to run
        params: The parameters to use
        model_name: Optional model name included in the debug log line
        redacted_values: The values to redact

    Returns:
        The connection/cursor object returned by duckdb_conn.execute()
    """
    # Redact sensitive values so they never reach the logs (the real stmt is still executed)
    redacted_stmt = stmt
    for value in redacted_values:
        redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")

    for_model_name = f" for model '{model_name}'" if model_name is not None else ""
    logger.debug(f"Running SQL statement{for_model_name}:\n{redacted_stmt}")
    try:
        return duckdb_conn.execute(stmt, params)
    except duckdb.ParserException as e:
        # NOTE(review): only parser errors are logged here; other duckdb errors propagate unlogged — confirm intended
        logger.error(f"Failed to run statement: {redacted_stmt}", exc_info=e)
        raise e
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def get_current_time() -> str:
    """
    Return the current local time formatted as HH:MM:SS.ms (millisecond precision).
    """
    now = datetime.now()
    return f"{now:%H:%M:%S}.{now.microsecond // 1000:03d}"
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def parse_dependent_tables(sql_query: str, all_table_names: Iterable[str]) -> tuple[set[str], sqlglot.Expression]:
    """
    Parses the dependent tables from a SQL query

    Arguments:
        sql_query: The SQL query to parse
        all_table_names: The list of all table names

    Returns:
        A tuple of (set of dependent table names, the parsed sqlglot expression)
    """
    # Parse the SQL query and extract all table references
    parsed = sqlglot.parse_one(sql_query)

    # Build the lookup set ONCE: the previous version rebuilt set(all_table_names)
    # on every table reference (O(n*m)) and would silently yield wrong results
    # if all_table_names were a one-shot iterator (exhausted after the first pass)
    known_tables = set(all_table_names)

    dependencies = {
        table.name for table in parsed.find_all(sqlglot.expressions.Table)
        if table.name in known_tables
    }

    return dependencies, parsed
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
async def asyncio_gather(coroutines: list):
    """
    Run coroutines concurrently; if any fails, cancel the remaining tasks before propagating.

    Arguments:
        coroutines: The coroutines to schedule as tasks

    Returns:
        The list of results, in the same order as the input coroutines
    """
    tasks = [asyncio.create_task(coro) for coro in coroutines]

    try:
        return await asyncio.gather(*tasks)
    except BaseException:
        # A task failed or we were cancelled: stop every task still running
        for pending in tasks:
            if not pending.done():
                pending.cancel()
        # Let cancellations settle before re-raising the original failure
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def hash_string(input_str: str, salt: str) -> str:
    """
    Return the hex-encoded SHA-256 digest of input_str with salt appended.
    """
    digest = hashlib.sha256()
    digest.update((input_str + salt).encode())
    return digest.hexdigest()
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
T = TypeVar('T')
def call_func(func: Callable[..., T], **kwargs) -> T:
    """
    Invoke func, forwarding only the keyword arguments its signature declares.
    A function taking no arguments is simply called with none.
    """
    accepted = inspect.signature(func).parameters
    forwarded = {name: value for name, value in kwargs.items() if name in accepted}
    return func(**forwarded)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def generate_pkce_challenge(code_verifier: str) -> str:
    """Derive the S256 PKCE code challenge (RFC 7636) from a code verifier:
    SHA-256 of the verifier, base64url-encoded with padding stripped."""
    digest = hashlib.sha256(code_verifier.encode('utf-8')).digest()
    return base64.urlsafe_b64encode(digest).decode('utf-8').rstrip('=')
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def to_title_case(input_str: str) -> str:
    """Convert a snake_case or kebab-case string to space-separated Title Case."""
    # Map both separator characters to spaces in a single pass, then title-case
    return input_str.translate(str.maketrans("_-", "  ")).title()
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def to_bool(val: object) -> bool:
    """Convert common truthy/falsey representations to a boolean.

    Accepted truthy values (case-insensitive): "1", "true", "t", "yes", "y", "on".
    All other values are considered falsey. None is falsey.
    """
    if isinstance(val, bool):
        return val
    if val is None:
        return False
    normalized = str(val).strip().lower()
    return normalized in {"1", "true", "t", "yes", "y", "on"}
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
ACCESS_LEVEL = Literal["admin", "member", "guest"]

# Lower rank = more privilege
_ACCESS_LEVEL_RANKS = {"admin": 1, "member": 2, "guest": 3}

def get_access_level_rank(access_level: ACCESS_LEVEL) -> int:
    """Get the rank of an access level. Lower ranks have more privileges.

    Unknown access levels rank BELOW "guest" (fail closed). The previous default
    of 1 silently granted admin-equivalent rank to any unrecognized string.
    """
    return _ACCESS_LEVEL_RANKS.get(access_level.lower(), len(_ACCESS_LEVEL_RANKS) + 1)
|
|
535
|
+
|
|
536
|
+
def user_has_elevated_privileges(user_access_level: ACCESS_LEVEL, required_access_level: ACCESS_LEVEL) -> bool:
    """Check if a user has privilege to access a resource (lower rank means more privilege)."""
    return get_access_level_rank(user_access_level) <= get_access_level_rank(required_access_level)
|