squirrels 0.5.0b2__py3-none-any.whl → 0.5.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- dateutils/__init__.py +6 -460
- dateutils/_enums.py +25 -0
- dateutils/_implementation.py +409 -0
- dateutils/types.py +6 -0
- squirrels/__init__.py +9 -13
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +262 -0
- squirrels/_api_routes/base.py +154 -0
- squirrels/_api_routes/dashboards.py +142 -0
- squirrels/_api_routes/data_management.py +103 -0
- squirrels/_api_routes/datasets.py +242 -0
- squirrels/_api_routes/oauth2.py +300 -0
- squirrels/_api_routes/project.py +214 -0
- squirrels/_api_server.py +145 -748
- squirrels/_arguments/__init__.py +0 -0
- squirrels/{arguments → _arguments}/init_time_args.py +7 -2
- squirrels/{arguments → _arguments}/run_time_args.py +4 -26
- squirrels/_auth.py +646 -93
- squirrels/_connection_set.py +5 -5
- squirrels/_constants.py +7 -1
- squirrels/{_dashboards_io.py → _dashboards.py} +87 -6
- squirrels/_data_sources.py +564 -0
- squirrels/_exceptions.py +9 -37
- squirrels/_initializer.py +31 -26
- squirrels/_manifest.py +5 -5
- squirrels/_model_builder.py +1 -1
- squirrels/_model_configs.py +2 -2
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +40 -27
- squirrels/{package_data → _package_data}/base_project/.env +1 -0
- squirrels/{package_data → _package_data}/base_project/.env.example +1 -0
- squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.py +4 -4
- squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.yml +2 -2
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/{package_data → _package_data}/base_project/models/builds/build_example.py +2 -2
- squirrels/{package_data → _package_data}/base_project/models/builds/build_example.sql +1 -1
- squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.sql +1 -1
- squirrels/_package_data/base_project/models/federates/federate_example.py +41 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +25 -0
- squirrels/{package_data → _package_data}/base_project/models/federates/federate_example.yml +6 -6
- squirrels/{package_data → _package_data}/base_project/parameters.yml +9 -8
- squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
- squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +14 -16
- squirrels/_package_data/base_project/pyconfigs/parameters.py +106 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +51 -0
- squirrels/_package_data/templates/dataset_results.html +112 -0
- squirrels/_package_data/templates/oauth_login.html +271 -0
- squirrels/_parameter_configs.py +35 -35
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +47 -37
- squirrels/_parameters.py +1664 -0
- squirrels/_project.py +76 -32
- squirrels/_py_module.py +3 -2
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +144 -0
- squirrels/_schemas/query_param_models.py +67 -0
- squirrels/{_api_response_models.py → _schemas/response_models.py} +12 -8
- squirrels/_utils.py +38 -4
- squirrels/arguments.py +2 -0
- squirrels/auth.py +1 -0
- squirrels/connections.py +1 -0
- squirrels/dashboards.py +1 -82
- squirrels/data_sources.py +8 -563
- squirrels/parameter_options.py +8 -348
- squirrels/parameters.py +9 -1266
- squirrels/types.py +11 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/METADATA +4 -1
- squirrels-0.5.0b4.dist-info/RECORD +94 -0
- squirrels/package_data/base_project/macros/macros_example.sql +0 -15
- squirrels/package_data/base_project/models/federates/federate_example.py +0 -44
- squirrels/package_data/base_project/models/federates/federate_example.sql +0 -17
- squirrels/package_data/base_project/pyconfigs/connections.py +0 -14
- squirrels/package_data/base_project/pyconfigs/parameters.py +0 -93
- squirrels/package_data/base_project/pyconfigs/user.py +0 -23
- squirrels-0.5.0b2.dist-info/RECORD +0 -70
- /squirrels/{dataset_result.py → _dataset_types.py} +0 -0
- /squirrels/{package_data → _package_data}/base_project/assets/expenses.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/connections.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/compose.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/duckdb_init.sql +0 -0
- /squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +0 -0
- /squirrels/{package_data → _package_data}/base_project/models/builds/build_example.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/models/sources.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/squirrels.yml.j2 +0 -0
- /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/WHEEL +0 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/licenses/LICENSE +0 -0
squirrels/_project.py
CHANGED
|
@@ -1,17 +1,20 @@
|
|
|
1
1
|
from dotenv import dotenv_values
|
|
2
2
|
from uuid import uuid4
|
|
3
|
+
from pathlib import Path
|
|
3
4
|
import asyncio, typing as t, functools as ft, shutil, json, os
|
|
4
5
|
import logging as l, matplotlib.pyplot as plt, networkx as nx, polars as pl
|
|
5
6
|
import sqlglot, sqlglot.expressions
|
|
6
7
|
|
|
7
|
-
from ._auth import Authenticator, BaseUser
|
|
8
|
+
from ._auth import Authenticator, BaseUser, AuthProviderArgs, ProviderFunctionType
|
|
9
|
+
from ._schemas import response_models as rm
|
|
8
10
|
from ._model_builder import ModelBuilder
|
|
9
11
|
from ._exceptions import InvalidInputError, ConfigurationError
|
|
10
|
-
from . import
|
|
12
|
+
from ._py_module import PyModule
|
|
13
|
+
from . import _dashboards as d, _utils as u, _constants as c, _manifest as mf, _connection_set as cs
|
|
11
14
|
from . import _seeds as s, _models as m, _model_configs as mc, _model_queries as mq, _sources as so
|
|
12
|
-
from . import _parameter_sets as ps,
|
|
15
|
+
from . import _parameter_sets as ps, _dataset_types as dr
|
|
13
16
|
|
|
14
|
-
T = t.TypeVar("T", bound=
|
|
17
|
+
T = t.TypeVar("T", bound=d.Dashboard)
|
|
15
18
|
M = t.TypeVar("M", bound=m.DataModel)
|
|
16
19
|
|
|
17
20
|
|
|
@@ -70,7 +73,7 @@ class SquirrelsProject:
|
|
|
70
73
|
raise ValueError("log_format must be either 'text' or 'json'")
|
|
71
74
|
|
|
72
75
|
if log_file:
|
|
73
|
-
path =
|
|
76
|
+
path = Path(base_path, c.LOGS_FOLDER, log_file)
|
|
74
77
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
75
78
|
|
|
76
79
|
handler = l.FileHandler(path)
|
|
@@ -128,12 +131,33 @@ class SquirrelsProject:
|
|
|
128
131
|
return cs.ConnectionSetIO.load_from_file(self._logger, self._filepath, self._manifest_cfg, self._conn_args)
|
|
129
132
|
|
|
130
133
|
@ft.cached_property
|
|
131
|
-
def
|
|
132
|
-
|
|
134
|
+
def _user_cls_and_provider_functions(self) -> tuple[type[BaseUser], list[ProviderFunctionType]]:
|
|
135
|
+
user_module_path = u.Path(self._filepath, c.PYCONFIGS_FOLDER, c.USER_FILE)
|
|
136
|
+
user_module = PyModule(user_module_path)
|
|
137
|
+
|
|
138
|
+
User = user_module.get_func_or_class("User", default_attr=BaseUser) # adds to Authenticator.providers as side effect
|
|
139
|
+
provider_functions = Authenticator.providers
|
|
140
|
+
Authenticator.providers = []
|
|
141
|
+
|
|
142
|
+
if not issubclass(User, BaseUser):
|
|
143
|
+
raise ConfigurationError(f"User class in '{c.USER_FILE}' must inherit from BaseUser")
|
|
144
|
+
|
|
145
|
+
return User, provider_functions
|
|
146
|
+
|
|
147
|
+
@ft.cached_property
|
|
148
|
+
def _auth_args(self) -> AuthProviderArgs:
|
|
149
|
+
conn_args = self._conn_args
|
|
150
|
+
return AuthProviderArgs(conn_args.project_path, conn_args.proj_vars, conn_args.env_vars)
|
|
151
|
+
|
|
152
|
+
@ft.cached_property
|
|
153
|
+
def _auth(self) -> Authenticator[BaseUser]:
|
|
154
|
+
User, provider_functions = self._user_cls_and_provider_functions
|
|
155
|
+
return Authenticator(self._logger, self._filepath, self._auth_args, provider_functions, user_cls=User)
|
|
133
156
|
|
|
134
157
|
@ft.cached_property
|
|
135
158
|
def _param_args(self) -> ps.ParametersArgs:
|
|
136
|
-
|
|
159
|
+
conn_args = self._conn_args
|
|
160
|
+
return ps.ParametersArgs(conn_args.project_path, conn_args.proj_vars, conn_args.env_vars)
|
|
137
161
|
|
|
138
162
|
@ft.cached_property
|
|
139
163
|
def _param_cfg_set(self) -> ps.ParameterConfigsSet:
|
|
@@ -143,12 +167,32 @@ class SquirrelsProject:
|
|
|
143
167
|
|
|
144
168
|
@ft.cached_property
|
|
145
169
|
def _j2_env(self) -> u.EnvironmentWithMacros:
|
|
146
|
-
|
|
170
|
+
env = u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self._filepath))
|
|
171
|
+
|
|
172
|
+
def value_to_str(value: t.Any, attribute: str | None = None) -> str:
|
|
173
|
+
if attribute is None:
|
|
174
|
+
return str(value)
|
|
175
|
+
else:
|
|
176
|
+
return str(getattr(value, attribute))
|
|
177
|
+
|
|
178
|
+
def join(value: list[t.Any], d: str = ", ", attribute: str | None = None) -> str:
|
|
179
|
+
return d.join(map(lambda x: value_to_str(x, attribute), value))
|
|
180
|
+
|
|
181
|
+
def quote(value: t.Any, q: str = "'", attribute: str | None = None) -> str:
|
|
182
|
+
return q + value_to_str(value, attribute) + q
|
|
183
|
+
|
|
184
|
+
def quote_and_join(value: list[t.Any], q: str = "'", d: str = ", ", attribute: str | None = None) -> str:
|
|
185
|
+
return d.join(map(lambda x: quote(x, q, attribute), value))
|
|
186
|
+
|
|
187
|
+
env.filters["join"] = join
|
|
188
|
+
env.filters["quote"] = quote
|
|
189
|
+
env.filters["quote_and_join"] = quote_and_join
|
|
190
|
+
return env
|
|
147
191
|
|
|
148
192
|
@ft.cached_property
|
|
149
193
|
def _duckdb_venv_path(self) -> str:
|
|
150
194
|
duckdb_filepath_setting_val = self._env_vars.get(c.SQRL_DUCKDB_VENV_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DUCKDB_VENV_FILE}")
|
|
151
|
-
return str(
|
|
195
|
+
return str(Path(self._filepath, duckdb_filepath_setting_val))
|
|
152
196
|
|
|
153
197
|
def close(self) -> None:
|
|
154
198
|
"""
|
|
@@ -236,7 +280,7 @@ class SquirrelsProject:
|
|
|
236
280
|
for model_name in dependencies:
|
|
237
281
|
model = models_dict[model_name]
|
|
238
282
|
if isinstance(model, m.SourceModel) and not model.model_config.load_to_duckdb:
|
|
239
|
-
raise InvalidInputError(
|
|
283
|
+
raise InvalidInputError(400, "Unqueryable source model", f"Source model '{model_name}' cannot be queried with DuckDB")
|
|
240
284
|
if isinstance(model, (m.SourceModel, m.BuildModel)):
|
|
241
285
|
substitutions[model_name] = f"venv.{model_name}"
|
|
242
286
|
|
|
@@ -255,7 +299,7 @@ class SquirrelsProject:
|
|
|
255
299
|
dag = m.DAG(None, fake_target_model, models_dict, self._duckdb_venv_path, self._logger)
|
|
256
300
|
return dag
|
|
257
301
|
|
|
258
|
-
def _draw_dag(self, dag: m.DAG, output_folder:
|
|
302
|
+
def _draw_dag(self, dag: m.DAG, output_folder: Path) -> None:
|
|
259
303
|
color_map = {
|
|
260
304
|
m.ModelType.SEED: "green", m.ModelType.DBVIEW: "red", m.ModelType.FEDERATE: "skyblue",
|
|
261
305
|
m.ModelType.BUILD: "purple", m.ModelType.SOURCE: "orange"
|
|
@@ -275,7 +319,7 @@ class SquirrelsProject:
|
|
|
275
319
|
|
|
276
320
|
fig.tight_layout()
|
|
277
321
|
plt.margins(x=0.1, y=0.1)
|
|
278
|
-
fig.savefig(
|
|
322
|
+
fig.savefig(Path(output_folder, "dag.png"))
|
|
279
323
|
plt.close(fig)
|
|
280
324
|
|
|
281
325
|
async def _get_compiled_dag(self, *, sql_query: str | None = None, selections: dict[str, t.Any] = {}, user: BaseUser | None = None) -> m.DAG:
|
|
@@ -285,18 +329,18 @@ class SquirrelsProject:
|
|
|
285
329
|
await dag.execute(self._param_args, self._param_cfg_set, self._context_func, user, selections, runquery=False, default_traits=default_traits)
|
|
286
330
|
return dag
|
|
287
331
|
|
|
288
|
-
def _get_all_connections(self) -> list[
|
|
332
|
+
def _get_all_connections(self) -> list[rm.ConnectionItemModel]:
|
|
289
333
|
connections = []
|
|
290
334
|
for conn_name, conn_props in self._conn_set.get_connections_as_dict().items():
|
|
291
335
|
if isinstance(conn_props, mf.ConnectionProperties):
|
|
292
336
|
label = conn_props.label if conn_props.label is not None else conn_name
|
|
293
|
-
connections.append(
|
|
337
|
+
connections.append(rm.ConnectionItemModel(name=conn_name, label=label))
|
|
294
338
|
return connections
|
|
295
339
|
|
|
296
|
-
def _get_all_data_models(self, compiled_dag: m.DAG) -> list[
|
|
340
|
+
def _get_all_data_models(self, compiled_dag: m.DAG) -> list[rm.DataModelItem]:
|
|
297
341
|
return compiled_dag.get_all_data_models()
|
|
298
342
|
|
|
299
|
-
async def get_all_data_models(self) -> list[
|
|
343
|
+
async def get_all_data_models(self) -> list[rm.DataModelItem]:
|
|
300
344
|
"""
|
|
301
345
|
Get all data models in the project
|
|
302
346
|
|
|
@@ -306,26 +350,26 @@ class SquirrelsProject:
|
|
|
306
350
|
compiled_dag = await self._get_compiled_dag()
|
|
307
351
|
return self._get_all_data_models(compiled_dag)
|
|
308
352
|
|
|
309
|
-
def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[
|
|
353
|
+
def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[rm.LineageRelation]:
|
|
310
354
|
all_lineage = compiled_dag.get_all_model_lineage()
|
|
311
355
|
|
|
312
356
|
# Add dataset nodes to the lineage
|
|
313
357
|
for dataset in self._manifest_cfg.datasets.values():
|
|
314
|
-
target_dataset =
|
|
315
|
-
source_model =
|
|
316
|
-
all_lineage.append(
|
|
358
|
+
target_dataset = rm.LineageNode(name=dataset.name, type="dataset")
|
|
359
|
+
source_model = rm.LineageNode(name=dataset.model, type="model")
|
|
360
|
+
all_lineage.append(rm.LineageRelation(type="runtime", source=source_model, target=target_dataset))
|
|
317
361
|
|
|
318
362
|
# Add dashboard nodes to the lineage
|
|
319
363
|
for dashboard in self._dashboards.values():
|
|
320
|
-
target_dashboard =
|
|
364
|
+
target_dashboard = rm.LineageNode(name=dashboard.dashboard_name, type="dashboard")
|
|
321
365
|
datasets = set(x.dataset for x in dashboard.config.depends_on)
|
|
322
366
|
for dataset in datasets:
|
|
323
|
-
source_dataset =
|
|
324
|
-
all_lineage.append(
|
|
367
|
+
source_dataset = rm.LineageNode(name=dataset, type="dataset")
|
|
368
|
+
all_lineage.append(rm.LineageRelation(type="runtime", source=source_dataset, target=target_dashboard))
|
|
325
369
|
|
|
326
370
|
return all_lineage
|
|
327
371
|
|
|
328
|
-
async def get_all_data_lineage(self) -> list[
|
|
372
|
+
async def get_all_data_lineage(self) -> list[rm.LineageRelation]:
|
|
329
373
|
"""
|
|
330
374
|
Get all data lineage in the project
|
|
331
375
|
|
|
@@ -371,28 +415,28 @@ class SquirrelsProject:
|
|
|
371
415
|
runquery=runquery, recurse=recurse, default_traits=self._manifest_cfg.get_default_traits()
|
|
372
416
|
)
|
|
373
417
|
|
|
374
|
-
output_folder =
|
|
418
|
+
output_folder = Path(self._filepath, c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
|
|
375
419
|
if output_folder.exists():
|
|
376
420
|
shutil.rmtree(output_folder)
|
|
377
421
|
output_folder.mkdir(parents=True, exist_ok=True)
|
|
378
422
|
|
|
379
423
|
def write_placeholders() -> None:
|
|
380
|
-
output_filepath =
|
|
424
|
+
output_filepath = Path(output_folder, "placeholders.json")
|
|
381
425
|
with open(output_filepath, 'w') as f:
|
|
382
426
|
json.dump(dag.placeholders, f, indent=4)
|
|
383
427
|
|
|
384
428
|
def write_model_outputs(model: m.DataModel) -> None:
|
|
385
429
|
assert isinstance(model, m.QueryModel)
|
|
386
430
|
subfolder = c.DBVIEWS_FOLDER if model.model_type == m.ModelType.DBVIEW else c.FEDERATES_FOLDER
|
|
387
|
-
subpath =
|
|
431
|
+
subpath = Path(output_folder, subfolder)
|
|
388
432
|
subpath.mkdir(parents=True, exist_ok=True)
|
|
389
433
|
if isinstance(model.compiled_query, mq.SqlModelQuery):
|
|
390
|
-
output_filepath =
|
|
434
|
+
output_filepath = Path(subpath, model.name+'.sql')
|
|
391
435
|
query = model.compiled_query.query
|
|
392
436
|
with open(output_filepath, 'w') as f:
|
|
393
437
|
f.write(query)
|
|
394
438
|
if runquery and isinstance(model.result, pl.LazyFrame):
|
|
395
|
-
output_filepath =
|
|
439
|
+
output_filepath = Path(subpath, model.name+'.csv')
|
|
396
440
|
model.result.collect().write_csv(output_filepath)
|
|
397
441
|
|
|
398
442
|
write_placeholders()
|
|
@@ -455,7 +499,7 @@ class SquirrelsProject:
|
|
|
455
499
|
|
|
456
500
|
def _permission_error(self, user: BaseUser | None, data_type: str, data_name: str, scope: str) -> InvalidInputError:
|
|
457
501
|
username = "" if user is None else f" '{user.username}'"
|
|
458
|
-
return InvalidInputError(
|
|
502
|
+
return InvalidInputError(403, f"Unauthorized access to {data_type}", f"User{username} does not have permission to access {scope} {data_type}: {data_name}")
|
|
459
503
|
|
|
460
504
|
def seed(self, name: str) -> pl.LazyFrame:
|
|
461
505
|
"""
|
|
@@ -520,7 +564,7 @@ class SquirrelsProject:
|
|
|
520
564
|
)
|
|
521
565
|
|
|
522
566
|
async def dashboard(
|
|
523
|
-
self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, dashboard_type: t.Type[T] =
|
|
567
|
+
self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, dashboard_type: t.Type[T] = d.PngDashboard
|
|
524
568
|
) -> T:
|
|
525
569
|
"""
|
|
526
570
|
Async method to retrieve a dashboard given parameter selections.
|
squirrels/_py_module.py
CHANGED
|
@@ -43,11 +43,12 @@ class PyModule:
|
|
|
43
43
|
return func_or_class
|
|
44
44
|
|
|
45
45
|
|
|
46
|
-
def run_pyconfig_main(base_path: str, filename: str, kwargs: dict[str, Any] = {}) -> None:
|
|
46
|
+
def run_pyconfig_main(base_path: str, filename: str, kwargs: dict[str, Any] = {}) -> Any | None:
|
|
47
47
|
"""
|
|
48
48
|
Given a python file in the 'pyconfigs' folder, run its main function
|
|
49
49
|
|
|
50
50
|
Arguments:
|
|
51
|
+
base_path: The base path of the project
|
|
51
52
|
filename: The name of the file to run main function
|
|
52
53
|
kwargs: Dictionary of the main function arguments
|
|
53
54
|
"""
|
|
@@ -56,6 +57,6 @@ def run_pyconfig_main(base_path: str, filename: str, kwargs: dict[str, Any] = {}
|
|
|
56
57
|
main_function = module.get_func_or_class(c.MAIN_FUNC, is_required=False)
|
|
57
58
|
if main_function:
|
|
58
59
|
try:
|
|
59
|
-
main_function(**kwargs)
|
|
60
|
+
return main_function(**kwargs)
|
|
60
61
|
except Exception as e:
|
|
61
62
|
raise FileExecutionError(f'Failed to run python file "{filepath}"', e) from e
|
|
File without changes
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from typing import Callable, Any
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseUser(BaseModel):
|
|
7
|
+
model_config = ConfigDict(from_attributes=True)
|
|
8
|
+
username: str
|
|
9
|
+
is_admin: bool = False
|
|
10
|
+
|
|
11
|
+
@classmethod
|
|
12
|
+
def dropped_columns(cls):
|
|
13
|
+
return []
|
|
14
|
+
|
|
15
|
+
def __hash__(self):
|
|
16
|
+
return hash(self.username)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ApiKey(BaseModel):
|
|
20
|
+
model_config = ConfigDict(from_attributes=True)
|
|
21
|
+
id: str
|
|
22
|
+
title: str
|
|
23
|
+
username: str
|
|
24
|
+
created_at: datetime
|
|
25
|
+
expires_at: datetime
|
|
26
|
+
|
|
27
|
+
@field_serializer('created_at', 'expires_at')
|
|
28
|
+
def serialize_datetime(self, dt: datetime) -> str:
|
|
29
|
+
return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class UserField(BaseModel):
|
|
33
|
+
name: str
|
|
34
|
+
type: str
|
|
35
|
+
nullable: bool
|
|
36
|
+
enum: list[str] | None
|
|
37
|
+
default: Any | None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ProviderConfigs(BaseModel):
|
|
41
|
+
client_id: str
|
|
42
|
+
client_secret: str
|
|
43
|
+
server_metadata_url: str
|
|
44
|
+
client_kwargs: dict = Field(default_factory=dict)
|
|
45
|
+
get_user: Callable[[dict], BaseUser]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class AuthProvider(BaseModel):
|
|
49
|
+
name: str
|
|
50
|
+
label: str
|
|
51
|
+
icon: str
|
|
52
|
+
provider_configs: ProviderConfigs
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# OAuth 2.1 Models
|
|
56
|
+
|
|
57
|
+
class OAuthClientModel(BaseModel):
|
|
58
|
+
"""OAuth client details"""
|
|
59
|
+
model_config = ConfigDict(from_attributes=True)
|
|
60
|
+
client_id: str
|
|
61
|
+
client_name: str
|
|
62
|
+
redirect_uris: list[str]
|
|
63
|
+
scope: str
|
|
64
|
+
grant_types: list[str]
|
|
65
|
+
response_types: list[str]
|
|
66
|
+
created_at: datetime
|
|
67
|
+
is_active: bool
|
|
68
|
+
|
|
69
|
+
@field_serializer('created_at')
|
|
70
|
+
def serialize_datetime(self, dt: datetime) -> str:
|
|
71
|
+
return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ClientRegistrationRequest(BaseModel):
|
|
75
|
+
"""Request model for OAuth client registration"""
|
|
76
|
+
client_name: str = Field(description="Human-readable name for the OAuth client")
|
|
77
|
+
redirect_uris: list[str] = Field(description="List of allowed redirect URIs for the client")
|
|
78
|
+
scope: str = Field(default="read", description="Default scope for the client")
|
|
79
|
+
grant_types: list[str] = Field(default=["authorization_code", "refresh_token"], description="Allowed grant types")
|
|
80
|
+
response_types: list[str] = Field(default=["code"], description="Allowed response types")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ClientUpdateRequest(BaseModel):
|
|
84
|
+
"""Request model for OAuth client update"""
|
|
85
|
+
client_name: str | None = Field(default=None, description="Human-readable name for the OAuth client")
|
|
86
|
+
redirect_uris: list[str] | None = Field(default=None, description="List of allowed redirect URIs for the client")
|
|
87
|
+
scope: str | None = Field(default=None, description="Default scope for the client")
|
|
88
|
+
grant_types: list[str] | None = Field(default=None, description="Allowed grant types")
|
|
89
|
+
response_types: list[str] | None = Field(default=None, description="Allowed response types")
|
|
90
|
+
is_active: bool | None = Field(default=None, description="Whether the client is active")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class ClientDetailsResponse(BaseModel):
|
|
94
|
+
"""Response model for OAuth client details (without client_secret)"""
|
|
95
|
+
client_id: str = Field(description="Client ID")
|
|
96
|
+
client_name: str = Field(description="Client name")
|
|
97
|
+
redirect_uris: list[str] = Field(description="Registered redirect URIs")
|
|
98
|
+
scope: str = Field(description="Default scope")
|
|
99
|
+
grant_types: list[str] = Field(description="Allowed grant types")
|
|
100
|
+
response_types: list[str] = Field(description="Allowed response types")
|
|
101
|
+
created_at: datetime = Field(description="Registration timestamp")
|
|
102
|
+
is_active: bool = Field(description="Whether the client is active")
|
|
103
|
+
|
|
104
|
+
@field_serializer('created_at')
|
|
105
|
+
def serialize_datetime(self, dt: datetime) -> str:
|
|
106
|
+
return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class ClientUpdateResponse(ClientDetailsResponse):
|
|
110
|
+
"""Response model for OAuth client update"""
|
|
111
|
+
registration_access_token: str | None = Field(default=None, description="Token for managing this client registration (store securely)")
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class ClientRegistrationResponse(ClientUpdateResponse):
|
|
115
|
+
"""Response model for OAuth client registration"""
|
|
116
|
+
client_secret: str = Field(description="Generated client secret (store securely)")
|
|
117
|
+
registration_client_uri: str | None = Field(default=None, description="URI for managing this client registration")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class TokenResponse(BaseModel):
|
|
121
|
+
access_token: str
|
|
122
|
+
token_type: str = "bearer"
|
|
123
|
+
expires_in: int
|
|
124
|
+
refresh_token: str | None = None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class OAuthServerMetadata(BaseModel):
|
|
128
|
+
"""OAuth 2.1 Authorization Server Metadata (RFC 8414)"""
|
|
129
|
+
issuer: str = Field(description="Authorization server's issuer identifier URL")
|
|
130
|
+
authorization_endpoint: str = Field(description="URL of the authorization endpoint")
|
|
131
|
+
token_endpoint: str = Field(description="URL of the token endpoint")
|
|
132
|
+
revocation_endpoint: str = Field(description="URL of the token revocation endpoint")
|
|
133
|
+
registration_endpoint: str = Field(description="URL of the client registration endpoint")
|
|
134
|
+
scopes_supported: list[str] = Field(description="List of OAuth 2.1 scope values supported")
|
|
135
|
+
response_types_supported: list[str] = Field(description="List of OAuth 2.1 response_type values supported")
|
|
136
|
+
grant_types_supported: list[str] = Field(description="List of OAuth 2.1 grant type values supported")
|
|
137
|
+
token_endpoint_auth_methods_supported: list[str] = Field(
|
|
138
|
+
default=["client_secret_basic", "client_secret_post"],
|
|
139
|
+
description="List of client authentication methods supported by the token endpoint"
|
|
140
|
+
)
|
|
141
|
+
code_challenge_methods_supported: list[str] = Field(
|
|
142
|
+
default=["S256"],
|
|
143
|
+
description="List of PKCE code challenge methods supported"
|
|
144
|
+
)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Query model generation utilities for API routes
|
|
3
|
+
"""
|
|
4
|
+
from typing import Annotated
|
|
5
|
+
from dataclasses import make_dataclass
|
|
6
|
+
from fastapi import Depends
|
|
7
|
+
from pydantic import create_model
|
|
8
|
+
|
|
9
|
+
from .._parameter_configs import APIParamFieldInfo
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _get_query_models_helper(widget_parameters: list[str] | None, predefined_params: list[APIParamFieldInfo], param_fields: dict):
|
|
13
|
+
"""Helper function to generate query models"""
|
|
14
|
+
if widget_parameters is None:
|
|
15
|
+
widget_parameters = list(param_fields.keys())
|
|
16
|
+
|
|
17
|
+
QueryModelForGetRaw = make_dataclass("QueryParams", [
|
|
18
|
+
param_fields[param].as_query_info() for param in widget_parameters
|
|
19
|
+
] + [param.as_query_info() for param in predefined_params])
|
|
20
|
+
QueryModelForGet = Annotated[QueryModelForGetRaw, Depends()]
|
|
21
|
+
|
|
22
|
+
field_definitions = {param: param_fields[param].as_body_info() for param in widget_parameters}
|
|
23
|
+
for param in predefined_params:
|
|
24
|
+
field_definitions[param.name] = param.as_body_info()
|
|
25
|
+
QueryModelForPost = create_model("RequestBodyParams", **field_definitions) # type: ignore
|
|
26
|
+
return QueryModelForGet, QueryModelForPost
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_query_models_for_parameters(widget_parameters: list[str] | None, param_fields: dict):
|
|
30
|
+
"""Generate query models for parameter endpoints"""
|
|
31
|
+
predefined_params = [
|
|
32
|
+
APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dataset"),
|
|
33
|
+
APIParamFieldInfo("x_parent_param", str, description="The parameter name used for parameter updates. If not provided, then all parameters are retrieved"),
|
|
34
|
+
]
|
|
35
|
+
return _get_query_models_helper(widget_parameters, predefined_params, param_fields)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_query_models_for_dataset(widget_parameters: list[str] | None, param_fields: dict):
|
|
39
|
+
"""Generate query models for dataset endpoints"""
|
|
40
|
+
predefined_params = [
|
|
41
|
+
APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dataset"),
|
|
42
|
+
APIParamFieldInfo("x_orientation", str, default="records", description="The orientation of the data to return, one of: 'records', 'rows', or 'columns'"),
|
|
43
|
+
APIParamFieldInfo("x_select", list[str], examples=[[]], description="The columns to select from the dataset. All are returned if not specified"),
|
|
44
|
+
APIParamFieldInfo("x_offset", int, default=0, description="The number of rows to skip before returning data (applied after data caching)"),
|
|
45
|
+
APIParamFieldInfo("x_limit", int, default=1000, description="The maximum number of rows to return (applied after data caching and offset)"),
|
|
46
|
+
]
|
|
47
|
+
return _get_query_models_helper(widget_parameters, predefined_params, param_fields)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def get_query_models_for_dashboard(widget_parameters: list[str] | None, param_fields: dict):
|
|
51
|
+
"""Generate query models for dashboard endpoints"""
|
|
52
|
+
predefined_params = [
|
|
53
|
+
APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dashboard"),
|
|
54
|
+
]
|
|
55
|
+
return _get_query_models_helper(widget_parameters, predefined_params, param_fields)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_query_models_for_querying_models(param_fields: dict):
|
|
59
|
+
"""Generate query models for querying data models"""
|
|
60
|
+
predefined_params = [
|
|
61
|
+
APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid"),
|
|
62
|
+
APIParamFieldInfo("x_orientation", str, default="records", description="The orientation of the data to return, one of: 'records', 'rows', or 'columns'"),
|
|
63
|
+
APIParamFieldInfo("x_offset", int, default=0, description="The number of rows to skip before returning data (applied after data caching)"),
|
|
64
|
+
APIParamFieldInfo("x_limit", int, default=1000, description="The maximum number of rows to return (applied after data caching and offset)"),
|
|
65
|
+
APIParamFieldInfo("x_sql_query", str, description="The SQL query to execute on the data models"),
|
|
66
|
+
]
|
|
67
|
+
return _get_query_models_helper(None, predefined_params, param_fields)
|
|
@@ -1,16 +1,20 @@
|
|
|
1
1
|
from typing import Annotated, Literal
|
|
2
2
|
from pydantic import BaseModel, Field
|
|
3
|
-
from datetime import
|
|
3
|
+
from datetime import date
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from .. import _model_configs as mc, _sources as s
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
8
|
+
## Simple Auth Response Models
|
|
9
|
+
|
|
10
|
+
class ApiKeyResponse(BaseModel):
|
|
11
|
+
api_key: Annotated[str, Field(examples=["sqrl-12345678"], description="The API key to use subsequent API requests")]
|
|
12
|
+
|
|
13
|
+
class ProviderResponse(BaseModel):
|
|
14
|
+
name: Annotated[str, Field(examples=["my_provider"], description="The name of the provider")]
|
|
15
|
+
label: Annotated[str, Field(examples=["My Provider"], description="The human-friendly display name for the provider")]
|
|
16
|
+
icon: Annotated[str, Field(examples=["https://example.com/my_provider_icon.png"], description="The URL of the provider's icon")]
|
|
17
|
+
login_url: Annotated[str, Field(examples=["https://example.com/my_provider_login"], description="The URL to redirect to for provider login")]
|
|
14
18
|
|
|
15
19
|
|
|
16
20
|
## Parameters Response Models
|
squirrels/_utils.py
CHANGED
|
@@ -2,10 +2,9 @@ from typing import Sequence, Optional, Union, TypeVar, Callable, Any, Iterable
|
|
|
2
2
|
from datetime import datetime
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from functools import lru_cache
|
|
5
|
-
from pydantic import BaseModel
|
|
6
5
|
import os, time, logging, json, duckdb, polars as pl, yaml
|
|
7
6
|
import jinja2 as j2, jinja2.nodes as j2_nodes
|
|
8
|
-
import sqlglot, sqlglot.expressions, asyncio
|
|
7
|
+
import sqlglot, sqlglot.expressions, asyncio, hashlib, inspect, base64
|
|
9
8
|
|
|
10
9
|
from . import _constants as c
|
|
11
10
|
from ._exceptions import ConfigurationError
|
|
@@ -290,7 +289,8 @@ def load_yaml_config(filepath: FilePath) -> dict:
|
|
|
290
289
|
|
|
291
290
|
|
|
292
291
|
def run_duckdb_stmt(
|
|
293
|
-
logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None,
|
|
292
|
+
logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None,
|
|
293
|
+
model_name: str | None = None, redacted_values: list[str] = []
|
|
294
294
|
) -> duckdb.DuckDBPyConnection:
|
|
295
295
|
"""
|
|
296
296
|
Runs a statement on a DuckDB connection
|
|
@@ -306,7 +306,8 @@ def run_duckdb_stmt(
|
|
|
306
306
|
for value in redacted_values:
|
|
307
307
|
redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")
|
|
308
308
|
|
|
309
|
-
|
|
309
|
+
for_model_name = f" for model '{model_name}'" if model_name is not None else ""
|
|
310
|
+
logger.info(f"Running SQL statement{for_model_name}:\n{redacted_stmt}", extra={"data": {"params": params}})
|
|
310
311
|
try:
|
|
311
312
|
return duckdb_conn.execute(stmt, params)
|
|
312
313
|
except duckdb.ParserException as e:
|
|
@@ -357,3 +358,36 @@ async def asyncio_gather(coroutines: list):
|
|
|
357
358
|
# Wait for tasks to be cancelled
|
|
358
359
|
await asyncio.gather(*tasks, return_exceptions=True)
|
|
359
360
|
raise
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def hash_string(input_str: str, salt: str) -> str:
|
|
364
|
+
"""
|
|
365
|
+
Hashes a string using SHA-256
|
|
366
|
+
"""
|
|
367
|
+
return hashlib.sha256((input_str + salt).encode()).hexdigest()
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
T = TypeVar('T')
|
|
371
|
+
def call_func(func: Callable[..., T], **kwargs) -> T:
|
|
372
|
+
"""
|
|
373
|
+
Calls a function with the given arguments if func expects arguments, otherwise calls func without arguments
|
|
374
|
+
"""
|
|
375
|
+
sig = inspect.signature(func)
|
|
376
|
+
# Filter kwargs to only include parameters that the function accepts
|
|
377
|
+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
|
|
378
|
+
return func(**filtered_kwargs)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def generate_pkce_challenge(code_verifier: str) -> str:
|
|
382
|
+
"""Generate PKCE code challenge from code verifier"""
|
|
383
|
+
# Generate SHA256 hash of code_verifier
|
|
384
|
+
verifier_hash = hashlib.sha256(code_verifier.encode('utf-8')).digest()
|
|
385
|
+
# Base64 URL encode (without padding)
|
|
386
|
+
expected_challenge = base64.urlsafe_b64encode(verifier_hash).decode('utf-8').rstrip('=')
|
|
387
|
+
return expected_challenge
|
|
388
|
+
|
|
389
|
+
def validate_pkce_challenge(code_verifier: str, code_challenge: str) -> bool:
|
|
390
|
+
"""Validate PKCE code verifier against code challenge"""
|
|
391
|
+
# Generate expected challenge
|
|
392
|
+
expected_challenge = generate_pkce_challenge(code_verifier)
|
|
393
|
+
return expected_challenge == code_challenge
|
squirrels/arguments.py
ADDED
squirrels/auth.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ._auth import BaseUser, ProviderConfigs, provider
|
squirrels/connections.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ._manifest import ConnectionProperties, ConnectionTypeEnum
|