squirrels 0.5.0b2__py3-none-any.whl → 0.5.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (96) hide show
  1. dateutils/__init__.py +6 -460
  2. dateutils/_enums.py +25 -0
  3. dateutils/_implementation.py +409 -0
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +9 -13
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +262 -0
  8. squirrels/_api_routes/base.py +154 -0
  9. squirrels/_api_routes/dashboards.py +142 -0
  10. squirrels/_api_routes/data_management.py +103 -0
  11. squirrels/_api_routes/datasets.py +242 -0
  12. squirrels/_api_routes/oauth2.py +300 -0
  13. squirrels/_api_routes/project.py +214 -0
  14. squirrels/_api_server.py +145 -748
  15. squirrels/_arguments/__init__.py +0 -0
  16. squirrels/{arguments → _arguments}/init_time_args.py +7 -2
  17. squirrels/{arguments → _arguments}/run_time_args.py +4 -26
  18. squirrels/_auth.py +646 -93
  19. squirrels/_connection_set.py +5 -5
  20. squirrels/_constants.py +7 -1
  21. squirrels/{_dashboards_io.py → _dashboards.py} +87 -6
  22. squirrels/_data_sources.py +564 -0
  23. squirrels/_exceptions.py +9 -37
  24. squirrels/_initializer.py +31 -26
  25. squirrels/_manifest.py +5 -5
  26. squirrels/_model_builder.py +1 -1
  27. squirrels/_model_configs.py +2 -2
  28. squirrels/_model_queries.py +1 -1
  29. squirrels/_models.py +40 -27
  30. squirrels/{package_data → _package_data}/base_project/.env +1 -0
  31. squirrels/{package_data → _package_data}/base_project/.env.example +1 -0
  32. squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.py +4 -4
  33. squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.yml +2 -2
  34. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  35. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.py +2 -2
  36. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.sql +1 -1
  37. squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.sql +1 -1
  38. squirrels/_package_data/base_project/models/federates/federate_example.py +41 -0
  39. squirrels/_package_data/base_project/models/federates/federate_example.sql +25 -0
  40. squirrels/{package_data → _package_data}/base_project/models/federates/federate_example.yml +6 -6
  41. squirrels/{package_data → _package_data}/base_project/parameters.yml +9 -8
  42. squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
  43. squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +14 -16
  44. squirrels/_package_data/base_project/pyconfigs/parameters.py +106 -0
  45. squirrels/_package_data/base_project/pyconfigs/user.py +51 -0
  46. squirrels/_package_data/templates/dataset_results.html +112 -0
  47. squirrels/_package_data/templates/oauth_login.html +271 -0
  48. squirrels/_parameter_configs.py +35 -35
  49. squirrels/_parameter_options.py +348 -0
  50. squirrels/_parameter_sets.py +47 -37
  51. squirrels/_parameters.py +1664 -0
  52. squirrels/_project.py +76 -32
  53. squirrels/_py_module.py +3 -2
  54. squirrels/_schemas/__init__.py +0 -0
  55. squirrels/_schemas/auth_models.py +144 -0
  56. squirrels/_schemas/query_param_models.py +67 -0
  57. squirrels/{_api_response_models.py → _schemas/response_models.py} +12 -8
  58. squirrels/_utils.py +38 -4
  59. squirrels/arguments.py +2 -0
  60. squirrels/auth.py +1 -0
  61. squirrels/connections.py +1 -0
  62. squirrels/dashboards.py +1 -82
  63. squirrels/data_sources.py +8 -563
  64. squirrels/parameter_options.py +8 -348
  65. squirrels/parameters.py +9 -1266
  66. squirrels/types.py +11 -0
  67. {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/METADATA +4 -1
  68. squirrels-0.5.0b4.dist-info/RECORD +94 -0
  69. squirrels/package_data/base_project/macros/macros_example.sql +0 -15
  70. squirrels/package_data/base_project/models/federates/federate_example.py +0 -44
  71. squirrels/package_data/base_project/models/federates/federate_example.sql +0 -17
  72. squirrels/package_data/base_project/pyconfigs/connections.py +0 -14
  73. squirrels/package_data/base_project/pyconfigs/parameters.py +0 -93
  74. squirrels/package_data/base_project/pyconfigs/user.py +0 -23
  75. squirrels-0.5.0b2.dist-info/RECORD +0 -70
  76. /squirrels/{dataset_result.py → _dataset_types.py} +0 -0
  77. /squirrels/{package_data → _package_data}/base_project/assets/expenses.db +0 -0
  78. /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
  79. /squirrels/{package_data → _package_data}/base_project/connections.yml +0 -0
  80. /squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +0 -0
  81. /squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +0 -0
  82. /squirrels/{package_data → _package_data}/base_project/docker/compose.yml +0 -0
  83. /squirrels/{package_data → _package_data}/base_project/duckdb_init.sql +0 -0
  84. /squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +0 -0
  85. /squirrels/{package_data → _package_data}/base_project/models/builds/build_example.yml +0 -0
  86. /squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.yml +0 -0
  87. /squirrels/{package_data → _package_data}/base_project/models/sources.yml +0 -0
  88. /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
  89. /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.yml +0 -0
  90. /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.csv +0 -0
  91. /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.yml +0 -0
  92. /squirrels/{package_data → _package_data}/base_project/squirrels.yml.j2 +0 -0
  93. /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
  94. {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/WHEEL +0 -0
  95. {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/entry_points.txt +0 -0
  96. {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/licenses/LICENSE +0 -0
squirrels/_project.py CHANGED
@@ -1,17 +1,20 @@
1
1
  from dotenv import dotenv_values
2
2
  from uuid import uuid4
3
+ from pathlib import Path
3
4
  import asyncio, typing as t, functools as ft, shutil, json, os
4
5
  import logging as l, matplotlib.pyplot as plt, networkx as nx, polars as pl
5
6
  import sqlglot, sqlglot.expressions
6
7
 
7
- from ._auth import Authenticator, BaseUser
8
+ from ._auth import Authenticator, BaseUser, AuthProviderArgs, ProviderFunctionType
9
+ from ._schemas import response_models as rm
8
10
  from ._model_builder import ModelBuilder
9
11
  from ._exceptions import InvalidInputError, ConfigurationError
10
- from . import _utils as u, _constants as c, _manifest as mf, _connection_set as cs, _api_response_models as arm
12
+ from ._py_module import PyModule
13
+ from . import _dashboards as d, _utils as u, _constants as c, _manifest as mf, _connection_set as cs
11
14
  from . import _seeds as s, _models as m, _model_configs as mc, _model_queries as mq, _sources as so
12
- from . import _parameter_sets as ps, _dashboards_io as d, dashboards as dash, dataset_result as dr
15
+ from . import _parameter_sets as ps, _dataset_types as dr
13
16
 
14
- T = t.TypeVar("T", bound=dash.Dashboard)
17
+ T = t.TypeVar("T", bound=d.Dashboard)
15
18
  M = t.TypeVar("M", bound=m.DataModel)
16
19
 
17
20
 
@@ -70,7 +73,7 @@ class SquirrelsProject:
70
73
  raise ValueError("log_format must be either 'text' or 'json'")
71
74
 
72
75
  if log_file:
73
- path = u.Path(base_path, c.LOGS_FOLDER, log_file)
76
+ path = Path(base_path, c.LOGS_FOLDER, log_file)
74
77
  path.parent.mkdir(parents=True, exist_ok=True)
75
78
 
76
79
  handler = l.FileHandler(path)
@@ -128,12 +131,33 @@ class SquirrelsProject:
128
131
  return cs.ConnectionSetIO.load_from_file(self._logger, self._filepath, self._manifest_cfg, self._conn_args)
129
132
 
130
133
  @ft.cached_property
131
- def _auth(self) -> Authenticator:
132
- return Authenticator(self._logger, self._filepath, self._env_vars)
134
def _user_cls_and_provider_functions(self) -> tuple[type[BaseUser], list[ProviderFunctionType]]:
    """
    Resolve the project's User class and the auth provider functions registered by the user module.

    Returns:
        A tuple of (User class, list of provider functions collected from Authenticator.providers)

    Raises:
        ConfigurationError: If the "User" attribute of the user module is not a subclass of BaseUser
    """
    user_module_path = u.Path(self._filepath, c.PYCONFIGS_FOLDER, c.USER_FILE)
    user_module = PyModule(user_module_path)

    # Falls back to BaseUser when the module defines no "User" attribute.
    # Loading the module adds to Authenticator.providers as a side effect.
    User = user_module.get_func_or_class("User", default_attr=BaseUser)

    # Take ownership of the registered providers and reset the class-level list
    provider_functions = Authenticator.providers
    Authenticator.providers = []

    # issubclass() raises TypeError for non-class objects, so check for a class first to make sure
    # a misconfigured "User" (e.g. a function or an instance) surfaces as a ConfigurationError
    if not (isinstance(User, type) and issubclass(User, BaseUser)):
        raise ConfigurationError(f"User class in '{c.USER_FILE}' must inherit from BaseUser")

    return User, provider_functions
146
+
147
+ @ft.cached_property
148
def _auth_args(self) -> AuthProviderArgs:
    """
    Build the arguments handed to auth provider functions, reusing the project path,
    project variables, and environment variables from the connection arguments.
    """
    source = self._conn_args
    project_path, proj_vars, env_vars = source.project_path, source.proj_vars, source.env_vars
    return AuthProviderArgs(project_path, proj_vars, env_vars)
151
+
152
+ @ft.cached_property
153
def _auth(self) -> Authenticator[BaseUser]:
    """
    Create the Authenticator for this project from the resolved User class and provider functions.
    """
    resolved = self._user_cls_and_provider_functions
    user_cls, providers = resolved
    return Authenticator(self._logger, self._filepath, self._auth_args, providers, user_cls=user_cls)
133
156
 
134
157
  @ft.cached_property
135
158
  def _param_args(self) -> ps.ParametersArgs:
136
- return ps.ParameterConfigsSetIO.get_param_args(self._conn_args)
159
+ conn_args = self._conn_args
160
+ return ps.ParametersArgs(conn_args.project_path, conn_args.proj_vars, conn_args.env_vars)
137
161
 
138
162
  @ft.cached_property
139
163
  def _param_cfg_set(self) -> ps.ParameterConfigsSet:
@@ -143,12 +167,32 @@ class SquirrelsProject:
143
167
 
144
168
  @ft.cached_property
145
169
  def _j2_env(self) -> u.EnvironmentWithMacros:
146
- return u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self._filepath))
170
+ env = u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self._filepath))
171
+
172
+ def value_to_str(value: t.Any, attribute: str | None = None) -> str:
173
+ if attribute is None:
174
+ return str(value)
175
+ else:
176
+ return str(getattr(value, attribute))
177
+
178
+ def join(value: list[t.Any], d: str = ", ", attribute: str | None = None) -> str:
179
+ return d.join(map(lambda x: value_to_str(x, attribute), value))
180
+
181
+ def quote(value: t.Any, q: str = "'", attribute: str | None = None) -> str:
182
+ return q + value_to_str(value, attribute) + q
183
+
184
+ def quote_and_join(value: list[t.Any], q: str = "'", d: str = ", ", attribute: str | None = None) -> str:
185
+ return d.join(map(lambda x: quote(x, q, attribute), value))
186
+
187
+ env.filters["join"] = join
188
+ env.filters["quote"] = quote
189
+ env.filters["quote_and_join"] = quote_and_join
190
+ return env
147
191
 
148
192
  @ft.cached_property
149
193
  def _duckdb_venv_path(self) -> str:
150
194
  duckdb_filepath_setting_val = self._env_vars.get(c.SQRL_DUCKDB_VENV_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DUCKDB_VENV_FILE}")
151
- return str(u.Path(self._filepath, duckdb_filepath_setting_val))
195
+ return str(Path(self._filepath, duckdb_filepath_setting_val))
152
196
 
153
197
  def close(self) -> None:
154
198
  """
@@ -236,7 +280,7 @@ class SquirrelsProject:
236
280
  for model_name in dependencies:
237
281
  model = models_dict[model_name]
238
282
  if isinstance(model, m.SourceModel) and not model.model_config.load_to_duckdb:
239
- raise InvalidInputError(203, f"Source model '{model_name}' cannot be queried with DuckDB")
283
+ raise InvalidInputError(400, "Unqueryable source model", f"Source model '{model_name}' cannot be queried with DuckDB")
240
284
  if isinstance(model, (m.SourceModel, m.BuildModel)):
241
285
  substitutions[model_name] = f"venv.{model_name}"
242
286
 
@@ -255,7 +299,7 @@ class SquirrelsProject:
255
299
  dag = m.DAG(None, fake_target_model, models_dict, self._duckdb_venv_path, self._logger)
256
300
  return dag
257
301
 
258
- def _draw_dag(self, dag: m.DAG, output_folder: u.Path) -> None:
302
+ def _draw_dag(self, dag: m.DAG, output_folder: Path) -> None:
259
303
  color_map = {
260
304
  m.ModelType.SEED: "green", m.ModelType.DBVIEW: "red", m.ModelType.FEDERATE: "skyblue",
261
305
  m.ModelType.BUILD: "purple", m.ModelType.SOURCE: "orange"
@@ -275,7 +319,7 @@ class SquirrelsProject:
275
319
 
276
320
  fig.tight_layout()
277
321
  plt.margins(x=0.1, y=0.1)
278
- fig.savefig(u.Path(output_folder, "dag.png"))
322
+ fig.savefig(Path(output_folder, "dag.png"))
279
323
  plt.close(fig)
280
324
 
281
325
  async def _get_compiled_dag(self, *, sql_query: str | None = None, selections: dict[str, t.Any] = {}, user: BaseUser | None = None) -> m.DAG:
@@ -285,18 +329,18 @@ class SquirrelsProject:
285
329
  await dag.execute(self._param_args, self._param_cfg_set, self._context_func, user, selections, runquery=False, default_traits=default_traits)
286
330
  return dag
287
331
 
288
- def _get_all_connections(self) -> list[arm.ConnectionItemModel]:
332
+ def _get_all_connections(self) -> list[rm.ConnectionItemModel]:
289
333
  connections = []
290
334
  for conn_name, conn_props in self._conn_set.get_connections_as_dict().items():
291
335
  if isinstance(conn_props, mf.ConnectionProperties):
292
336
  label = conn_props.label if conn_props.label is not None else conn_name
293
- connections.append(arm.ConnectionItemModel(name=conn_name, label=label))
337
+ connections.append(rm.ConnectionItemModel(name=conn_name, label=label))
294
338
  return connections
295
339
 
296
- def _get_all_data_models(self, compiled_dag: m.DAG) -> list[arm.DataModelItem]:
340
+ def _get_all_data_models(self, compiled_dag: m.DAG) -> list[rm.DataModelItem]:
297
341
  return compiled_dag.get_all_data_models()
298
342
 
299
- async def get_all_data_models(self) -> list[arm.DataModelItem]:
343
+ async def get_all_data_models(self) -> list[rm.DataModelItem]:
300
344
  """
301
345
  Get all data models in the project
302
346
 
@@ -306,26 +350,26 @@ class SquirrelsProject:
306
350
  compiled_dag = await self._get_compiled_dag()
307
351
  return self._get_all_data_models(compiled_dag)
308
352
 
309
- def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[arm.LineageRelation]:
353
+ def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[rm.LineageRelation]:
310
354
  all_lineage = compiled_dag.get_all_model_lineage()
311
355
 
312
356
  # Add dataset nodes to the lineage
313
357
  for dataset in self._manifest_cfg.datasets.values():
314
- target_dataset = arm.LineageNode(name=dataset.name, type="dataset")
315
- source_model = arm.LineageNode(name=dataset.model, type="model")
316
- all_lineage.append(arm.LineageRelation(type="runtime", source=source_model, target=target_dataset))
358
+ target_dataset = rm.LineageNode(name=dataset.name, type="dataset")
359
+ source_model = rm.LineageNode(name=dataset.model, type="model")
360
+ all_lineage.append(rm.LineageRelation(type="runtime", source=source_model, target=target_dataset))
317
361
 
318
362
  # Add dashboard nodes to the lineage
319
363
  for dashboard in self._dashboards.values():
320
- target_dashboard = arm.LineageNode(name=dashboard.dashboard_name, type="dashboard")
364
+ target_dashboard = rm.LineageNode(name=dashboard.dashboard_name, type="dashboard")
321
365
  datasets = set(x.dataset for x in dashboard.config.depends_on)
322
366
  for dataset in datasets:
323
- source_dataset = arm.LineageNode(name=dataset, type="dataset")
324
- all_lineage.append(arm.LineageRelation(type="runtime", source=source_dataset, target=target_dashboard))
367
+ source_dataset = rm.LineageNode(name=dataset, type="dataset")
368
+ all_lineage.append(rm.LineageRelation(type="runtime", source=source_dataset, target=target_dashboard))
325
369
 
326
370
  return all_lineage
327
371
 
328
- async def get_all_data_lineage(self) -> list[arm.LineageRelation]:
372
+ async def get_all_data_lineage(self) -> list[rm.LineageRelation]:
329
373
  """
330
374
  Get all data lineage in the project
331
375
 
@@ -371,28 +415,28 @@ class SquirrelsProject:
371
415
  runquery=runquery, recurse=recurse, default_traits=self._manifest_cfg.get_default_traits()
372
416
  )
373
417
 
374
- output_folder = u.Path(self._filepath, c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
418
+ output_folder = Path(self._filepath, c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
375
419
  if output_folder.exists():
376
420
  shutil.rmtree(output_folder)
377
421
  output_folder.mkdir(parents=True, exist_ok=True)
378
422
 
379
423
  def write_placeholders() -> None:
380
- output_filepath = u.Path(output_folder, "placeholders.json")
424
+ output_filepath = Path(output_folder, "placeholders.json")
381
425
  with open(output_filepath, 'w') as f:
382
426
  json.dump(dag.placeholders, f, indent=4)
383
427
 
384
428
  def write_model_outputs(model: m.DataModel) -> None:
385
429
  assert isinstance(model, m.QueryModel)
386
430
  subfolder = c.DBVIEWS_FOLDER if model.model_type == m.ModelType.DBVIEW else c.FEDERATES_FOLDER
387
- subpath = u.Path(output_folder, subfolder)
431
+ subpath = Path(output_folder, subfolder)
388
432
  subpath.mkdir(parents=True, exist_ok=True)
389
433
  if isinstance(model.compiled_query, mq.SqlModelQuery):
390
- output_filepath = u.Path(subpath, model.name+'.sql')
434
+ output_filepath = Path(subpath, model.name+'.sql')
391
435
  query = model.compiled_query.query
392
436
  with open(output_filepath, 'w') as f:
393
437
  f.write(query)
394
438
  if runquery and isinstance(model.result, pl.LazyFrame):
395
- output_filepath = u.Path(subpath, model.name+'.csv')
439
+ output_filepath = Path(subpath, model.name+'.csv')
396
440
  model.result.collect().write_csv(output_filepath)
397
441
 
398
442
  write_placeholders()
@@ -455,7 +499,7 @@ class SquirrelsProject:
455
499
 
456
500
  def _permission_error(self, user: BaseUser | None, data_type: str, data_name: str, scope: str) -> InvalidInputError:
457
501
  username = "" if user is None else f" '{user.username}'"
458
- return InvalidInputError(25, f"User{username} does not have permission to access {scope} {data_type}: {data_name}")
502
+ return InvalidInputError(403, f"Unauthorized access to {data_type}", f"User{username} does not have permission to access {scope} {data_type}: {data_name}")
459
503
 
460
504
  def seed(self, name: str) -> pl.LazyFrame:
461
505
  """
@@ -520,7 +564,7 @@ class SquirrelsProject:
520
564
  )
521
565
 
522
566
  async def dashboard(
523
- self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, dashboard_type: t.Type[T] = dash.Dashboard
567
+ self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, dashboard_type: t.Type[T] = d.PngDashboard
524
568
  ) -> T:
525
569
  """
526
570
  Async method to retrieve a dashboard given parameter selections.
squirrels/_py_module.py CHANGED
@@ -43,11 +43,12 @@ class PyModule:
43
43
  return func_or_class
44
44
 
45
45
 
46
- def run_pyconfig_main(base_path: str, filename: str, kwargs: dict[str, Any] = {}) -> None:
46
+ def run_pyconfig_main(base_path: str, filename: str, kwargs: dict[str, Any] = {}) -> Any | None:
47
47
  """
48
48
  Given a python file in the 'pyconfigs' folder, run its main function
49
49
 
50
50
  Arguments:
51
+ base_path: The base path of the project
51
52
  filename: The name of the file to run main function
52
53
  kwargs: Dictionary of the main function arguments
53
54
  """
@@ -56,6 +57,6 @@ def run_pyconfig_main(base_path: str, filename: str, kwargs: dict[str, Any] = {}
56
57
  main_function = module.get_func_or_class(c.MAIN_FUNC, is_required=False)
57
58
  if main_function:
58
59
  try:
59
- main_function(**kwargs)
60
+ return main_function(**kwargs)
60
61
  except Exception as e:
61
62
  raise FileExecutionError(f'Failed to run python file "{filepath}"', e) from e
File without changes
@@ -0,0 +1,144 @@
1
+ from typing import Callable, Any
2
+ from datetime import datetime
3
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer
4
+
5
+
6
class BaseUser(BaseModel):
    # Base user model: a username plus an admin flag (not admin by default).
    # from_attributes=True is the pydantic setting for building the model from
    # attribute-bearing objects in addition to dicts.
    model_config = ConfigDict(from_attributes=True)

    username: str
    is_admin: bool = False

    @classmethod
    def dropped_columns(cls):
        # The base class reports no dropped columns; a new empty list is returned on every call.
        return list()

    def __hash__(self):
        # Hash is derived from the username only.
        return hash(self.username)
17
+
18
+
19
class ApiKey(BaseModel):
    # Describes an API key: identifier, title, owning username, and creation/expiry timestamps.
    model_config = ConfigDict(from_attributes=True)

    id: str
    title: str
    username: str
    created_at: datetime
    expires_at: datetime

    @field_serializer('created_at', 'expires_at')
    def serialize_datetime(self, dt: datetime) -> str:
        # Output format: yyyy-MM-ddTHH:mm:ss.ffffffZ
        # utcoffset() is None for naive datetimes and zero for UTC ones, so both are formatted
        # unchanged. Any other aware datetime is shifted to UTC wall-clock time first so that
        # the literal "Z" suffix is truthful.
        offset = dt.utcoffset()
        if offset:
            dt = dt - offset
        return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
30
+
31
+
32
class UserField(BaseModel):
    # Describes a single field: its name, type name, nullability,
    # allowed values (or None), and default value (or None).
    # NOTE(review): presumably these describe the attributes of the User model — confirm against callers.
    name: str
    type: str
    nullable: bool

    enum: list[str] | None
    default: Any | None
38
+
39
+
40
class ProviderConfigs(BaseModel):
    # Client settings for an external auth provider.
    client_id: str
    client_secret: str
    server_metadata_url: str

    # Each instance gets its own empty dict unless a value is supplied
    client_kwargs: dict = Field(default_factory=dict)

    # Callable that takes a dict and returns a BaseUser
    # NOTE(review): the dict is presumably the provider's user info/claims — confirm against callers.
    get_user: Callable[[dict], BaseUser]
46
+
47
+
48
class AuthProvider(BaseModel):
    # An auth provider entry: name, label, and icon, together with its ProviderConfigs.
    name: str
    label: str
    icon: str

    provider_configs: ProviderConfigs
53
+
54
+
55
+ # OAuth 2.1 Models
56
+
57
class OAuthClientModel(BaseModel):
    """OAuth client details"""
    model_config = ConfigDict(from_attributes=True)

    client_id: str
    client_name: str
    redirect_uris: list[str]
    scope: str
    grant_types: list[str]
    response_types: list[str]
    created_at: datetime
    is_active: bool

    @field_serializer('created_at')
    def serialize_datetime(self, dt: datetime) -> str:
        # Output format: yyyy-MM-ddTHH:mm:ss.ffffffZ
        # utcoffset() is None for naive datetimes and zero for UTC ones, so both are formatted
        # unchanged. Any other aware datetime is shifted to UTC wall-clock time first so that
        # the literal "Z" suffix is truthful.
        offset = dt.utcoffset()
        if offset:
            dt = dt - offset
        return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
72
+
73
+
74
class ClientRegistrationRequest(BaseModel):
    """Request model for OAuth client registration"""
    # Required fields
    client_name: str = Field(
        description="Human-readable name for the OAuth client",
    )
    redirect_uris: list[str] = Field(
        description="List of allowed redirect URIs for the client",
    )

    # Optional fields with defaults
    scope: str = Field(
        default="read",
        description="Default scope for the client",
    )
    grant_types: list[str] = Field(
        default=["authorization_code", "refresh_token"],
        description="Allowed grant types",
    )
    response_types: list[str] = Field(
        default=["code"],
        description="Allowed response types",
    )
81
+
82
+
83
class ClientUpdateRequest(BaseModel):
    """Request model for OAuth client update"""
    # Every field is optional and defaults to None
    client_name: str | None = Field(
        default=None,
        description="Human-readable name for the OAuth client",
    )
    redirect_uris: list[str] | None = Field(
        default=None,
        description="List of allowed redirect URIs for the client",
    )
    scope: str | None = Field(
        default=None,
        description="Default scope for the client",
    )
    grant_types: list[str] | None = Field(
        default=None,
        description="Allowed grant types",
    )
    response_types: list[str] | None = Field(
        default=None,
        description="Allowed response types",
    )
    is_active: bool | None = Field(
        default=None,
        description="Whether the client is active",
    )
91
+
92
+
93
class ClientDetailsResponse(BaseModel):
    """Response model for OAuth client details (without client_secret)"""
    client_id: str = Field(description="Client ID")
    client_name: str = Field(description="Client name")
    redirect_uris: list[str] = Field(description="Registered redirect URIs")
    scope: str = Field(description="Default scope")
    grant_types: list[str] = Field(description="Allowed grant types")
    response_types: list[str] = Field(description="Allowed response types")
    created_at: datetime = Field(description="Registration timestamp")
    is_active: bool = Field(description="Whether the client is active")

    @field_serializer('created_at')
    def serialize_datetime(self, dt: datetime) -> str:
        # Output format: yyyy-MM-ddTHH:mm:ss.ffffffZ
        # utcoffset() is None for naive datetimes and zero for UTC ones, so both are formatted
        # unchanged. Any other aware datetime is shifted to UTC wall-clock time first so that
        # the literal "Z" suffix is truthful.
        offset = dt.utcoffset()
        if offset:
            dt = dt - offset
        return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
107
+
108
+
109
class ClientUpdateResponse(ClientDetailsResponse):
    """Response model for OAuth client update"""
    # Adds an optional registration access token to the client details
    registration_access_token: str | None = Field(
        default=None,
        description="Token for managing this client registration (store securely)",
    )
112
+
113
+
114
class ClientRegistrationResponse(ClientUpdateResponse):
    """Response model for OAuth client registration"""
    # Adds the (required) client secret and an optional management URI to the update response
    client_secret: str = Field(
        description="Generated client secret (store securely)",
    )
    registration_client_uri: str | None = Field(
        default=None,
        description="URI for managing this client registration",
    )
118
+
119
+
120
class TokenResponse(BaseModel):
    # Token payload: the access token, its type ("bearer" unless overridden),
    # an integer expires_in, and an optional refresh token.
    # NOTE(review): expires_in is presumably a lifetime in seconds — confirm against the token endpoint.
    access_token: str
    token_type: str = "bearer"
    expires_in: int

    refresh_token: str | None = None
125
+
126
+
127
class OAuthServerMetadata(BaseModel):
    """OAuth 2.1 Authorization Server Metadata (RFC 8414)"""
    # Endpoint locations
    issuer: str = Field(
        description="Authorization server's issuer identifier URL",
    )
    authorization_endpoint: str = Field(
        description="URL of the authorization endpoint",
    )
    token_endpoint: str = Field(
        description="URL of the token endpoint",
    )
    revocation_endpoint: str = Field(
        description="URL of the token revocation endpoint",
    )
    registration_endpoint: str = Field(
        description="URL of the client registration endpoint",
    )

    # Supported capabilities (required)
    scopes_supported: list[str] = Field(
        description="List of OAuth 2.1 scope values supported",
    )
    response_types_supported: list[str] = Field(
        description="List of OAuth 2.1 response_type values supported",
    )
    grant_types_supported: list[str] = Field(
        description="List of OAuth 2.1 grant type values supported",
    )

    # Supported capabilities (with defaults)
    token_endpoint_auth_methods_supported: list[str] = Field(
        default=["client_secret_basic", "client_secret_post"],
        description="List of client authentication methods supported by the token endpoint",
    )
    code_challenge_methods_supported: list[str] = Field(
        default=["S256"],
        description="List of PKCE code challenge methods supported",
    )
@@ -0,0 +1,67 @@
1
+ """
2
+ Query model generation utilities for API routes
3
+ """
4
+ from typing import Annotated
5
+ from dataclasses import make_dataclass
6
+ from fastapi import Depends
7
+ from pydantic import create_model
8
+
9
+ from .._parameter_configs import APIParamFieldInfo
10
+
11
+
12
def _get_query_models_helper(widget_parameters: list[str] | None, predefined_params: list[APIParamFieldInfo], param_fields: dict):
    """
    Build the pair of models used to receive parameters: a dataclass-based model for GET query
    strings (wrapped as a FastAPI dependency) and a pydantic model for POST request bodies.

    When widget_parameters is None, every key of param_fields is used. The predefined
    parameters are appended after the widget parameters in both models.
    """
    selected = list(param_fields.keys()) if widget_parameters is None else widget_parameters

    # GET: one dataclass field per selected parameter, followed by the predefined ones
    query_fields = [param_fields[name].as_query_info() for name in selected]
    query_fields.extend(extra.as_query_info() for extra in predefined_params)
    QueryModelForGet = Annotated[make_dataclass("QueryParams", query_fields), Depends()]

    # POST: same parameters, expressed as pydantic field definitions keyed by name
    body_fields = {name: param_fields[name].as_body_info() for name in selected}
    body_fields.update((extra.name, extra.as_body_info()) for extra in predefined_params)
    QueryModelForPost = create_model("RequestBodyParams", **body_fields) # type: ignore

    return QueryModelForGet, QueryModelForPost
27
+
28
+
29
def get_query_models_for_parameters(widget_parameters: list[str] | None, param_fields: dict):
    """
    Generate the GET/POST query models for parameter endpoints.

    Adds the predefined "x_verify_params" and "x_parent_param" fields to the widget parameters.
    """
    verify_params = APIParamFieldInfo(
        "x_verify_params", bool, default=False,
        description="If true, the query parameters are verified to be valid for the dataset",
    )
    parent_param = APIParamFieldInfo(
        "x_parent_param", str,
        description="The parameter name used for parameter updates. If not provided, then all parameters are retrieved",
    )
    return _get_query_models_helper(widget_parameters, [verify_params, parent_param], param_fields)
36
+
37
+
38
def get_query_models_for_dataset(widget_parameters: list[str] | None, param_fields: dict):
    """
    Generate the GET/POST query models for dataset endpoints.

    Adds the predefined verification, orientation, column selection, offset, and limit
    fields to the widget parameters.
    """
    verify_params = APIParamFieldInfo(
        "x_verify_params", bool, default=False,
        description="If true, the query parameters are verified to be valid for the dataset",
    )
    orientation = APIParamFieldInfo(
        "x_orientation", str, default="records",
        description="The orientation of the data to return, one of: 'records', 'rows', or 'columns'",
    )
    select = APIParamFieldInfo(
        "x_select", list[str], examples=[[]],
        description="The columns to select from the dataset. All are returned if not specified",
    )
    offset = APIParamFieldInfo(
        "x_offset", int, default=0,
        description="The number of rows to skip before returning data (applied after data caching)",
    )
    limit = APIParamFieldInfo(
        "x_limit", int, default=1000,
        description="The maximum number of rows to return (applied after data caching and offset)",
    )
    extras = [verify_params, orientation, select, offset, limit]
    return _get_query_models_helper(widget_parameters, extras, param_fields)
48
+
49
+
50
def get_query_models_for_dashboard(widget_parameters: list[str] | None, param_fields: dict):
    """
    Generate the GET/POST query models for dashboard endpoints.

    Adds only the predefined "x_verify_params" field to the widget parameters.
    """
    verify_params = APIParamFieldInfo(
        "x_verify_params", bool, default=False,
        description="If true, the query parameters are verified to be valid for the dashboard",
    )
    return _get_query_models_helper(widget_parameters, [verify_params], param_fields)
56
+
57
+
58
def get_query_models_for_querying_models(param_fields: dict):
    """
    Generate the GET/POST query models for querying data models.

    All parameters in param_fields are included, plus the predefined verification,
    orientation, offset, limit, and SQL query fields.
    """
    verify_params = APIParamFieldInfo(
        "x_verify_params", bool, default=False,
        description="If true, the query parameters are verified to be valid",
    )
    orientation = APIParamFieldInfo(
        "x_orientation", str, default="records",
        description="The orientation of the data to return, one of: 'records', 'rows', or 'columns'",
    )
    offset = APIParamFieldInfo(
        "x_offset", int, default=0,
        description="The number of rows to skip before returning data (applied after data caching)",
    )
    limit = APIParamFieldInfo(
        "x_limit", int, default=1000,
        description="The maximum number of rows to return (applied after data caching and offset)",
    )
    sql_query = APIParamFieldInfo(
        "x_sql_query", str,
        description="The SQL query to execute on the data models",
    )
    extras = [verify_params, orientation, offset, limit, sql_query]
    # Passing None selects every parameter in param_fields
    return _get_query_models_helper(None, extras, param_fields)
@@ -1,16 +1,20 @@
1
1
  from typing import Annotated, Literal
2
2
  from pydantic import BaseModel, Field
3
- from datetime import datetime, date
3
+ from datetime import date
4
4
 
5
- from . import _model_configs as mc, _sources as s
5
+ from .. import _model_configs as mc, _sources as s
6
6
 
7
7
 
8
- class LoginReponse(BaseModel):
9
- access_token: Annotated[str, Field(examples=["encoded_jwt_token"], description="An encoded JSON web token to use subsequent API requests")]
10
- token_type: Annotated[str, Field(examples=["bearer"], description='Always "bearer" for Bearer token')]
11
- username: Annotated[str, Field(examples=["johndoe"], description='The username authenticated with from the form data')]
12
- is_admin: Annotated[bool, Field(examples=[False], description="A boolean for whether the user is an admin")]
13
- expiry_time: Annotated[datetime, Field(examples=["2023-08-01T12:00:00.000000Z"], description="The expiry time of the access token in yyyy-MM-dd'T'hh:mm:ss.SSSSSS'Z' format")]
8
+ ## Simple Auth Response Models
9
+
10
class ApiKeyResponse(BaseModel):
    # Response body carrying a single API key string
    api_key: Annotated[
        str,
        Field(examples=["sqrl-12345678"], description="The API key to use subsequent API requests"),
    ]
12
+
13
class ProviderResponse(BaseModel):
    # Response body describing one auth provider: identifier, display label, icon URL, and login URL
    name: Annotated[
        str,
        Field(examples=["my_provider"], description="The name of the provider"),
    ]
    label: Annotated[
        str,
        Field(examples=["My Provider"], description="The human-friendly display name for the provider"),
    ]
    icon: Annotated[
        str,
        Field(examples=["https://example.com/my_provider_icon.png"], description="The URL of the provider's icon"),
    ]
    login_url: Annotated[
        str,
        Field(examples=["https://example.com/my_provider_login"], description="The URL to redirect to for provider login"),
    ]
14
18
 
15
19
 
16
20
  ## Parameters Response Models
squirrels/_utils.py CHANGED
@@ -2,10 +2,9 @@ from typing import Sequence, Optional, Union, TypeVar, Callable, Any, Iterable
2
2
  from datetime import datetime
3
3
  from pathlib import Path
4
4
  from functools import lru_cache
5
- from pydantic import BaseModel
6
5
  import os, time, logging, json, duckdb, polars as pl, yaml
7
6
  import jinja2 as j2, jinja2.nodes as j2_nodes
8
- import sqlglot, sqlglot.expressions, asyncio
7
+ import sqlglot, sqlglot.expressions, asyncio, hashlib, inspect, base64
9
8
 
10
9
  from . import _constants as c
11
10
  from ._exceptions import ConfigurationError
@@ -290,7 +289,8 @@ def load_yaml_config(filepath: FilePath) -> dict:
290
289
 
291
290
 
292
291
  def run_duckdb_stmt(
293
- logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None, redacted_values: list[str] = []
292
+ logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None,
293
+ model_name: str | None = None, redacted_values: list[str] = []
294
294
  ) -> duckdb.DuckDBPyConnection:
295
295
  """
296
296
  Runs a statement on a DuckDB connection
@@ -306,7 +306,8 @@ def run_duckdb_stmt(
306
306
  for value in redacted_values:
307
307
  redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")
308
308
 
309
- logger.info(f"Running statement: {redacted_stmt}", extra={"data": {"params": params}})
309
+ for_model_name = f" for model '{model_name}'" if model_name is not None else ""
310
+ logger.info(f"Running SQL statement{for_model_name}:\n{redacted_stmt}", extra={"data": {"params": params}})
310
311
  try:
311
312
  return duckdb_conn.execute(stmt, params)
312
313
  except duckdb.ParserException as e:
@@ -357,3 +358,36 @@ async def asyncio_gather(coroutines: list):
357
358
  # Wait for tasks to be cancelled
358
359
  await asyncio.gather(*tasks, return_exceptions=True)
359
360
  raise
361
+
362
+
363
def hash_string(input_str: str, salt: str) -> str:
    """
    Returns the hexadecimal SHA-256 digest of the input string with the salt appended

    Arguments:
        input_str: The string to hash
        salt: The salt, concatenated after input_str before hashing
    """
    salted = input_str + salt
    hasher = hashlib.sha256()
    hasher.update(salted.encode())
    return hasher.hexdigest()
368
+
369
+
370
T = TypeVar('T')
def call_func(func: Callable[..., T], **kwargs) -> T:
    """
    Calls func with only those keyword arguments whose names appear in its signature.

    Keyword arguments that func does not declare are silently dropped, so a function
    that declares no parameters is simply called without arguments.

    Arguments:
        func: The function to call
        kwargs: Candidate keyword arguments for func

    Returns:
        Whatever func returns
    """
    accepted = inspect.signature(func).parameters
    usable = {}
    for name, value in kwargs.items():
        if name in accepted:
            usable[name] = value
    return func(**usable)
379
+
380
+
381
def generate_pkce_challenge(code_verifier: str) -> str:
    """
    Generate the PKCE code challenge for a code verifier.

    The challenge is the URL-safe base64 encoding of the verifier's SHA-256 digest,
    with the trailing '=' padding removed.
    """
    hasher = hashlib.sha256()
    hasher.update(code_verifier.encode('utf-8'))
    encoded = base64.urlsafe_b64encode(hasher.digest())
    return encoded.decode('utf-8').rstrip('=')
388
+
389
def validate_pkce_challenge(code_verifier: str, code_challenge: str) -> bool:
    """
    Validate PKCE code verifier against code challenge

    Arguments:
        code_verifier: The verifier to check
        code_challenge: The challenge the verifier is expected to produce

    Returns:
        True if the challenge generated from code_verifier equals code_challenge, False otherwise
    """
    import hmac  # local import so this function does not depend on the module-level import list

    # A non-string challenge can never match (same result as the plain equality check it replaces)
    if not isinstance(code_challenge, str):
        return False

    expected_challenge = generate_pkce_challenge(code_verifier)

    # Compare as bytes in constant time. Bytes are used because compare_digest only accepts
    # ASCII-only str values, and "surrogatepass" keeps malformed input from raising on encode
    return hmac.compare_digest(
        expected_challenge.encode('utf-8'),
        code_challenge.encode('utf-8', 'surrogatepass'),
    )
squirrels/arguments.py ADDED
@@ -0,0 +1,2 @@
1
+ from ._arguments.init_time_args import ConnectionsArgs, AuthProviderArgs, ParametersArgs, BuildModelArgs
2
+ from ._arguments.run_time_args import ContextArgs, ModelArgs, DashboardArgs
squirrels/auth.py ADDED
@@ -0,0 +1 @@
1
+ from ._auth import BaseUser, ProviderConfigs, provider
@@ -0,0 +1 @@
1
+ from ._manifest import ConnectionProperties, ConnectionTypeEnum