squirrels-0.4.0-py3-none-any.whl → squirrels-0.5.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of squirrels might be problematic.
- squirrels/__init__.py +10 -6
- squirrels/_api_response_models.py +93 -44
- squirrels/_api_server.py +571 -219
- squirrels/_auth.py +451 -0
- squirrels/_command_line.py +61 -20
- squirrels/_connection_set.py +38 -25
- squirrels/_constants.py +44 -34
- squirrels/_dashboards_io.py +34 -16
- squirrels/_exceptions.py +57 -0
- squirrels/_initializer.py +117 -44
- squirrels/_manifest.py +124 -62
- squirrels/_model_builder.py +111 -0
- squirrels/_model_configs.py +74 -0
- squirrels/_model_queries.py +52 -0
- squirrels/_models.py +860 -354
- squirrels/_package_loader.py +8 -4
- squirrels/_parameter_configs.py +45 -65
- squirrels/_parameter_sets.py +15 -13
- squirrels/_project.py +561 -0
- squirrels/_py_module.py +4 -3
- squirrels/_seeds.py +35 -16
- squirrels/_sources.py +106 -0
- squirrels/_utils.py +166 -63
- squirrels/_version.py +1 -1
- squirrels/arguments/init_time_args.py +78 -15
- squirrels/arguments/run_time_args.py +62 -101
- squirrels/dashboards.py +4 -4
- squirrels/data_sources.py +94 -162
- squirrels/dataset_result.py +86 -0
- squirrels/dateutils.py +4 -4
- squirrels/package_data/base_project/.env +30 -0
- squirrels/package_data/base_project/.env.example +30 -0
- squirrels/package_data/base_project/.gitignore +3 -2
- squirrels/package_data/base_project/assets/expenses.db +0 -0
- squirrels/package_data/base_project/connections.yml +11 -3
- squirrels/package_data/base_project/dashboards/dashboard_example.py +15 -13
- squirrels/package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/package_data/base_project/docker/.dockerignore +5 -2
- squirrels/package_data/base_project/docker/Dockerfile +3 -3
- squirrels/package_data/base_project/docker/compose.yml +1 -1
- squirrels/package_data/base_project/duckdb_init.sql +9 -0
- squirrels/package_data/base_project/macros/macros_example.sql +15 -0
- squirrels/package_data/base_project/models/builds/build_example.py +26 -0
- squirrels/package_data/base_project/models/builds/build_example.sql +16 -0
- squirrels/package_data/base_project/models/builds/build_example.yml +55 -0
- squirrels/package_data/base_project/models/dbviews/dbview_example.sql +12 -22
- squirrels/package_data/base_project/models/dbviews/dbview_example.yml +26 -0
- squirrels/package_data/base_project/models/federates/federate_example.py +38 -15
- squirrels/package_data/base_project/models/federates/federate_example.sql +16 -2
- squirrels/package_data/base_project/models/federates/federate_example.yml +65 -0
- squirrels/package_data/base_project/models/sources.yml +39 -0
- squirrels/package_data/base_project/parameters.yml +36 -21
- squirrels/package_data/base_project/pyconfigs/connections.py +6 -11
- squirrels/package_data/base_project/pyconfigs/context.py +20 -33
- squirrels/package_data/base_project/pyconfigs/parameters.py +19 -21
- squirrels/package_data/base_project/pyconfigs/user.py +23 -0
- squirrels/package_data/base_project/seeds/seed_categories.yml +15 -0
- squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -15
- squirrels/package_data/base_project/seeds/seed_subcategories.yml +21 -0
- squirrels/package_data/base_project/squirrels.yml.j2 +17 -40
- squirrels/parameters.py +20 -20
- {squirrels-0.4.0.dist-info → squirrels-0.5.0b1.dist-info}/METADATA +31 -32
- squirrels-0.5.0b1.dist-info/RECORD +70 -0
- {squirrels-0.4.0.dist-info → squirrels-0.5.0b1.dist-info}/WHEEL +1 -1
- squirrels-0.5.0b1.dist-info/entry_points.txt +3 -0
- {squirrels-0.4.0.dist-info → squirrels-0.5.0b1.dist-info/licenses}/LICENSE +1 -1
- squirrels/_authenticator.py +0 -85
- squirrels/_environcfg.py +0 -84
- squirrels/package_data/assets/favicon.ico +0 -0
- squirrels/package_data/assets/index.css +0 -1
- squirrels/package_data/assets/index.js +0 -58
- squirrels/package_data/base_project/dashboards.yml +0 -10
- squirrels/package_data/base_project/env.yml +0 -29
- squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
- squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
- squirrels/package_data/templates/index.html +0 -18
- squirrels/project.py +0 -378
- squirrels/user_base.py +0 -55
- squirrels-0.4.0.dist-info/RECORD +0 -60
- squirrels-0.4.0.dist-info/entry_points.txt +0 -4
squirrels/_project.py
ADDED
@@ -0,0 +1,561 @@
+from dotenv import dotenv_values
+from uuid import uuid4
+import asyncio, typing as t, functools as ft, shutil, json, os
+import logging as l, matplotlib.pyplot as plt, networkx as nx, polars as pl
+import sqlglot, sqlglot.expressions
+
+from ._auth import Authenticator, BaseUser
+from ._model_builder import ModelBuilder
+from ._exceptions import InvalidInputError, ConfigurationError
+from . import _utils as u, _constants as c, _manifest as mf, _connection_set as cs, _api_response_models as arm
+from . import _seeds as s, _models as m, _model_configs as mc, _model_queries as mq, _sources as so
+from . import _parameter_sets as ps, _dashboards_io as d, dashboards as dash, dataset_result as dr
+
+T = t.TypeVar("T", bound=dash.Dashboard)
+M = t.TypeVar("M", bound=m.DataModel)
+
+
+class _CustomJsonFormatter(l.Formatter):
+    def format(self, record: l.LogRecord) -> str:
+        super().format(record)
+        info = {
+            "timestamp": self.formatTime(record),
+            "project_id": record.name,
+            "level": record.levelname,
+            "message": record.getMessage(),
+            "thread": record.thread,
+            "thread_name": record.threadName,
+            "process": record.process,
+            **record.__dict__.get("info", {})
+        }
+        output = {
+            "data": record.__dict__.get("data", {}),
+            "info": info
+        }
+        return json.dumps(output)
+
+
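As context for the formatter above: "data" and "info" are not standard LogRecord attributes, so they would arrive through the `extra` argument of stdlib logging calls. A minimal sketch (the logger name and payloads are illustrative, not from the package):

    import logging

    logger = logging.getLogger("example")
    handler = logging.StreamHandler()
    handler.setFormatter(_CustomJsonFormatter())
    logger.addHandler(handler)

    # keys passed via `extra` land on record.__dict__, which is where the
    # formatter's record.__dict__.get("data"/"info") lookups find them
    logger.warning("query finished", extra={"data": {"rows": 10}, "info": {"dataset": "expenses"}})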
+class SquirrelsProject:
+    """
+    Initiate an instance of this class to interact with a Squirrels project through Python code. For example, this can be handy for experimenting with the datasets produced by Squirrels in a Jupyter notebook.
+    """
+
+    def __init__(self, *, filepath: str = ".", log_file: str | None = c.LOGS_FILE, log_level: str = "INFO", log_format: str = "text") -> None:
+        """
+        Constructor for SquirrelsProject class. Loads the file contents of the Squirrels project into memory as member fields.
+
+        Arguments:
+            filepath: The path to the Squirrels project file. Defaults to the current working directory.
+            log_file: The name of the log file to write to in the "logs/" subfolder. If None or an empty string, file logging is disabled. Default is "squirrels.log".
+            log_level: The logging level to use. Options are "DEBUG", "INFO", and "WARNING". Default is "INFO".
+            log_format: The format of the log records. Options are "text" and "json". Default is "text".
+        """
+        self._filepath = filepath
+        self._logger = self._get_logger(self._filepath, log_file, log_level, log_format)
+
+    def _get_logger(self, base_path: str, log_file: str | None, log_level: str, log_format: str) -> u.Logger:
+        logger = u.Logger(name=uuid4().hex)
+        logger.setLevel(log_level.upper())
+
+        handler = l.StreamHandler()
+        handler.setLevel("WARNING")
+        handler.setFormatter(l.Formatter("%(levelname)s: %(asctime)s - %(message)s"))
+        logger.addHandler(handler)
+
+        if log_format.lower() == "json":
+            formatter = _CustomJsonFormatter()
+        elif log_format.lower() == "text":
+            formatter = l.Formatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s")
+        else:
+            raise ValueError("log_format must be either 'text' or 'json'")
+
+        if log_file:
+            path = u.Path(base_path, c.LOGS_FOLDER, log_file)
+            path.parent.mkdir(parents=True, exist_ok=True)
+
+            handler = l.FileHandler(path)
+            handler.setFormatter(formatter)
+            logger.addHandler(handler)
+
+        return logger
+
+    @ft.cached_property
+    def _env_vars(self) -> dict[str, str]:
+        dotenv_files = [c.DOTENV_FILE, c.DOTENV_LOCAL_FILE]
+        dotenv_vars = {}
+        for file in dotenv_files:
+            dotenv_vars.update({k: v for k, v in dotenv_values(f"{self._filepath}/{file}").items() if v is not None})
+        return {**os.environ, **dotenv_vars}
+
+    @ft.cached_property
+    def _manifest_cfg(self) -> mf.ManifestConfig:
+        return mf.ManifestIO.load_from_file(self._logger, self._filepath, self._env_vars)
+
+    @ft.cached_property
+    def _seeds(self) -> s.Seeds:
+        return s.SeedsIO.load_files(self._logger, self._filepath, self._env_vars)
+
+    @ft.cached_property
+    def _sources(self) -> so.Sources:
+        return so.SourcesIO.load_file(self._logger, self._filepath, self._env_vars)
+
+    @ft.cached_property
+    def _build_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
+        return m.ModelsIO.load_build_files(self._logger, self._filepath)
+
+    @ft.cached_property
+    def _dbview_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
+        return m.ModelsIO.load_dbview_files(self._logger, self._filepath, self._env_vars)
+
+    @ft.cached_property
+    def _federate_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
+        return m.ModelsIO.load_federate_files(self._logger, self._filepath)
+
+    @ft.cached_property
+    def _context_func(self) -> m.ContextFunc:
+        return m.ModelsIO.load_context_func(self._logger, self._filepath)
+
+    @ft.cached_property
+    def _dashboards(self) -> dict[str, d.DashboardDefinition]:
+        return d.DashboardsIO.load_files(self._logger, self._filepath)
+
+    @ft.cached_property
+    def _conn_args(self) -> cs.ConnectionsArgs:
+        return cs.ConnectionSetIO.load_conn_py_args(self._logger, self._filepath, self._env_vars, self._manifest_cfg)
+
+    @ft.cached_property
+    def _conn_set(self) -> cs.ConnectionSet:
+        return cs.ConnectionSetIO.load_from_file(self._logger, self._filepath, self._manifest_cfg, self._conn_args)
+
+    @ft.cached_property
+    def _auth(self) -> Authenticator:
+        return Authenticator(self._logger, self._filepath, self._env_vars)
+
+    @ft.cached_property
+    def _param_args(self) -> ps.ParametersArgs:
+        return ps.ParameterConfigsSetIO.get_param_args(self._conn_args)
+
+    @ft.cached_property
+    def _param_cfg_set(self) -> ps.ParameterConfigsSet:
+        return ps.ParameterConfigsSetIO.load_from_file(
+            self._logger, self._filepath, self._manifest_cfg, self._seeds, self._conn_set, self._param_args
+        )
+
+    @ft.cached_property
+    def _j2_env(self) -> u.EnvironmentWithMacros:
+        return u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self._filepath))
+
+    @ft.cached_property
+    def _duckdb_venv_path(self) -> str:
+        duckdb_filepath_setting_val = self._env_vars.get(c.SQRL_DUCKDB_VENV_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DUCKDB_VENV_FILE}")
+        return str(u.Path(self._filepath, duckdb_filepath_setting_val))
+
+    def close(self) -> None:
+        """
+        Deliberately close any open resources within the Squirrels project, such as database connections (instead of relying on the garbage collector).
+        """
+        self._conn_set.dispose()
+        self._auth.close()
+
+    def __exit__(self, exc_type, exc_val, traceback):
+        self.close()
+
+
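The constructor and close() above bound the lifecycle of a project handle. A minimal usage sketch (paths and settings are illustrative; the class is imported from the module this diff defines):

    from squirrels._project import SquirrelsProject

    project = SquirrelsProject(filepath=".", log_level="DEBUG", log_format="json")
    try:
        ...  # use the dataset/compile/build methods defined below
    finally:
        project.close()  # dispose connections deliberately rather than relying on GC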
+    def _add_model(self, models_dict: dict[str, M], model: M) -> None:
+        if model.name in models_dict:
+            raise ConfigurationError(f"Names across all models must be unique. Model '{model.name}' is duplicated")
+        models_dict[model.name] = model
+
+    def _get_static_models(self) -> dict[str, m.StaticModel]:
+        models_dict: dict[str, m.StaticModel] = {}
+
+        seeds_dict = self._seeds.get_dataframes()
+        for key, seed in seeds_dict.items():
+            self._add_model(models_dict, m.Seed(key, seed.config, seed.df, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set))
+
+        for source_name, source_config in self._sources.sources.items():
+            self._add_model(models_dict, m.SourceModel(source_name, source_config, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set))
+
+        for name, val in self._build_model_files.items():
+            model = m.BuildModel(name, val.config, val.query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env)
+            self._add_model(models_dict, model)
+
+        return models_dict
+
+    async def build(self, *, full_refresh: bool = False, select: str | None = None, stage_file: bool = False) -> None:
+        """
+        Build the virtual data environment for the Squirrels project
+
+        Arguments:
+            full_refresh: Whether to drop all tables and rebuild the virtual data environment from scratch. Default is False.
+            stage_file: Whether to stage the DuckDB file to overwrite the existing one later if the virtual data environment is in use. Default is False.
+        """
+        models_dict: dict[str, m.StaticModel] = self._get_static_models()
+        builder = ModelBuilder(self._duckdb_venv_path, self._conn_set, models_dict, self._conn_args, self._logger)
+        await builder.build(full_refresh, select, stage_file)
+
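build() is a coroutine, so a plain script needs an event loop to drive it. A minimal sketch (assuming a project in the current working directory):

    import asyncio
    from squirrels._project import SquirrelsProject

    project = SquirrelsProject(filepath=".")
    asyncio.run(project.build(full_refresh=True))  # drop and rebuild the DuckDB virtual data environment
    project.close()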
+    def _get_models_dict(self, always_python_df: bool) -> dict[str, m.DataModel]:
+        models_dict: dict[str, m.DataModel] = dict(self._get_static_models())
+
+        for name, val in self._dbview_model_files.items():
+            self._add_model(models_dict, m.DbviewModel(
+                name, val.config, val.query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env
+            ))
+            models_dict[name].needs_python_df = always_python_df
+
+        for name, val in self._federate_model_files.items():
+            self._add_model(models_dict, m.FederateModel(
+                name, val.config, val.query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env
+            ))
+            models_dict[name].needs_python_df = always_python_df
+
+        return models_dict
+
+    def _generate_dag(self, dataset: str, *, target_model_name: str | None = None, always_python_df: bool = False) -> m.DAG:
+        models_dict = self._get_models_dict(always_python_df)
+
+        dataset_config = self._manifest_cfg.datasets[dataset]
+        target_model_name = dataset_config.model if target_model_name is None else target_model_name
+        target_model = models_dict[target_model_name]
+        target_model.is_target = True
+        dag = m.DAG(dataset_config, target_model, models_dict, self._duckdb_venv_path, self._logger)
+
+        return dag
+
+    def _generate_dag_with_fake_target(self, sql_query: str | None) -> m.DAG:
+        models_dict = self._get_models_dict(always_python_df=False)
+
+        if sql_query is None:
+            dependencies = set(models_dict.keys())
+        else:
+            dependencies, parsed = u.parse_dependent_tables(sql_query, models_dict.keys())
+
+            substitutions = {}
+            for model_name in dependencies:
+                model = models_dict[model_name]
+                if isinstance(model, m.SourceModel) and not model.model_config.load_to_duckdb:
+                    raise InvalidInputError(203, f"Source model '{model_name}' cannot be queried with DuckDB")
+                if isinstance(model, (m.SourceModel, m.BuildModel)):
+                    substitutions[model_name] = f"venv.{model_name}"
+
+            sql_query = parsed.transform(
+                lambda node: sqlglot.expressions.Table(this=substitutions[node.name])
+                if isinstance(node, sqlglot.expressions.Table) and node.name in substitutions
+                else node
+            ).sql()
+
+        model_config = mc.FederateModelConfig(depends_on=dependencies)
+        query_file = mq.SqlQueryFile("", sql_query or "")
+        fake_target_model = m.FederateModel(
+            "__fake_target", model_config, query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env
+        )
+        fake_target_model.is_target = True
+        dag = m.DAG(None, fake_target_model, models_dict, self._duckdb_venv_path, self._logger)
+        return dag
+
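The transform above rewrites references to source/build models so DuckDB reads them from the attached venv schema. A standalone sketch of the same sqlglot technique (the query and table names are illustrative, not from the package):

    import sqlglot, sqlglot.expressions

    substitutions = {"orders": "venv.orders"}

    parsed = sqlglot.parse_one("SELECT * FROM orders WHERE amount > 0")
    rewritten = parsed.transform(
        lambda node: sqlglot.expressions.Table(this=substitutions[node.name])
        if isinstance(node, sqlglot.expressions.Table) and node.name in substitutions
        else node
    ).sql()
    # rewritten is roughly: SELECT * FROM venv.orders WHERE amount > 0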
+    def _draw_dag(self, dag: m.DAG, output_folder: u.Path) -> None:
+        color_map = {
+            m.ModelType.SEED: "green", m.ModelType.DBVIEW: "red", m.ModelType.FEDERATE: "skyblue",
+            m.ModelType.BUILD: "purple", m.ModelType.SOURCE: "orange"
+        }
+
+        G = dag.to_networkx_graph()
+
+        fig, _ = plt.subplots()
+        pos = nx.multipartite_layout(G, subset_key="layer")
+        colors = [color_map[node[1]] for node in G.nodes(data="model_type")]  # type: ignore
+        nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
+
+        y_values = [val[1] for val in pos.values()]
+        scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
+        label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
+        nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
+
+        fig.tight_layout()
+        plt.margins(x=0.1, y=0.1)
+        fig.savefig(u.Path(output_folder, "dag.png"))
+        plt.close(fig)
+
+    async def _get_compiled_dag(self, *, sql_query: str | None = None, selections: dict[str, t.Any] = {}, user: BaseUser | None = None) -> m.DAG:
+        dag = self._generate_dag_with_fake_target(sql_query)
+
+        default_traits = self._manifest_cfg.get_default_traits()
+        await dag.execute(self._param_args, self._param_cfg_set, self._context_func, user, selections, runquery=False, default_traits=default_traits)
+        return dag
+
+    def _get_all_connections(self) -> list[arm.ConnectionItemModel]:
+        connections = []
+        for conn_name, conn_props in self._conn_set.get_connections_as_dict().items():
+            if isinstance(conn_props, mf.ConnectionProperties):
+                label = conn_props.label if conn_props.label is not None else conn_name
+                connections.append(arm.ConnectionItemModel(name=conn_name, label=label))
+        return connections
+
+    def _get_all_data_models(self, compiled_dag: m.DAG) -> list[arm.DataModelItem]:
+        return compiled_dag.get_all_data_models()
+
+    async def get_all_data_models(self) -> list[arm.DataModelItem]:
+        """
+        Get all data models in the project.
+
+        Returns:
+            A list of DataModelItem objects
+        """
+        compiled_dag = await self._get_compiled_dag()
+        return self._get_all_data_models(compiled_dag)
+
+    def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[arm.LineageRelation]:
+        all_lineage = compiled_dag.get_all_model_lineage()
+
+        # Add dataset nodes to the lineage
+        for dataset in self._manifest_cfg.datasets.values():
+            target_dataset = arm.LineageNode(name=dataset.name, type="dataset")
+            source_model = arm.LineageNode(name=dataset.model, type="model")
+            all_lineage.append(arm.LineageRelation(type="runtime", source=source_model, target=target_dataset))
+
+        # Add dashboard nodes to the lineage
+        for dashboard in self._dashboards.values():
+            target_dashboard = arm.LineageNode(name=dashboard.dashboard_name, type="dashboard")
+            datasets = set(x.dataset for x in dashboard.config.depends_on)
+            for dataset in datasets:
+                source_dataset = arm.LineageNode(name=dataset, type="dataset")
+                all_lineage.append(arm.LineageRelation(type="runtime", source=source_dataset, target=target_dashboard))
+
+        return all_lineage
+
+    async def get_all_data_lineage(self) -> list[arm.LineageRelation]:
+        """
+        Get all data lineage in the project.
+
+        Returns:
+            A list of LineageRelation objects
+        """
+        compiled_dag = await self._get_compiled_dag()
+        return self._get_all_data_lineage(compiled_dag)
+
+    async def _write_dataset_outputs_given_test_set(
+        self, dataset: str, select: str, test_set: str | None, runquery: bool, recurse: bool
+    ) -> t.Any | None:
+        dataset_conf = self._manifest_cfg.datasets[dataset]
+        default_test_set_conf = self._manifest_cfg.get_default_test_set(dataset)
+        if test_set in self._manifest_cfg.selection_test_sets:
+            test_set_conf = self._manifest_cfg.selection_test_sets[test_set]
+        elif test_set is None or test_set == default_test_set_conf.name:
+            test_set, test_set_conf = default_test_set_conf.name, default_test_set_conf
+        else:
+            raise ConfigurationError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if it is not the default for the dataset.")
+
+        error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
+        if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
+            raise ConfigurationError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' do not include dataset '{dataset}'.")
+
+        user_attributes = test_set_conf.user_attributes.copy() if test_set_conf.user_attributes is not None else {}
+        selections = test_set_conf.parameters.copy()
+        username, is_admin = user_attributes.pop("username", ""), user_attributes.pop("is_admin", False)
+        if test_set_conf.is_authenticated:
+            user = self._auth.User(username=username, is_admin=is_admin, **user_attributes)
+        elif dataset_conf.scope == mf.PermissionScope.PUBLIC:
+            user = None
+        else:
+            raise ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with a 'user_attributes' section defined")
+
+        if dataset_conf.scope == mf.PermissionScope.PRIVATE and not is_admin:
+            raise ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_admin' set to true")
+
+        # always_python_df is set to True for creating CSV files from results (when runquery is True)
+        dag = self._generate_dag(dataset, target_model_name=select, always_python_df=runquery)
+        await dag.execute(
+            self._param_args, self._param_cfg_set, self._context_func, user, selections,
+            runquery=runquery, recurse=recurse, default_traits=self._manifest_cfg.get_default_traits()
+        )
+
+        output_folder = u.Path(self._filepath, c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
+        if output_folder.exists():
+            shutil.rmtree(output_folder)
+        output_folder.mkdir(parents=True, exist_ok=True)
+
+        def write_placeholders() -> None:
+            output_filepath = u.Path(output_folder, "placeholders.json")
+            with open(output_filepath, 'w') as f:
+                json.dump(dag.placeholders, f, indent=4)
+
+        def write_model_outputs(model: m.DataModel) -> None:
+            assert isinstance(model, m.QueryModel)
+            subfolder = c.DBVIEWS_FOLDER if model.model_type == m.ModelType.DBVIEW else c.FEDERATES_FOLDER
+            subpath = u.Path(output_folder, subfolder)
+            subpath.mkdir(parents=True, exist_ok=True)
+            if isinstance(model.compiled_query, mq.SqlModelQuery):
+                output_filepath = u.Path(subpath, model.name+'.sql')
+                query = model.compiled_query.query
+                with open(output_filepath, 'w') as f:
+                    f.write(query)
+            if runquery and isinstance(model.result, pl.LazyFrame):
+                output_filepath = u.Path(subpath, model.name+'.csv')
+                model.result.collect().write_csv(output_filepath)
+
+        write_placeholders()
+        all_model_names = dag.get_all_query_models()
+        coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
+        await u.asyncio_gather(coroutines)
+
+        if recurse:
+            self._draw_dag(dag, output_folder)
+
+        if isinstance(dag.target_model, m.QueryModel) and dag.target_model.compiled_query is not None:
+            return dag.target_model.compiled_query.query
+
+    async def compile(
+        self, *, dataset: str | None = None, do_all_datasets: bool = False, selected_model: str | None = None, test_set: str | None = None,
+        do_all_test_sets: bool = False, runquery: bool = False
+    ) -> None:
+        """
+        Async method to compile the SQL templates into files in the "target/" folder. Same functionality as the "sqrl compile" CLI.
+
+        Although all arguments are "optional", the "dataset" argument is required if the "do_all_datasets" argument is False.
+
+        Arguments:
+            dataset: The name of the dataset to compile. Ignored if the "do_all_datasets" argument is True, but required (i.e., cannot be None) if "do_all_datasets" is False. Default is None.
+            do_all_datasets: If True, compile all datasets and ignore the "dataset" argument. Default is False.
+            selected_model: The name of the model to compile. If specified, the compiled SQL query is also printed in the terminal. If None, all models for the selected dataset are compiled. Default is None.
+            test_set: The name of the test set to compile with. If None, the default test set is used (which can vary by dataset). Ignored if the `do_all_test_sets` argument is True. Default is None.
+            do_all_test_sets: Whether to compile all applicable test sets for the selected dataset(s). If True, the `test_set` argument is ignored. Default is False.
+            runquery: Whether to run all compiled queries and save each result as a CSV file. If True and `selected_model` is specified, all upstream models of the selected model are compiled as well. Default is False.
+        """
+        recurse = True
+        if do_all_datasets:
+            selected_models = [(dataset.name, dataset.model) for dataset in self._manifest_cfg.datasets.values()]
+        else:
+            assert isinstance(dataset, str), "argument 'dataset' must be provided a string value if argument 'do_all_datasets' is False"
+            assert dataset in self._manifest_cfg.datasets, f"dataset '{dataset}' not found in {c.MANIFEST_FILE}"
+            if selected_model is None:
+                selected_model = self._manifest_cfg.datasets[dataset].model
+            else:
+                recurse = False
+            selected_models = [(dataset, selected_model)]
+
+        coroutines: list[t.Coroutine] = []
+        for dataset, selected_model in selected_models:
+            if do_all_test_sets:
+                for test_set_name in self._manifest_cfg.get_applicable_test_sets(dataset):
+                    coroutine = self._write_dataset_outputs_given_test_set(dataset, selected_model, test_set_name, runquery, recurse)
+                    coroutines.append(coroutine)
+
+            coroutine = self._write_dataset_outputs_given_test_set(dataset, selected_model, test_set, runquery, recurse)
+            coroutines.append(coroutine)
+
+        queries = await u.asyncio_gather(coroutines)
+
+        print(f"Compiled successfully! See the '{c.TARGET_FOLDER}/' folder for results.")
+        print()
+        if not recurse and len(queries) == 1 and isinstance(queries[0], str):
+            print(queries[0])
+            print()
+
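compile() is also a coroutine. A minimal sketch that compiles one dataset and runs its queries (the dataset name is illustrative):

    import asyncio
    from squirrels._project import SquirrelsProject

    project = SquirrelsProject(filepath=".")
    # writes compiled SQL (and CSV results, since runquery=True) into the "target/" folder
    asyncio.run(project.compile(dataset="expenses", runquery=True))
    project.close()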
+    def _permission_error(self, user: BaseUser | None, data_type: str, data_name: str, scope: str) -> InvalidInputError:
+        username = "" if user is None else f" '{user.username}'"
+        return InvalidInputError(25, f"User{username} does not have permission to access {scope} {data_type}: {data_name}")
+
+    def seed(self, name: str) -> pl.LazyFrame:
+        """
+        Method to retrieve a seed as a polars LazyFrame given a seed name.
+
+        Arguments:
+            name: The name of the seed to retrieve
+
+        Returns:
+            The seed as a polars LazyFrame
+        """
+        seeds_dict = self._seeds.get_dataframes()
+        try:
+            return seeds_dict[name].df
+        except KeyError:
+            available_seeds = list(seeds_dict.keys())
+            raise KeyError(f"Seed '{name}' not found. Available seeds are: {available_seeds}")
+
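A quick sketch of seed(); the name is a seed's file stem (here assumed to be the seed_categories seed that the base project template ships with, per the file list above), and the returned LazyFrame must be collected to materialize it:

    from squirrels._project import SquirrelsProject

    project = SquirrelsProject(filepath=".")
    lf = project.seed("seed_categories")  # polars LazyFrame
    df = lf.collect()                     # materialize into a DataFrame
    project.close()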
+    def dataset_metadata(self, name: str) -> dr.DatasetMetadata:
+        """
+        Method to retrieve the metadata of a dataset given a dataset name.
+
+        Arguments:
+            name: The name of the dataset to retrieve.
+
+        Returns:
+            A DatasetMetadata object containing the dataset description and column details.
+        """
+        dag = self._generate_dag(name)
+        dag.target_model.process_pass_through_columns(dag.models_dict)
+        return dr.DatasetMetadata(
+            target_model_config=dag.target_model.model_config
+        )
+
+    async def dataset(
+        self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, require_auth: bool = True
+    ) -> dr.DatasetResult:
+        """
+        Async method to retrieve a dataset as a DatasetResult object (with metadata) given parameter selections.
+
+        Arguments:
+            name: The name of the dataset to retrieve.
+            selections: A dictionary of parameter selections to apply to the dataset. Optional, default is an empty dictionary.
+            user: The user to use for authentication. If None, no user is used. Optional, default is None.
+
+        Returns:
+            A DatasetResult object containing the dataset result (as a polars DataFrame), its description, and the column details.
+        """
+        scope = self._manifest_cfg.datasets[name].scope
+        if require_auth and not self._auth.can_user_access_scope(user, scope):
+            raise self._permission_error(user, "dataset", name, scope.name)
+
+        dag = self._generate_dag(name)
+        await dag.execute(
+            self._param_args, self._param_cfg_set, self._context_func, user, dict(selections),
+            default_traits=self._manifest_cfg.get_default_traits()
+        )
+        assert isinstance(dag.target_model.result, pl.LazyFrame)
+        return dr.DatasetResult(
+            target_model_config=dag.target_model.model_config,
+            df=dag.target_model.result.collect().with_row_index("_row_num", offset=1)
+        )
+
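dataset() is the primary programmatic entry point. A minimal sketch (the dataset and parameter names are illustrative):

    import asyncio
    from squirrels._project import SquirrelsProject

    project = SquirrelsProject(filepath=".")
    result = asyncio.run(project.dataset("expenses", selections={"group_by": "category"}))
    print(result.df)  # polars DataFrame, with a "_row_num" column numbered from 1
    project.close()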
+    async def dashboard(
+        self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, dashboard_type: t.Type[T] = dash.Dashboard
+    ) -> T:
+        """
+        Async method to retrieve a dashboard given parameter selections.
+
+        Arguments:
+            name: The name of the dashboard to retrieve.
+            selections: A dictionary of parameter selections to apply to the dashboard. Optional, default is an empty dictionary.
+            user: The user to use for authentication. If None, no user is used. Optional, default is None.
+            dashboard_type: Return type of the method (mainly used for type hints). For instance, provide PngDashboard if you want the return type to be a PngDashboard. Optional, default is squirrels.Dashboard.
+
+        Returns:
+            The dashboard type specified by the "dashboard_type" argument.
+        """
+        scope = self._dashboards[name].config.scope
+        if not self._auth.can_user_access_scope(user, scope):
+            raise self._permission_error(user, "dashboard", name, scope.name)
+
+        async def get_dataset_df(dataset_name: str, fixed_params: dict[str, t.Any]) -> pl.DataFrame:
+            final_selections = {**selections, **fixed_params}
+            result = await self.dataset(dataset_name, selections=final_selections, user=user, require_auth=False)
+            return result.df
+
+        args = d.DashboardArgs(self._param_args, get_dataset_df)
+        try:
+            return await self._dashboards[name].get_dashboard(args, dashboard_type=dashboard_type)
+        except KeyError:
+            raise KeyError(f"No dashboard file found for: {name}")
+
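A sketch of dashboard(), using the dashboard_type hint the docstring describes (the dashboard name is assumed from the base project's dashboard_example files in the list above):

    import asyncio
    from squirrels._project import SquirrelsProject
    from squirrels import dashboards as dash

    project = SquirrelsProject(filepath=".")
    dashboard = asyncio.run(project.dashboard("dashboard_example", dashboard_type=dash.Dashboard))
    project.close()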
+    async def query_models(
+        self, sql_query: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None
+    ) -> dr.DatasetResult:
+        dag = await self._get_compiled_dag(sql_query=sql_query, selections=selections, user=user)
+        await dag._run_models()
+        assert isinstance(dag.target_model.result, pl.LazyFrame)
+        return dr.DatasetResult(
+            target_model_config=dag.target_model.model_config,
+            df=dag.target_model.result.collect().with_row_index("_row_num", offset=1)
+        )
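query_models() runs an ad-hoc SQL query against the project's models (via the fake-target DAG above) and returns a DatasetResult just like dataset() does. A minimal sketch (the model name is illustrative, borrowed from the base project's build_example):

    import asyncio
    from squirrels._project import SquirrelsProject

    project = SquirrelsProject(filepath=".")
    result = asyncio.run(project.query_models("SELECT * FROM build_example LIMIT 10"))
    print(result.df)
    project.close()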
squirrels/_py_module.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Type, Optional, Any
 import importlib.util
 
 from . import _constants as c, _utils as u
+from ._exceptions import ConfigurationError, FileExecutionError
 
 
 class PyModule:
@@ -21,7 +22,7 @@ class PyModule:
             spec.loader.exec_module(self.module)
         except FileNotFoundError as e:
             if is_required:
-                raise
+                raise ConfigurationError(f"Required file not found: '{self.filepath}'") from e
             self.module = default_class
 
     def get_func_or_class(self, attr_name: str, *, default_attr: Any = None, is_required: bool = True) -> Any:
@@ -38,7 +39,7 @@ class PyModule:
         """
        func_or_class = getattr(self.module, attr_name, default_attr)
         if func_or_class is None and is_required:
-            raise
+            raise ConfigurationError(f"Module '{self.filepath}' missing required attribute '{attr_name}'")
         return func_or_class
 
 
@@ -57,4 +58,4 @@ def run_pyconfig_main(base_path: str, filename: str, kwargs: dict[str, Any] = {}
     try:
         main_function(**kwargs)
     except Exception as e:
-        raise
+        raise FileExecutionError(f'Failed to run python file "{filepath}"', e) from e
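These replacements fix real failure modes, not just message quality: inside the `except FileNotFoundError` block a bare `raise` merely re-raised the raw FileNotFoundError, but in get_func_or_class there is no active exception at all, so the bare `raise` itself fails. A minimal illustration of that second case:

    def get_attr(module, name):
        value = getattr(module, name, None)
        if value is None:
            raise  # RuntimeError: No active exception to re-raise
        return value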
squirrels/_seeds.py
CHANGED
@@ -1,39 +1,58 @@
 from dataclasses import dataclass
-import os, time, glob,
+import os, time, glob, polars as pl, json
 
-from .
-
+from . import _utils as u, _constants as c, _model_configs as mc
+
+
+@dataclass
+class Seed:
+    config: mc.SeedConfig
+    df: pl.LazyFrame
+
+    def __post_init__(self):
+        if self.config.cast_column_types:
+            exprs = []
+            for col_config in self.config.columns:
+                sqrl_dtype = "double" if col_config.type.lower().startswith("decimal") else col_config.type
+                polars_dtype = u.sqrl_dtypes_to_polars_dtypes.get(sqrl_dtype, pl.String)
+                exprs.append(pl.col(col_config.name).cast(polars_dtype))
+
+            self.df = self.df.with_columns(*exprs)
 
 
 @dataclass
 class Seeds:
-    _data: dict[str,
-    _manifest_cfg: ManifestConfig
+    _data: dict[str, Seed]
 
-    def run_query(self, sql_query: str) ->
-
-        return
+    def run_query(self, sql_query: str) -> pl.DataFrame:
+        dataframes = {key: seed.df for key, seed in self._data.items()}
+        return u.run_sql_on_dataframes(sql_query, dataframes)
 
-    def get_dataframes(self) -> dict[str,
+    def get_dataframes(self) -> dict[str, Seed]:
         return self._data.copy()
 
 
 class SeedsIO:
 
     @classmethod
-    def load_files(cls, logger:
+    def load_files(cls, logger: u.Logger, base_path: str, env_vars: dict[str, str]) -> Seeds:
         start = time.time()
-
-
-        csv_dtype = None if infer_schema else str
+        infer_schema_setting: bool = (env_vars.get(c.SQRL_SEEDS_INFER_SCHEMA, "true").lower() == "true")
+        na_values_setting: list[str] = json.loads(env_vars.get(c.SQRL_SEEDS_NA_VALUES, "[]"))
 
         seeds_dict = {}
         csv_files = glob.glob(os.path.join(base_path, c.SEEDS_FOLDER, '**/*.csv'), recursive=True)
         for csv_file in csv_files:
+            config_file = os.path.splitext(csv_file)[0] + '.yml'
+            config_dict = u.load_yaml_config(config_file) if os.path.exists(config_file) else {}
+            config = mc.SeedConfig(**config_dict)
+
             file_stem = os.path.splitext(os.path.basename(csv_file))[0]
-
-
+            infer_schema = not config.cast_column_types and infer_schema_setting
+            df = pl.read_csv(csv_file, try_parse_dates=True, infer_schema=infer_schema, null_values=na_values_setting).lazy()
+
+            seeds_dict[file_stem] = Seed(config, df)
 
-        seeds = Seeds(seeds_dict
+        seeds = Seeds(seeds_dict)
         logger.log_activity_time("loading seed files", start)
        return seeds
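Putting the new loader together: each seed CSV may now sit next to a .yml file whose contents feed mc.SeedConfig. Based only on the fields read above (cast_column_types, plus columns entries with name and type), such a file would look roughly like this; the exact schema lives in _model_configs.py, and the column names and type strings here are illustrative:

    cast_column_types: true
    columns:
      - name: category_id
        type: integer
      - name: category
        type: string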