squirrels 0.2.1-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- squirrels/__init__.py +11 -4
- squirrels/_api_response_models.py +118 -0
- squirrels/_api_server.py +140 -75
- squirrels/_authenticator.py +10 -8
- squirrels/_command_line.py +17 -11
- squirrels/_connection_set.py +2 -2
- squirrels/_constants.py +13 -5
- squirrels/_initializer.py +23 -13
- squirrels/_manifest.py +20 -10
- squirrels/_models.py +303 -148
- squirrels/_parameter_configs.py +195 -57
- squirrels/_parameter_sets.py +14 -17
- squirrels/_py_module.py +2 -4
- squirrels/_seeds.py +38 -0
- squirrels/_utils.py +41 -33
- squirrels/arguments/run_time_args.py +76 -34
- squirrels/data_sources.py +172 -51
- squirrels/dateutils.py +3 -3
- squirrels/package_data/assets/index.js +14 -14
- squirrels/package_data/base_project/connections.yml +1 -1
- squirrels/package_data/base_project/database/expenses.db +0 -0
- squirrels/package_data/base_project/docker/Dockerfile +1 -1
- squirrels/package_data/base_project/environcfg.yml +7 -7
- squirrels/package_data/base_project/models/dbviews/database_view1.py +25 -14
- squirrels/package_data/base_project/models/dbviews/database_view1.sql +21 -14
- squirrels/package_data/base_project/models/federates/dataset_example.py +6 -5
- squirrels/package_data/base_project/models/federates/dataset_example.sql +1 -1
- squirrels/package_data/base_project/parameters.yml +57 -28
- squirrels/package_data/base_project/pyconfigs/auth.py +11 -10
- squirrels/package_data/base_project/pyconfigs/connections.py +6 -8
- squirrels/package_data/base_project/pyconfigs/context.py +49 -33
- squirrels/package_data/base_project/pyconfigs/parameters.py +62 -30
- squirrels/package_data/base_project/seeds/seed_categories.csv +6 -0
- squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/package_data/base_project/squirrels.yml.j2 +37 -20
- squirrels/parameter_options.py +30 -10
- squirrels/parameters.py +300 -70
- squirrels/user_base.py +3 -13
- squirrels-0.3.0.dist-info/LICENSE +201 -0
- {squirrels-0.2.1.dist-info → squirrels-0.3.0.dist-info}/METADATA +15 -15
- squirrels-0.3.0.dist-info/RECORD +56 -0
- {squirrels-0.2.1.dist-info → squirrels-0.3.0.dist-info}/WHEEL +1 -1
- squirrels/package_data/base_project/seeds/mocks/category.csv +0 -3
- squirrels/package_data/base_project/seeds/mocks/max_filter.csv +0 -2
- squirrels/package_data/base_project/seeds/mocks/subcategory.csv +0 -6
- squirrels-0.2.1.dist-info/LICENSE +0 -22
- squirrels-0.2.1.dist-info/RECORD +0 -55
- {squirrels-0.2.1.dist-info → squirrels-0.3.0.dist-info}/entry_points.txt +0 -0
squirrels/_models.py
CHANGED
@@ -1,21 +1,26 @@
 from __future__ import annotations
-from typing import …
+from typing import Optional, Callable, Iterable, Any
 from dataclasses import dataclass, field
+from abc import ABCMeta, abstractmethod
 from enum import Enum
 from pathlib import Path
-…
+from sqlalchemy import create_engine, text, Connection
+import asyncio, os, shutil, pandas as pd, json
+import matplotlib.pyplot as plt, networkx as nx
 
 from . import _constants as c, _utils as u, _py_module as pm
 from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
 from ._authenticator import User, Authenticator
 from ._connection_set import ConnectionSetIO
-from ._manifest import ManifestIO, DatasetsConfig
+from ._manifest import ManifestIO, DatasetsConfig, DatasetScope
 from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
+from ._seeds import SeedsIO
 from ._timer import timer, time
 
 class ModelType(Enum):
     DBVIEW = 1
     FEDERATE = 2
+    SEED = 3
 
 class QueryType(Enum):
     SQL = 0
@@ -27,7 +32,7 @@ class Materialization(Enum):
 
 
 @dataclass
-class …
+class _SqlModelConfig:
     ## Applicable for dbview models
     connection_name: str
 
@@ -58,63 +63,132 @@ ContextFunc = Callable[[dict[str, Any], ContextArgs], None]
 
 
 @dataclass(frozen=True)
-class …
+class _RawQuery(metaclass=ABCMeta):
     pass
 
 @dataclass(frozen=True)
-class …
+class _RawSqlQuery(_RawQuery):
     query: str
 
 @dataclass(frozen=True)
-class …
+class _RawPyQuery(_RawQuery):
     query: Callable[[Any], pd.DataFrame]
     dependencies_func: Callable[[Any], Iterable]
 
 
 @dataclass
-class …
+class _Query(metaclass=ABCMeta):
     query: Any
 
 @dataclass
-class …
+class _WorkInProgress(_Query):
     query: None = field(default=None, init=False)
 
 @dataclass
-class …
+class _SqlModelQuery(_Query):
     query: str
-    config: …
+    config: _SqlModelConfig
 
 @dataclass
-class …
+class _PyModelQuery(_Query):
     query: Callable[[], pd.DataFrame]
 
 
 @dataclass(frozen=True)
-class …
+class _QueryFile:
     filepath: str
     model_type: ModelType
     query_type: QueryType
-    raw_query: …
+    raw_query: _RawQuery
 
 
 @dataclass
-class Model:
+class _Referable(metaclass=ABCMeta):
     name: str
-    query_file: QueryFile
     is_target: bool = field(default=False, init=False)
-    compiled_query: Optional[Query] = field(default=None, init=False)
 
     needs_sql_table: bool = field(default=False, init=False)
-    needs_pandas: bool = False
+    needs_pandas: bool = field(default=False, init=False)
     result: Optional[pd.DataFrame] = field(default=None, init=False, repr=False)
-
-    wait_count: int = field(default=0, init=False, repr=False)
-    upstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
-    downstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
 
+    wait_count: int = field(default=0, init=False, repr=False)
     confirmed_no_cycles: bool = field(default=False, init=False)
+    upstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
+    downstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
+
+    @abstractmethod
+    def get_model_type(self) -> ModelType:
+        pass
+
+    async def compile(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
+    ) -> None:
+        pass
+
+    @abstractmethod
+    def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
+        pass
+
+    def _load_pandas_to_table(self, df: pd.DataFrame, conn: Connection) -> None:
+        df.to_sql(self.name, conn, index=False)
+
+    def _load_table_to_pandas(self, conn: Connection) -> pd.DataFrame:
+        query = f"SELECT * FROM {self.name}"
+        return pd.read_sql(query, conn)
+
+    async def _trigger(self, conn: Connection, placeholders: dict = {}) -> None:
+        self.wait_count -= 1
+        if (self.wait_count == 0):
+            await self.run_model(conn, placeholders)
+
+    @abstractmethod
+    async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
+        coroutines = []
+        for model in self.downstreams.values():
+            coroutines.append(model._trigger(conn, placeholders))
+        await asyncio.gather(*coroutines)
+
+    def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
+        pass
+
+    def get_max_path_length_to_target(self) -> int:
+        if not hasattr(self, "max_path_len_to_target"):
+            path_lengths = []
+            for child_model in self.downstreams.values():
+                path_lengths.append(child_model.get_max_path_length_to_target()+1)
+            if len(path_lengths) > 0:
+                self.max_path_len_to_target = max(path_lengths)
+            else:
+                self.max_path_len_to_target = 0 if self.is_target else None
+        return self.max_path_len_to_target
+
 
-…
+@dataclass
+class _Seed(_Referable):
+    result: pd.DataFrame
+
+    def get_model_type(self) -> ModelType:
+        return ModelType.SEED
+
+    def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
+        return {self.name}
+
+    async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
+        if self.needs_sql_table:
+            await asyncio.to_thread(self._load_pandas_to_table, self.result, conn)
+        await super().run_model(conn, placeholders)
+
+
+@dataclass
+class _Model(_Referable):
+    query_file: _QueryFile
+
+    compiled_query: Optional[_Query] = field(default=None, init=False)
+
+    def get_model_type(self) -> ModelType:
+        return self.query_file.model_type
+
+    def _add_upstream(self, other: _Referable) -> None:
         self.upstreams[other.name] = other
         other.downstreams[self.name] = self
 
@@ -137,17 +211,20 @@ class Model:
         materialized = federate_config.materialized
         return Materialization[materialized.upper()]
 
-    async def _compile_sql_model(
-…
+    async def _compile_sql_model(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
+    ) -> tuple[_SqlModelQuery, set]:
+        assert(isinstance(self.query_file.raw_query, _RawSqlQuery))
+
         raw_query = self.query_file.raw_query.query
-
         connection_name = self._get_dbview_conn_name()
         materialized = self._get_materialized()
-        configuration = …
+        configuration = _SqlModelConfig(connection_name, materialized)
+        is_placeholder = lambda x: x in placeholders
         kwargs = {
-            "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars,
-            "…
-            "…
+            "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars, "user": ctx_args.user, "prms": ctx_args.prms,
+            "traits": ctx_args.traits, "ctx": ctx, "is_placeholder": is_placeholder, "set_placeholder": ctx_args.set_placeholder,
+            "config": configuration.set_attribute, "is_param_enabled": ctx_args.param_exists
         }
         dependencies = set()
         if self.query_file.model_type == ModelType.FEDERATE:
@@ -157,16 +234,21 @@ class Model:
             kwargs["ref"] = ref
 
         try:
-            query = await asyncio.to_thread(u.render_string, raw_query, kwargs)
+            query = await asyncio.to_thread(u.render_string, raw_query, **kwargs)
         except Exception as e:
             raise u.FileExecutionError(f'Failed to compile sql model "{self.name}"', e)
 
-        compiled_query = …
+        compiled_query = _SqlModelQuery(query, configuration)
         return compiled_query, dependencies
 
-    async def _compile_python_model(
-…
-…
+    async def _compile_python_model(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
+    ) -> tuple[_PyModelQuery, set]:
+        assert(isinstance(self.query_file.raw_query, _RawPyQuery))
+
+        sqrl_args = ModelDepsArgs(
+            ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx
+        )
         try:
             dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
         except Exception as e:
@@ -175,34 +257,42 @@ class Model:
         dbview_conn_name = self._get_dbview_conn_name()
         connections = ConnectionSetIO.obj.get_engines_as_dict()
         ref = lambda x: self.upstreams[x].result
-        sqrl_args = ModelArgs(
-…
+        sqrl_args = ModelArgs(
+            ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx,
+            dbview_conn_name, connections, dependencies, ref
+        )
 
         def compiled_query():
             try:
-…
+                raw_query: _RawPyQuery = self.query_file.raw_query
+                return raw_query.query(sqrl=sqrl_args)
             except Exception as e:
                 raise u.FileExecutionError(f'Failed to run "{c.MAIN_FUNC}" function for python model "{self.name}"', e)
 
-        return …
+        return _PyModelQuery(compiled_query), dependencies
 
-    async def compile(…
+    async def compile(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
+    ) -> None:
         if self.compiled_query is not None:
             return
         else:
-            self.compiled_query = …
+            self.compiled_query = _WorkInProgress()
 
         start = time.time()
+
         if self.query_file.query_type == QueryType.SQL:
-            compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args)
+            compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders)
         elif self.query_file.query_type == QueryType.PYTHON:
-            compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args)
+            compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders)
         else:
             raise NotImplementedError(f"Query type not supported: {self.query_file.query_type}")
 
         self.compiled_query = compiled_query
         self.wait_count = len(dependencies)
-…
+
+        model_type = self.get_model_type().name.lower()
+        timer.add_activity_time(f"compiling {model_type} model '{self.name}'", start)
 
         if not recurse:
             return
@@ -211,13 +301,13 @@ class Model:
         coroutines = []
         for dep_model in dep_models:
             self._add_upstream(dep_model)
-            coro = dep_model.compile(ctx, ctx_args, models_dict, recurse)
+            coro = dep_model.compile(ctx, ctx_args, placeholders, models_dict, recurse)
             coroutines.append(coro)
         await asyncio.gather(*coroutines)
 
-    def …
+    def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
         if self.confirmed_no_cycles:
-            return …
+            return set()
 
         if self.name in depencency_path:
             raise u.ConfigurationError(f'Cycle found in model dependency graph')
@@ -229,34 +319,21 @@ class Model:
         new_path = set(depencency_path)
         new_path.add(self.name)
         for dep_model in self.upstreams.values():
-            terminal_nodes_under_dep = dep_model.…
+            terminal_nodes_under_dep = dep_model.get_terminal_nodes(new_path)
             terminal_nodes = terminal_nodes.union(terminal_nodes_under_dep)
 
         self.confirmed_no_cycles = True
         return terminal_nodes
-
-    def _load_pandas_to_table(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None:
-        if u.use_duckdb():
-            conn.execute(f"CREATE TABLE {self.name} AS FROM df")
-        else:
-            df.to_sql(self.name, conn, index=False)
-
-    def _load_table_to_pandas(self, conn: sqlite3.Connection) -> pd.DataFrame:
-        if u.use_duckdb():
-            return conn.execute(f"FROM {self.name}").df()
-        else:
-            query = f"SELECT * FROM {self.name}"
-            return pd.read_sql(query, conn)
 
-    async def _run_sql_model(self, conn: …
-        assert(isinstance(self.compiled_query, …
+    async def _run_sql_model(self, conn: Connection, placeholders: dict = {}) -> None:
+        assert(isinstance(self.compiled_query, _SqlModelQuery))
         config = self.compiled_query.config
         query = self.compiled_query.query
 
         if self.query_file.model_type == ModelType.DBVIEW:
             def run_sql_query():
                 try:
-                    return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name)
+                    return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name, placeholders)
                 except RuntimeError as e:
                     raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e)
 
@@ -268,7 +345,7 @@ class Model:
             def create_table():
                 create_query = config.get_sql_for_create(self.name, query)
                 try:
-                    return conn.execute(create_query)
+                    return conn.execute(text(create_query), placeholders)
                 except Exception as e:
                     raise u.FileExecutionError(f'Failed to run federate sql model "{self.name}"', e)
 
@@ -276,8 +353,8 @@ class Model:
         if self.needs_pandas or self.is_target:
             self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)
 
-    async def _run_python_model(self, conn: …
-        assert(isinstance(self.compiled_query, …
+    async def _run_python_model(self, conn: Connection) -> None:
+        assert(isinstance(self.compiled_query, _PyModelQuery))
 
         df = await asyncio.to_thread(self.compiled_query.query)
         if self.needs_sql_table:
@@ -285,90 +362,86 @@ class Model:
         if self.needs_pandas or self.is_target:
             self.result = df
 
-    async def run_model(self, conn: …
+    async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
        start = time.time()
+
        if self.query_file.query_type == QueryType.SQL:
-            await self._run_sql_model(conn)
+            await self._run_sql_model(conn, placeholders)
         elif self.query_file.query_type == QueryType.PYTHON:
             await self._run_python_model(conn)
-        timer.add_activity_time(f"running model '{self.name}'", start)
 
-…
-…
-…
-        await …
-…
-    async def trigger(self, conn: sqlite3.Connection) -> None:
-        self.wait_count -= 1
-        if (self.wait_count == 0):
-            await self.run_model(conn)
+        model_type = self.get_model_type().name.lower()
+        timer.add_activity_time(f"running {model_type} model '{self.name}'", start)
+
+        await super().run_model(conn, placeholders)
 
-    def …
+    def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
         if self.name not in dependent_model_names:
             dependent_model_names.add(self.name)
             for dep_model in self.upstreams.values():
-                dep_model.…
+                dep_model.retrieve_dependent_query_models(dependent_model_names)
 
 
 @dataclass
-class DAG:
+class _DAG:
     dataset: DatasetsConfig
-    target_model: …
-    models_dict: dict[str, …
+    target_model: _Referable
+    models_dict: dict[str, _Referable]
     parameter_set: Optional[ParameterSet] = field(default=None, init=False)
+    placeholders: dict[str, Any] = field(init=False, default_factory=dict)
 
     def apply_selections(
         self, user: Optional[User], selections: dict[str, str], *, updates_only: bool = False, request_version: Optional[int] = None
     ) -> None:
         start = time.time()
         dataset_params = self.dataset.parameters
-        parameter_set = ParameterConfigsSetIO.obj.apply_selections(
-…
+        parameter_set = ParameterConfigsSetIO.obj.apply_selections(
+            dataset_params, selections, user, updates_only=updates_only, request_version=request_version
+        )
         self.parameter_set = parameter_set
-        timer.add_activity_time(f"applying selections for dataset", start)
+        timer.add_activity_time(f"applying selections for dataset '{self.dataset.name}'", start)
 
     def _compile_context(self, context_func: ContextFunc, user: Optional[User]) -> tuple[dict[str, Any], ContextArgs]:
         start = time.time()
         context = {}
         param_args = ParameterConfigsSetIO.args
         prms = self.parameter_set.get_parameters_as_dict()
-        args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits)
+        args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits, self.placeholders)
         try:
             context_func(ctx=context, sqrl=args)
         except Exception as e:
-            raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset}"', e)
-        timer.add_activity_time(f"running context.py for dataset", start)
+            raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset.name}"', e)
+        timer.add_activity_time(f"running context.py for dataset '{self.dataset.name}'", start)
         return context, args
 
     async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
-        await self.target_model.compile(context, ctx_args, self.models_dict, recurse)
+        await self.target_model.compile(context, ctx_args, self.placeholders, self.models_dict, recurse)
 
-    def …
+    def _get_terminal_nodes(self) -> set[str]:
         start = time.time()
-        terminal_nodes = self.target_model.…
-…
+        terminal_nodes = self.target_model.get_terminal_nodes(set())
+        for model in self.models_dict.values():
+            model.confirmed_no_cycles = False
+        timer.add_activity_time(f"validating no cycles in model dependencies", start)
         return terminal_nodes
 
-    async def _run_models(self, terminal_nodes: set[str]) -> None:
-        if u.use_duckdb()…
-…
-            conn = duckdb.connect()
-        else:
-            conn = sqlite3.connect(":memory:", check_same_thread=False)
+    async def _run_models(self, terminal_nodes: set[str], placeholders: dict = {}) -> None:
+        conn_url = "duckdb:///" if u.use_duckdb() else "sqlite:///?check_same_thread=False"
+        engine = create_engine(conn_url)
 
-…
+        with engine.connect() as conn:
             coroutines = []
             for model_name in terminal_nodes:
                 model = self.models_dict[model_name]
-                coroutines.append(model.run_model(conn))
+                coroutines.append(model.run_model(conn, placeholders))
             await asyncio.gather(*coroutines)
-…
-…
+
+        engine.dispose()
 
     async def execute(
         self, context_func: ContextFunc, user: Optional[User], selections: dict[str, str], *, request_version: Optional[int] = None,
         runquery: bool = True, recurse: bool = True
-    ) -> …
+    ) -> dict[str, Any]:
         recurse = (recurse or runquery)
 
         self.apply_selections(user, selections, request_version=request_version)
@@ -377,19 +450,37 @@ class DAG:
 
         await self._compile_models(context, ctx_args, recurse)
 
-        terminal_nodes = self.…
+        terminal_nodes = self._get_terminal_nodes()
 
+        placeholders = ctx_args._placeholders.copy()
         if runquery:
-            await self._run_models(terminal_nodes)
+            await self._run_models(terminal_nodes, placeholders)
+
+        return placeholders
 
-    def …
+    def get_all_query_models(self) -> set[str]:
         all_model_names = set()
-        self.target_model.…
+        self.target_model.retrieve_dependent_query_models(all_model_names)
         return all_model_names
-
+
+    def to_networkx_graph(self) -> nx.DiGraph:
+        G = nx.DiGraph()
+
+        for model_name, model in self.models_dict.items():
+            model_type = model.get_model_type()
+            level = model.get_max_path_length_to_target()
+            if level is not None:
+                G.add_node(model_name, layer=-level, model_type=model_type)
+
+        for model_name in G.nodes:
+            model = self.models_dict[model_name]
+            for dep_model_name in model.downstreams:
+                G.add_edge(model_name, dep_model_name)
+
+        return G
 
 class ModelsIO:
-    raw_queries_by_model: dict[str, …
+    raw_queries_by_model: dict[str, _QueryFile]
     context_func: ContextFunc
 
     @classmethod
@@ -398,7 +489,6 @@ class ModelsIO:
         cls.raw_queries_by_model = {}
 
         def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType):
-
            def populate_from_file(dp, file):
                 query_type = None
                 filepath = os.path.join(dp, file)
@@ -407,13 +497,13 @@ class ModelsIO:
                     query_type = QueryType.PYTHON
                     module = pm.PyModule(filepath)
                     dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda x: [])
-                    raw_query = …
+                    raw_query = _RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
                 elif extension == '.sql':
                     query_type = QueryType.SQL
-                    raw_query = …
+                    raw_query = _RawSqlQuery(u.read_file(filepath))
 
                 if query_type is not None:
-                    query_file = …
+                    query_file = _QueryFile(filepath, model_type, query_type, raw_query)
                     if file_stem in cls.raw_queries_by_model:
                         conflicts = [cls.raw_queries_by_model[file_stem].filepath, filepath]
                         raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
@@ -429,44 +519,98 @@ class ModelsIO:
         federates_path = u.join_paths(c.MODELS_FOLDER, c.FEDERATES_FOLDER)
         populate_raw_queries_for_type(federates_path, ModelType.FEDERATE)
 
-        context_path = u.join_paths(c.…
+        context_path = u.join_paths(c.PYCONFIGS_FOLDER, c.CONTEXT_FILE)
         cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda x, y: None)
 
-        timer.add_activity_time("loading models and …
+        timer.add_activity_time("loading files for models and context.py", start)
 
     @classmethod
-    def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> …
-…
+    def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> _DAG:
+        seeds_dict = SeedsIO.obj.get_dataframes()
+
+        models_dict: dict[str, _Referable] = {key: _Seed(key, df) for key, df in seeds_dict.items()}
+        for key, val in cls.raw_queries_by_model.items():
+            models_dict[key] = _Model(key, val)
+            models_dict[key].needs_pandas = always_pandas
 
         dataset_config = ManifestIO.obj.datasets[dataset]
         target_model_name = dataset_config.model if target_model_name is None else target_model_name
         target_model = models_dict[target_model_name]
         target_model.is_target = True
 
-        return …
+        return _DAG(dataset_config, target_model, models_dict)
 
     @classmethod
-…
-…
-…
-…
+    def draw_dag(cls, dag: _DAG, output_folder: Path) -> None:
+        color_map = {ModelType.SEED: "green", ModelType.DBVIEW: "red", ModelType.FEDERATE: "skyblue"}
+
+        G = dag.to_networkx_graph()
+
+        fig, _ = plt.subplots()
+        pos = nx.multipartite_layout(G, subset_key="layer")
+        colors = [color_map[node[1]] for node in G.nodes(data="model_type")]
+        nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
+
+        y_values = [val[1] for val in pos.values()]
+        scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
+        label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
+        nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
+
+        fig.tight_layout()
+        plt.margins(x=0.1, y=0.1)
+        plt.savefig(u.join_paths(output_folder, "dag.png"))
+        plt.close(fig)
+
+    @classmethod
+    async def WriteDatasetOutputsGivenTestSet(
+        cls, dataset_conf: DatasetsConfig, select: str, test_set: Optional[str], runquery: bool, recurse: bool
+    ) -> Any:
+        dataset = dataset_conf.name
+        default_test_set, default_test_set_conf = ManifestIO.obj.get_default_test_set(dataset)
+        if test_set is None or test_set == default_test_set:
+            test_set, test_set_conf = default_test_set, default_test_set_conf
+        elif test_set in ManifestIO.obj.selection_test_sets:
+            test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
+        else:
+            raise u.InvalidInputError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
 
-…
-…
-…
+        error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
+        if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
+            raise u.InvalidInputError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
 
+        user_attributes = test_set_conf.user_attributes.copy()
+        selections = test_set_conf.parameters.copy()
+        username, is_internal = user_attributes.pop("username", ""), user_attributes.pop("is_internal", False)
+        if test_set_conf.is_authenticated:
+            user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
+            user = user_cls.Create(username, is_internal=is_internal, **user_attributes)
+        elif dataset_conf.scope == DatasetScope.PUBLIC:
+            user = None
+        else:
+            raise u.ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
+
+        if dataset_conf.scope == DatasetScope.PRIVATE and not is_internal:
+            raise u.ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_internal' set to true")
+
+        # always_pandas is set to True for creating CSV files from results (when runquery is True)
         dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
-        await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
+        placeholders = await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
 
-        output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, …
+        output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
         if os.path.exists(output_folder):
             shutil.rmtree(output_folder)
+        os.makedirs(output_folder, exist_ok=True)
+
+        def write_placeholders() -> None:
+            output_filepath = u.join_paths(output_folder, "placeholders.json")
+            with open(output_filepath, 'w') as f:
+                json.dump(placeholders, f, indent=4)
 
-        def write_model_outputs(model: …
+        def write_model_outputs(model: _Model) -> None:
             subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
             subpath = u.join_paths(output_folder, subfolder)
             os.makedirs(subpath, exist_ok=True)
-            if isinstance(model.compiled_query, …
+            if isinstance(model.compiled_query, _SqlModelQuery):
                 output_filepath = u.join_paths(subpath, model.name+'.sql')
                 query = model.compiled_query.query
                 with open(output_filepath, 'w') as f:
@@ -475,39 +619,50 @@ class ModelsIO:
                 output_filepath = u.join_paths(subpath, model.name+'.csv')
                 model.result.to_csv(output_filepath, index=False)
 
-…
+        write_placeholders()
+        all_model_names = dag.get_all_query_models()
         coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
         await asyncio.gather(*coroutines)
-…
+
+        if recurse:
+            cls.draw_dag(dag, output_folder)
+
+        if isinstance(dag.target_model, _Model):
+            return dag.target_model.compiled_query.query # else return None
 
     @classmethod
     async def WriteOutputs(
-        cls, dataset: Optional[str], select: Optional[str], …
+        cls, dataset: Optional[str], do_all_datasets: bool, select: Optional[str], test_set: Optional[str], do_all_test_sets: bool,
+        runquery: bool
     ) -> None:
-…
-…
-…
-…
-…
-…
-…
+
+        def get_applicable_test_sets(dataset: str) -> list[str]:
+            applicable_test_sets = []
+            for test_set_name, test_set_config in ManifestIO.obj.selection_test_sets.items():
+                if test_set_config.datasets is None or dataset in test_set_config.datasets:
+                    applicable_test_sets.append(test_set_name)
+            return applicable_test_sets
 
         recurse = True
         dataset_configs = ManifestIO.obj.datasets
-        if …
-            selected_models = [(dataset…
+        if do_all_datasets:
+            selected_models = [(dataset, dataset.model) for dataset in dataset_configs.values()]
         else:
            if select is None:
                 select = dataset_configs[dataset].model
             else:
                 recurse = False
-            selected_models = [(dataset, select)]
+            selected_models = [(dataset_configs[dataset], select)]
 
         coroutines = []
-        for …
-…
-…
-…
+        for dataset_conf, select in selected_models:
+            if do_all_test_sets:
+                for test_set_name in get_applicable_test_sets(dataset_conf.name):
+                    coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set_name, runquery, recurse)
+                    coroutines.append(coroutine)
+
+            coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set, runquery, recurse)
+            coroutines.append(coroutine)
 
         queries = await asyncio.gather(*coroutines)
         if not recurse and len(queries) == 1 and isinstance(queries[0], str):