squirrels 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- dateutils/__init__.py +6 -0
- dateutils/_enums.py +25 -0
- squirrels/dateutils.py → dateutils/_implementation.py +58 -111
- dateutils/types.py +6 -0
- squirrels/__init__.py +13 -11
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +271 -0
- squirrels/_api_routes/base.py +165 -0
- squirrels/_api_routes/dashboards.py +150 -0
- squirrels/_api_routes/data_management.py +145 -0
- squirrels/_api_routes/datasets.py +257 -0
- squirrels/_api_routes/oauth2.py +298 -0
- squirrels/_api_routes/project.py +252 -0
- squirrels/_api_server.py +256 -450
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/init_time_args.py +108 -0
- squirrels/_arguments/run_time_args.py +147 -0
- squirrels/_auth.py +960 -0
- squirrels/_command_line.py +126 -45
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +48 -26
- squirrels/_constants.py +68 -38
- squirrels/_dashboards.py +160 -0
- squirrels/_data_sources.py +570 -0
- squirrels/_dataset_types.py +84 -0
- squirrels/_exceptions.py +29 -0
- squirrels/_initializer.py +177 -80
- squirrels/_logging.py +115 -0
- squirrels/_manifest.py +208 -79
- squirrels/_model_builder.py +69 -0
- squirrels/_model_configs.py +74 -0
- squirrels/_model_queries.py +52 -0
- squirrels/_models.py +926 -367
- squirrels/_package_data/base_project/.env +42 -0
- squirrels/_package_data/base_project/.env.example +42 -0
- squirrels/_package_data/base_project/assets/expenses.db +0 -0
- squirrels/_package_data/base_project/connections.yml +16 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +34 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +5 -2
- squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +3 -3
- squirrels/{package_data → _package_data}/base_project/docker/compose.yml +1 -1
- squirrels/_package_data/base_project/duckdb_init.sql +10 -0
- squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +3 -2
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
- squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
- squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +12 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +26 -0
- squirrels/_package_data/base_project/models/federates/federate_example.py +37 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +19 -0
- squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
- squirrels/_package_data/base_project/models/sources.yml +38 -0
- squirrels/{package_data → _package_data}/base_project/parameters.yml +56 -40
- squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
- squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +21 -40
- squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +44 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
- squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
- squirrels/_package_data/templates/dataset_results.html +112 -0
- squirrels/_package_data/templates/oauth_login.html +271 -0
- squirrels/_package_data/templates/squirrels_studio.html +20 -0
- squirrels/_package_loader.py +8 -4
- squirrels/_parameter_configs.py +104 -103
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +57 -47
- squirrels/_parameters.py +1664 -0
- squirrels/_project.py +721 -0
- squirrels/_py_module.py +7 -5
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +167 -0
- squirrels/_schemas/query_param_models.py +75 -0
- squirrels/{_api_response_models.py → _schemas/response_models.py} +126 -47
- squirrels/_seeds.py +35 -16
- squirrels/_sources.py +110 -0
- squirrels/_utils.py +248 -73
- squirrels/_version.py +1 -1
- squirrels/arguments.py +7 -0
- squirrels/auth.py +4 -0
- squirrels/connections.py +3 -0
- squirrels/dashboards.py +2 -81
- squirrels/data_sources.py +14 -631
- squirrels/parameter_options.py +13 -348
- squirrels/parameters.py +14 -1266
- squirrels/types.py +16 -0
- squirrels-0.5.0.dist-info/METADATA +113 -0
- squirrels-0.5.0.dist-info/RECORD +97 -0
- {squirrels-0.4.0.dist-info → squirrels-0.5.0.dist-info}/WHEEL +1 -1
- squirrels-0.5.0.dist-info/entry_points.txt +3 -0
- {squirrels-0.4.0.dist-info → squirrels-0.5.0.dist-info/licenses}/LICENSE +1 -1
- squirrels/_authenticator.py +0 -85
- squirrels/_dashboards_io.py +0 -61
- squirrels/_environcfg.py +0 -84
- squirrels/arguments/init_time_args.py +0 -40
- squirrels/arguments/run_time_args.py +0 -208
- squirrels/package_data/assets/favicon.ico +0 -0
- squirrels/package_data/assets/index.css +0 -1
- squirrels/package_data/assets/index.js +0 -58
- squirrels/package_data/base_project/assets/expenses.db +0 -0
- squirrels/package_data/base_project/connections.yml +0 -7
- squirrels/package_data/base_project/dashboards/dashboard_example.py +0 -32
- squirrels/package_data/base_project/dashboards.yml +0 -10
- squirrels/package_data/base_project/env.yml +0 -29
- squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
- squirrels/package_data/base_project/models/dbviews/dbview_example.sql +0 -22
- squirrels/package_data/base_project/models/federates/federate_example.py +0 -21
- squirrels/package_data/base_project/models/federates/federate_example.sql +0 -3
- squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
- squirrels/package_data/base_project/pyconfigs/connections.py +0 -19
- squirrels/package_data/base_project/pyconfigs/parameters.py +0 -95
- squirrels/package_data/base_project/seeds/seed_subcategories.csv +0 -15
- squirrels/package_data/base_project/squirrels.yml.j2 +0 -94
- squirrels/package_data/templates/index.html +0 -18
- squirrels/project.py +0 -378
- squirrels/user_base.py +0 -55
- squirrels-0.4.0.dist-info/METADATA +0 -117
- squirrels-0.4.0.dist-info/RECORD +0 -60
- squirrels-0.4.0.dist-info/entry_points.txt +0 -4
- /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
squirrels/_sources.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from pydantic import BaseModel, Field, model_validator
|
|
3
|
+
import time, sqlglot, yaml
|
|
4
|
+
|
|
5
|
+
from . import _utils as u, _constants as c, _model_configs as mc
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class UpdateHints(BaseModel):
    # Optional hints attached to a source; the `description` of each field below
    # states how the hint is meant to be applied.

    # Name of the column the hints refer to (None when no such column is configured).
    increasing_column: str | None = Field(default=None)

    strictly_increasing: bool = Field(
        default=True,
        description="Delete the max value of the increasing column, ignored if selective_overwrite_value is set",
    )

    selective_overwrite_value: Any = Field(
        default=None,
        description="Delete all values of the increasing column greater than or equal to this value",
    )
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Source(mc.ConnectionInterface, mc.ModelConfig):
    # Configuration of a single source table.
    # NOTE(review): `columns` and `finalize_connection` are used by this file but not
    # defined here, so they are expected to come from the base classes - confirm.
    table: str | None = Field(default=None)
    load_to_vdl: bool = Field(default=False, description="Whether to load the data to the 'virtual data lake' (VDL)")
    primary_key: list[str] = Field(default_factory=list)
    update_hints: UpdateHints = Field(default_factory=UpdateHints)

    def finalize_table(self, source_name: str):
        """
        Defaults the table name to the source name when no table was configured

        Arguments:
            source_name: The name of the source, used as the table name fallback

        Returns:
            This same object (to allow chaining)
        """
        if self.table is None:
            self.table = source_name
        return self

    def get_table(self) -> str:
        """
        Returns the table name, which must have been set already

        Raises:
            AssertionError: If the table name is still None
        """
        # Raised explicitly (rather than with an `assert` statement) so the check
        # is not stripped when Python runs with optimizations (-O).
        if self.table is None:
            raise AssertionError("Table must be set")
        return self.table

    def get_cols_for_create_table_stmt(self) -> str:
        """
        Returns the comma-separated "<name> <type>" column definitions of this source
        """
        return ", ".join(f"{col.name} {col.type}" for col in self.columns)

    def get_max_incr_col_query(self, source_name: str) -> str:
        """
        Returns a query that selects the max value of the configured increasing column

        Arguments:
            source_name: The table/relation to select from
        """
        return f"SELECT max({self.update_hints.increasing_column}) FROM {source_name}"

    def get_query_for_upsert(self, dialect: str, conn_name: str, table_name: str, max_value_of_increasing_col: Any | None, *, full_refresh: bool = True) -> str:
        """
        Builds the SELECT query that reads the rows of this source from the attached database "db_<conn_name>"

        Arguments:
            dialect: The SQL dialect of the connection (currently only referenced by the pushdown TODO below)
            conn_name: The connection name; the table is read from "db_<conn_name>"
            table_name: The table to select from
            max_value_of_increasing_col: The current max value of the increasing column, or None if unknown
            full_refresh: If True, select all rows regardless of the increasing column

        Returns:
            A query for all rows if full_refresh is True or the max value is None. Otherwise, a query
            for the rows whose increasing column is greater than the max value.

        Raises:
            ConfigurationError: If an incremental query is needed but the increasing column is not one of the columns
        """
        select_cols = ", ".join(col.name for col in self.columns)
        base_query = f"SELECT {select_cols} FROM db_{conn_name}.{table_name}"
        if full_refresh or max_value_of_increasing_col is None:
            return base_query

        increasing_col = self.update_hints.increasing_column
        increasing_col_type = next((col.type for col in self.columns if col.name == increasing_col), None)
        if increasing_col_type is None:
            # Without the default above, next() would leak a bare StopIteration to the caller
            raise u.ConfigurationError(
                f"Increasing column '{increasing_col}' is not defined in the columns of table '{table_name}'"
            )

        # Double any single quotes so the value cannot terminate the SQL string literal early
        max_value_literal = str(max_value_of_increasing_col).replace("'", "''")
        where_cond = f"{increasing_col}::{increasing_col_type} > '{max_value_literal}'::{increasing_col_type}"

        # TODO: figure out if using pushdown query is worth it
        # if dialect in ['postgres', 'mysql']:
        #     pushdown_query = f"SELECT {select_cols} FROM {table_name} WHERE {where_cond}"
        #     transpiled_query = sqlglot.transpile(pushdown_query, read='duckdb', write=dialect)[0].replace("'", "''")
        #     return f"FROM {dialect}_query('db_{conn_name}', '{transpiled_query}')"

        return f"{base_query} WHERE {where_cond}"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class Sources(BaseModel):
    # Collection of all configured sources, keyed by source name.
    sources: dict[str, Source] = Field(default_factory=dict)

    @model_validator(mode="before")
    @classmethod
    def convert_sources_list_to_dict(cls, data: dict[str, Any]) -> dict[str, Any]:
        """
        Converts a "sources" list (each item a mapping with a "name" key) into a dictionary keyed by name

        The input is left untouched: a new top-level dictionary is returned and the "name" key is
        dropped from copies of the source mappings rather than from the caller's objects.

        Raises:
            ConfigurationError: If a source has no "name" field or a name appears more than once
        """
        # Anything that is not a mapping is passed through for pydantic to validate/reject
        if not isinstance(data, dict):
            return data

        raw_sources = data.get("sources")
        if not isinstance(raw_sources, list):
            return data

        sources_dict: dict[str, Any] = {}
        for source in raw_sources:
            if not (isinstance(source, dict) and "name" in source):
                raise u.ConfigurationError("All sources must have a name field in sources file")

            name = source["name"]
            if name in sources_dict:
                raise u.ConfigurationError(f"Duplicate source name found: {name}")

            # The name becomes the key, so it is excluded from the source config itself
            sources_dict[name] = {key: value for key, value in source.items() if key != "name"}

        return {**data, "sources": sources_dict}

    @model_validator(mode="after")
    def validate_column_types(self):
        """
        Ensures every column of every source has a non-empty type

        Raises:
            ConfigurationError: If a column has no type specified
        """
        for source_name, source in self.sources.items():
            for col in source.columns:
                if not col.type:
                    raise u.ConfigurationError(f"Column '{col.name}' in source '{source_name}' must have a type specified")
        return self

    def finalize_null_fields(self, env_vars: dict[str, str]):
        """
        Finalizes the connection and table of every source

        Arguments:
            env_vars: The environment variables, passed to each source's finalize_connection

        Returns:
            This same object (to allow chaining)
        """
        for source_name, source in self.sources.items():
            source.finalize_connection(env_vars)
            source.finalize_table(source_name)
        return self
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class SourcesIO:
    """
    Loader for the sources file of a project
    """

    @classmethod
    def load_file(cls, logger: u.Logger, base_path: str, env_vars: dict[str, str]) -> Sources:
        """
        Reads, renders, and parses the sources file under the models folder

        Arguments:
            logger: The logger used to record how long the loading took
            base_path: The base path of the project
            env_vars: The environment variables made available when rendering and finalizing

        Returns:
            The Sources object (empty if the sources file does not exist or has no content)

        Raises:
            ConfigurationError: If the parsed YAML content is not a dictionary
        """
        start = time.time()

        sources_path = u.Path(base_path, c.MODELS_FOLDER, c.SOURCES_FILE)

        # Start from "no sources" and only replace it when the file exists and has content
        sources_data: Any = {}
        if sources_path.exists():
            rendered = u.render_string(u.read_file(sources_path), base_path=base_path, env_vars=env_vars)
            sources_data = yaml.safe_load(rendered) or {}

        if not isinstance(sources_data, dict):
            raise u.ConfigurationError(
                f"Parsed content from YAML file must be a dictionary. Got: {sources_data}"
            )

        loaded_sources = Sources(**sources_data).finalize_null_fields(env_vars)

        logger.log_activity_time("loading sources", start)
        return loaded_sources
|
squirrels/_utils.py
CHANGED
|
@@ -1,35 +1,33 @@
|
|
|
1
|
-
from typing import Sequence, Optional, Union, TypeVar, Callable
|
|
2
|
-
from pathlib import Path
|
|
3
|
-
from pandas.api import types as pd_types
|
|
1
|
+
from typing import Sequence, Optional, Union, TypeVar, Callable, Iterable, Literal, Any
|
|
4
2
|
from datetime import datetime
|
|
5
|
-
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import os, time, logging, json, duckdb, polars as pl, yaml
|
|
6
5
|
import jinja2 as j2, jinja2.nodes as j2_nodes
|
|
6
|
+
import sqlglot, sqlglot.expressions, asyncio, hashlib, inspect, base64
|
|
7
7
|
|
|
8
8
|
from . import _constants as c
|
|
9
|
+
from ._exceptions import ConfigurationError
|
|
9
10
|
|
|
10
11
|
FilePath = Union[str, Path]
|
|
11
12
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
"""
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
""
|
|
23
|
-
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
new_message = f"\n" + message + f"\n{t}Produced error message:\n{t}{t}{error} (see above for more details on handled exception)"
|
|
31
|
-
super().__init__(new_message, *args)
|
|
32
|
-
self.error = error
|
|
13
|
+
# Polars
# Each polars dtype class mapped to the squirrels type names that resolve to it
polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
    pl.String: ["string", "varchar", "char", "text"],
    pl.Int8: ["tinyint", "int1"],
    pl.Int16: ["smallint", "short", "int2"],
    pl.Int32: ["integer", "int", "int4"],
    pl.Int64: ["bigint", "long", "int8"],
    pl.Float32: ["float", "float4", "real"],
    # Note: Polars Decimal type is considered unstable, so we use Float64 for "decimal"
    pl.Float64: ["double", "float8", "decimal"],
    pl.Boolean: ["boolean", "bool", "logical"],
    pl.Date: ["date"],
    pl.Time: ["time"],
    pl.Datetime: ["timestamp", "datetime"],
    pl.Duration: ["interval"],
    pl.Binary: ["blob", "binary", "varbinary"],
}

# Reverse lookup of the table above: squirrels type name -> polars dtype class
sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = {
    alias: polars_dtype
    for polars_dtype, aliases in polars_dtypes_to_sqrl_dtypes.items()
    for alias in aliases
}
|
|
33
31
|
|
|
34
32
|
|
|
35
33
|
## Other utility classes
|
|
@@ -40,7 +38,7 @@ class Logger(logging.Logger):
|
|
|
40
38
|
time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
|
|
41
39
|
data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
|
|
42
40
|
info = { "request_id": request_id } if request_id else {}
|
|
43
|
-
self.
|
|
41
|
+
self.info(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
|
|
44
42
|
|
|
45
43
|
|
|
46
44
|
class EnvironmentWithMacros(j2.Environment):
|
|
@@ -85,14 +83,6 @@ class EnvironmentWithMacros(j2.Environment):
|
|
|
85
83
|
|
|
86
84
|
## Utility functions/variables
|
|
87
85
|
|
|
88
|
-
def log_activity_time(logger: logging.Logger, activity: str, start_timestamp: float, *, request_id: str | None = None) -> None:
|
|
89
|
-
end_timestamp = time.time()
|
|
90
|
-
time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
|
|
91
|
-
data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
|
|
92
|
-
info = { "request_id": request_id } if request_id else {}
|
|
93
|
-
logger.debug(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
|
|
94
|
-
|
|
95
|
-
|
|
96
86
|
def render_string(raw_str: str, *, base_path: str = ".", **kwargs) -> str:
|
|
97
87
|
"""
|
|
98
88
|
Given a template string, render it with the given keyword arguments
|
|
@@ -115,7 +105,6 @@ def read_file(filepath: FilePath) -> str:
|
|
|
115
105
|
|
|
116
106
|
Arguments:
|
|
117
107
|
filepath (str | pathlib.Path): The path to the file to read
|
|
118
|
-
is_required: If true, throw error if file doesn't exist
|
|
119
108
|
|
|
120
109
|
Returns:
|
|
121
110
|
Content of the file, or None if doesn't exist and not required
|
|
@@ -129,7 +118,7 @@ def read_file(filepath: FilePath) -> str:
|
|
|
129
118
|
|
|
130
119
|
def normalize_name(name: str) -> str:
|
|
131
120
|
"""
|
|
132
|
-
Normalizes names to the convention of the squirrels manifest file.
|
|
121
|
+
Normalizes names to the convention of the squirrels manifest file (with underscores instead of dashes).
|
|
133
122
|
|
|
134
123
|
Arguments:
|
|
135
124
|
name: The name to normalize.
|
|
@@ -142,7 +131,7 @@ def normalize_name(name: str) -> str:
|
|
|
142
131
|
|
|
143
132
|
def normalize_name_for_api(name: str) -> str:
|
|
144
133
|
"""
|
|
145
|
-
Normalizes names to the REST API convention.
|
|
134
|
+
Normalizes names to the REST API convention (with dashes instead of underscores).
|
|
146
135
|
|
|
147
136
|
Arguments:
|
|
148
137
|
name: The name to normalize.
|
|
@@ -180,7 +169,7 @@ def load_json_or_comma_delimited_str_as_list(input_str: Union[str, Sequence]) ->
|
|
|
180
169
|
return [x.strip() for x in input_str.split(",")]
|
|
181
170
|
|
|
182
171
|
|
|
183
|
-
X
|
|
172
|
+
X = TypeVar('X'); Y = TypeVar('Y')
|
|
184
173
|
def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) -> Optional[Y]:
|
|
185
174
|
"""
|
|
186
175
|
Given a input value and a function that processes the value, return the output of the function unless input is None
|
|
@@ -197,60 +186,246 @@ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) ->
|
|
|
197
186
|
return processor(input_val)
|
|
198
187
|
|
|
199
188
|
|
|
200
|
-
def
|
|
189
|
+
def _read_duckdb_init_sql(
    *,
    datalake_db_path: str | None = None,
) -> str:
    """
    Builds the SQL script used to initialize a DuckDB connection.

    The script is made of the following parts (in order, separated by blank lines):
    1. The content of the duckdb init file in the global folder under the user's home directory, if it exists
    2. The content of the duckdb init file in the current working directory, if it exists
    3. A statement that attaches the database at "datalake_db_path" as 'vdl' (READ_ONLY), if a path is given

    Nothing is cached; the files are read again on every call.

    Arguments:
        datalake_db_path: The path of the database to attach as 'vdl'. Default is None (nothing is attached).

    Returns:
        The combined SQL script, which is an empty string if none of the parts apply

    Raises:
        ConfigurationError: If any of the above fails (e.g. an init file cannot be read)
    """
    try:
        init_contents = []
        global_init_path = Path(os.path.expanduser('~'), c.GLOBAL_ENV_FOLDER, c.DUCKDB_INIT_FILE)
        if global_init_path.exists():
            with open(global_init_path, 'r') as f:
                init_contents.append(f.read())

        if Path(c.DUCKDB_INIT_FILE).exists():
            with open(c.DUCKDB_INIT_FILE, 'r') as f:
                init_contents.append(f.read())

        if datalake_db_path:
            # Double any single quotes so a path such as "o'brien/vdl.db" cannot end the SQL string literal early
            escaped_path = str(datalake_db_path).replace("'", "''")
            attach_stmt = f"ATTACH '{escaped_path}' AS vdl (READ_ONLY);"
            init_contents.append(attach_stmt)

        init_sql = "\n\n".join(init_contents).strip()
        return init_sql
    except Exception as e:
        raise ConfigurationError(f"Failed to read {c.DUCKDB_INIT_FILE}: {str(e)}") from e
|
|
216
|
+
|
|
217
|
+
def create_duckdb_connection(
    db_path: str | Path = ":memory:",
    *,
    datalake_db_path: str | None = None
) -> duckdb.DuckDBPyConnection:
    """
    Opens a DuckDB connection and runs the duckdb init statements on it

    Arguments:
        db_path: Path to the DuckDB database file. Defaults to an in-memory database.
        datalake_db_path: The path to the VDL catalog database if applicable (forwarded to the init SQL builder). Default is None.

    Returns:
        A DuckDB connection (which must be closed after use)

    Raises:
        ConfigurationError: If building or executing the init statements fails (the connection is closed first)
    """
    connection = duckdb.connect(db_path)

    try:
        connection.execute(_read_duckdb_init_sql(datalake_db_path=datalake_db_path))
    except Exception as exc:
        # Do not leak the half-initialized connection to the caller
        connection.close()
        raise ConfigurationError(f"Failed to execute {c.DUCKDB_INIT_FILE}: {exc!s}") from exc

    return connection
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> pl.DataFrame:
    """
    Executes a SQL query in which each given dataframe can be referenced by its name

    Arguments:
        sql_query: The SQL query to run
        dataframes: A dictionary of table names to their polars LazyFrame

    Returns:
        The result of the query as a polars DataFrame
    """
    connection = create_duckdb_connection()

    try:
        # Expose every frame to the query under its dictionary key
        for table_name, frame in dataframes.items():
            connection.register(table_name, frame)

        return connection.sql(sql_query).pl()
    finally:
        # The connection only lives for this one query
        connection.close()
|
|
227
266
|
|
|
228
267
|
|
|
229
|
-
def
|
|
268
|
+
def load_yaml_config(filepath: FilePath) -> dict:
    """
    Loads a YAML config file

    Arguments:
        filepath: The path to the YAML file

    Returns:
        A dictionary representation of the YAML file (empty dictionary if the file has no content)

    Raises:
        ConfigurationError: If the file is not valid YAML or its top-level content is not a dictionary
    """
    try:
        # Read as bytes so that the YAML parser determines the encoding (UTF-8 or UTF-16) itself
        # instead of relying on the platform's locale encoding
        with open(filepath, 'rb') as f:
            content = yaml.safe_load(f)
        content = content if content else {}

        if not isinstance(content, dict):
            # Raised as a YAMLError so that it is reported through the same handler below
            raise yaml.YAMLError(f"Parsed content from YAML file must be a dictionary. Got: {content}")

        return content
    except yaml.YAMLError as e:
        raise ConfigurationError(f"Failed to parse yaml file: {filepath}") from e
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def run_duckdb_stmt(
    logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None,
    model_name: str | None = None, redacted_values: list[str] | None = None
) -> duckdb.DuckDBPyConnection:
    """
    Runs a statement on a DuckDB connection

    The statement is logged at debug level before it runs, with every redacted value replaced by "[REDACTED]".
    The statement that is actually executed is never altered.

    Arguments:
        logger: The logger to use
        duckdb_conn: The DuckDB connection
        stmt: The statement to run
        params: The parameters to use
        model_name: If provided, the name of the model to mention in the log message
        redacted_values: The values to redact from the logged statement. Default is None (nothing is redacted).

    Returns:
        The result of calling "execute" on the DuckDB connection

    Raises:
        duckdb.ParserException: If the statement cannot be parsed (logged as an error, then re-raised)
    """
    redacted_stmt = stmt
    for value in redacted_values or ():
        # Replacing an empty string would insert "[REDACTED]" between every character
        if value:
            redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")

    for_model_name = f" for model '{model_name}'" if model_name is not None else ""
    logger.debug(f"Running SQL statement{for_model_name}:\n{redacted_stmt}", extra={"data": {"params": params}})
    try:
        return duckdb_conn.execute(stmt, params)
    except duckdb.ParserException as e:
        logger.error(f"Failed to run statement: {redacted_stmt}", exc_info=e)
        raise
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def get_current_time() -> str:
    """
    Formats the current local time as HH:MM:SS.mmm (millisecond precision)
    """
    now = datetime.now()
    # Truncate (not round) microseconds to milliseconds
    millis = now.microsecond // 1000
    return f"{now:%H:%M:%S}.{millis:03d}"
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def parse_dependent_tables(sql_query: str, all_table_names: Iterable[str]) -> tuple[set[str], sqlglot.Expression]:
    """
    Parses the dependent tables from a SQL query

    Arguments:
        sql_query: The SQL query to parse
        all_table_names: The names of all known tables (any iterable, including one-shot iterators)

    Returns:
        A tuple of two items:
        1. The set of known table names that are referenced in the query
        2. The parsed query
    """
    # Parse the SQL query and extract all table references
    parsed = sqlglot.parse_one(sql_query)

    # Materialize the names once. Rebuilding the set for every table reference is wasteful, and
    # a one-shot iterable would be exhausted after the first reference (dropping later matches).
    known_table_names = set(all_table_names)

    # Collect the table references from the parsed SQL that are known tables
    dependencies = {
        table.name
        for table in parsed.find_all(sqlglot.expressions.Table)
        if table.name in known_table_names
    }

    return dependencies, parsed
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
async def asyncio_gather(coroutines: list):
    """
    Runs the coroutines concurrently as tasks and returns their results in order

    If anything is raised while waiting (including cancellation), every task that has not finished
    is cancelled and awaited before the original exception is re-raised.
    """
    tasks = [asyncio.create_task(coroutine) for coroutine in coroutines]

    try:
        results = await asyncio.gather(*tasks)
    except BaseException:
        # Stop whatever is still running
        unfinished = [task for task in tasks if not task.done()]
        for task in unfinished:
            task.cancel()
        # Let the cancellations settle (their exceptions are collected, not raised)
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    return results
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def hash_string(input_str: str, salt: str) -> str:
    """
    Returns the hex SHA-256 digest of the input string with the salt appended
    """
    hasher = hashlib.sha256()
    hasher.update((input_str + salt).encode())
    return hasher.hexdigest()
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
T = TypeVar('T')
def call_func(func: Callable[..., T], **kwargs) -> T:
    """
    Calls a function with only the keyword arguments that the function is able to accept

    - If the function declares a "**kwargs" style parameter, all given keyword arguments are passed
    - Otherwise, only the arguments that match a parameter that can be passed by keyword are used
    - A function without any parameters is called without arguments

    Arguments:
        func: The function to call
        kwargs: The candidate keyword arguments

    Returns:
        The return value of the function
    """
    parameters = inspect.signature(func).parameters

    # A function that declares **kwargs accepts any keyword, so nothing needs to be filtered out
    if any(param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters.values()):
        return func(**kwargs)

    # Names bound to *args or positional-only parameters cannot be passed by keyword
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    filtered_kwargs = {
        name: value for name, value in kwargs.items()
        if name in parameters and parameters[name].kind in keyword_kinds
    }
    return func(**filtered_kwargs)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def generate_pkce_challenge(code_verifier: str) -> str:
    """
    Generate PKCE code challenge from code verifier

    Arguments:
        code_verifier: The PKCE code verifier

    Returns:
        The URL-safe base64 encoding (without "=" padding) of the SHA-256 hash of the code verifier
    """
    # Generate SHA256 hash of code_verifier
    verifier_hash = hashlib.sha256(code_verifier.encode('utf-8')).digest()
    # Base64 URL encode (without padding)
    expected_challenge = base64.urlsafe_b64encode(verifier_hash).decode('utf-8').rstrip('=')
    return expected_challenge


def validate_pkce_challenge(code_verifier: str, code_challenge: str) -> bool:
    """
    Validate PKCE code verifier against code challenge

    Arguments:
        code_verifier: The PKCE code verifier
        code_challenge: The code challenge that the verifier is expected to produce

    Returns:
        True if the challenge generated from the verifier equals the given challenge, False otherwise
    """
    import hmac

    # Generate expected challenge
    expected_challenge = generate_pkce_challenge(code_verifier)

    # Constant-time comparison, so the time taken does not reveal how much of the challenge matched.
    # Both sides are compared as bytes since compare_digest rejects non-ASCII strings.
    return hmac.compare_digest(expected_challenge.encode('utf-8'), code_challenge.encode('utf-8'))
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def get_scheme(hostname: str | None) -> str:
|
|
397
|
+
"""Get the scheme of the request"""
|
|
398
|
+
return "http" if hostname in ("localhost", "127.0.0.1") else "https"
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def to_title_case(input_str: str) -> str:
    """Replace underscores and dashes with spaces, then capitalize each word"""
    separators_to_spaces = str.maketrans("_-", "  ")
    return input_str.translate(separators_to_spaces).title()
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def to_bool(val: object) -> bool:
    """Interpret a value as a boolean.

    Booleans are returned as is and None is False. Any other value is converted to a string,
    stripped of surrounding whitespace, and compared case-insensitively against the truthy
    tokens "1", "true", "t", "yes", "y", "on". Everything else is False.
    """
    if val is None:
        return False
    if isinstance(val, bool):
        return val

    truthy_tokens = ("1", "true", "t", "yes", "y", "on")
    token = str(val).strip().lower()
    return token in truthy_tokens
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
ACCESS_LEVEL = Literal["admin", "member", "guest"]

# Lower ranks have more privileges
_ACCESS_LEVEL_RANKS = { "admin": 1, "member": 2, "guest": 3 }

def get_access_level_rank(access_level: ACCESS_LEVEL, *, default: int = 1) -> int:
    """
    Get the rank of an access level. Lower ranks have more privileges.

    Arguments:
        access_level: The access level (case-insensitive)
        default: The rank to return for an unrecognized access level. Default is 1 (the admin rank).

    Returns:
        1 for "admin", 2 for "member", 3 for "guest", or the default for anything else
    """
    return _ACCESS_LEVEL_RANKS.get(access_level.lower(), default)

def user_has_elevated_privileges(user_access_level: ACCESS_LEVEL, required_access_level: ACCESS_LEVEL) -> bool:
    """
    Check if a user has privilege to access a resource

    Unrecognized access levels are resolved in the restrictive direction: an unrecognized user
    level is treated as the least privileged level, and an unrecognized required level is
    treated as the most privileged level.

    Arguments:
        user_access_level: The access level of the user
        required_access_level: The access level required by the resource

    Returns:
        True if the user's access level is at least as privileged as the required access level
    """
    # An unknown user level must never be ranked as admin, so it falls back to the lowest privilege
    least_privileged_rank = max(_ACCESS_LEVEL_RANKS.values())
    user_access_level_rank = get_access_level_rank(user_access_level, default=least_privileged_rank)

    # An unknown required level keeps the default of the most privileged rank
    required_access_level_rank = get_access_level_rank(required_access_level)
    return user_access_level_rank <= required_access_level_rank
|
squirrels/_version.py
CHANGED
squirrels/arguments.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
from ._arguments.init_time_args import ConnectionsArgs, AuthProviderArgs, ParametersArgs, BuildModelArgs
|
|
2
|
+
from ._arguments.run_time_args import ContextArgs, ModelArgs, DashboardArgs
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"ConnectionsArgs", "AuthProviderArgs", "ParametersArgs", "BuildModelArgs",
|
|
6
|
+
"ContextArgs", "ModelArgs", "DashboardArgs"
|
|
7
|
+
]
|
squirrels/auth.py
ADDED
squirrels/connections.py
ADDED
squirrels/dashboards.py
CHANGED
|
@@ -1,82 +1,3 @@
|
|
|
1
|
-
|
|
1
|
+
from ._dashboards import PngDashboard, HtmlDashboard
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class Dashboard(metaclass=_abc.ABCMeta):
|
|
7
|
-
"""
|
|
8
|
-
Abstract parent class for all Dashboard classes.
|
|
9
|
-
"""
|
|
10
|
-
|
|
11
|
-
@property
|
|
12
|
-
@_abc.abstractmethod
|
|
13
|
-
def _content(self) -> bytes | str:
|
|
14
|
-
pass
|
|
15
|
-
|
|
16
|
-
@property
|
|
17
|
-
@_abc.abstractmethod
|
|
18
|
-
def _format(self) -> str:
|
|
19
|
-
pass
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class PngDashboard(Dashboard):
|
|
23
|
-
"""
|
|
24
|
-
Instantiate a Dashboard in PNG format from a matplotlib figure or bytes
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
def __init__(self, content: _figure.Figure | _io.BytesIO | bytes) -> None:
|
|
28
|
-
"""
|
|
29
|
-
Constructor for PngDashboard
|
|
30
|
-
|
|
31
|
-
Arguments:
|
|
32
|
-
content: The content of the dashboard as a matplotlib.figure.Figure or bytes
|
|
33
|
-
"""
|
|
34
|
-
if isinstance(content, _figure.Figure):
|
|
35
|
-
buffer = _io.BytesIO()
|
|
36
|
-
content.savefig(buffer, format=_c.PNG)
|
|
37
|
-
content = buffer.getvalue()
|
|
38
|
-
|
|
39
|
-
if isinstance(content, _io.BytesIO):
|
|
40
|
-
content = content.getvalue()
|
|
41
|
-
|
|
42
|
-
self.__content = content
|
|
43
|
-
|
|
44
|
-
@property
|
|
45
|
-
def _content(self) -> bytes:
|
|
46
|
-
return self.__content
|
|
47
|
-
|
|
48
|
-
@property
|
|
49
|
-
def _format(self) -> _t.Literal['png']:
|
|
50
|
-
return _c.PNG
|
|
51
|
-
|
|
52
|
-
def _repr_png_(self):
|
|
53
|
-
return self._content
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
class HtmlDashboard(Dashboard):
|
|
57
|
-
"""
|
|
58
|
-
Instantiate a Dashboard from an HTML string
|
|
59
|
-
"""
|
|
60
|
-
|
|
61
|
-
def __init__(self, content: _io.StringIO | str) -> None:
|
|
62
|
-
"""
|
|
63
|
-
Constructor for HtmlDashboard
|
|
64
|
-
|
|
65
|
-
Arguments:
|
|
66
|
-
content: The content of the dashboard as HTML string
|
|
67
|
-
"""
|
|
68
|
-
if isinstance(content, _io.StringIO):
|
|
69
|
-
content = content.getvalue()
|
|
70
|
-
|
|
71
|
-
self.__content = content
|
|
72
|
-
|
|
73
|
-
@property
|
|
74
|
-
def _content(self) -> str:
|
|
75
|
-
return self.__content
|
|
76
|
-
|
|
77
|
-
@property
|
|
78
|
-
def _format(self) -> _t.Literal['html']:
|
|
79
|
-
return _c.HTML
|
|
80
|
-
|
|
81
|
-
def _repr_html_(self):
|
|
82
|
-
return self._content
|
|
3
|
+
__all__ = ["PngDashboard", "HtmlDashboard"]
|