squirrels-0.5.0b3-py3-none-any.whl → squirrels-0.6.0.post0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- squirrels/__init__.py +4 -0
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +337 -0
- squirrels/_api_routes/base.py +196 -0
- squirrels/_api_routes/dashboards.py +156 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +220 -0
- squirrels/_api_routes/project.py +289 -0
- squirrels/_api_server.py +440 -792
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/{_init_time_args.py → init_time_args.py} +23 -43
- squirrels/_arguments/{_run_time_args.py → run_time_args.py} +32 -68
- squirrels/_auth.py +590 -264
- squirrels/_command_line.py +130 -58
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +16 -15
- squirrels/_constants.py +36 -11
- squirrels/_dashboards.py +179 -0
- squirrels/_data_sources.py +40 -34
- squirrels/_dataset_types.py +16 -11
- squirrels/_env_vars.py +209 -0
- squirrels/_exceptions.py +9 -37
- squirrels/_http_error_responses.py +52 -0
- squirrels/_initializer.py +7 -6
- squirrels/_logging.py +121 -0
- squirrels/_manifest.py +155 -77
- squirrels/_mcp_server.py +578 -0
- squirrels/_model_builder.py +11 -55
- squirrels/_model_configs.py +5 -5
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +276 -143
- squirrels/_package_data/base_project/.env +1 -24
- squirrels/_package_data/base_project/.env.example +31 -17
- squirrels/_package_data/base_project/connections.yml +4 -3
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +13 -7
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +6 -6
- squirrels/_package_data/base_project/docker/Dockerfile +2 -2
- squirrels/_package_data/base_project/docker/compose.yml +1 -1
- squirrels/_package_data/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +2 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
- squirrels/_package_data/base_project/models/federates/federate_example.py +27 -17
- squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
- squirrels/_package_data/base_project/models/federates/federate_example.yml +7 -7
- squirrels/_package_data/base_project/models/sources.yml +5 -6
- squirrels/_package_data/base_project/parameters.yml +24 -38
- squirrels/_package_data/base_project/pyconfigs/connections.py +8 -3
- squirrels/_package_data/base_project/pyconfigs/context.py +26 -14
- squirrels/_package_data/base_project/pyconfigs/parameters.py +124 -81
- squirrels/_package_data/base_project/pyconfigs/user.py +48 -15
- squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +21 -31
- squirrels/_package_data/templates/login_successful.html +53 -0
- squirrels/_package_data/templates/squirrels_studio.html +22 -0
- squirrels/_parameter_configs.py +43 -22
- squirrels/_parameter_options.py +1 -1
- squirrels/_parameter_sets.py +41 -30
- squirrels/_parameters.py +560 -123
- squirrels/_project.py +487 -277
- squirrels/_py_module.py +71 -10
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +83 -0
- squirrels/_schemas/query_param_models.py +70 -0
- squirrels/_schemas/request_models.py +26 -0
- squirrels/_schemas/response_models.py +286 -0
- squirrels/_seeds.py +52 -13
- squirrels/_sources.py +29 -23
- squirrels/_utils.py +221 -42
- squirrels/_version.py +1 -3
- squirrels/arguments.py +7 -2
- squirrels/auth.py +4 -0
- squirrels/connections.py +2 -0
- squirrels/dashboards.py +3 -1
- squirrels/data_sources.py +6 -0
- squirrels/parameter_options.py +5 -0
- squirrels/parameters.py +5 -0
- squirrels/types.py +10 -3
- squirrels-0.6.0.post0.dist-info/METADATA +148 -0
- squirrels-0.6.0.post0.dist-info/RECORD +101 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -1
- squirrels/_api_response_models.py +0 -190
- squirrels/_dashboard_types.py +0 -82
- squirrels/_dashboards_io.py +0 -79
- squirrels-0.5.0b3.dist-info/METADATA +0 -110
- squirrels-0.5.0b3.dist-info/RECORD +0 -80
- /squirrels/_package_data/base_project/{assets → resources}/expenses.db +0 -0
- /squirrels/_package_data/base_project/{assets → resources}/weather.db +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/licenses/LICENSE +0 -0
squirrels/_project.py
CHANGED
```diff
@@ -1,39 +1,25 @@
-from
-from
+from typing import TYPE_CHECKING
+from dotenv import dotenv_values, load_dotenv
 from pathlib import Path
 import asyncio, typing as t, functools as ft, shutil, json, os
-import
-import sqlglot, sqlglot.expressions
+import sqlglot, sqlglot.expressions, duckdb, polars as pl
 
-from ._auth import Authenticator,
+from ._auth import Authenticator, AuthProviderArgs, ProviderFunctionType
+from ._schemas.auth_models import CustomUserFields, AbstractUser, GuestUser, RegisteredUser
+from ._schemas import response_models as rm
 from ._model_builder import ModelBuilder
+from ._env_vars import SquirrelsEnvVars
 from ._exceptions import InvalidInputError, ConfigurationError
-from . import
+from ._py_module import PyModule
+from . import _dashboards as d, _utils as u, _constants as c, _manifest as mf, _connection_set as cs
 from . import _seeds as s, _models as m, _model_configs as mc, _model_queries as mq, _sources as so
-from . import _parameter_sets as ps,
-
-T = t.TypeVar("T", bound=dash.Dashboard)
-M = t.TypeVar("M", bound=m.DataModel)
+from . import _parameter_sets as ps, _dataset_types as dr, _logging as l
 
+if TYPE_CHECKING:
+    from ._api_server import FastAPIComponents
 
-
-
-        super().format(record)
-        info = {
-            "timestamp": self.formatTime(record),
-            "project_id": record.name,
-            "level": record.levelname,
-            "message": record.getMessage(),
-            "thread": record.thread,
-            "thread_name": record.threadName,
-            "process": record.process,
-            **record.__dict__.get("info", {})
-        }
-        output = {
-            "data": record.__dict__.get("data", {}),
-            "info": info
-        }
-        return json.dumps(output)
+T = t.TypeVar("T", bound=d.Dashboard)
+M = t.TypeVar("M", bound=m.DataModel)
 
 
 class SquirrelsProject:
```
```diff
@@ -41,114 +27,179 @@ class SquirrelsProject:
     Initiate an instance of this class to interact with a Squirrels project through Python code. For example this can be handy to experiment with the datasets produced by Squirrels in a Jupyter notebook.
     """
 
-    def __init__(
+    def __init__(
+        self, *, project_path: str = ".", load_dotenv_globally: bool = False,
+        log_to_file: bool = False, log_level: str | None = None, log_format: str | None = None,
+    ) -> None:
         """
         Constructor for SquirrelsProject class. Loads the file contents of the Squirrels project into memory as member fields.
 
         Arguments:
-
-            log_level: The logging level to use. Options are "DEBUG", "INFO", and "WARNING". Default is "INFO".
-
-            log_format: The format of the log records. Options are "text" and "json". Default is "text".
+            project_path: The path to the Squirrels project file. Defaults to the current working directory.
+            log_level: The logging level to use. Options are "DEBUG", "INFO", and "WARNING". Default is from SQRL_LOGGING__LEVEL environment variable or "INFO".
+            log_to_file: Whether to enable logging to file(s) in the "logs/" folder (or a custom folder). Default is from SQRL_LOGGING__TO_FILE environment variable or False.
+            log_format: The format of the log records. Options are "text" and "json". Default is from SQRL_LOGGING__FORMAT environment variable or "text".
         """
-
-        self._logger = self._get_logger(self._filepath, log_file, log_level, log_format)
-
-    def _get_logger(self, base_path: str, log_file: str | None, log_level: str, log_format: str) -> u.Logger:
-        logger = u.Logger(name=uuid4().hex)
-        logger.setLevel(log_level.upper())
-
-        handler = l.StreamHandler()
-        handler.setLevel("WARNING")
-        handler.setFormatter(l.Formatter("%(levelname)s: %(asctime)s - %(message)s"))
-        logger.addHandler(handler)
-
-        if log_format.lower() == "json":
-            formatter = _CustomJsonFormatter()
-        elif log_format.lower() == "text":
-            formatter = l.Formatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s")
-        else:
-            raise ValueError("log_format must be either 'text' or 'json'")
-
-        if log_file:
-            path = Path(base_path, c.LOGS_FOLDER, log_file)
-            path.parent.mkdir(parents=True, exist_ok=True)
+        project_path = str(Path(project_path).resolve())
 
-
-
-
+        self._project_path = project_path
+        self._env_vars_unformatted = self._load_env_vars(project_path, load_dotenv_globally)
+        self._env_vars = SquirrelsEnvVars(project_path=project_path, **self._env_vars_unformatted)
+        self._vdl_catalog_db_path = self._env_vars.vdl_catalog_db_path
 
-
-
-
-
+        self._logger = self._get_logger(project_path, self._env_vars, log_to_file, log_level, log_format)
+        self._ensure_virtual_datalake_exists(project_path, self._vdl_catalog_db_path, self._env_vars.vdl_data_path)
+
+    @staticmethod
+    def _load_env_vars(project_path: str, load_dotenv_globally: bool) -> dict[str, str]:
         dotenv_files = [c.DOTENV_FILE, c.DOTENV_LOCAL_FILE]
         dotenv_vars = {}
         for file in dotenv_files:
-
+            full_path = u.Path(project_path, file)
+            if load_dotenv_globally:
+                load_dotenv(full_path)
+            dotenv_vars.update({k: v for k, v in dotenv_values(full_path).items() if v is not None})
         return {**os.environ, **dotenv_vars}
 
+    @staticmethod
+    def _get_logger(
+        filepath: str, env_vars: SquirrelsEnvVars, log_to_file: bool, log_level: str | None, log_format: str | None
+    ) -> u.Logger:
+        # CLI arguments take precedence over environment variables
+        log_level = log_level if log_level is not None else env_vars.logging_level
+        log_format = log_format if log_format is not None else env_vars.logging_format
+        log_to_file = env_vars.logging_to_file or log_to_file
+        log_file_size_mb = float(env_vars.logging_file_size_mb)
+        log_file_backup_count = int(env_vars.logging_file_backup_count)
+        return l.get_logger(filepath, log_to_file, log_level, log_format, log_file_size_mb, log_file_backup_count)
+
+    @staticmethod
+    def _ensure_virtual_datalake_exists(project_path: str, vdl_catalog_db_path: str, vdl_data_path: str) -> None:
+        target_path = u.Path(project_path, c.TARGET_FOLDER)
+        target_path.mkdir(parents=True, exist_ok=True)
+
+        # Attempt to set up the virtual data lake with DATA_PATH if possible
+        try:
+            is_ducklake = vdl_catalog_db_path.startswith("ducklake:")
+
+            options = f"(DATA_PATH '{vdl_data_path}')" if is_ducklake else ""
+            attach_stmt = f"ATTACH '{vdl_catalog_db_path}' AS vdl {options}"
+            with duckdb.connect() as conn:
+                conn.execute(attach_stmt)
+                # TODO: support incremental loads for build models and avoid cleaning up old files all the time
+                conn.execute("CALL ducklake_expire_snapshots('vdl', older_than => now())")
+                conn.execute("CALL ducklake_cleanup_old_files('vdl', cleanup_all => true)")
+
+        except Exception as e:
+            if "DATA_PATH parameter" in str(e):
+                first_line = str(e).split("\n")[0]
+                note = "NOTE: Squirrels does not allow changing the data path for an existing Virtual Data Lake (VDL)"
+                raise u.ConfigurationError(f"{first_line}\n\n{note}")
+
+            if is_ducklake and not any(x in vdl_catalog_db_path for x in [":sqlite:", ":postgres:", ":mysql:"]):
+                extended_error = "\n- Note: if you're using DuckDB for the metadata database, only one process can connect to the VDL at a time."
+            else:
+                extended_error = ""
+
+            raise u.ConfigurationError(f"Failed to attach Virtual Data Lake (VDL).{extended_error}") from e
+
     @ft.cached_property
     def _manifest_cfg(self) -> mf.ManifestConfig:
-        return mf.ManifestIO.load_from_file(self._logger, self.
+        return mf.ManifestIO.load_from_file(self._logger, self._project_path, self._env_vars_unformatted)
 
     @ft.cached_property
     def _seeds(self) -> s.Seeds:
-        return s.SeedsIO.load_files(self._logger, self.
+        return s.SeedsIO.load_files(self._logger, self._env_vars)
 
     @ft.cached_property
     def _sources(self) -> so.Sources:
-        return so.SourcesIO.load_file(self._logger, self.
+        return so.SourcesIO.load_file(self._logger, self._env_vars, self._env_vars_unformatted)
 
     @ft.cached_property
     def _build_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
-        return m.ModelsIO.load_build_files(self._logger, self.
+        return m.ModelsIO.load_build_files(self._logger, self._env_vars)
 
     @ft.cached_property
     def _dbview_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
-        return m.ModelsIO.load_dbview_files(self._logger, self.
+        return m.ModelsIO.load_dbview_files(self._logger, self._env_vars)
 
     @ft.cached_property
     def _federate_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
-        return m.ModelsIO.load_federate_files(self._logger, self.
+        return m.ModelsIO.load_federate_files(self._logger, self._env_vars)
 
     @ft.cached_property
     def _context_func(self) -> m.ContextFunc:
-        return m.ModelsIO.load_context_func(self._logger, self.
+        return m.ModelsIO.load_context_func(self._logger, self._project_path)
 
     @ft.cached_property
     def _dashboards(self) -> dict[str, d.DashboardDefinition]:
-        return d.DashboardsIO.load_files(
+        return d.DashboardsIO.load_files(
+            self._logger, self._project_path, self._manifest_cfg.project_variables.auth_type, self._manifest_cfg.configurables
+        )
 
     @ft.cached_property
     def _conn_args(self) -> cs.ConnectionsArgs:
-
+        proj_vars = self._manifest_cfg.project_variables.model_dump()
+        conn_args = cs.ConnectionsArgs(self._project_path, proj_vars, self._env_vars_unformatted)
+        return conn_args
 
     @ft.cached_property
     def _conn_set(self) -> cs.ConnectionSet:
-        return cs.ConnectionSetIO.load_from_file(self._logger, self.
+        return cs.ConnectionSetIO.load_from_file(self._logger, self._project_path, self._manifest_cfg, self._conn_args)
+
+    @ft.cached_property
+    def _custom_user_fields_cls_and_provider_functions(self) -> tuple[type[CustomUserFields], list[ProviderFunctionType]]:
+        user_module_path = u.Path(self._project_path, c.PYCONFIGS_FOLDER, c.USER_FILE)
+        user_module = PyModule(user_module_path, self._project_path)
+
+        # Load CustomUserFields class (adds to Authenticator.providers as side effect)
+        CustomUserFieldsCls = user_module.get_func_or_class("CustomUserFields", default_attr=CustomUserFields)
+        provider_functions = Authenticator.providers
+        Authenticator.providers = []
+
+        if not issubclass(CustomUserFieldsCls, CustomUserFields):
+            raise ConfigurationError(f"CustomUserFields class in '{c.USER_FILE}' must inherit from CustomUserFields")
+
+        return CustomUserFieldsCls, provider_functions
 
     @ft.cached_property
     def _auth(self) -> Authenticator:
-
+        auth_args = AuthProviderArgs(**self._conn_args.__dict__)
+        CustomUserFieldsCls, provider_functions = self._custom_user_fields_cls_and_provider_functions
+        external_only = (self._manifest_cfg.project_variables.auth_strategy == mf.AuthStrategy.EXTERNAL)
+
+        if external_only and len(provider_functions) != 1:
+            raise ConfigurationError(f"When auth_strategy is 'external', there must be exactly one auth provider function. Found {len(provider_functions)} auth providers.")
+
+        return Authenticator(
+            self._logger, self._env_vars, auth_args, provider_functions,
+            custom_user_fields_cls=CustomUserFieldsCls, external_only=external_only
+        )
 
     @ft.cached_property
-    def
-
+    def _guest_user(self) -> AbstractUser:
+        custom_fields = self._auth.CustomUserFields()
+        return GuestUser(username="", custom_fields=custom_fields)
+
+    @ft.cached_property
+    def _admin_user(self) -> AbstractUser:
+        custom_fields = self._auth.CustomUserFields()
+        return RegisteredUser(username="", access_level="admin", custom_fields=custom_fields)
 
     @ft.cached_property
     def _param_args(self) -> ps.ParametersArgs:
-
+        conn_args = self._conn_args
+        return ps.ParametersArgs(**conn_args.__dict__)
 
     @ft.cached_property
     def _param_cfg_set(self) -> ps.ParameterConfigsSet:
         return ps.ParameterConfigsSetIO.load_from_file(
-            self._logger, self.
+            self._logger, self._env_vars, self._manifest_cfg, self._seeds, self._conn_set, self._param_args
         )
 
     @ft.cached_property
     def _j2_env(self) -> u.EnvironmentWithMacros:
-        env = u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self.
+        env = u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self._project_path))
 
         def value_to_str(value: t.Any, attribute: str | None = None) -> str:
             if attribute is None:
```
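The new constructor wires logging through environment variables, with the `_get_logger` helper above combining them with the constructor arguments (`log_level` and `log_format` arguments override the environment; `log_to_file` is OR-combined with it). A minimal sketch of how this might be used — the import path is an assumption, and the values set below are hypothetical; only the `SQRL_LOGGING__*` names come from the docstring in the diff:

```python
import os

# Assumed import path; the diff does not show how SquirrelsProject is exported.
from squirrels import SquirrelsProject

# Environment variables supply logging defaults (per the docstring above)
os.environ["SQRL_LOGGING__LEVEL"] = "DEBUG"
os.environ["SQRL_LOGGING__FORMAT"] = "json"

project = SquirrelsProject(
    project_path=".",   # resolved to an absolute path internally
    log_to_file=True,   # OR-combined with SQRL_LOGGING__TO_FILE
    log_level="INFO",   # explicit argument overrides SQRL_LOGGING__LEVEL
)
```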
```diff
@@ -170,11 +221,26 @@ class SquirrelsProject:
         env.filters["quote_and_join"] = quote_and_join
         return env
 
-
-
-
-
-
+    def get_fastapi_components(
+        self, *, no_cache: bool = False, host: str = "localhost", port: int = 8000,
+        mount_path_format: str = "/analytics/{project_name}/v{project_version}"
+    ) -> "FastAPIComponents":
+        """
+        Get the FastAPI components for the Squirrels project including mount path, lifespan, and FastAPI app.
+
+        Arguments:
+            no_cache: Whether to disable caching for parameter options, datasets, and dashboard results in the API server.
+            host: The host the API server will listen on. Only used for the welcome banner.
+            port: The port the API server will listen on. Only used for the welcome banner.
+            mount_path_format: The format of the mount path. Use {project_name} and {project_version} as placeholders.
+
+        Returns:
+            A FastAPIComponents object containing the mount path, lifespan, and FastAPI app.
+        """
+        from ._api_server import ApiServer
+        api_server = ApiServer(no_cache=no_cache, project=self)
+        return api_server.get_fastapi_components(host=host, port=port, mount_path_format=mount_path_format)
+
     def close(self) -> None:
         """
         Deliberately close any open resources within the Squirrels project, such as database connections (instead of relying on the garbage collector).
```
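`get_fastapi_components` means a project can now hand its API surface to a host ASGI application rather than running a server itself. A sketch under stated assumptions — the attribute names `lifespan`, `mount_path`, and `app` on `FastAPIComponents` are guesses from the docstring's "mount path, lifespan, and FastAPI app", and the import path is assumed:

```python
import uvicorn
from fastapi import FastAPI

from squirrels import SquirrelsProject  # assumed import path

project = SquirrelsProject(project_path=".")
components = project.get_fastapi_components(host="0.0.0.0", port=8080)

# Hypothetical attribute names; the diff only documents that the object
# carries a mount path, a lifespan, and a FastAPI app.
root = FastAPI(lifespan=components.lifespan)
root.mount(components.mount_path, components.app)

if __name__ == "__main__":
    uvicorn.run(root, host="0.0.0.0", port=8080)
```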
```diff
@@ -182,6 +248,9 @@ class SquirrelsProject:
         self._conn_set.dispose()
         self._auth.close()
 
+    def __enter__(self):
+        return self
+
     def __exit__(self, exc_type, exc_val, traceback):
         self.close()
 
```
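With `__enter__` added alongside the existing `__exit__`, a project can be used as a context manager so `close()` runs even when an exception escapes. For example (import path assumed; the seed name mirrors `seed_categories` from the base project files listed above):

```python
from squirrels import SquirrelsProject  # assumed import path

with SquirrelsProject(project_path=".") as project:
    # seed() returns a polars LazyFrame; collect() materializes it
    categories = project.seed("seed_categories").collect()
    print(categories)
# on exit, close() disposes database connections and the authenticator
```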
```diff
@@ -197,60 +266,59 @@ class SquirrelsProject:
 
         seeds_dict = self._seeds.get_dataframes()
         for key, seed in seeds_dict.items():
-            self._add_model(models_dict, m.Seed(key, seed.config, seed.df, logger=self._logger,
+            self._add_model(models_dict, m.Seed(key, seed.config, seed.df, logger=self._logger, conn_set=self._conn_set))
 
         for source_name, source_config in self._sources.sources.items():
-            self._add_model(models_dict, m.SourceModel(source_name, source_config, logger=self._logger,
+            self._add_model(models_dict, m.SourceModel(source_name, source_config, logger=self._logger, conn_set=self._conn_set))
 
         for name, val in self._build_model_files.items():
-            model = m.BuildModel(name, val.config, val.query_file, logger=self._logger,
+            model = m.BuildModel(name, val.config, val.query_file, logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env)
             self._add_model(models_dict, model)
 
         return models_dict
 
 
-    async def build(self, *, full_refresh: bool = False, select: str | None = None
+    async def build(self, *, full_refresh: bool = False, select: str | None = None) -> None:
         """
-        Build the
+        Build the Virtual Data Lake (VDL) for the Squirrels project
 
         Arguments:
-            full_refresh: Whether to drop all tables and rebuild the
-
+            full_refresh: Whether to drop all tables and rebuild the VDL from scratch. Default is False.
+            select: The name of a specific model to build. If None, all models are built. Default is None.
         """
         models_dict: dict[str, m.StaticModel] = self._get_static_models()
-        builder = ModelBuilder(self.
-        await builder.build(full_refresh, select
+        builder = ModelBuilder(self._vdl_catalog_db_path, self._conn_set, models_dict, self._conn_args, self._logger)
+        await builder.build(full_refresh, select)
 
     def _get_models_dict(self, always_python_df: bool) -> dict[str, m.DataModel]:
-        models_dict: dict[str, m.DataModel] =
+        models_dict: dict[str, m.DataModel] = self._get_static_models()
 
         for name, val in self._dbview_model_files.items():
             self._add_model(models_dict, m.DbviewModel(
-                name, val.config, val.query_file, logger=self._logger,
+                name, val.config, val.query_file, logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env
             ))
             models_dict[name].needs_python_df = always_python_df
 
         for name, val in self._federate_model_files.items():
             self._add_model(models_dict, m.FederateModel(
-                name, val.config, val.query_file, logger=self._logger,
+                name, val.config, val.query_file, logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env
             ))
             models_dict[name].needs_python_df = always_python_df
 
         return models_dict
 
-    def _generate_dag(self, dataset: str
-        models_dict = self._get_models_dict(always_python_df)
+    def _generate_dag(self, dataset: str) -> m.DAG:
+        models_dict = self._get_models_dict(always_python_df=False)
 
         dataset_config = self._manifest_cfg.datasets[dataset]
-
-        target_model = models_dict[target_model_name]
+        target_model = models_dict[dataset_config.model]
         target_model.is_target = True
-        dag = m.DAG(dataset_config, target_model, models_dict, self.
+        dag = m.DAG(dataset_config, target_model, models_dict, self._vdl_catalog_db_path, self._logger)
 
         return dag
 
-    def _generate_dag_with_fake_target(self, sql_query: str | None) -> m.DAG:
-        models_dict = self._get_models_dict(always_python_df=
+    def _generate_dag_with_fake_target(self, sql_query: str | None, *, always_python_df: bool = False) -> m.DAG:
+        models_dict = self._get_models_dict(always_python_df=always_python_df)
 
         if sql_query is None:
             dependencies = set(models_dict.keys())
```
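`build()` is a coroutine, so scripts drive it through an event loop. A minimal sketch (import path assumed; the model name mirrors `build_example` from the base project):

```python
import asyncio

from squirrels import SquirrelsProject  # assumed import path

async def main() -> None:
    with SquirrelsProject(project_path=".") as project:
        # Drop all tables in the VDL and rebuild from scratch
        await project.build(full_refresh=True)
        # Or rebuild a single model and leave the rest untouched:
        # await project.build(select="build_example")

asyncio.run(main())
```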
```diff
@@ -260,227 +328,260 @@ class SquirrelsProject:
             substitutions = {}
             for model_name in dependencies:
                 model = models_dict[model_name]
-                if isinstance(model, m.SourceModel) and not model.
-                    raise InvalidInputError(
-                if isinstance(model,
-                    substitutions[model_name] = f"
+                if isinstance(model, m.SourceModel) and not model.is_queryable:
+                    raise InvalidInputError(400, "cannot_query_source_model", f"Source model '{model_name}' cannot be queried with DuckDB")
+                if isinstance(model, m.BuildModel):
+                    substitutions[model_name] = f"vdl.{model_name}"
+                elif isinstance(model, m.SourceModel):
+                    if model.model_config.load_to_vdl:
+                        substitutions[model_name] = f"vdl.{model_name}"
+                    else:
+                        # DuckDB connection without load_to_vdl - reference via attached database
+                        conn_name = model.model_config.get_connection()
+                        table_name = model.model_config.get_table()
+                        substitutions[model_name] = f"db_{conn_name}.{table_name}"
 
             sql_query = parsed.transform(
-                lambda node: sqlglot.expressions.Table(this=substitutions[node.name])
+                lambda node: sqlglot.expressions.Table(this=substitutions[node.name], alias=node.alias)
                 if isinstance(node, sqlglot.expressions.Table) and node.name in substitutions
                 else node
             ).sql()
 
         model_config = mc.FederateModelConfig(depends_on=dependencies)
-        query_file = mq.SqlQueryFile("", sql_query or "")
+        query_file = mq.SqlQueryFile("", sql_query or "SELECT 1")
         fake_target_model = m.FederateModel(
-            "__fake_target", model_config, query_file, logger=self._logger,
+            "__fake_target", model_config, query_file, logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env
         )
         fake_target_model.is_target = True
-        dag = m.DAG(None, fake_target_model, models_dict, self.
+        dag = m.DAG(None, fake_target_model, models_dict, self._vdl_catalog_db_path, self._logger)
         return dag
 
-    def
-
-
-
-
-
-        G = dag.to_networkx_graph()
-
-        fig, _ = plt.subplots()
-        pos = nx.multipartite_layout(G, subset_key="layer")
-        colors = [color_map[node[1]] for node in G.nodes(data="model_type")] # type: ignore
-        nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
-
-        y_values = [val[1] for val in pos.values()]
-        scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
-        label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
-        nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
-
-        fig.tight_layout()
-        plt.margins(x=0.1, y=0.1)
-        fig.savefig(Path(output_folder, "dag.png"))
-        plt.close(fig)
-
-    async def _get_compiled_dag(self, *, sql_query: str | None = None, selections: dict[str, t.Any] = {}, user: BaseUser | None = None) -> m.DAG:
-        dag = self._generate_dag_with_fake_target(sql_query)
+    async def _get_compiled_dag(
+        self, user: AbstractUser, *, sql_query: str | None = None, selections: dict[str, t.Any] = {}, configurables: dict[str, str] = {},
+        always_python_df: bool = False
+    ) -> m.DAG:
+        dag = self._generate_dag_with_fake_target(sql_query, always_python_df=always_python_df)
 
-
-        await dag.execute(
+        configurables = {**self._manifest_cfg.get_default_configurables(), **configurables}
+        await dag.execute(
+            self._param_args, self._param_cfg_set, self._context_func, user, selections,
+            runquery=False, configurables=configurables
+        )
         return dag
 
-    def _get_all_connections(self) -> list[
+    def _get_all_connections(self) -> list[rm.ConnectionItemModel]:
         connections = []
         for conn_name, conn_props in self._conn_set.get_connections_as_dict().items():
             if isinstance(conn_props, mf.ConnectionProperties):
                 label = conn_props.label if conn_props.label is not None else conn_name
-                connections.append(
+                connections.append(rm.ConnectionItemModel(name=conn_name, label=label))
         return connections
 
-    def _get_all_data_models(self, compiled_dag: m.DAG) -> list[
+    def _get_all_data_models(self, compiled_dag: m.DAG) -> list[rm.DataModelItem]:
         return compiled_dag.get_all_data_models()
 
-    async def get_all_data_models(self) -> list[
+    async def get_all_data_models(self) -> list[rm.DataModelItem]:
         """
         Get all data models in the project
 
         Returns:
             A list of DataModelItem objects
         """
-        compiled_dag = await self._get_compiled_dag()
+        compiled_dag = await self._get_compiled_dag(self._admin_user)
         return self._get_all_data_models(compiled_dag)
 
-    def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[
+    def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[rm.LineageRelation]:
         all_lineage = compiled_dag.get_all_model_lineage()
 
         # Add dataset nodes to the lineage
         for dataset in self._manifest_cfg.datasets.values():
-            target_dataset =
-            source_model =
-            all_lineage.append(
+            target_dataset = rm.LineageNode(name=dataset.name, type="dataset")
+            source_model = rm.LineageNode(name=dataset.model, type="model")
+            all_lineage.append(rm.LineageRelation(type="runtime", source=source_model, target=target_dataset))
 
         # Add dashboard nodes to the lineage
         for dashboard in self._dashboards.values():
-            target_dashboard =
+            target_dashboard = rm.LineageNode(name=dashboard.dashboard_name, type="dashboard")
             datasets = set(x.dataset for x in dashboard.config.depends_on)
             for dataset in datasets:
-                source_dataset =
-                all_lineage.append(
+                source_dataset = rm.LineageNode(name=dataset, type="dataset")
+                all_lineage.append(rm.LineageRelation(type="runtime", source=source_dataset, target=target_dashboard))
 
         return all_lineage
 
-    async def get_all_data_lineage(self) -> list[
+    async def get_all_data_lineage(self) -> list[rm.LineageRelation]:
         """
         Get all data lineage in the project
 
         Returns:
             A list of LineageRelation objects
         """
-        compiled_dag = await self._get_compiled_dag()
+        compiled_dag = await self._get_compiled_dag(self._admin_user)
         return self._get_all_data_lineage(compiled_dag)
 
-    async def _write_dataset_outputs_given_test_set(
-        self, dataset: str, select: str, test_set: str | None, runquery: bool, recurse: bool
-    ) -> t.Any | None:
-        dataset_conf = self._manifest_cfg.datasets[dataset]
-        default_test_set_conf = self._manifest_cfg.get_default_test_set(dataset)
-        if test_set in self._manifest_cfg.selection_test_sets:
-            test_set_conf = self._manifest_cfg.selection_test_sets[test_set]
-        elif test_set is None or test_set == default_test_set_conf.name:
-            test_set, test_set_conf = default_test_set_conf.name, default_test_set_conf
-        else:
-            raise ConfigurationError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
-
-        error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
-        if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
-            raise ConfigurationError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
-
-        user_attributes = test_set_conf.user_attributes.copy() if test_set_conf.user_attributes is not None else {}
-        selections = test_set_conf.parameters.copy()
-        username, is_admin = user_attributes.pop("username", ""), user_attributes.pop("is_admin", False)
-        if test_set_conf.is_authenticated:
-            user = self._auth.User(username=username, is_admin=is_admin, **user_attributes)
-        elif dataset_conf.scope == mf.PermissionScope.PUBLIC:
-            user = None
-        else:
-            raise ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
-
-        if dataset_conf.scope == mf.PermissionScope.PRIVATE and not is_admin:
-            raise ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_admin' set to true")
-
-        # always_python_df is set to True for creating CSV files from results (when runquery is True)
-        dag = self._generate_dag(dataset, target_model_name=select, always_python_df=runquery)
-        await dag.execute(
-            self._param_args, self._param_cfg_set, self._context_func, user, selections,
-            runquery=runquery, recurse=recurse, default_traits=self._manifest_cfg.get_default_traits()
-        )
-
-        output_folder = Path(self._filepath, c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
-        if output_folder.exists():
-            shutil.rmtree(output_folder)
-        output_folder.mkdir(parents=True, exist_ok=True)
-
-        def write_placeholders() -> None:
-            output_filepath = Path(output_folder, "placeholders.json")
-            with open(output_filepath, 'w') as f:
-                json.dump(dag.placeholders, f, indent=4)
-
-        def write_model_outputs(model: m.DataModel) -> None:
-            assert isinstance(model, m.QueryModel)
-            subfolder = c.DBVIEWS_FOLDER if model.model_type == m.ModelType.DBVIEW else c.FEDERATES_FOLDER
-            subpath = Path(output_folder, subfolder)
-            subpath.mkdir(parents=True, exist_ok=True)
-            if isinstance(model.compiled_query, mq.SqlModelQuery):
-                output_filepath = Path(subpath, model.name+'.sql')
-                query = model.compiled_query.query
-                with open(output_filepath, 'w') as f:
-                    f.write(query)
-            if runquery and isinstance(model.result, pl.LazyFrame):
-                output_filepath = Path(subpath, model.name+'.csv')
-                model.result.collect().write_csv(output_filepath)
-
-        write_placeholders()
-        all_model_names = dag.get_all_query_models()
-        coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
-        await u.asyncio_gather(coroutines)
-
-        if recurse:
-            self._draw_dag(dag, output_folder)
-
-        if isinstance(dag.target_model, m.QueryModel) and dag.target_model.compiled_query is not None:
-            return dag.target_model.compiled_query.query
-
     async def compile(
-        self, *,
-
+        self, *, selected_model: str | None = None, test_set: str | None = None, do_all_test_sets: bool = False,
+        runquery: bool = False, clear: bool = False, buildtime_only: bool = False, runtime_only: bool = False
     ) -> None:
         """
-
+        Compile models into the "target/compile" folder.
 
-
+        Behavior:
+        - Buildtime outputs: target/compile/buildtime/*.sql (for SQL build models) and dag.png
+        - Runtime outputs: target/compile/runtime/[test_set]/dbviews/*.sql, federates/*.sql, dag.png
+          If runquery=True, also write CSVs for runtime models.
+        - Options: clear entire compile folder first; compile only buildtime or only runtime.
 
         Arguments:
-            dataset: The name of the dataset to compile. Ignored if "do_all_datasets" argument is True, but required (i.e., cannot be None) if "do_all_datasets" is False. Default is None.
-            do_all_datasets: If True, compile all datasets and ignore the "dataset" argument. Default is False.
             selected_model: The name of the model to compile. If specified, the compiled SQL query is also printed in the terminal. If None, all models for the selected dataset are compiled. Default is None.
             test_set: The name of the test set to compile with. If None, the default test set is used (which can vary by dataset). Ignored if `do_all_test_sets` argument is True. Default is None.
             do_all_test_sets: Whether to compile all applicable test sets for the selected dataset(s). If True, the `test_set` argument is ignored. Default is False.
-            runquery
+            runquery: Whether to run all compiled queries and save each result as a CSV file. If True and `selected_model` is specified, all upstream models of the selected model is compiled as well. Default is False.
+            clear: Whether to clear the "target/compile/" folder before compiling. Default is False.
+            buildtime_only: Whether to compile only buildtime models. Default is False.
+            runtime_only: Whether to compile only runtime models. Default is False.
         """
-
-
-
-
-
-
-
-
-
-
-
+        border = "=" * 80
+        underlines = "-" * len(border)
+
+        compile_root = Path(self._project_path, c.TARGET_FOLDER, c.COMPILE_FOLDER)
+        if clear and compile_root.exists():
+            shutil.rmtree(compile_root)
+
+        models_dict = self._get_models_dict(always_python_df=False)
+
+        if selected_model is not None:
+            selected_model = u.normalize_name(selected_model)
+            if selected_model not in models_dict:
+                print(f"No such model found: {selected_model}")
+                return
+            if not isinstance(models_dict[selected_model], m.QueryModel):
+                print(f"Model '{selected_model}' is not a query model. Nothing to do.")
+                return
 
-
-
-
-
-
-
+        model_to_compile = None
+
+        # Buildtime compilation
+        if not runtime_only:
+            print(underlines)
+            print(f"Compiling buildtime models")
+            print(underlines)
+
+            buildtime_folder = Path(compile_root, c.COMPILE_BUILDTIME_FOLDER)
+            buildtime_folder.mkdir(parents=True, exist_ok=True)
+
+            def write_buildtime_model(model: m.DataModel, static_models: dict[str, m.StaticModel]) -> None:
+                if not isinstance(model, m.BuildModel):
+                    return
+
+                model.compile_for_build(self._conn_args, static_models)
+
+                if isinstance(model.compiled_query, mq.SqlModelQuery):
+                    out_path = Path(buildtime_folder, f"{model.name}.sql")
+                    with open(out_path, 'w') as f:
+                        f.write(model.compiled_query.query)
+                    print(f"Successfully compiled build model: {model.name}")
+                elif isinstance(model.compiled_query, mq.PyModelQuery):
+                    print(f"The build model '{model.name}' is in Python. Compilation for Python is not supported yet.")
+
+            static_models = self._get_static_models()
+            if selected_model is not None:
+                model_to_compile = models_dict[selected_model]
+                write_buildtime_model(model_to_compile, static_models)
+            else:
+                coros = [asyncio.to_thread(write_buildtime_model, m, static_models) for m in static_models.values()]
+                await u.asyncio_gather(coros)
 
-
-
-
-        queries = await u.asyncio_gather(coroutines)
+            print(underlines)
+            print()
 
-
-
-
-
+        # Runtime compilation
+        if not buildtime_only:
+            if do_all_test_sets:
+                test_set_names_set = set(self._manifest_cfg.selection_test_sets.keys())
+                test_set_names_set.add(c.DEFAULT_TEST_SET_NAME)
+                test_set_names = list(test_set_names_set)
+            else:
+                test_set_names = [test_set or c.DEFAULT_TEST_SET_NAME]
+
+            for ts_name in test_set_names:
+                print(underlines)
+                print(f"Compiling runtime models (test set '{ts_name}')")
+                print(underlines)
+
+                # Build user and selections from test set config if present
+                ts_conf = self._manifest_cfg.selection_test_sets.get(ts_name, self._manifest_cfg.get_default_test_set())
+                # Separate base fields from custom fields
+                access_level = ts_conf.user.access_level
+                custom_fields = self._auth.CustomUserFields(**ts_conf.user.custom_fields)
+                if access_level == "guest":
+                    user = GuestUser(username="", custom_fields=custom_fields)
+                else:
+                    user = RegisteredUser(username="", access_level=access_level, custom_fields=custom_fields)
+
+                # Generate DAG across all models. When runquery=True, force models to produce Python dataframes so CSVs can be written.
+                dag = await self._get_compiled_dag(
+                    user=user, selections=ts_conf.parameters, configurables=ts_conf.configurables, always_python_df=runquery,
+                )
+                if runquery:
+                    await dag._run_models()
+
+                # Prepare output folders
+                runtime_folder = Path(compile_root, c.COMPILE_RUNTIME_FOLDER, ts_name)
+                dbviews_folder = Path(runtime_folder, c.DBVIEWS_FOLDER)
+                federates_folder = Path(runtime_folder, c.FEDERATES_FOLDER)
+                dbviews_folder.mkdir(parents=True, exist_ok=True)
+                federates_folder.mkdir(parents=True, exist_ok=True)
+                with open(Path(runtime_folder, "placeholders.json"), "w") as f:
+                    json.dump(dag.placeholders, f)
+
+                # Function to write runtime models
+                def write_runtime_model(model: m.DataModel) -> None:
+                    if not isinstance(model, m.QueryModel):
+                        return
+
+                    if model.model_type not in (m.ModelType.DBVIEW, m.ModelType.FEDERATE):
+                        return
+
+                    subfolder = dbviews_folder if model.model_type == m.ModelType.DBVIEW else federates_folder
+                    model_type = "dbview" if model.model_type == m.ModelType.DBVIEW else "federate"
+
+                    if isinstance(model.compiled_query, mq.SqlModelQuery):
+                        out_sql = Path(subfolder, f"{model.name}.sql")
+                        with open(out_sql, 'w') as f:
+                            f.write(model.compiled_query.query)
+                        print(f"Successfully compiled {model_type} model: {model.name}")
+                    elif isinstance(model.compiled_query, mq.PyModelQuery):
+                        print(f"The {model_type} model '{model.name}' is in Python. Compilation for Python is not supported yet.")
+
+                    if runquery and isinstance(model.result, pl.LazyFrame):
+                        out_csv = Path(subfolder, f"{model.name}.csv")
+                        model.result.collect().write_csv(out_csv)
+                        print(f"Successfully created CSV for {model_type} model: {model.name}")
+
+                # If selected_model is provided for runtime, only emit that model's outputs
+                if selected_model is not None:
+                    model_to_compile = dag.models_dict[selected_model]
+                    write_runtime_model(model_to_compile)
+                else:
+                    coros = [asyncio.to_thread(write_runtime_model, model) for model in dag.models_dict.values()]
+                    await u.asyncio_gather(coros)
+
+                print(underlines)
+                print()
+
+        print(f"All compilations complete! See the '{c.TARGET_FOLDER}/{c.COMPILE_FOLDER}/' folder for results.")
+        if model_to_compile and isinstance(model_to_compile, m.QueryModel) and isinstance(model_to_compile.compiled_query, mq.SqlModelQuery):
+            print()
+            print(border)
+            print(f"Compiled SQL query for model '{model_to_compile.name}':")
+            print(underlines)
+            print(model_to_compile.compiled_query.query)
+            print(border)
         print()
 
-    def _permission_error(self, user:
-
-        return InvalidInputError(25, f"User{username} does not have permission to access {scope} {data_type}: {data_name}")
+    def _permission_error(self, user: AbstractUser, data_type: str, data_name: str, scope: str) -> InvalidInputError:
+        return InvalidInputError(403, f"unauthorized_access_to_{data_type}", f"User '{user}' does not have permission to access {scope} {data_type}: {data_name}")
 
     def seed(self, name: str) -> pl.LazyFrame:
         """
```
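The reworked `compile()` splits outputs into buildtime and runtime folders and selects test sets rather than datasets. A sketch of driving it from a script (import path assumed; the model name is illustrative, mirroring `federate_example` from the base project):

```python
import asyncio

from squirrels import SquirrelsProject  # assumed import path

async def main() -> None:
    with SquirrelsProject(project_path=".") as project:
        # Clear target/compile/ and write compiled SQL for all models
        await project.compile(clear=True)
        # Compile one model, run its query, and save the result as a CSV;
        # the compiled SQL is also printed to the terminal
        await project.compile(selected_model="federate_example", runquery=True)

asyncio.run(main())
```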
```diff
@@ -515,37 +616,77 @@ class SquirrelsProject:
             target_model_config=dag.target_model.model_config
         )
 
-
-        self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, require_auth: bool = True
-    ) -> dr.DatasetResult:
+    def _enforce_max_result_rows(self, lazy_df: pl.LazyFrame, error_type: str) -> pl.DataFrame:
         """
-
-
+        Collect at most max_rows + 1 rows from a LazyFrame to detect overflow.
+        Raises InvalidInputError if the result exceeds the maximum allowed rows.
+
         Arguments:
-
-
-            user: The user to use for authentication. If None, no user is used. Optional, default is None.
+            lazy_df: The LazyFrame to collect and check
+            error_type: Either "dataset" or "query" to customize the error message
 
         Returns:
-            A
+            A DataFrame with at most max_rows rows (or raises if exceeded)
         """
+        max_rows = self._env_vars.datasets_max_rows_output
+        # Collect max_rows + 1 to detect overflow without loading unbounded results
+        collected = lazy_df.limit(max_rows + 1).collect()
+        row_count = collected.select(pl.len()).item()
+
+        if row_count > max_rows:
+            raise InvalidInputError(
+                413, f"{error_type}_result_too_large",
+                f"The {error_type} result contains {row_count} rows, which exceeds the maximum allowed of {max_rows} rows."
+            )
+
+        return collected
+
+    async def _dataset_result(
+        self, name: str, *, selections: dict[str, t.Any] = {}, user: AbstractUser | None = None,
+        configurables: dict[str, str] = {}, check_user_access: bool = True
+    ) -> dr.DatasetResult:
+        if user is None:
+            user = self._guest_user
+
         scope = self._manifest_cfg.datasets[name].scope
-        if
+        if check_user_access and not self._auth.can_user_access_scope(user, scope):
             raise self._permission_error(user, "dataset", name, scope.name)
 
+        dataset_config = self._manifest_cfg.datasets[name]
+        configurables = {**self._manifest_cfg.get_default_configurables(overrides=dataset_config.configurables), **configurables}
+
         dag = self._generate_dag(name)
         await dag.execute(
-            self._param_args, self._param_cfg_set, self._context_func, user, dict(selections),
-            default_traits=self._manifest_cfg.get_default_traits()
+            self._param_args, self._param_cfg_set, self._context_func, user, dict(selections), configurables=configurables
         )
         assert isinstance(dag.target_model.result, pl.LazyFrame)
+        df = self._enforce_max_result_rows(dag.target_model.result, "dataset")
         return dr.DatasetResult(
             target_model_config=dag.target_model.model_config,
-            df=
+            df=df.with_row_index("_row_num", offset=1)
         )
 
+    async def dataset_result(
+        self, name: str, *, selections: dict[str, t.Any] = {}, user: AbstractUser | None = None, configurables: dict[str, str] = {}
+    ) -> dr.DatasetResult:
+        """
+        Async method to retrieve a dataset as a DatasetResult object (with metadata) given parameter selections.
+
+        Arguments:
+            name: The name of the dataset to retrieve.
+            selections: A dictionary of parameter selections to apply to the dataset. Optional, default is empty dictionary.
+            user: The user to use for authentication. If None, no user is used. Optional, default is None.
+            configurables: A dictionary of configurables to apply to the dataset. Optional, default is empty dictionary.
+
+        Returns:
+            A DatasetResult object containing the dataset result (as a polars DataFrame), its description, and the column details.
+        """
+        result = await self._dataset_result(name, selections=selections, user=user, configurables=configurables, check_user_access=False)
+        return result
+
     async def dashboard(
-        self, name: str, *, selections: dict[str, t.Any] = {}, user:
+        self, name: str, *, selections: dict[str, t.Any] = {}, user: AbstractUser | None = None, dashboard_type: t.Type[T] = d.PngDashboard,
+        configurables: dict[str, str] = {}
     ) -> T:
         """
         Async method to retrieve a dashboard given parameter selections.
```
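`dataset_result()` is the new public entry point; the underscored variant keeps the scope check for callers like the API server, and both cap results via `_enforce_max_result_rows`. A sketch (import path assumed; dataset and parameter names are illustrative):

```python
import asyncio

from squirrels import SquirrelsProject  # assumed import path

async def main() -> None:
    with SquirrelsProject(project_path=".") as project:
        result = await project.dataset_result(
            "dataset_example",              # illustrative dataset name
            selections={"group_by": "g0"},  # illustrative parameter selection
        )
        # result.df is a polars DataFrame capped at the configured row limit,
        # with a 1-based "_row_num" index column added
        print(result.df)

asyncio.run(main())
```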
```diff
@@ -559,28 +700,97 @@ class SquirrelsProject:
         Returns:
             The dashboard type specified by the "dashboard_type" argument.
         """
+        if user is None:
+            user = self._guest_user
+
         scope = self._dashboards[name].config.scope
         if not self._auth.can_user_access_scope(user, scope):
             raise self._permission_error(user, "dashboard", name, scope.name)
 
         async def get_dataset_df(dataset_name: str, fixed_params: dict[str, t.Any]) -> pl.DataFrame:
             final_selections = {**selections, **fixed_params}
-            result = await self.
+            result = await self.dataset_result(
+                dataset_name, selections=final_selections, user=user, configurables=configurables
+            )
             return result.df
 
-
+        dashboard_config = self._dashboards[name].config
+        parameter_set = self._param_cfg_set.apply_selections(dashboard_config.parameters, selections, user)
+        prms = parameter_set.get_parameters_as_dict()
+
+        configurables = {**self._manifest_cfg.get_default_configurables(overrides=dashboard_config.configurables), **configurables}
+        context = {}
+        ctx_args = m.ContextArgs(
+            **self._param_args.__dict__, user=user, prms=prms, configurables=configurables, _conn_args=self._conn_args
+        )
+        self._context_func(context, ctx_args)
+
+        args = d.DashboardArgs(
+            **ctx_args.__dict__, ctx=context, _get_dataset=get_dataset_df
+        )
         try:
             return await self._dashboards[name].get_dashboard(args, dashboard_type=dashboard_type)
         except KeyError:
             raise KeyError(f"No dashboard file found for: {name}")
 
     async def query_models(
-        self, sql_query: str, *, selections: dict[str, t.Any] = {},
+        self, sql_query: str, *, user: AbstractUser | None = None, selections: dict[str, t.Any] = {}, configurables: dict[str, str] = {}
    ) -> dr.DatasetResult:
-
+        if user is None:
+            user = self._guest_user
+
+        dag = await self._get_compiled_dag(user=user, sql_query=sql_query, selections=selections, configurables=configurables)
         await dag._run_models()
         assert isinstance(dag.target_model.result, pl.LazyFrame)
+        df = self._enforce_max_result_rows(dag.target_model.result, "query")
         return dr.DatasetResult(
             target_model_config=dag.target_model.model_config,
-            df=
+            df=df.with_row_index("_row_num", offset=1)
         )
+
+    async def get_compiled_model_query(
+        self, model_name: str, *, user: AbstractUser | None = None, selections: dict[str, t.Any] = {}, configurables: dict[str, str] = {}
+    ) -> rm.CompiledQueryModel:
+        """
+        Compile the specified data model and return its language and compiled definition.
+        """
+        if user is None:
+            user = self._guest_user
+
+        name = u.normalize_name(model_name)
+        models_dict = self._get_models_dict(always_python_df=False)
+        if name not in models_dict:
+            raise InvalidInputError(404, "model_not_found", f"No data model found with name: {model_name}")
+
+        model = models_dict[name]
+        # Only build, dbview, and federate models support runtime compiled definition in this context
+        if not isinstance(model, (m.BuildModel, m.DbviewModel, m.FederateModel)):
+            raise InvalidInputError(400, "unsupported_model_type", "Only build, dbview, and federate models currently support compiled definition via this endpoint")
+
+        # Build a DAG with this model as the target, without a dataset context
+        model.is_target = True
+        dag = m.DAG(None, model, models_dict, self._vdl_catalog_db_path, self._logger)
+
+        cfg = {**self._manifest_cfg.get_default_configurables(), **configurables}
+        await dag.execute(
+            self._param_args, self._param_cfg_set, self._context_func, user, selections, runquery=False, configurables=cfg
+        )
+
+        language = "sql" if isinstance(model.query_file, mq.SqlQueryFile) else "python"
+        if isinstance(model, m.BuildModel):
+            # Compile SQL build models; Python build models not yet supported
+            if isinstance(model.query_file, mq.SqlQueryFile):
+                static_models = self._get_static_models()
+                compiled = model._compile_sql_model(model.query_file, self._conn_args, static_models)
+                definition = compiled.query
+            else:
+                definition = "# Compiling Python build models is currently not supported. This will be available in a future version of Squirrels..."
+        elif isinstance(model.compiled_query, mq.SqlModelQuery):
+            definition = model.compiled_query.query
+        elif isinstance(model.compiled_query, mq.PyModelQuery):
+            definition = "# Compiling Python data models is currently not supported. This will be available in a future version of Squirrels..."
+        else:
+            raise NotImplementedError(f"Query type not supported: {model.compiled_query.__class__.__name__}")
+
+        return rm.CompiledQueryModel(language=language, definition=definition, placeholders=dag.placeholders)
```
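`query_models()` now accepts a user and configurables, and `get_compiled_model_query()` exposes a model's compiled definition. A sketch (import path assumed; model names are illustrative):

```python
import asyncio

from squirrels import SquirrelsProject  # assumed import path

async def main() -> None:
    with SquirrelsProject(project_path=".") as project:
        # Model references in the SQL are rewritten (e.g. to vdl.<model>)
        # before the query runs against DuckDB
        result = await project.query_models("SELECT COUNT(*) AS n FROM build_example")
        print(result.df)

        compiled = await project.get_compiled_model_query("federate_example")
        print(compiled.language)    # "sql" or "python"
        print(compiled.definition)  # compiled SQL, or a placeholder for Python models

asyncio.run(main())
```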