squirrels 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- squirrels/__init__.py +7 -3
- squirrels/_api_response_models.py +96 -72
- squirrels/_api_server.py +375 -201
- squirrels/_authenticator.py +23 -22
- squirrels/_command_line.py +70 -46
- squirrels/_connection_set.py +23 -25
- squirrels/_constants.py +29 -78
- squirrels/_dashboards_io.py +61 -0
- squirrels/_environcfg.py +53 -50
- squirrels/_initializer.py +184 -141
- squirrels/_manifest.py +168 -195
- squirrels/_models.py +159 -292
- squirrels/_package_loader.py +7 -8
- squirrels/_parameter_configs.py +173 -141
- squirrels/_parameter_sets.py +49 -38
- squirrels/_py_module.py +7 -7
- squirrels/_seeds.py +13 -12
- squirrels/_utils.py +114 -54
- squirrels/_version.py +1 -1
- squirrels/arguments/init_time_args.py +16 -10
- squirrels/arguments/run_time_args.py +89 -24
- squirrels/dashboards.py +82 -0
- squirrels/data_sources.py +212 -232
- squirrels/dateutils.py +29 -26
- squirrels/package_data/assets/index.css +1 -1
- squirrels/package_data/assets/index.js +27 -18
- squirrels/package_data/base_project/.gitignore +2 -2
- squirrels/package_data/base_project/connections.yml +1 -1
- squirrels/package_data/base_project/dashboards/dashboard_example.py +32 -0
- squirrels/package_data/base_project/dashboards.yml +10 -0
- squirrels/package_data/base_project/docker/.dockerignore +9 -4
- squirrels/package_data/base_project/docker/Dockerfile +7 -6
- squirrels/package_data/base_project/docker/compose.yml +1 -1
- squirrels/package_data/base_project/env.yml +2 -2
- squirrels/package_data/base_project/models/dbviews/{database_view1.py → dbview_example.py} +2 -1
- squirrels/package_data/base_project/models/dbviews/{database_view1.sql → dbview_example.sql} +3 -2
- squirrels/package_data/base_project/models/federates/{dataset_example.py → federate_example.py} +6 -6
- squirrels/package_data/base_project/models/federates/{dataset_example.sql → federate_example.sql} +1 -1
- squirrels/package_data/base_project/parameters.yml +6 -4
- squirrels/package_data/base_project/pyconfigs/auth.py +1 -1
- squirrels/package_data/base_project/pyconfigs/connections.py +1 -1
- squirrels/package_data/base_project/pyconfigs/context.py +38 -10
- squirrels/package_data/base_project/pyconfigs/parameters.py +15 -7
- squirrels/package_data/base_project/squirrels.yml.j2 +14 -7
- squirrels/package_data/templates/index.html +3 -3
- squirrels/parameter_options.py +103 -106
- squirrels/parameters.py +347 -195
- squirrels/project.py +378 -0
- squirrels/user_base.py +14 -6
- {squirrels-0.3.3.dist-info → squirrels-0.4.0.dist-info}/METADATA +9 -21
- squirrels-0.4.0.dist-info/RECORD +60 -0
- squirrels/_timer.py +0 -23
- squirrels-0.3.3.dist-info/RECORD +0 -56
- {squirrels-0.3.3.dist-info → squirrels-0.4.0.dist-info}/LICENSE +0 -0
- {squirrels-0.3.3.dist-info → squirrels-0.4.0.dist-info}/WHEEL +0 -0
- {squirrels-0.3.3.dist-info → squirrels-0.4.0.dist-info}/entry_points.txt +0 -0
squirrels/_models.py
CHANGED
|
@@ -1,32 +1,28 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import Iterable, Callable, Any
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
4
|
from abc import ABCMeta, abstractmethod
|
|
5
5
|
from enum import Enum
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from sqlalchemy import create_engine, text, Connection
|
|
8
|
-
import asyncio, os,
|
|
9
|
-
import matplotlib.pyplot as plt, networkx as nx
|
|
8
|
+
import asyncio, os, time, pandas as pd, networkx as nx
|
|
10
9
|
|
|
11
10
|
from . import _constants as c, _utils as u, _py_module as pm
|
|
12
11
|
from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
|
|
13
|
-
from ._authenticator import User
|
|
14
|
-
from ._connection_set import
|
|
15
|
-
from ._manifest import
|
|
16
|
-
from ._parameter_sets import
|
|
17
|
-
|
|
18
|
-
|
|
12
|
+
from ._authenticator import User
|
|
13
|
+
from ._connection_set import ConnectionSet
|
|
14
|
+
from ._manifest import ManifestConfig, DatasetConfig
|
|
15
|
+
from ._parameter_sets import ParameterConfigsSet, ParametersArgs, ParameterSet
|
|
16
|
+
|
|
17
|
+
ContextFunc = Callable[[dict[str, Any], ContextArgs], None]
|
|
18
|
+
|
|
19
19
|
|
|
20
20
|
class ModelType(Enum):
|
|
21
21
|
DBVIEW = 1
|
|
22
22
|
FEDERATE = 2
|
|
23
23
|
SEED = 3
|
|
24
24
|
|
|
25
|
-
class
|
|
26
|
-
SQL = 0
|
|
27
|
-
PYTHON = 1
|
|
28
|
-
|
|
29
|
-
class Materialization(Enum):
|
|
25
|
+
class _Materialization(Enum):
|
|
30
26
|
TABLE = 0
|
|
31
27
|
VIEW = 1
|
|
32
28
|
|
|
@@ -37,52 +33,46 @@ class _SqlModelConfig:
|
|
|
37
33
|
connection_name: str
|
|
38
34
|
|
|
39
35
|
## Applicable for federated models
|
|
40
|
-
materialized:
|
|
36
|
+
materialized: _Materialization
|
|
41
37
|
|
|
42
|
-
def set_attribute(self, **kwargs) -> str:
|
|
43
|
-
connection_name = kwargs.get(c.DBVIEW_CONN_KEY)
|
|
38
|
+
def set_attribute(self, *, connection_name: str | None = None, materialized: str | None = None, **kwargs) -> str:
|
|
44
39
|
if connection_name is not None:
|
|
45
40
|
if not isinstance(connection_name, str):
|
|
46
41
|
raise u.ConfigurationError("The 'connection_name' argument of 'config' macro must be a string")
|
|
47
42
|
self.connection_name = connection_name
|
|
48
43
|
|
|
49
|
-
materialized: str = kwargs.get(c.MATERIALIZED_KEY)
|
|
50
44
|
if materialized is not None:
|
|
51
45
|
if not isinstance(materialized, str):
|
|
52
46
|
raise u.ConfigurationError("The 'materialized' argument of 'config' macro must be a string")
|
|
53
47
|
try:
|
|
54
|
-
self.materialized =
|
|
48
|
+
self.materialized = _Materialization[materialized.upper()]
|
|
55
49
|
except KeyError as e:
|
|
56
|
-
valid_options = [x.name for x in
|
|
50
|
+
valid_options = [x.name for x in _Materialization]
|
|
57
51
|
raise u.ConfigurationError(f"The 'materialized' argument value '{materialized}' is not valid. Must be one of: {valid_options}") from e
|
|
58
52
|
return ""
|
|
59
53
|
|
|
60
54
|
def get_sql_for_create(self, model_name: str, select_query: str) -> str:
|
|
61
|
-
|
|
62
|
-
create_prefix = f"CREATE TABLE {model_name} AS\n"
|
|
63
|
-
elif self.materialized == Materialization.VIEW:
|
|
64
|
-
create_prefix = f"CREATE VIEW {model_name} AS\n"
|
|
65
|
-
else:
|
|
66
|
-
raise u.ConfigurationError(f"Materialization option not supported: {self.materialized}")
|
|
67
|
-
|
|
55
|
+
create_prefix = f"CREATE {self.materialized.name} {model_name} AS\n"
|
|
68
56
|
return create_prefix + select_query
|
|
69
57
|
|
|
70
58
|
|
|
71
|
-
|
|
72
|
-
|
|
59
|
+
@dataclass(frozen=True)
|
|
60
|
+
class QueryFile:
|
|
61
|
+
filepath: str
|
|
62
|
+
model_type: ModelType
|
|
73
63
|
|
|
74
64
|
@dataclass(frozen=True)
|
|
75
|
-
class
|
|
76
|
-
|
|
65
|
+
class SqlQueryFile(QueryFile):
|
|
66
|
+
raw_query: str
|
|
77
67
|
|
|
78
68
|
@dataclass(frozen=True)
|
|
79
|
-
class
|
|
80
|
-
query:
|
|
69
|
+
class _RawPyQuery:
|
|
70
|
+
query: Callable[[ModelArgs], pd.DataFrame]
|
|
71
|
+
dependencies_func: Callable[[ModelDepsArgs], Iterable[str]]
|
|
81
72
|
|
|
82
73
|
@dataclass(frozen=True)
|
|
83
|
-
class
|
|
84
|
-
|
|
85
|
-
dependencies_func: Callable[[Any], Iterable]
|
|
74
|
+
class PyQueryFile(QueryFile):
|
|
75
|
+
raw_query: _RawPyQuery
|
|
86
76
|
|
|
87
77
|
|
|
88
78
|
@dataclass
|
|
@@ -94,43 +84,35 @@ class _WorkInProgress(_Query):
|
|
|
94
84
|
query: None = field(default=None, init=False)
|
|
95
85
|
|
|
96
86
|
@dataclass
|
|
97
|
-
class
|
|
87
|
+
class SqlModelQuery(_Query):
|
|
98
88
|
query: str
|
|
99
89
|
config: _SqlModelConfig
|
|
100
90
|
|
|
101
91
|
@dataclass
|
|
102
|
-
class
|
|
92
|
+
class PyModelQuery(_Query):
|
|
103
93
|
query: Callable[[], pd.DataFrame]
|
|
104
94
|
|
|
105
95
|
|
|
106
|
-
@dataclass(frozen=True)
|
|
107
|
-
class _QueryFile:
|
|
108
|
-
filepath: str
|
|
109
|
-
model_type: ModelType
|
|
110
|
-
query_type: QueryType
|
|
111
|
-
raw_query: _RawQuery
|
|
112
|
-
|
|
113
|
-
|
|
114
96
|
@dataclass
|
|
115
|
-
class
|
|
97
|
+
class Referable(metaclass=ABCMeta):
|
|
116
98
|
name: str
|
|
117
99
|
is_target: bool = field(default=False, init=False)
|
|
118
100
|
|
|
119
101
|
needs_sql_table: bool = field(default=False, init=False)
|
|
120
102
|
needs_pandas: bool = field(default=False, init=False)
|
|
121
|
-
result:
|
|
103
|
+
result: pd.DataFrame | None = field(default=None, init=False, repr=False)
|
|
122
104
|
|
|
123
105
|
wait_count: int = field(default=0, init=False, repr=False)
|
|
124
106
|
confirmed_no_cycles: bool = field(default=False, init=False)
|
|
125
|
-
upstreams: dict[str,
|
|
126
|
-
downstreams: dict[str,
|
|
107
|
+
upstreams: dict[str, Referable] = field(default_factory=dict, init=False, repr=False)
|
|
108
|
+
downstreams: dict[str, Referable] = field(default_factory=dict, init=False, repr=False)
|
|
127
109
|
|
|
128
110
|
@abstractmethod
|
|
129
111
|
def get_model_type(self) -> ModelType:
|
|
130
112
|
pass
|
|
131
113
|
|
|
132
114
|
async def compile(
|
|
133
|
-
self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str,
|
|
115
|
+
self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, Referable], recurse: bool
|
|
134
116
|
) -> None:
|
|
135
117
|
pass
|
|
136
118
|
|
|
@@ -160,11 +142,12 @@ class _Referable(metaclass=ABCMeta):
|
|
|
160
142
|
def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
|
|
161
143
|
pass
|
|
162
144
|
|
|
163
|
-
def get_max_path_length_to_target(self) -> int:
|
|
145
|
+
def get_max_path_length_to_target(self) -> int | None:
|
|
164
146
|
if not hasattr(self, "max_path_len_to_target"):
|
|
165
147
|
path_lengths = []
|
|
166
148
|
for child_model in self.downstreams.values():
|
|
167
|
-
|
|
149
|
+
assert isinstance(child_model_path_length := child_model.get_max_path_length_to_target(), int)
|
|
150
|
+
path_lengths.append(child_model_path_length+1)
|
|
168
151
|
if len(path_lengths) > 0:
|
|
169
152
|
self.max_path_len_to_target = max(path_lengths)
|
|
170
153
|
else:
|
|
@@ -173,7 +156,7 @@ class _Referable(metaclass=ABCMeta):
|
|
|
173
156
|
|
|
174
157
|
|
|
175
158
|
@dataclass
|
|
176
|
-
class
|
|
159
|
+
class Seed(Referable):
|
|
177
160
|
result: pd.DataFrame
|
|
178
161
|
|
|
179
162
|
def get_model_type(self) -> ModelType:
|
|
@@ -189,43 +172,45 @@ class _Seed(_Referable):
|
|
|
189
172
|
|
|
190
173
|
|
|
191
174
|
@dataclass
|
|
192
|
-
class
|
|
193
|
-
query_file:
|
|
194
|
-
|
|
195
|
-
|
|
175
|
+
class Model(Referable):
|
|
176
|
+
query_file: QueryFile
|
|
177
|
+
manifest_cfg: ManifestConfig
|
|
178
|
+
conn_set: ConnectionSet
|
|
179
|
+
logger: u.Logger = field(default_factory=lambda: u.Logger(""))
|
|
180
|
+
j2_env: u.j2.Environment = field(default_factory=lambda: u.j2.Environment(loader=u.j2.FileSystemLoader(".")))
|
|
181
|
+
compiled_query: _Query | None = field(default=None, init=False)
|
|
196
182
|
|
|
197
183
|
def get_model_type(self) -> ModelType:
|
|
198
184
|
return self.query_file.model_type
|
|
199
185
|
|
|
200
|
-
def _add_upstream(self, other:
|
|
186
|
+
def _add_upstream(self, other: Referable) -> None:
|
|
201
187
|
self.upstreams[other.name] = other
|
|
202
188
|
other.downstreams[self.name] = self
|
|
203
189
|
|
|
204
|
-
if self.query_file
|
|
205
|
-
other.needs_pandas = True
|
|
206
|
-
elif self.query_file.query_type == QueryType.SQL:
|
|
190
|
+
if isinstance(self.query_file, SqlQueryFile):
|
|
207
191
|
other.needs_sql_table = True
|
|
192
|
+
elif isinstance(self.query_file, PyQueryFile):
|
|
193
|
+
other.needs_pandas = True
|
|
208
194
|
|
|
209
195
|
def _get_dbview_conn_name(self) -> str:
|
|
210
|
-
dbview_config =
|
|
196
|
+
dbview_config = self.manifest_cfg.dbviews.get(self.name)
|
|
211
197
|
if dbview_config is None or dbview_config.connection_name is None:
|
|
212
|
-
return
|
|
198
|
+
return self.manifest_cfg.settings.get(c.DB_CONN_DEFAULT_USED_SETTING, c.DEFAULT_DB_CONN)
|
|
213
199
|
return dbview_config.connection_name
|
|
214
200
|
|
|
215
|
-
def _get_materialized(self) ->
|
|
216
|
-
federate_config =
|
|
201
|
+
def _get_materialized(self) -> _Materialization:
|
|
202
|
+
federate_config = self.manifest_cfg.federates.get(self.name)
|
|
217
203
|
if federate_config is None or federate_config.materialized is None:
|
|
218
|
-
materialized =
|
|
204
|
+
materialized = self.manifest_cfg.settings.get(c.DEFAULT_MATERIALIZE_SETTING, c.DEFAULT_MATERIALIZE)
|
|
219
205
|
else:
|
|
220
206
|
materialized = federate_config.materialized
|
|
221
|
-
return
|
|
207
|
+
return _Materialization[materialized.upper()]
|
|
222
208
|
|
|
223
209
|
async def _compile_sql_model(
|
|
224
|
-
self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
|
|
225
|
-
) -> tuple[
|
|
226
|
-
assert
|
|
227
|
-
|
|
228
|
-
raw_query = self.query_file.raw_query.query
|
|
210
|
+
self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, Referable]
|
|
211
|
+
) -> tuple[SqlModelQuery, set]:
|
|
212
|
+
assert isinstance(self.query_file, SqlQueryFile)
|
|
213
|
+
|
|
229
214
|
connection_name = self._get_dbview_conn_name()
|
|
230
215
|
materialized = self._get_materialized()
|
|
231
216
|
configuration = _SqlModelConfig(connection_name, materialized)
|
|
@@ -237,51 +222,68 @@ class _Model(_Referable):
|
|
|
237
222
|
}
|
|
238
223
|
dependencies = set()
|
|
239
224
|
if self.query_file.model_type == ModelType.FEDERATE:
|
|
240
|
-
def ref(
|
|
241
|
-
|
|
242
|
-
|
|
225
|
+
def ref(dependent_model_name):
|
|
226
|
+
if dependent_model_name not in models_dict:
|
|
227
|
+
raise u.ConfigurationError(f'Model "{self.name}" references unknown model "{dependent_model_name}"')
|
|
228
|
+
dependencies.add(dependent_model_name)
|
|
229
|
+
return dependent_model_name
|
|
243
230
|
kwargs["ref"] = ref
|
|
244
231
|
|
|
245
232
|
try:
|
|
246
|
-
|
|
233
|
+
template = self.j2_env.from_string(self.query_file.raw_query)
|
|
234
|
+
query = await asyncio.to_thread(template.render, kwargs)
|
|
247
235
|
except Exception as e:
|
|
248
236
|
raise u.FileExecutionError(f'Failed to compile sql model "{self.name}"', e) from e
|
|
249
237
|
|
|
250
|
-
compiled_query =
|
|
238
|
+
compiled_query = SqlModelQuery(query, configuration)
|
|
251
239
|
return compiled_query, dependencies
|
|
252
240
|
|
|
253
241
|
async def _compile_python_model(
|
|
254
|
-
self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
|
|
255
|
-
) -> tuple[
|
|
256
|
-
assert
|
|
242
|
+
self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, Referable]
|
|
243
|
+
) -> tuple[PyModelQuery, Iterable]:
|
|
244
|
+
assert isinstance(self.query_file, PyQueryFile)
|
|
257
245
|
|
|
258
246
|
sqrl_args = ModelDepsArgs(
|
|
259
247
|
ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx
|
|
260
248
|
)
|
|
261
249
|
try:
|
|
262
250
|
dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
|
|
251
|
+
for dependent_model_name in dependencies:
|
|
252
|
+
if dependent_model_name not in models_dict:
|
|
253
|
+
raise u.ConfigurationError(f'Model "{self.name}" references unknown model "{dependent_model_name}"')
|
|
263
254
|
except Exception as e:
|
|
264
255
|
raise u.FileExecutionError(f'Failed to run "{c.DEP_FUNC}" function for python model "{self.name}"', e) from e
|
|
265
256
|
|
|
266
257
|
dbview_conn_name = self._get_dbview_conn_name()
|
|
267
|
-
connections =
|
|
268
|
-
|
|
258
|
+
connections = self.conn_set.get_engines_as_dict()
|
|
259
|
+
|
|
260
|
+
def ref(dependent_model_name):
|
|
261
|
+
if dependent_model_name not in self.upstreams:
|
|
262
|
+
raise u.ConfigurationError(f'Model "{self.name}" must include model "{dependent_model_name}" as a dependency to use')
|
|
263
|
+
return pd.DataFrame(self.upstreams[dependent_model_name].result)
|
|
264
|
+
|
|
265
|
+
def run_external_sql(sql_query: str, connection_name: str | None):
|
|
266
|
+
connection_name = dbview_conn_name if connection_name is None else connection_name
|
|
267
|
+
return self.conn_set.run_sql_query_from_conn_name(sql_query, connection_name, placeholders)
|
|
268
|
+
|
|
269
|
+
use_duckdb = self.manifest_cfg.settings_obj.do_use_duckdb()
|
|
269
270
|
sqrl_args = ModelArgs(
|
|
270
271
|
ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx,
|
|
271
|
-
dbview_conn_name, connections, dependencies, ref
|
|
272
|
+
dbview_conn_name, connections, dependencies, ref, run_external_sql, use_duckdb
|
|
272
273
|
)
|
|
273
274
|
|
|
274
275
|
def compiled_query():
|
|
275
276
|
try:
|
|
277
|
+
assert isinstance(self.query_file, PyQueryFile)
|
|
276
278
|
raw_query: _RawPyQuery = self.query_file.raw_query
|
|
277
|
-
return raw_query.query(
|
|
279
|
+
return raw_query.query(sqrl_args)
|
|
278
280
|
except Exception as e:
|
|
279
281
|
raise u.FileExecutionError(f'Failed to run "{c.MAIN_FUNC}" function for python model "{self.name}"', e) from e
|
|
280
282
|
|
|
281
|
-
return
|
|
283
|
+
return PyModelQuery(compiled_query), dependencies
|
|
282
284
|
|
|
283
285
|
async def compile(
|
|
284
|
-
self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str,
|
|
286
|
+
self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, Referable], recurse: bool
|
|
285
287
|
) -> None:
|
|
286
288
|
if self.compiled_query is not None:
|
|
287
289
|
return
|
|
@@ -290,18 +292,18 @@ class _Model(_Referable):
|
|
|
290
292
|
|
|
291
293
|
start = time.time()
|
|
292
294
|
|
|
293
|
-
if self.query_file
|
|
294
|
-
compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders)
|
|
295
|
-
elif self.query_file
|
|
296
|
-
compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders)
|
|
295
|
+
if isinstance(self.query_file, SqlQueryFile):
|
|
296
|
+
compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders, models_dict)
|
|
297
|
+
elif isinstance(self.query_file, PyQueryFile):
|
|
298
|
+
compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders, models_dict)
|
|
297
299
|
else:
|
|
298
|
-
raise
|
|
300
|
+
raise NotImplementedError(f"Query type not supported: {self.query_file.__class__.__name__}")
|
|
299
301
|
|
|
300
302
|
self.compiled_query = compiled_query
|
|
301
|
-
self.wait_count = len(dependencies)
|
|
303
|
+
self.wait_count = len(set(dependencies))
|
|
302
304
|
|
|
303
305
|
model_type = self.get_model_type().name.lower()
|
|
304
|
-
|
|
306
|
+
self.logger.log_activity_time(f"compiling {model_type} model '{self.name}'", start)
|
|
305
307
|
|
|
306
308
|
if not recurse:
|
|
307
309
|
return
|
|
@@ -335,14 +337,14 @@ class _Model(_Referable):
|
|
|
335
337
|
return terminal_nodes
|
|
336
338
|
|
|
337
339
|
async def _run_sql_model(self, conn: Connection, placeholders: dict = {}) -> None:
|
|
338
|
-
assert(isinstance(self.compiled_query,
|
|
340
|
+
assert(isinstance(self.compiled_query, SqlModelQuery))
|
|
339
341
|
config = self.compiled_query.config
|
|
340
342
|
query = self.compiled_query.query
|
|
341
343
|
|
|
342
344
|
if self.query_file.model_type == ModelType.DBVIEW:
|
|
343
345
|
def run_sql_query():
|
|
344
346
|
try:
|
|
345
|
-
return
|
|
347
|
+
return self.conn_set.run_sql_query_from_conn_name(query, config.connection_name, placeholders)
|
|
346
348
|
except RuntimeError as e:
|
|
347
349
|
raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e) from e
|
|
348
350
|
|
|
@@ -363,7 +365,7 @@ class _Model(_Referable):
|
|
|
363
365
|
self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)
|
|
364
366
|
|
|
365
367
|
async def _run_python_model(self, conn: Connection) -> None:
|
|
366
|
-
assert(isinstance(self.compiled_query,
|
|
368
|
+
assert(isinstance(self.compiled_query, PyModelQuery))
|
|
367
369
|
|
|
368
370
|
df = await asyncio.to_thread(self.compiled_query.query)
|
|
369
371
|
if self.needs_sql_table:
|
|
@@ -374,13 +376,15 @@ class _Model(_Referable):
|
|
|
374
376
|
async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
|
|
375
377
|
start = time.time()
|
|
376
378
|
|
|
377
|
-
if self.query_file
|
|
379
|
+
if isinstance(self.query_file, SqlQueryFile):
|
|
378
380
|
await self._run_sql_model(conn, placeholders)
|
|
379
|
-
elif self.query_file
|
|
381
|
+
elif isinstance(self.query_file, PyQueryFile):
|
|
380
382
|
await self._run_python_model(conn)
|
|
383
|
+
else:
|
|
384
|
+
raise NotImplementedError(f"Query type not supported: {self.query_file.__class__.__name__}")
|
|
381
385
|
|
|
382
386
|
model_type = self.get_model_type().name.lower()
|
|
383
|
-
|
|
387
|
+
self.logger.log_activity_time(f"running {model_type} model '{self.name}'", start)
|
|
384
388
|
|
|
385
389
|
await super().run_model(conn, placeholders)
|
|
386
390
|
|
|
@@ -392,35 +396,37 @@ class _Model(_Referable):
|
|
|
392
396
|
|
|
393
397
|
|
|
394
398
|
@dataclass
|
|
395
|
-
class
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
399
|
+
class DAG:
|
|
400
|
+
manifest_cfg: ManifestConfig
|
|
401
|
+
dataset: DatasetConfig
|
|
402
|
+
target_model: Referable
|
|
403
|
+
models_dict: dict[str, Referable]
|
|
404
|
+
logger: u.Logger = field(default_factory=lambda: u.Logger(""))
|
|
405
|
+
parameter_set: ParameterSet | None = field(default=None, init=False) # set in apply_selections
|
|
400
406
|
placeholders: dict[str, Any] = field(init=False, default_factory=dict)
|
|
401
407
|
|
|
402
408
|
def apply_selections(
|
|
403
|
-
self, user:
|
|
409
|
+
self, param_cfg_set: ParameterConfigsSet, user: User | None, selections: dict[str, str], *, updates_only: bool = False, request_version: int | None = None
|
|
404
410
|
) -> None:
|
|
405
411
|
start = time.time()
|
|
406
412
|
dataset_params = self.dataset.parameters
|
|
407
|
-
parameter_set =
|
|
413
|
+
parameter_set = param_cfg_set.apply_selections(
|
|
408
414
|
dataset_params, selections, user, updates_only=updates_only, request_version=request_version
|
|
409
415
|
)
|
|
410
416
|
self.parameter_set = parameter_set
|
|
411
|
-
|
|
417
|
+
self.logger.log_activity_time(f"applying selections for dataset '{self.dataset.name}'", start)
|
|
412
418
|
|
|
413
|
-
def _compile_context(self, context_func: ContextFunc, user:
|
|
419
|
+
def _compile_context(self, param_args: ParametersArgs, context_func: ContextFunc, user: User | None) -> tuple[dict[str, Any], ContextArgs]:
|
|
414
420
|
start = time.time()
|
|
415
421
|
context = {}
|
|
416
|
-
|
|
422
|
+
assert isinstance(self.parameter_set, ParameterSet)
|
|
417
423
|
prms = self.parameter_set.get_parameters_as_dict()
|
|
418
424
|
args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits, self.placeholders)
|
|
419
425
|
try:
|
|
420
|
-
context_func(
|
|
426
|
+
context_func(context, args)
|
|
421
427
|
except Exception as e:
|
|
422
428
|
raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset.name}"', e) from e
|
|
423
|
-
|
|
429
|
+
self.logger.log_activity_time(f"running context.py for dataset '{self.dataset.name}'", start)
|
|
424
430
|
return context, args
|
|
425
431
|
|
|
426
432
|
async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
|
|
@@ -431,11 +437,12 @@ class _DAG:
|
|
|
431
437
|
terminal_nodes = self.target_model.get_terminal_nodes(set())
|
|
432
438
|
for model in self.models_dict.values():
|
|
433
439
|
model.confirmed_no_cycles = False
|
|
434
|
-
|
|
440
|
+
self.logger.log_activity_time(f"validating no cycles in model dependencies", start)
|
|
435
441
|
return terminal_nodes
|
|
436
442
|
|
|
437
443
|
async def _run_models(self, terminal_nodes: set[str], placeholders: dict = {}) -> None:
|
|
438
|
-
|
|
444
|
+
use_duckdb = self.manifest_cfg.settings_obj.do_use_duckdb()
|
|
445
|
+
conn_url = "duckdb:///" if use_duckdb else "sqlite:///?check_same_thread=False"
|
|
439
446
|
engine = create_engine(conn_url)
|
|
440
447
|
|
|
441
448
|
with engine.connect() as conn:
|
|
@@ -448,14 +455,14 @@ class _DAG:
|
|
|
448
455
|
engine.dispose()
|
|
449
456
|
|
|
450
457
|
async def execute(
|
|
451
|
-
self, context_func: ContextFunc, user:
|
|
452
|
-
runquery: bool = True, recurse: bool = True
|
|
458
|
+
self, param_args: ParametersArgs, param_cfg_set: ParameterConfigsSet, context_func: ContextFunc, user: User | None, selections: dict[str, str],
|
|
459
|
+
*, request_version: int | None = None, runquery: bool = True, recurse: bool = True
|
|
453
460
|
) -> dict[str, Any]:
|
|
454
461
|
recurse = (recurse or runquery)
|
|
455
462
|
|
|
456
|
-
self.apply_selections(user, selections, request_version=request_version)
|
|
463
|
+
self.apply_selections(param_cfg_set, user, selections, request_version=request_version)
|
|
457
464
|
|
|
458
|
-
context, ctx_args = self._compile_context(context_func, user)
|
|
465
|
+
context, ctx_args = self._compile_context(param_args, context_func, user)
|
|
459
466
|
|
|
460
467
|
await self._compile_models(context, ctx_args, recurse)
|
|
461
468
|
|
|
@@ -488,194 +495,54 @@ class _DAG:
|
|
|
488
495
|
|
|
489
496
|
return G
|
|
490
497
|
|
|
498
|
+
|
|
491
499
|
class ModelsIO:
|
|
492
|
-
raw_queries_by_model: dict[str, _QueryFile]
|
|
493
|
-
context_func: ContextFunc
|
|
494
500
|
|
|
495
501
|
@classmethod
|
|
496
|
-
def
|
|
502
|
+
def load_files(cls, logger: u.Logger, base_path: str) -> dict[str, QueryFile]:
|
|
497
503
|
start = time.time()
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
def
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
raw_query = _RawSqlQuery(u.read_file(filepath))
|
|
513
|
-
|
|
514
|
-
if query_type is not None:
|
|
515
|
-
query_file = _QueryFile(filepath, model_type, query_type, raw_query)
|
|
516
|
-
if file_stem in cls.raw_queries_by_model:
|
|
517
|
-
conflicts = [cls.raw_queries_by_model[file_stem].filepath, filepath]
|
|
518
|
-
raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
|
|
519
|
-
cls.raw_queries_by_model[file_stem] = query_file
|
|
504
|
+
raw_queries_by_model: dict[str, QueryFile] = {}
|
|
505
|
+
|
|
506
|
+
def populate_from_file(dp: str, file: str, model_type: ModelType) -> None:
|
|
507
|
+
filepath = Path(dp, file)
|
|
508
|
+
file_stem, extension = os.path.splitext(file)
|
|
509
|
+
if extension == '.py':
|
|
510
|
+
module = pm.PyModule(filepath)
|
|
511
|
+
dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda sqrl: [])
|
|
512
|
+
raw_query = _RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
|
|
513
|
+
query_file = PyQueryFile(filepath.as_posix(), model_type, raw_query)
|
|
514
|
+
elif extension == '.sql':
|
|
515
|
+
query_file = SqlQueryFile(filepath.as_posix(), model_type, filepath.read_text())
|
|
516
|
+
else:
|
|
517
|
+
query_file = None
|
|
520
518
|
|
|
519
|
+
if query_file is not None:
|
|
520
|
+
if file_stem in raw_queries_by_model:
|
|
521
|
+
conflicts = [raw_queries_by_model[file_stem].filepath, filepath]
|
|
522
|
+
raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
|
|
523
|
+
raw_queries_by_model[file_stem] = query_file
|
|
524
|
+
|
|
525
|
+
def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType) -> None:
|
|
521
526
|
for dp, _, filenames in os.walk(folder_path):
|
|
522
527
|
for file in filenames:
|
|
523
|
-
populate_from_file(dp, file)
|
|
528
|
+
populate_from_file(dp, file, model_type)
|
|
524
529
|
|
|
525
|
-
dbviews_path = u.
|
|
530
|
+
dbviews_path = u.Path(base_path, c.MODELS_FOLDER, c.DBVIEWS_FOLDER)
|
|
526
531
|
populate_raw_queries_for_type(dbviews_path, ModelType.DBVIEW)
|
|
527
532
|
|
|
528
|
-
federates_path = u.
|
|
533
|
+
federates_path = u.Path(base_path, c.MODELS_FOLDER, c.FEDERATES_FOLDER)
|
|
529
534
|
populate_raw_queries_for_type(federates_path, ModelType.FEDERATE)
|
|
530
535
|
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
timer.add_activity_time("loading files for models and context.py", start)
|
|
536
|
+
logger.log_activity_time("loading files for models", start)
|
|
537
|
+
return raw_queries_by_model
|
|
535
538
|
|
|
536
539
|
@classmethod
|
|
537
|
-
def
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
models_dict: dict[str, _Referable] = {key: _Seed(key, df) for key, df in seeds_dict.items()}
|
|
541
|
-
for key, val in cls.raw_queries_by_model.items():
|
|
542
|
-
models_dict[key] = _Model(key, val)
|
|
543
|
-
models_dict[key].needs_pandas = always_pandas
|
|
544
|
-
|
|
545
|
-
dataset_config = ManifestIO.obj.datasets[dataset]
|
|
546
|
-
target_model_name = dataset_config.model if target_model_name is None else target_model_name
|
|
547
|
-
target_model = models_dict[target_model_name]
|
|
548
|
-
target_model.is_target = True
|
|
549
|
-
|
|
550
|
-
return _DAG(dataset_config, target_model, models_dict)
|
|
551
|
-
|
|
552
|
-
@classmethod
|
|
553
|
-
def draw_dag(cls, dag: _DAG, output_folder: Path) -> None:
|
|
554
|
-
color_map = {ModelType.SEED: "green", ModelType.DBVIEW: "red", ModelType.FEDERATE: "skyblue"}
|
|
555
|
-
|
|
556
|
-
G = dag.to_networkx_graph()
|
|
557
|
-
|
|
558
|
-
fig, _ = plt.subplots()
|
|
559
|
-
pos = nx.multipartite_layout(G, subset_key="layer")
|
|
560
|
-
colors = [color_map[node[1]] for node in G.nodes(data="model_type")]
|
|
561
|
-
nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
|
|
562
|
-
|
|
563
|
-
y_values = [val[1] for val in pos.values()]
|
|
564
|
-
scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
|
|
565
|
-
label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
|
|
566
|
-
nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
|
|
567
|
-
|
|
568
|
-
fig.tight_layout()
|
|
569
|
-
plt.margins(x=0.1, y=0.1)
|
|
570
|
-
plt.savefig(u.join_paths(output_folder, "dag.png"))
|
|
571
|
-
plt.close(fig)
|
|
572
|
-
|
|
573
|
-
@classmethod
|
|
574
|
-
async def WriteDatasetOutputsGivenTestSet(
|
|
575
|
-
cls, dataset_conf: DatasetsConfig, select: str, test_set: Optional[str], runquery: bool, recurse: bool
|
|
576
|
-
) -> Any:
|
|
577
|
-
dataset = dataset_conf.name
|
|
578
|
-
default_test_set, default_test_set_conf = ManifestIO.obj.get_default_test_set(dataset)
|
|
579
|
-
if test_set is None or test_set == default_test_set:
|
|
580
|
-
test_set, test_set_conf = default_test_set, default_test_set_conf
|
|
581
|
-
elif test_set in ManifestIO.obj.selection_test_sets:
|
|
582
|
-
test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
|
|
583
|
-
else:
|
|
584
|
-
raise u.ConfigurationError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
|
|
585
|
-
|
|
586
|
-
error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
|
|
587
|
-
if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
|
|
588
|
-
raise u.ConfigurationError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
|
|
589
|
-
|
|
590
|
-
user_attributes = test_set_conf.user_attributes.copy()
|
|
591
|
-
selections = test_set_conf.parameters.copy()
|
|
592
|
-
username, is_internal = user_attributes.pop("username", ""), user_attributes.pop("is_internal", False)
|
|
593
|
-
if test_set_conf.is_authenticated:
|
|
594
|
-
user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
|
|
595
|
-
user = user_cls.Create(username, is_internal=is_internal, **user_attributes)
|
|
596
|
-
elif dataset_conf.scope == DatasetScope.PUBLIC:
|
|
597
|
-
user = None
|
|
598
|
-
else:
|
|
599
|
-
raise u.ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
|
|
600
|
-
|
|
601
|
-
if dataset_conf.scope == DatasetScope.PRIVATE and not is_internal:
|
|
602
|
-
raise u.ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_internal' set to true")
|
|
603
|
-
|
|
604
|
-
# always_pandas is set to True for creating CSV files from results (when runquery is True)
|
|
605
|
-
dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
|
|
606
|
-
placeholders = await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
|
|
607
|
-
|
|
608
|
-
output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
|
|
609
|
-
if os.path.exists(output_folder):
|
|
610
|
-
shutil.rmtree(output_folder)
|
|
611
|
-
os.makedirs(output_folder, exist_ok=True)
|
|
612
|
-
|
|
613
|
-
def write_placeholders() -> None:
|
|
614
|
-
output_filepath = u.join_paths(output_folder, "placeholders.json")
|
|
615
|
-
with open(output_filepath, 'w') as f:
|
|
616
|
-
json.dump(placeholders, f, indent=4)
|
|
617
|
-
|
|
618
|
-
def write_model_outputs(model: _Model) -> None:
|
|
619
|
-
subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
|
|
620
|
-
subpath = u.join_paths(output_folder, subfolder)
|
|
621
|
-
os.makedirs(subpath, exist_ok=True)
|
|
622
|
-
if isinstance(model.compiled_query, _SqlModelQuery):
|
|
623
|
-
output_filepath = u.join_paths(subpath, model.name+'.sql')
|
|
624
|
-
query = model.compiled_query.query
|
|
625
|
-
with open(output_filepath, 'w') as f:
|
|
626
|
-
f.write(query)
|
|
627
|
-
if runquery and isinstance(model.result, pd.DataFrame):
|
|
628
|
-
output_filepath = u.join_paths(subpath, model.name+'.csv')
|
|
629
|
-
model.result.to_csv(output_filepath, index=False)
|
|
630
|
-
|
|
631
|
-
write_placeholders()
|
|
632
|
-
all_model_names = dag.get_all_query_models()
|
|
633
|
-
coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
|
|
634
|
-
await asyncio.gather(*coroutines)
|
|
540
|
+
def load_context_func(cls, logger: u.Logger, base_path: str) -> ContextFunc:
|
|
541
|
+
start = time.time()
|
|
635
542
|
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
if isinstance(dag.target_model, _Model):
|
|
640
|
-
return dag.target_model.compiled_query.query # else return None
|
|
543
|
+
context_path = u.Path(base_path, c.PYCONFIGS_FOLDER, c.CONTEXT_FILE)
|
|
544
|
+
context_func: ContextFunc = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda ctx, sqrl: None)
|
|
641
545
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
cls, dataset: Optional[str], do_all_datasets: bool, select: Optional[str], test_set: Optional[str], do_all_test_sets: bool,
|
|
645
|
-
runquery: bool
|
|
646
|
-
) -> None:
|
|
647
|
-
|
|
648
|
-
def get_applicable_test_sets(dataset: str) -> list[str]:
|
|
649
|
-
applicable_test_sets = []
|
|
650
|
-
for test_set_name, test_set_config in ManifestIO.obj.selection_test_sets.items():
|
|
651
|
-
if test_set_config.datasets is None or dataset in test_set_config.datasets:
|
|
652
|
-
applicable_test_sets.append(test_set_name)
|
|
653
|
-
return applicable_test_sets
|
|
654
|
-
|
|
655
|
-
recurse = True
|
|
656
|
-
dataset_configs = ManifestIO.obj.datasets
|
|
657
|
-
if do_all_datasets:
|
|
658
|
-
selected_models = [(dataset, dataset.model) for dataset in dataset_configs.values()]
|
|
659
|
-
else:
|
|
660
|
-
if select is None:
|
|
661
|
-
select = dataset_configs[dataset].model
|
|
662
|
-
else:
|
|
663
|
-
recurse = False
|
|
664
|
-
selected_models = [(dataset_configs[dataset], select)]
|
|
665
|
-
|
|
666
|
-
coroutines = []
|
|
667
|
-
for dataset_conf, select in selected_models:
|
|
668
|
-
if do_all_test_sets:
|
|
669
|
-
for test_set_name in get_applicable_test_sets(dataset_conf.name):
|
|
670
|
-
coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set_name, runquery, recurse)
|
|
671
|
-
coroutines.append(coroutine)
|
|
672
|
-
|
|
673
|
-
coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set, runquery, recurse)
|
|
674
|
-
coroutines.append(coroutine)
|
|
675
|
-
|
|
676
|
-
queries = await asyncio.gather(*coroutines)
|
|
677
|
-
if not recurse and len(queries) == 1 and isinstance(queries[0], str):
|
|
678
|
-
print()
|
|
679
|
-
print(queries[0])
|
|
680
|
-
print()
|
|
546
|
+
logger.log_activity_time("loading file for context.py", start)
|
|
547
|
+
return context_func
|
|
681
548
|
|