squirrels 0.1.0__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dateutils/__init__.py +6 -0
- dateutils/_enums.py +25 -0
- squirrels/dateutils.py → dateutils/_implementation.py +409 -380
- dateutils/types.py +6 -0
- squirrels/__init__.py +21 -18
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +337 -0
- squirrels/_api_routes/base.py +196 -0
- squirrels/_api_routes/dashboards.py +156 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +220 -0
- squirrels/_api_routes/project.py +289 -0
- squirrels/_api_server.py +552 -134
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/init_time_args.py +83 -0
- squirrels/_arguments/run_time_args.py +111 -0
- squirrels/_auth.py +777 -0
- squirrels/_command_line.py +239 -107
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +94 -0
- squirrels/_constants.py +141 -64
- squirrels/_dashboards.py +179 -0
- squirrels/_data_sources.py +570 -0
- squirrels/_dataset_types.py +91 -0
- squirrels/_env_vars.py +209 -0
- squirrels/_exceptions.py +29 -0
- squirrels/_http_error_responses.py +52 -0
- squirrels/_initializer.py +319 -110
- squirrels/_logging.py +121 -0
- squirrels/_manifest.py +357 -187
- squirrels/_mcp_server.py +578 -0
- squirrels/_model_builder.py +69 -0
- squirrels/_model_configs.py +74 -0
- squirrels/_model_queries.py +52 -0
- squirrels/_models.py +1201 -0
- squirrels/_package_data/base_project/.env +7 -0
- squirrels/_package_data/base_project/.env.example +44 -0
- squirrels/_package_data/base_project/connections.yml +16 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +40 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/_package_data/base_project/docker/.dockerignore +16 -0
- squirrels/_package_data/base_project/docker/Dockerfile +16 -0
- squirrels/_package_data/base_project/docker/compose.yml +7 -0
- squirrels/_package_data/base_project/duckdb_init.sql +10 -0
- squirrels/_package_data/base_project/gitignore +13 -0
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
- squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
- squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
- squirrels/_package_data/base_project/models/federates/federate_example.py +51 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
- squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
- squirrels/_package_data/base_project/models/sources.yml +38 -0
- squirrels/_package_data/base_project/parameters.yml +142 -0
- squirrels/_package_data/base_project/pyconfigs/connections.py +19 -0
- squirrels/_package_data/base_project/pyconfigs/context.py +96 -0
- squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +56 -0
- squirrels/_package_data/base_project/resources/expenses.db +0 -0
- squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
- squirrels/_package_data/base_project/resources/weather.db +0 -0
- squirrels/_package_data/base_project/seeds/seed_categories.csv +6 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
- squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
- squirrels/_package_data/base_project/tmp/.gitignore +2 -0
- squirrels/_package_data/templates/login_successful.html +53 -0
- squirrels/_package_data/templates/squirrels_studio.html +22 -0
- squirrels/_package_loader.py +29 -0
- squirrels/_parameter_configs.py +592 -0
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +207 -0
- squirrels/_parameters.py +1703 -0
- squirrels/_project.py +796 -0
- squirrels/_py_module.py +122 -0
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +83 -0
- squirrels/_schemas/query_param_models.py +70 -0
- squirrels/_schemas/request_models.py +26 -0
- squirrels/_schemas/response_models.py +286 -0
- squirrels/_seeds.py +97 -0
- squirrels/_sources.py +112 -0
- squirrels/_utils.py +540 -149
- squirrels/_version.py +1 -3
- squirrels/arguments.py +7 -0
- squirrels/auth.py +4 -0
- squirrels/connections.py +3 -0
- squirrels/dashboards.py +3 -0
- squirrels/data_sources.py +14 -282
- squirrels/parameter_options.py +13 -189
- squirrels/parameters.py +14 -801
- squirrels/types.py +18 -0
- squirrels-0.6.0.post0.dist-info/METADATA +148 -0
- squirrels-0.6.0.post0.dist-info/RECORD +101 -0
- {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -2
- {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +1 -0
- squirrels-0.6.0.post0.dist-info/licenses/LICENSE +201 -0
- squirrels/_credentials_manager.py +0 -87
- squirrels/_module_loader.py +0 -37
- squirrels/_parameter_set.py +0 -151
- squirrels/_renderer.py +0 -286
- squirrels/_timed_imports.py +0 -37
- squirrels/connection_set.py +0 -126
- squirrels/package_data/base_project/.gitignore +0 -4
- squirrels/package_data/base_project/connections.py +0 -21
- squirrels/package_data/base_project/database/sample_database.db +0 -0
- squirrels/package_data/base_project/database/seattle_weather.db +0 -0
- squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -8
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -23
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -7
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -10
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -2
- squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -30
- squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -6
- squirrels/package_data/base_project/squirrels.yaml +0 -26
- squirrels/package_data/static/favicon.ico +0 -0
- squirrels/package_data/static/script.js +0 -234
- squirrels/package_data/static/style.css +0 -110
- squirrels/package_data/templates/index.html +0 -32
- squirrels-0.1.0.dist-info/LICENSE +0 -22
- squirrels-0.1.0.dist-info/METADATA +0 -67
- squirrels-0.1.0.dist-info/RECORD +0 -40
- squirrels-0.1.0.dist-info/top_level.txt +0 -1
squirrels/_project.py
ADDED
|
@@ -0,0 +1,796 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
from dotenv import dotenv_values, load_dotenv
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import asyncio, typing as t, functools as ft, shutil, json, os
|
|
5
|
+
import sqlglot, sqlglot.expressions, duckdb, polars as pl
|
|
6
|
+
|
|
7
|
+
from ._auth import Authenticator, AuthProviderArgs, ProviderFunctionType
|
|
8
|
+
from ._schemas.auth_models import CustomUserFields, AbstractUser, GuestUser, RegisteredUser
|
|
9
|
+
from ._schemas import response_models as rm
|
|
10
|
+
from ._model_builder import ModelBuilder
|
|
11
|
+
from ._env_vars import SquirrelsEnvVars
|
|
12
|
+
from ._exceptions import InvalidInputError, ConfigurationError
|
|
13
|
+
from ._py_module import PyModule
|
|
14
|
+
from . import _dashboards as d, _utils as u, _constants as c, _manifest as mf, _connection_set as cs
|
|
15
|
+
from . import _seeds as s, _models as m, _model_configs as mc, _model_queries as mq, _sources as so
|
|
16
|
+
from . import _parameter_sets as ps, _dataset_types as dr, _logging as l
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from ._api_server import FastAPIComponents
|
|
20
|
+
|
|
21
|
+
T = t.TypeVar("T", bound=d.Dashboard)
|
|
22
|
+
M = t.TypeVar("M", bound=m.DataModel)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SquirrelsProject:
|
|
26
|
+
"""
|
|
27
|
+
Initiate an instance of this class to interact with a Squirrels project through Python code. For example this can be handy to experiment with the datasets produced by Squirrels in a Jupyter notebook.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __init__(
    self, *, project_path: str = ".", load_dotenv_globally: bool = False,
    log_to_file: bool = False, log_level: str | None = None, log_format: str | None = None,
) -> None:
    """
    Constructor for SquirrelsProject class. Loads the file contents of the Squirrels project into memory as member fields.

    Arguments:
        project_path: The path to the Squirrels project file. Defaults to the current working directory.
        load_dotenv_globally: Whether to also load the project's .env files into the process environment (via load_dotenv) in addition to reading them for this instance. Default is False.
        log_level: The logging level to use. Options are "DEBUG", "INFO", and "WARNING". Default is from SQRL_LOGGING__LEVEL environment variable or "INFO".
        log_to_file: Whether to enable logging to file(s) in the "logs/" folder (or a custom folder). Default is from SQRL_LOGGING__TO_FILE environment variable or False.
        log_format: The format of the log records. Options are "text" and "json". Default is from SQRL_LOGGING__FORMAT environment variable or "text".
    """
    # Normalize to an absolute path so downstream path joins are stable
    project_path = str(Path(project_path).resolve())

    self._project_path = project_path
    # Raw environment: os.environ overlaid with the project's .env / .env.local values
    self._env_vars_unformatted = self._load_env_vars(project_path, load_dotenv_globally)
    self._env_vars = SquirrelsEnvVars(project_path=project_path, **self._env_vars_unformatted)
    self._vdl_catalog_db_path = self._env_vars.vdl_catalog_db_path

    self._logger = self._get_logger(project_path, self._env_vars, log_to_file, log_level, log_format)
    # Side effect: creates the target folder and attaches/maintains the Virtual Data Lake
    self._ensure_virtual_datalake_exists(project_path, self._vdl_catalog_db_path, self._env_vars.vdl_data_path)
|
|
52
|
+
|
|
53
|
+
@staticmethod
def _load_env_vars(project_path: str, load_dotenv_globally: bool) -> dict[str, str]:
    """
    Read the project's .env and .env.local files and overlay them on top of os.environ.

    Values from files later in the list win over earlier ones, and all dotenv
    values win over pre-existing os.environ entries in the returned dict. When
    load_dotenv_globally is True, the files are also loaded into the process
    environment via load_dotenv.
    """
    overrides: dict[str, str] = {}
    for dotenv_name in (c.DOTENV_FILE, c.DOTENV_LOCAL_FILE):
        dotenv_path = u.Path(project_path, dotenv_name)
        if load_dotenv_globally:
            load_dotenv(dotenv_path)
        for key, value in dotenv_values(dotenv_path).items():
            if value is not None:  # skip keys declared without a value
                overrides[key] = value
    return {**os.environ, **overrides}
|
|
63
|
+
|
|
64
|
+
@staticmethod
def _get_logger(
    filepath: str, env_vars: SquirrelsEnvVars, log_to_file: bool, log_level: str | None, log_format: str | None
) -> u.Logger:
    """Build the project logger, letting explicit arguments override environment settings."""
    # CLI arguments take precedence over environment variables
    if log_level is None:
        log_level = env_vars.logging_level
    if log_format is None:
        log_format = env_vars.logging_format
    # File logging is enabled if either the env var or the argument asks for it
    log_to_file = env_vars.logging_to_file or log_to_file
    max_file_size_mb = float(env_vars.logging_file_size_mb)
    backup_count = int(env_vars.logging_file_backup_count)
    return l.get_logger(filepath, log_to_file, log_level, log_format, max_file_size_mb, backup_count)
|
|
75
|
+
|
|
76
|
+
@staticmethod
def _ensure_virtual_datalake_exists(project_path: str, vdl_catalog_db_path: str, vdl_data_path: str) -> None:
    """
    Create the target folder and attach the Virtual Data Lake (VDL) catalog to verify it is usable.

    For DuckLake catalogs, also expires old snapshots and cleans up orphaned data files.

    Raises:
        u.ConfigurationError: if the catalog cannot be attached (with extra guidance when
            the failure is caused by changing DATA_PATH on an existing VDL).
    """
    target_path = u.Path(project_path, c.TARGET_FOLDER)
    target_path.mkdir(parents=True, exist_ok=True)

    # Computed before the try block so the except handler can always reference it
    is_ducklake = vdl_catalog_db_path.startswith("ducklake:")

    # Attempt to set up the virtual data lake with DATA_PATH if possible
    try:
        options = f"(DATA_PATH '{vdl_data_path}')" if is_ducklake else ""
        attach_stmt = f"ATTACH '{vdl_catalog_db_path}' AS vdl {options}"
        with duckdb.connect() as conn:
            conn.execute(attach_stmt)
            if is_ducklake:
                # The ducklake_* maintenance functions only apply to DuckLake catalogs;
                # running them unconditionally would fail a successful non-DuckLake attach
                # TODO: support incremental loads for build models and avoid cleaning up old files all the time
                conn.execute("CALL ducklake_expire_snapshots('vdl', older_than => now())")
                conn.execute("CALL ducklake_cleanup_old_files('vdl', cleanup_all => true)")

    except Exception as e:
        if "DATA_PATH parameter" in str(e):
            first_line = str(e).split("\n")[0]
            note = "NOTE: Squirrels does not allow changing the data path for an existing Virtual Data Lake (VDL)"
            raise u.ConfigurationError(f"{first_line}\n\n{note}")

        # DuckLake with a DuckDB metadata db (i.e. not sqlite/postgres/mysql) is single-process
        if is_ducklake and not any(x in vdl_catalog_db_path for x in [":sqlite:", ":postgres:", ":mysql:"]):
            extended_error = "\n- Note: if you're using DuckDB for the metadata database, only one process can connect to the VDL at a time."
        else:
            extended_error = ""

        raise u.ConfigurationError(f"Failed to attach Virtual Data Lake (VDL).{extended_error}") from e
|
|
105
|
+
|
|
106
|
+
@ft.cached_property
def _manifest_cfg(self) -> mf.ManifestConfig:
    # Parsed project manifest config, loaded from file once and cached
    return mf.ManifestIO.load_from_file(self._logger, self._project_path, self._env_vars_unformatted)
|
|
109
|
+
|
|
110
|
+
@ft.cached_property
def _seeds(self) -> s.Seeds:
    # Seed (static data) files, loaded once and cached
    return s.SeedsIO.load_files(self._logger, self._env_vars)
|
|
113
|
+
|
|
114
|
+
@ft.cached_property
def _sources(self) -> so.Sources:
    # Source model definitions, loaded once and cached
    return so.SourcesIO.load_file(self._logger, self._env_vars, self._env_vars_unformatted)
|
|
117
|
+
|
|
118
|
+
@ft.cached_property
def _build_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
    # Build model query files keyed by model name, loaded once and cached
    return m.ModelsIO.load_build_files(self._logger, self._env_vars)
|
|
121
|
+
|
|
122
|
+
@ft.cached_property
def _dbview_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
    # Dbview model query files keyed by model name, loaded once and cached
    return m.ModelsIO.load_dbview_files(self._logger, self._env_vars)
|
|
125
|
+
|
|
126
|
+
@ft.cached_property
def _federate_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
    # Federate model query files keyed by model name, loaded once and cached
    return m.ModelsIO.load_federate_files(self._logger, self._env_vars)
|
|
129
|
+
|
|
130
|
+
@ft.cached_property
def _context_func(self) -> m.ContextFunc:
    # User-defined context function loaded from the project, once and cached
    return m.ModelsIO.load_context_func(self._logger, self._project_path)
|
|
133
|
+
|
|
134
|
+
@ft.cached_property
def _dashboards(self) -> dict[str, d.DashboardDefinition]:
    # Dashboard definitions keyed by name, loaded once and cached;
    # loading depends on the project's auth type and configurables
    return d.DashboardsIO.load_files(
        self._logger, self._project_path, self._manifest_cfg.project_variables.auth_type, self._manifest_cfg.configurables
    )
|
|
139
|
+
|
|
140
|
+
@ft.cached_property
def _conn_args(self) -> cs.ConnectionsArgs:
    """Arguments object shared by connection-related loaders (cached)."""
    return cs.ConnectionsArgs(
        self._project_path,
        self._manifest_cfg.project_variables.model_dump(),
        self._env_vars_unformatted,
    )
|
|
145
|
+
|
|
146
|
+
@ft.cached_property
def _conn_set(self) -> cs.ConnectionSet:
    # Database connection set built from the connections config, once and cached
    return cs.ConnectionSetIO.load_from_file(self._logger, self._project_path, self._manifest_cfg, self._conn_args)
|
|
149
|
+
|
|
150
|
+
@ft.cached_property
def _custom_user_fields_cls_and_provider_functions(self) -> tuple[type[CustomUserFields], list[ProviderFunctionType]]:
    """
    Load the project's CustomUserFields class and harvest the auth provider functions.

    Importing the user module registers provider functions onto the class-level
    Authenticator.providers list as a side effect; this property captures that
    list and then resets the registry so providers are not collected twice.

    Raises:
        ConfigurationError: if the project's CustomUserFields does not inherit
            from the framework's CustomUserFields base class.
    """
    user_module_path = u.Path(self._project_path, c.PYCONFIGS_FOLDER, c.USER_FILE)
    user_module = PyModule(user_module_path, self._project_path)

    # Load CustomUserFields class (adds to Authenticator.providers as side effect)
    CustomUserFieldsCls = user_module.get_func_or_class("CustomUserFields", default_attr=CustomUserFields)
    # Capture, then reset, the class-level registry populated by the import above
    provider_functions = Authenticator.providers
    Authenticator.providers = []

    if not issubclass(CustomUserFieldsCls, CustomUserFields):
        raise ConfigurationError(f"CustomUserFields class in '{c.USER_FILE}' must inherit from CustomUserFields")

    return CustomUserFieldsCls, provider_functions
|
|
164
|
+
|
|
165
|
+
@ft.cached_property
def _auth(self) -> Authenticator:
    """
    Construct the project's Authenticator (cached).

    Raises:
        ConfigurationError: when auth_strategy is 'external' but the project
            does not define exactly one auth provider function.
    """
    # AuthProviderArgs shares the same fields as ConnectionsArgs
    auth_args = AuthProviderArgs(**self._conn_args.__dict__)
    CustomUserFieldsCls, provider_functions = self._custom_user_fields_cls_and_provider_functions
    external_only = (self._manifest_cfg.project_variables.auth_strategy == mf.AuthStrategy.EXTERNAL)

    if external_only and len(provider_functions) != 1:
        raise ConfigurationError(f"When auth_strategy is 'external', there must be exactly one auth provider function. Found {len(provider_functions)} auth providers.")

    return Authenticator(
        self._logger, self._env_vars, auth_args, provider_functions,
        custom_user_fields_cls=CustomUserFieldsCls, external_only=external_only
    )
|
|
178
|
+
|
|
179
|
+
@ft.cached_property
def _guest_user(self) -> AbstractUser:
    """A cached anonymous (guest) user with empty username and default custom fields."""
    return GuestUser(username="", custom_fields=self._auth.CustomUserFields())
|
|
183
|
+
|
|
184
|
+
@ft.cached_property
def _admin_user(self) -> AbstractUser:
    """A cached internal admin-level user with empty username and default custom fields."""
    return RegisteredUser(username="", access_level="admin", custom_fields=self._auth.CustomUserFields())
|
|
188
|
+
|
|
189
|
+
@ft.cached_property
def _param_args(self) -> ps.ParametersArgs:
    """Parameters-stage arguments, built from the same fields as the connections arguments (cached)."""
    return ps.ParametersArgs(**self._conn_args.__dict__)
|
|
193
|
+
|
|
194
|
+
@ft.cached_property
def _param_cfg_set(self) -> ps.ParameterConfigsSet:
    # All parameter configs, resolved against manifest, seeds, and connections (cached)
    return ps.ParameterConfigsSetIO.load_from_file(
        self._logger, self._env_vars, self._manifest_cfg, self._seeds, self._conn_set, self._param_args
    )
|
|
199
|
+
|
|
200
|
+
@ft.cached_property
def _j2_env(self) -> u.EnvironmentWithMacros:
    """Jinja environment rooted at the project folder, extended with join/quote filters (cached)."""
    environment = u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self._project_path))

    def _stringify(value: t.Any, attribute: str | None = None) -> str:
        # Render the value itself, or one of its attributes when requested
        return str(value) if attribute is None else str(getattr(value, attribute))

    # NOTE: the parameter names (d, q, attribute) are part of the template-facing
    # filter API (templates may pass them as keywords) — keep them stable
    def join(value: list[t.Any], d: str = ", ", attribute: str | None = None) -> str:
        return d.join(_stringify(item, attribute) for item in value)

    def quote(value: t.Any, q: str = "'", attribute: str | None = None) -> str:
        return f"{q}{_stringify(value, attribute)}{q}"

    def quote_and_join(value: list[t.Any], q: str = "'", d: str = ", ", attribute: str | None = None) -> str:
        return d.join(quote(item, q, attribute) for item in value)

    environment.filters["join"] = join
    environment.filters["quote"] = quote
    environment.filters["quote_and_join"] = quote_and_join
    return environment
|
|
223
|
+
|
|
224
|
+
def get_fastapi_components(
    self, *, no_cache: bool = False, host: str = "localhost", port: int = 8000,
    mount_path_format: str = "/analytics/{project_name}/v{project_version}"
) -> "FastAPIComponents":
    """
    Get the FastAPI components for the Squirrels project including mount path, lifespan, and FastAPI app.

    Arguments:
        no_cache: Whether to disable caching for parameter options, datasets, and dashboard results in the API server.
        host: The host the API server will listen on. Only used for the welcome banner.
        port: The port the API server will listen on. Only used for the welcome banner.
        mount_path_format: The format of the mount path. Use {project_name} and {project_version} as placeholders.

    Returns:
        A FastAPIComponents object containing the mount path, lifespan, and FastAPI app.
    """
    # Imported lazily here — at module level _api_server is only imported under
    # TYPE_CHECKING, presumably to avoid a circular import
    from ._api_server import ApiServer
    api_server = ApiServer(no_cache=no_cache, project=self)
    return api_server.get_fastapi_components(host=host, port=port, mount_path_format=mount_path_format)
|
|
243
|
+
|
|
244
|
+
def close(self) -> None:
    """
    Deliberately close any open resources within the Squirrels project, such as database connections (instead of relying on the garbage collector).
    """
    # NOTE(review): _conn_set and _auth are cached properties — accessing them here
    # constructs them first if they were never used; confirm that is acceptable
    self._conn_set.dispose()
    self._auth.close()
|
|
250
|
+
|
|
251
|
+
def __enter__(self):
    # Context manager entry: no setup needed beyond __init__
    return self
|
|
253
|
+
|
|
254
|
+
def __exit__(self, exc_type, exc_val, traceback):
    # Always release resources on context exit; returns None so exceptions propagate
    self.close()
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _add_model(self, models_dict: dict[str, M], model: M) -> None:
    """Register *model* in *models_dict* keyed by its name, rejecting duplicate names."""
    if model.name not in models_dict:
        models_dict[model.name] = model
        return
    raise ConfigurationError(f"Names across all models must be unique. Model '{model.name}' is duplicated")
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _get_static_models(self) -> dict[str, m.StaticModel]:
    """Collect all static (buildtime) models: seeds, source tables, and build models."""
    static_models: dict[str, m.StaticModel] = {}

    # Seed models
    for seed_name, seed in self._seeds.get_dataframes().items():
        self._add_model(static_models, m.Seed(seed_name, seed.config, seed.df, logger=self._logger, conn_set=self._conn_set))

    # Source models
    for source_name, source_config in self._sources.sources.items():
        self._add_model(static_models, m.SourceModel(source_name, source_config, logger=self._logger, conn_set=self._conn_set))

    # Build models
    for build_name, file_with_cfg in self._build_model_files.items():
        build_model = m.BuildModel(
            build_name, file_with_cfg.config, file_with_cfg.query_file,
            logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env
        )
        self._add_model(static_models, build_model)

    return static_models
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
async def build(self, *, full_refresh: bool = False, select: str | None = None) -> None:
    """
    Build the Virtual Data Lake (VDL) for the Squirrels project

    Arguments:
        full_refresh: Whether to drop all tables and rebuild the VDL from scratch. Default is False.
        select: The name of a specific model to build. If None, all models are built. Default is None.
    """
    static_models: dict[str, m.StaticModel] = self._get_static_models()
    model_builder = ModelBuilder(self._vdl_catalog_db_path, self._conn_set, static_models, self._conn_args, self._logger)
    await model_builder.build(full_refresh, select)
|
|
292
|
+
|
|
293
|
+
def _get_models_dict(self, always_python_df: bool) -> dict[str, m.DataModel]:
    """
    All models in the project (static + runtime), keyed by name.

    always_python_df forces runtime (dbview/federate) models to produce Python dataframes.
    """
    models_dict: dict[str, m.DataModel] = self._get_static_models()

    # Dbview then federate models — both follow the same construction pattern
    runtime_groups = (
        (self._dbview_model_files, m.DbviewModel),
        (self._federate_model_files, m.FederateModel),
    )
    for files_by_name, model_cls in runtime_groups:
        for model_name, file_with_cfg in files_by_name.items():
            self._add_model(models_dict, model_cls(
                model_name, file_with_cfg.config, file_with_cfg.query_file,
                logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env
            ))
            models_dict[model_name].needs_python_df = always_python_df

    return models_dict
|
|
309
|
+
|
|
310
|
+
def _generate_dag(self, dataset: str) -> m.DAG:
    """Build the model DAG targeting the given dataset's model (not yet executed)."""
    models_dict = self._get_models_dict(always_python_df=False)
    dataset_config = self._manifest_cfg.datasets[dataset]

    target = models_dict[dataset_config.model]
    target.is_target = True
    return m.DAG(dataset_config, target, models_dict, self._vdl_catalog_db_path, self._logger)
|
|
319
|
+
|
|
320
|
+
def _generate_dag_with_fake_target(self, sql_query: str | None, *, always_python_df: bool = False) -> m.DAG:
    """
    Build a DAG whose target is a synthetic federate model wrapping *sql_query*.

    When sql_query is None, the fake target depends on every model and uses a
    trivial "SELECT 1" body. Otherwise, the query's table references determine
    the dependencies, and each referenced model name is rewritten to its
    physical location (vdl.<name>, or db_<connection>.<table> for source models
    not loaded to the VDL).

    NOTE(review): source indentation was lost in extraction; the substitution
    loop and transform are placed inside the else-branch because `parsed` is
    only bound there and `sql_query or "SELECT 1"` implies None can survive to
    that point — confirm against the original file.
    """
    models_dict = self._get_models_dict(always_python_df=always_python_df)

    if sql_query is None:
        dependencies = set(models_dict.keys())
    else:
        dependencies, parsed = u.parse_dependent_tables(sql_query, models_dict.keys())

        # Map each referenced model to the physical table name DuckDB should query
        substitutions = {}
        for model_name in dependencies:
            model = models_dict[model_name]
            if isinstance(model, m.SourceModel) and not model.is_queryable:
                raise InvalidInputError(400, "cannot_query_source_model", f"Source model '{model_name}' cannot be queried with DuckDB")
            if isinstance(model, m.BuildModel):
                substitutions[model_name] = f"vdl.{model_name}"
            elif isinstance(model, m.SourceModel):
                if model.model_config.load_to_vdl:
                    substitutions[model_name] = f"vdl.{model_name}"
                else:
                    # DuckDB connection without load_to_vdl - reference via attached database
                    conn_name = model.model_config.get_connection()
                    table_name = model.model_config.get_table()
                    substitutions[model_name] = f"db_{conn_name}.{table_name}"

        # Rewrite table nodes in the parsed SQL to their substituted physical names
        sql_query = parsed.transform(
            lambda node: sqlglot.expressions.Table(this=substitutions[node.name], alias=node.alias)
            if isinstance(node, sqlglot.expressions.Table) and node.name in substitutions
            else node
        ).sql()

    model_config = mc.FederateModelConfig(depends_on=dependencies)
    query_file = mq.SqlQueryFile("", sql_query or "SELECT 1")  # "SELECT 1" only when sql_query was None
    fake_target_model = m.FederateModel(
        "__fake_target", model_config, query_file, logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env
    )
    fake_target_model.is_target = True
    dag = m.DAG(None, fake_target_model, models_dict, self._vdl_catalog_db_path, self._logger)
    return dag
|
|
358
|
+
|
|
359
|
+
async def _get_compiled_dag(
    self, user: AbstractUser, *, sql_query: str | None = None, selections: dict[str, t.Any] | None = None,
    configurables: dict[str, str] | None = None, always_python_df: bool = False
) -> m.DAG:
    """
    Generate and compile (without running queries) a DAG with a fake target.

    Arguments:
        user: The user to compile parameter selections against.
        sql_query: Optional SQL whose table references define the DAG dependencies; None means all models.
        selections: Parameter selections to apply. Defaults to no selections.
        configurables: Overrides merged on top of the manifest's default configurables.
        always_python_df: Whether runtime models must produce Python dataframes.

    Returns:
        The compiled (but not executed) DAG.
    """
    # None sentinels instead of mutable default arguments ({} defaults are shared across calls)
    selections = {} if selections is None else selections
    configurables = {} if configurables is None else configurables

    dag = self._generate_dag_with_fake_target(sql_query, always_python_df=always_python_df)

    # Manifest defaults first, then caller-provided overrides win
    configurables = {**self._manifest_cfg.get_default_configurables(), **configurables}
    await dag.execute(
        self._param_args, self._param_cfg_set, self._context_func, user, selections,
        runquery=False, configurables=configurables
    )
    return dag
|
|
371
|
+
|
|
372
|
+
def _get_all_connections(self) -> list[rm.ConnectionItemModel]:
    """List all configured connections (name + label), skipping entries that are not ConnectionProperties."""
    return [
        rm.ConnectionItemModel(
            name=conn_name,
            label=conn_props.label if conn_props.label is not None else conn_name,
        )
        for conn_name, conn_props in self._conn_set.get_connections_as_dict().items()
        if isinstance(conn_props, mf.ConnectionProperties)
    ]
|
|
379
|
+
|
|
380
|
+
def _get_all_data_models(self, compiled_dag: m.DAG) -> list[rm.DataModelItem]:
    # Thin delegate: the compiled DAG knows how to describe its data models
    return compiled_dag.get_all_data_models()
|
|
382
|
+
|
|
383
|
+
async def get_all_data_models(self) -> list[rm.DataModelItem]:
    """
    Get all data models in the project

    Returns:
        A list of DataModelItem objects
    """
    # Compiled as the admin user — presumably so no models are filtered by access level
    compiled_dag = await self._get_compiled_dag(self._admin_user)
    return self._get_all_data_models(compiled_dag)
|
|
392
|
+
|
|
393
|
+
def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[rm.LineageRelation]:
    """Model-to-model lineage from the DAG, extended with dataset and dashboard nodes."""
    lineage = compiled_dag.get_all_model_lineage()

    # Each dataset is fed by exactly one model
    lineage.extend(
        rm.LineageRelation(
            type="runtime",
            source=rm.LineageNode(name=dataset_cfg.model, type="model"),
            target=rm.LineageNode(name=dataset_cfg.name, type="dataset"),
        )
        for dataset_cfg in self._manifest_cfg.datasets.values()
    )

    # Each dashboard may depend on multiple datasets (deduplicated)
    for dashboard in self._dashboards.values():
        dashboard_node = rm.LineageNode(name=dashboard.dashboard_name, type="dashboard")
        for dataset_name in set(dep.dataset for dep in dashboard.config.depends_on):
            lineage.append(rm.LineageRelation(
                type="runtime",
                source=rm.LineageNode(name=dataset_name, type="dataset"),
                target=dashboard_node,
            ))

    return lineage
|
|
411
|
+
|
|
412
|
+
async def get_all_data_lineage(self) -> list[rm.LineageRelation]:
    """
    Get all data lineage in the project

    Returns:
        A list of LineageRelation objects
    """
    # Compiled as the admin user — presumably so no models are filtered by access level
    compiled_dag = await self._get_compiled_dag(self._admin_user)
    return self._get_all_data_lineage(compiled_dag)
|
|
421
|
+
|
|
422
|
+
async def compile(
|
|
423
|
+
self, *, selected_model: str | None = None, test_set: str | None = None, do_all_test_sets: bool = False,
|
|
424
|
+
runquery: bool = False, clear: bool = False, buildtime_only: bool = False, runtime_only: bool = False
|
|
425
|
+
) -> None:
|
|
426
|
+
"""
|
|
427
|
+
Compile models into the "target/compile" folder.
|
|
428
|
+
|
|
429
|
+
Behavior:
|
|
430
|
+
- Buildtime outputs: target/compile/buildtime/*.sql (for SQL build models) and dag.png
|
|
431
|
+
- Runtime outputs: target/compile/runtime/[test_set]/dbviews/*.sql, federates/*.sql, dag.png
|
|
432
|
+
If runquery=True, also write CSVs for runtime models.
|
|
433
|
+
- Options: clear entire compile folder first; compile only buildtime or only runtime.
|
|
434
|
+
|
|
435
|
+
Arguments:
|
|
436
|
+
selected_model: The name of the model to compile. If specified, the compiled SQL query is also printed in the terminal. If None, all models for the selected dataset are compiled. Default is None.
|
|
437
|
+
test_set: The name of the test set to compile with. If None, the default test set is used (which can vary by dataset). Ignored if `do_all_test_sets` argument is True. Default is None.
|
|
438
|
+
do_all_test_sets: Whether to compile all applicable test sets for the selected dataset(s). If True, the `test_set` argument is ignored. Default is False.
|
|
439
|
+
runquery: Whether to run all compiled queries and save each result as a CSV file. If True and `selected_model` is specified, all upstream models of the selected model is compiled as well. Default is False.
|
|
440
|
+
clear: Whether to clear the "target/compile/" folder before compiling. Default is False.
|
|
441
|
+
buildtime_only: Whether to compile only buildtime models. Default is False.
|
|
442
|
+
runtime_only: Whether to compile only runtime models. Default is False.
|
|
443
|
+
"""
|
|
444
|
+
border = "=" * 80
|
|
445
|
+
underlines = "-" * len(border)
|
|
446
|
+
|
|
447
|
+
compile_root = Path(self._project_path, c.TARGET_FOLDER, c.COMPILE_FOLDER)
|
|
448
|
+
if clear and compile_root.exists():
|
|
449
|
+
shutil.rmtree(compile_root)
|
|
450
|
+
|
|
451
|
+
models_dict = self._get_models_dict(always_python_df=False)
|
|
452
|
+
|
|
453
|
+
if selected_model is not None:
|
|
454
|
+
selected_model = u.normalize_name(selected_model)
|
|
455
|
+
if selected_model not in models_dict:
|
|
456
|
+
print(f"No such model found: {selected_model}")
|
|
457
|
+
return
|
|
458
|
+
if not isinstance(models_dict[selected_model], m.QueryModel):
|
|
459
|
+
print(f"Model '{selected_model}' is not a query model. Nothing to do.")
|
|
460
|
+
return
|
|
461
|
+
|
|
462
|
+
model_to_compile = None
|
|
463
|
+
|
|
464
|
+
# Buildtime compilation
|
|
465
|
+
if not runtime_only:
|
|
466
|
+
print(underlines)
|
|
467
|
+
print(f"Compiling buildtime models")
|
|
468
|
+
print(underlines)
|
|
469
|
+
|
|
470
|
+
buildtime_folder = Path(compile_root, c.COMPILE_BUILDTIME_FOLDER)
|
|
471
|
+
buildtime_folder.mkdir(parents=True, exist_ok=True)
|
|
472
|
+
|
|
473
|
+
def write_buildtime_model(model: m.DataModel, static_models: dict[str, m.StaticModel]) -> None:
|
|
474
|
+
if not isinstance(model, m.BuildModel):
|
|
475
|
+
return
|
|
476
|
+
|
|
477
|
+
model.compile_for_build(self._conn_args, static_models)
|
|
478
|
+
|
|
479
|
+
if isinstance(model.compiled_query, mq.SqlModelQuery):
|
|
480
|
+
out_path = Path(buildtime_folder, f"{model.name}.sql")
|
|
481
|
+
with open(out_path, 'w') as f:
|
|
482
|
+
f.write(model.compiled_query.query)
|
|
483
|
+
print(f"Successfully compiled build model: {model.name}")
|
|
484
|
+
elif isinstance(model.compiled_query, mq.PyModelQuery):
|
|
485
|
+
print(f"The build model '{model.name}' is in Python. Compilation for Python is not supported yet.")
|
|
486
|
+
|
|
487
|
+
static_models = self._get_static_models()
|
|
488
|
+
if selected_model is not None:
|
|
489
|
+
model_to_compile = models_dict[selected_model]
|
|
490
|
+
write_buildtime_model(model_to_compile, static_models)
|
|
491
|
+
else:
|
|
492
|
+
coros = [asyncio.to_thread(write_buildtime_model, m, static_models) for m in static_models.values()]
|
|
493
|
+
await u.asyncio_gather(coros)
|
|
494
|
+
|
|
495
|
+
print(underlines)
|
|
496
|
+
print()
|
|
497
|
+
|
|
498
|
+
# Runtime compilation
|
|
499
|
+
if not buildtime_only:
|
|
500
|
+
if do_all_test_sets:
|
|
501
|
+
test_set_names_set = set(self._manifest_cfg.selection_test_sets.keys())
|
|
502
|
+
test_set_names_set.add(c.DEFAULT_TEST_SET_NAME)
|
|
503
|
+
test_set_names = list(test_set_names_set)
|
|
504
|
+
else:
|
|
505
|
+
test_set_names = [test_set or c.DEFAULT_TEST_SET_NAME]
|
|
506
|
+
|
|
507
|
+
for ts_name in test_set_names:
|
|
508
|
+
print(underlines)
|
|
509
|
+
print(f"Compiling runtime models (test set '{ts_name}')")
|
|
510
|
+
print(underlines)
|
|
511
|
+
|
|
512
|
+
# Build user and selections from test set config if present
|
|
513
|
+
ts_conf = self._manifest_cfg.selection_test_sets.get(ts_name, self._manifest_cfg.get_default_test_set())
|
|
514
|
+
# Separate base fields from custom fields
|
|
515
|
+
access_level = ts_conf.user.access_level
|
|
516
|
+
custom_fields = self._auth.CustomUserFields(**ts_conf.user.custom_fields)
|
|
517
|
+
if access_level == "guest":
|
|
518
|
+
user = GuestUser(username="", custom_fields=custom_fields)
|
|
519
|
+
else:
|
|
520
|
+
user = RegisteredUser(username="", access_level=access_level, custom_fields=custom_fields)
|
|
521
|
+
|
|
522
|
+
# Generate DAG across all models. When runquery=True, force models to produce Python dataframes so CSVs can be written.
|
|
523
|
+
dag = await self._get_compiled_dag(
|
|
524
|
+
user=user, selections=ts_conf.parameters, configurables=ts_conf.configurables, always_python_df=runquery,
|
|
525
|
+
)
|
|
526
|
+
if runquery:
|
|
527
|
+
await dag._run_models()
|
|
528
|
+
|
|
529
|
+
# Prepare output folders
|
|
530
|
+
runtime_folder = Path(compile_root, c.COMPILE_RUNTIME_FOLDER, ts_name)
|
|
531
|
+
dbviews_folder = Path(runtime_folder, c.DBVIEWS_FOLDER)
|
|
532
|
+
federates_folder = Path(runtime_folder, c.FEDERATES_FOLDER)
|
|
533
|
+
dbviews_folder.mkdir(parents=True, exist_ok=True)
|
|
534
|
+
federates_folder.mkdir(parents=True, exist_ok=True)
|
|
535
|
+
with open(Path(runtime_folder, "placeholders.json"), "w") as f:
|
|
536
|
+
json.dump(dag.placeholders, f)
|
|
537
|
+
|
|
538
|
+
# Function to write runtime models
|
|
539
|
+
def write_runtime_model(model: m.DataModel) -> None:
|
|
540
|
+
if not isinstance(model, m.QueryModel):
|
|
541
|
+
return
|
|
542
|
+
|
|
543
|
+
if model.model_type not in (m.ModelType.DBVIEW, m.ModelType.FEDERATE):
|
|
544
|
+
return
|
|
545
|
+
|
|
546
|
+
subfolder = dbviews_folder if model.model_type == m.ModelType.DBVIEW else federates_folder
|
|
547
|
+
model_type = "dbview" if model.model_type == m.ModelType.DBVIEW else "federate"
|
|
548
|
+
|
|
549
|
+
if isinstance(model.compiled_query, mq.SqlModelQuery):
|
|
550
|
+
out_sql = Path(subfolder, f"{model.name}.sql")
|
|
551
|
+
with open(out_sql, 'w') as f:
|
|
552
|
+
f.write(model.compiled_query.query)
|
|
553
|
+
print(f"Successfully compiled {model_type} model: {model.name}")
|
|
554
|
+
elif isinstance(model.compiled_query, mq.PyModelQuery):
|
|
555
|
+
print(f"The {model_type} model '{model.name}' is in Python. Compilation for Python is not supported yet.")
|
|
556
|
+
|
|
557
|
+
if runquery and isinstance(model.result, pl.LazyFrame):
|
|
558
|
+
out_csv = Path(subfolder, f"{model.name}.csv")
|
|
559
|
+
model.result.collect().write_csv(out_csv)
|
|
560
|
+
print(f"Successfully created CSV for {model_type} model: {model.name}")
|
|
561
|
+
|
|
562
|
+
# If selected_model is provided for runtime, only emit that model's outputs
|
|
563
|
+
if selected_model is not None:
|
|
564
|
+
model_to_compile = dag.models_dict[selected_model]
|
|
565
|
+
write_runtime_model(model_to_compile)
|
|
566
|
+
else:
|
|
567
|
+
coros = [asyncio.to_thread(write_runtime_model, model) for model in dag.models_dict.values()]
|
|
568
|
+
await u.asyncio_gather(coros)
|
|
569
|
+
|
|
570
|
+
print(underlines)
|
|
571
|
+
print()
|
|
572
|
+
|
|
573
|
+
print(f"All compilations complete! See the '{c.TARGET_FOLDER}/{c.COMPILE_FOLDER}/' folder for results.")
|
|
574
|
+
if model_to_compile and isinstance(model_to_compile, m.QueryModel) and isinstance(model_to_compile.compiled_query, mq.SqlModelQuery):
|
|
575
|
+
print()
|
|
576
|
+
print(border)
|
|
577
|
+
print(f"Compiled SQL query for model '{model_to_compile.name}':")
|
|
578
|
+
print(underlines)
|
|
579
|
+
print(model_to_compile.compiled_query.query)
|
|
580
|
+
print(border)
|
|
581
|
+
print()
|
|
582
|
+
|
|
583
|
+
def _permission_error(self, user: AbstractUser, data_type: str, data_name: str, scope: str) -> InvalidInputError:
    """
    Build (without raising) a 403 InvalidInputError for a user that lacks
    permission to access the named dataset/dashboard within the given scope.
    """
    error_code = f"unauthorized_access_to_{data_type}"
    message = f"User '{user}' does not have permission to access {scope} {data_type}: {data_name}"
    return InvalidInputError(403, error_code, message)
|
|
585
|
+
|
|
586
|
+
def seed(self, name: str) -> pl.LazyFrame:
    """
    Method to retrieve a seed as a polars LazyFrame given a seed name.

    Arguments:
        name: The name of the seed to retrieve

    Returns:
        The seed as a polars LazyFrame

    Raises:
        KeyError: If no seed exists with the given name (message lists the available seeds).
    """
    seeds_dict = self._seeds.get_dataframes()
    if name not in seeds_dict:
        available_seeds = list(seeds_dict.keys())
        raise KeyError(f"Seed '{name}' not found. Available seeds are: {available_seeds}")
    return seeds_dict[name].df
|
|
602
|
+
|
|
603
|
+
def dataset_metadata(self, name: str) -> dr.DatasetMetadata:
    """
    Method to retrieve the metadata of a dataset given a dataset name.

    Arguments:
        name: The name of the dataset to retrieve.

    Returns:
        A DatasetMetadata object containing the dataset description and column details.
    """
    dag = self._generate_dag(name)
    target = dag.target_model
    # Resolve pass-through columns first so the reported column details are complete
    target.process_pass_through_columns(dag.models_dict)
    return dr.DatasetMetadata(target_model_config=target.model_config)
|
|
618
|
+
|
|
619
|
+
def _enforce_max_result_rows(self, lazy_df: pl.LazyFrame, error_type: str) -> pl.DataFrame:
    """
    Collect at most max_rows + 1 rows from a LazyFrame to detect overflow.
    Raises InvalidInputError if the result exceeds the maximum allowed rows.

    Arguments:
        lazy_df: The LazyFrame to collect and check
        error_type: Either "dataset" or "query" to customize the error message

    Returns:
        A DataFrame with at most max_rows rows (or raises if exceeded)
    """
    max_rows = self._env_vars.datasets_max_rows_output
    # Fetch one extra row: its presence proves the full result would exceed the cap,
    # without materializing an unbounded result set
    collected = lazy_df.limit(max_rows + 1).collect()
    row_count = collected.height
    if row_count <= max_rows:
        return collected
    raise InvalidInputError(
        413, f"{error_type}_result_too_large",
        f"The {error_type} result contains {row_count} rows, which exceeds the maximum allowed of {max_rows} rows."
    )
|
|
643
|
+
|
|
644
|
+
async def _dataset_result(
    self, name: str, *, selections: dict[str, t.Any] = {}, user: AbstractUser | None = None,
    configurables: dict[str, str] = {}, check_user_access: bool = True
) -> dr.DatasetResult:
    """
    Run the DAG for a dataset and return its result with metadata.

    Arguments:
        name: The name of the dataset to run.
        selections: Parameter selections to apply to the dataset.
        user: The user to authenticate with. If None, the guest user is used.
        configurables: Configurable overrides, merged on top of project and dataset-level defaults.
        check_user_access: Whether to verify the user can access the dataset's scope.

    Returns:
        A DatasetResult with the target model's config and a DataFrame that includes a "_row_num" index column.

    Raises:
        InvalidInputError: If access is denied (403) or the result exceeds the maximum row count (413).
    """
    if user is None:
        user = self._guest_user

    # Fetch the dataset config once (previously looked up twice: once for scope, once for configurables)
    dataset_config = self._manifest_cfg.datasets[name]
    scope = dataset_config.scope
    if check_user_access and not self._auth.can_user_access_scope(user, scope):
        raise self._permission_error(user, "dataset", name, scope.name)

    # Dataset-level overrides take precedence over project defaults; caller-provided values win overall
    configurables = {**self._manifest_cfg.get_default_configurables(overrides=dataset_config.configurables), **configurables}

    dag = self._generate_dag(name)
    await dag.execute(
        self._param_args, self._param_cfg_set, self._context_func, user, dict(selections), configurables=configurables
    )
    assert isinstance(dag.target_model.result, pl.LazyFrame)
    df = self._enforce_max_result_rows(dag.target_model.result, "dataset")
    return dr.DatasetResult(
        target_model_config=dag.target_model.model_config,
        df=df.with_row_index("_row_num", offset=1)
    )
|
|
668
|
+
|
|
669
|
+
async def dataset_result(
    self, name: str, *, selections: dict[str, t.Any] = {}, user: AbstractUser | None = None, configurables: dict[str, str] = {}
) -> dr.DatasetResult:
    """
    Async method to retrieve a dataset as a DatasetResult object (with metadata) given parameter selections.

    Arguments:
        name: The name of the dataset to retrieve.
        selections: A dictionary of parameter selections to apply to the dataset. Optional, default is empty dictionary.
        user: The user to use for authentication. If None, no user is used. Optional, default is None.
        configurables: A dictionary of configurables to apply to the dataset. Optional, default is empty dictionary.

    Returns:
        A DatasetResult object containing the dataset result (as a polars DataFrame), its description, and the column details.
    """
    # Public entry point intentionally skips the scope check (check_user_access=False)
    return await self._dataset_result(
        name, selections=selections, user=user, configurables=configurables, check_user_access=False
    )
|
|
686
|
+
|
|
687
|
+
async def dashboard(
    self, name: str, *, selections: dict[str, t.Any] = {}, user: AbstractUser | None = None, dashboard_type: t.Type[T] = d.PngDashboard,
    configurables: dict[str, str] = {}
) -> T:
    """
    Async method to retrieve a dashboard given parameter selections.

    Arguments:
        name: The name of the dashboard to retrieve.
        selections: A dictionary of parameter selections to apply to the dashboard. Optional, default is empty dictionary.
        user: The user to use for authentication. If None, no user is used. Optional, default is None.
        dashboard_type: Return type of the method (mainly used for type hints). For instance, provide PngDashboard if you want the return type to be a PngDashboard. Optional, default is squirrels.Dashboard.
        configurables: A dictionary of configurables to apply to the dashboard's datasets. Optional, default is empty dictionary.

    Returns:
        The dashboard type specified by the "dashboard_type" argument.

    Raises:
        KeyError: If no dashboard exists with the given name.
        InvalidInputError: If the user cannot access the dashboard's scope.
    """
    # Validate the name up front: previously an unknown name raised a bare KeyError on the
    # first self._dashboards[name] access below, never reaching the friendly message that
    # guarded only the final get_dashboard call.
    if name not in self._dashboards:
        raise KeyError(f"No dashboard file found for: {name}")

    if user is None:
        user = self._guest_user

    dashboard_config = self._dashboards[name].config  # single lookup (was fetched twice)
    scope = dashboard_config.scope
    if not self._auth.can_user_access_scope(user, scope):
        raise self._permission_error(user, "dashboard", name, scope.name)

    async def get_dataset_df(dataset_name: str, fixed_params: dict[str, t.Any]) -> pl.DataFrame:
        # Fixed params defined by the dashboard take precedence over the caller's selections
        final_selections = {**selections, **fixed_params}
        result = await self.dataset_result(
            dataset_name, selections=final_selections, user=user, configurables=configurables
        )
        return result.df

    parameter_set = self._param_cfg_set.apply_selections(dashboard_config.parameters, selections, user)
    prms = parameter_set.get_parameters_as_dict()

    configurables = {**self._manifest_cfg.get_default_configurables(overrides=dashboard_config.configurables), **configurables}
    context = {}
    ctx_args = m.ContextArgs(
        **self._param_args.__dict__, user=user, prms=prms, configurables=configurables, _conn_args=self._conn_args
    )
    self._context_func(context, ctx_args)

    args = d.DashboardArgs(
        **ctx_args.__dict__, ctx=context, _get_dataset=get_dataset_df
    )
    try:
        return await self._dashboards[name].get_dashboard(args, dashboard_type=dashboard_type)
    except KeyError:
        # Kept for the case where get_dashboard itself reports a missing dashboard file
        raise KeyError(f"No dashboard file found for: {name}")
|
|
735
|
+
|
|
736
|
+
async def query_models(
    self, sql_query: str, *, user: AbstractUser | None = None, selections: dict[str, t.Any] = {}, configurables: dict[str, str] = {}
) -> dr.DatasetResult:
    """
    Compile and run an ad-hoc SQL query against the project's data models,
    returning the result (with metadata) as a DatasetResult.
    """
    effective_user = user if user is not None else self._guest_user

    dag = await self._get_compiled_dag(user=effective_user, sql_query=sql_query, selections=selections, configurables=configurables)
    await dag._run_models()

    target = dag.target_model
    assert isinstance(target.result, pl.LazyFrame)
    df = self._enforce_max_result_rows(target.result, "query")
    return dr.DatasetResult(
        target_model_config=target.model_config,
        df=df.with_row_index("_row_num", offset=1)
    )
|
|
750
|
+
|
|
751
|
+
async def get_compiled_model_query(
    self, model_name: str, *, user: AbstractUser | None = None, selections: dict[str, t.Any] = {}, configurables: dict[str, str] = {}
) -> rm.CompiledQueryModel:
    """
    Compile the specified data model and return its language and compiled definition.

    Arguments:
        model_name: The name of the data model to compile (normalized before lookup).
        user: The user to authenticate with. If None, the guest user is used.
        selections: Parameter selections applied when executing the DAG (compile only; runquery=False).
        configurables: Configurable overrides, merged on top of the project defaults.

    Returns:
        A CompiledQueryModel with the model's language ("sql" or "python"), its compiled
        definition, and the DAG's placeholders.

    Raises:
        InvalidInputError: 404 if no model matches the name; 400 if the model type is unsupported.
        NotImplementedError: If the compiled query is of an unrecognized type.
    """
    if user is None:
        user = self._guest_user

    name = u.normalize_name(model_name)
    models_dict = self._get_models_dict(always_python_df=False)
    if name not in models_dict:
        raise InvalidInputError(404, "model_not_found", f"No data model found with name: {model_name}")

    model = models_dict[name]
    # Only build, dbview, and federate models support runtime compiled definition in this context
    if not isinstance(model, (m.BuildModel, m.DbviewModel, m.FederateModel)):
        raise InvalidInputError(400, "unsupported_model_type", "Only build, dbview, and federate models currently support compiled definition via this endpoint")

    # Build a DAG with this model as the target, without a dataset context
    # NOTE(review): this mutates the shared model object from models_dict — presumably
    # models_dict is rebuilt per call by _get_models_dict; confirm it is not cached
    model.is_target = True
    dag = m.DAG(None, model, models_dict, self._vdl_catalog_db_path, self._logger)

    # Caller-provided configurables override the project defaults
    cfg = {**self._manifest_cfg.get_default_configurables(), **configurables}
    # Execute with runquery=False so models are compiled but not run
    await dag.execute(
        self._param_args, self._param_cfg_set, self._context_func, user, selections, runquery=False, configurables=cfg
    )

    language = "sql" if isinstance(model.query_file, mq.SqlQueryFile) else "python"
    if isinstance(model, m.BuildModel):
        # Compile SQL build models; Python build models not yet supported
        if isinstance(model.query_file, mq.SqlQueryFile):
            # Build models compile against static models rather than the runtime DAG
            static_models = self._get_static_models()
            compiled = model._compile_sql_model(model.query_file, self._conn_args, static_models)
            definition = compiled.query
        else:
            definition = "# Compiling Python build models is currently not supported. This will be available in a future version of Squirrels..."
    elif isinstance(model.compiled_query, mq.SqlModelQuery):
        # Dbview/federate SQL models: the DAG execution above produced the compiled query
        definition = model.compiled_query.query
    elif isinstance(model.compiled_query, mq.PyModelQuery):
        definition = "# Compiling Python data models is currently not supported. This will be available in a future version of Squirrels..."
    else:
        raise NotImplementedError(f"Query type not supported: {model.compiled_query.__class__.__name__}")

    return rm.CompiledQueryModel(language=language, definition=definition, placeholders=dag.placeholders)
|
|
796
|
+
|