squirrels-0.2.2-py3-none-any.whl → squirrels-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of squirrels might be problematic.
- squirrels/__init__.py +11 -4
- squirrels/_api_response_models.py +118 -0
- squirrels/_api_server.py +140 -75
- squirrels/_authenticator.py +10 -8
- squirrels/_command_line.py +17 -11
- squirrels/_connection_set.py +2 -2
- squirrels/_constants.py +13 -5
- squirrels/_initializer.py +23 -13
- squirrels/_manifest.py +20 -10
- squirrels/_models.py +295 -142
- squirrels/_parameter_configs.py +195 -57
- squirrels/_parameter_sets.py +14 -17
- squirrels/_py_module.py +2 -4
- squirrels/_seeds.py +38 -0
- squirrels/_utils.py +41 -33
- squirrels/arguments/run_time_args.py +76 -34
- squirrels/data_sources.py +172 -51
- squirrels/dateutils.py +3 -3
- squirrels/package_data/assets/index.js +14 -14
- squirrels/package_data/base_project/connections.yml +1 -1
- squirrels/package_data/base_project/database/expenses.db +0 -0
- squirrels/package_data/base_project/docker/Dockerfile +1 -1
- squirrels/package_data/base_project/environcfg.yml +7 -7
- squirrels/package_data/base_project/models/dbviews/database_view1.py +25 -14
- squirrels/package_data/base_project/models/dbviews/database_view1.sql +21 -14
- squirrels/package_data/base_project/models/federates/dataset_example.py +6 -5
- squirrels/package_data/base_project/models/federates/dataset_example.sql +1 -1
- squirrels/package_data/base_project/parameters.yml +57 -28
- squirrels/package_data/base_project/pyconfigs/auth.py +11 -10
- squirrels/package_data/base_project/pyconfigs/connections.py +6 -8
- squirrels/package_data/base_project/pyconfigs/context.py +49 -33
- squirrels/package_data/base_project/pyconfigs/parameters.py +62 -30
- squirrels/package_data/base_project/seeds/seed_categories.csv +6 -0
- squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/package_data/base_project/squirrels.yml.j2 +37 -20
- squirrels/parameter_options.py +30 -10
- squirrels/parameters.py +300 -70
- squirrels/user_base.py +3 -13
- squirrels-0.3.0.dist-info/LICENSE +201 -0
- {squirrels-0.2.2.dist-info → squirrels-0.3.0.dist-info}/METADATA +15 -15
- squirrels-0.3.0.dist-info/RECORD +56 -0
- squirrels/package_data/base_project/seeds/mocks/category.csv +0 -3
- squirrels/package_data/base_project/seeds/mocks/max_filter.csv +0 -2
- squirrels/package_data/base_project/seeds/mocks/subcategory.csv +0 -6
- squirrels-0.2.2.dist-info/LICENSE +0 -22
- squirrels-0.2.2.dist-info/RECORD +0 -55
- {squirrels-0.2.2.dist-info → squirrels-0.3.0.dist-info}/WHEEL +0 -0
- {squirrels-0.2.2.dist-info → squirrels-0.3.0.dist-info}/entry_points.txt +0 -0
squirrels/_models.py
CHANGED
@@ -1,21 +1,26 @@
 from __future__ import annotations
-from typing import
+from typing import Optional, Callable, Iterable, Any
 from dataclasses import dataclass, field
+from abc import ABCMeta, abstractmethod
 from enum import Enum
 from pathlib import Path
-
+from sqlalchemy import create_engine, text, Connection
+import asyncio, os, shutil, pandas as pd, json
+import matplotlib.pyplot as plt, networkx as nx
 
 from . import _constants as c, _utils as u, _py_module as pm
 from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
 from ._authenticator import User, Authenticator
 from ._connection_set import ConnectionSetIO
-from ._manifest import ManifestIO, DatasetsConfig
+from ._manifest import ManifestIO, DatasetsConfig, DatasetScope
 from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
+from ._seeds import SeedsIO
 from ._timer import timer, time
 
 class ModelType(Enum):
     DBVIEW = 1
     FEDERATE = 2
+    SEED = 3
 
 class QueryType(Enum):
     SQL = 0
@@ -27,7 +32,7 @@ class Materialization(Enum):
 
 
 @dataclass
-class
+class _SqlModelConfig:
     ## Applicable for dbview models
     connection_name: str
 
@@ -58,63 +63,132 @@ ContextFunc = Callable[[dict[str, Any], ContextArgs], None]
 
 
 @dataclass(frozen=True)
-class
+class _RawQuery(metaclass=ABCMeta):
     pass
 
 @dataclass(frozen=True)
-class
+class _RawSqlQuery(_RawQuery):
     query: str
 
 @dataclass(frozen=True)
-class
+class _RawPyQuery(_RawQuery):
     query: Callable[[Any], pd.DataFrame]
     dependencies_func: Callable[[Any], Iterable]
 
 
 @dataclass
-class
+class _Query(metaclass=ABCMeta):
     query: Any
 
 @dataclass
-class
+class _WorkInProgress(_Query):
     query: None = field(default=None, init=False)
 
 @dataclass
-class
+class _SqlModelQuery(_Query):
     query: str
-    config:
+    config: _SqlModelConfig
 
 @dataclass
-class
+class _PyModelQuery(_Query):
     query: Callable[[], pd.DataFrame]
 
 
 @dataclass(frozen=True)
-class
+class _QueryFile:
     filepath: str
     model_type: ModelType
     query_type: QueryType
-    raw_query:
+    raw_query: _RawQuery
 
 
 @dataclass
-class
+class _Referable(metaclass=ABCMeta):
     name: str
-    query_file: QueryFile
     is_target: bool = field(default=False, init=False)
-    compiled_query: Optional[Query] = field(default=None, init=False)
 
     needs_sql_table: bool = field(default=False, init=False)
-    needs_pandas: bool = False
+    needs_pandas: bool = field(default=False, init=False)
     result: Optional[pd.DataFrame] = field(default=None, init=False, repr=False)
-
-    wait_count: int = field(default=0, init=False, repr=False)
-    upstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
-    downstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
 
+    wait_count: int = field(default=0, init=False, repr=False)
     confirmed_no_cycles: bool = field(default=False, init=False)
+    upstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
+    downstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
+
+    @abstractmethod
+    def get_model_type(self) -> ModelType:
+        pass
+
+    async def compile(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
+    ) -> None:
+        pass
+
+    @abstractmethod
+    def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
+        pass
+
+    def _load_pandas_to_table(self, df: pd.DataFrame, conn: Connection) -> None:
+        df.to_sql(self.name, conn, index=False)
+
+    def _load_table_to_pandas(self, conn: Connection) -> pd.DataFrame:
+        query = f"SELECT * FROM {self.name}"
+        return pd.read_sql(query, conn)
+
+    async def _trigger(self, conn: Connection, placeholders: dict = {}) -> None:
+        self.wait_count -= 1
+        if (self.wait_count == 0):
+            await self.run_model(conn, placeholders)
+
+    @abstractmethod
+    async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
+        coroutines = []
+        for model in self.downstreams.values():
+            coroutines.append(model._trigger(conn, placeholders))
+        await asyncio.gather(*coroutines)
+
+    def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
+        pass
+
+    def get_max_path_length_to_target(self) -> int:
+        if not hasattr(self, "max_path_len_to_target"):
+            path_lengths = []
+            for child_model in self.downstreams.values():
+                path_lengths.append(child_model.get_max_path_length_to_target()+1)
+            if len(path_lengths) > 0:
+                self.max_path_len_to_target = max(path_lengths)
+            else:
+                self.max_path_len_to_target = 0 if self.is_target else None
+        return self.max_path_len_to_target
 
-
+
+@dataclass
+class _Seed(_Referable):
+    result: pd.DataFrame
+
+    def get_model_type(self) -> ModelType:
+        return ModelType.SEED
+
+    def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
+        return {self.name}
+
+    async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
+        if self.needs_sql_table:
+            await asyncio.to_thread(self._load_pandas_to_table, self.result, conn)
+        await super().run_model(conn, placeholders)
+
+
+@dataclass
+class _Model(_Referable):
+    query_file: _QueryFile
+
+    compiled_query: Optional[_Query] = field(default=None, init=False)
+
+    def get_model_type(self) -> ModelType:
+        return self.query_file.model_type
+
+    def _add_upstream(self, other: _Referable) -> None:
         self.upstreams[other.name] = other
         other.downstreams[self.name] = self
 
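The new `_Referable` base class above centralizes the DAG scheduling that the old `Model.trigger` handled: `compile` sets `wait_count = len(dependencies)`, each finished upstream calls `_trigger` on its downstreams, and a node runs once its counter reaches zero, fanning out further with `asyncio.gather`. A minimal standalone sketch of that wait-count pattern (the `Node` class and names are illustrative, not squirrels APIs):

```python
import asyncio

class Node:
    """Illustrative stand-in for _Referable; not a squirrels API."""
    def __init__(self, name: str):
        self.name = name
        self.upstreams: dict[str, "Node"] = {}
        self.downstreams: dict[str, "Node"] = {}
        self.wait_count = 0

    def add_upstream(self, other: "Node") -> None:
        # Mirrors _Model._add_upstream: link both directions
        self.upstreams[other.name] = other
        other.downstreams[self.name] = self

    async def _trigger(self) -> None:
        # An upstream finished; run once all upstreams have reported in
        self.wait_count -= 1
        if self.wait_count == 0:
            await self.run()

    async def run(self) -> None:
        print(f"running {self.name}")
        # Fan out to downstream nodes concurrently, like _Referable.run_model
        await asyncio.gather(*(m._trigger() for m in self.downstreams.values()))

async def main() -> None:
    a, b, c = Node("a"), Node("b"), Node("c")
    c.add_upstream(a)
    c.add_upstream(b)
    c.wait_count = len(c.upstreams)  # like self.wait_count = len(dependencies)
    await asyncio.gather(a.run(), b.run())  # start the terminal nodes; "c" runs last

asyncio.run(main())
```

Terminal nodes with no upstreams are launched directly, which matches how `_run_models` below gathers `run_model` coroutines for the `terminal_nodes` set.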
@@ -137,17 +211,20 @@ class Model:
         materialized = federate_config.materialized
         return Materialization[materialized.upper()]
 
-    async def _compile_sql_model(
-
+    async def _compile_sql_model(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
+    ) -> tuple[_SqlModelQuery, set]:
+        assert(isinstance(self.query_file.raw_query, _RawSqlQuery))
+
         raw_query = self.query_file.raw_query.query
-
         connection_name = self._get_dbview_conn_name()
         materialized = self._get_materialized()
-        configuration =
+        configuration = _SqlModelConfig(connection_name, materialized)
+        is_placeholder = lambda x: x in placeholders
         kwargs = {
-            "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars,
-            "
-            "
+            "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars, "user": ctx_args.user, "prms": ctx_args.prms,
+            "traits": ctx_args.traits, "ctx": ctx, "is_placeholder": is_placeholder, "set_placeholder": ctx_args.set_placeholder,
+            "config": configuration.set_attribute, "is_param_enabled": ctx_args.param_exists
         }
         dependencies = set()
         if self.query_file.model_type == ModelType.FEDERATE:
@@ -157,16 +234,21 @@ class Model:
             kwargs["ref"] = ref
 
         try:
-            query = await asyncio.to_thread(u.render_string, raw_query, kwargs)
+            query = await asyncio.to_thread(u.render_string, raw_query, **kwargs)
         except Exception as e:
             raise u.FileExecutionError(f'Failed to compile sql model "{self.name}"', e)
 
-        compiled_query =
+        compiled_query = _SqlModelQuery(query, configuration)
         return compiled_query, dependencies
 
-    async def _compile_python_model(
-
-
+    async def _compile_python_model(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
+    ) -> tuple[_PyModelQuery, set]:
+        assert(isinstance(self.query_file.raw_query, _RawPyQuery))
+
+        sqrl_args = ModelDepsArgs(
+            ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx
+        )
         try:
             dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
         except Exception as e:
@@ -175,34 +257,42 @@ class Model:
         dbview_conn_name = self._get_dbview_conn_name()
         connections = ConnectionSetIO.obj.get_engines_as_dict()
         ref = lambda x: self.upstreams[x].result
-        sqrl_args = ModelArgs(
-
+        sqrl_args = ModelArgs(
+            ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx,
+            dbview_conn_name, connections, dependencies, ref
+        )
 
         def compiled_query():
             try:
-
+                raw_query: _RawPyQuery = self.query_file.raw_query
+                return raw_query.query(sqrl=sqrl_args)
             except Exception as e:
                 raise u.FileExecutionError(f'Failed to run "{c.MAIN_FUNC}" function for python model "{self.name}"', e)
 
-        return
+        return _PyModelQuery(compiled_query), dependencies
 
-    async def compile(
+    async def compile(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
+    ) -> None:
         if self.compiled_query is not None:
             return
         else:
-            self.compiled_query =
+            self.compiled_query = _WorkInProgress()
 
         start = time.time()
+
         if self.query_file.query_type == QueryType.SQL:
-            compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args)
+            compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders)
         elif self.query_file.query_type == QueryType.PYTHON:
-            compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args)
+            compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders)
         else:
             raise NotImplementedError(f"Query type not supported: {self.query_file.query_type}")
 
         self.compiled_query = compiled_query
         self.wait_count = len(dependencies)
-
+
+        model_type = self.get_model_type().name.lower()
+        timer.add_activity_time(f"compiling {model_type} model '{self.name}'", start)
 
         if not recurse:
             return
@@ -211,7 +301,7 @@ class Model:
         coroutines = []
         for dep_model in dep_models:
             self._add_upstream(dep_model)
-            coro = dep_model.compile(ctx, ctx_args, models_dict, recurse)
+            coro = dep_model.compile(ctx, ctx_args, placeholders, models_dict, recurse)
             coroutines.append(coro)
         await asyncio.gather(*coroutines)
 
@@ -234,29 +324,16 @@ class Model:
 
         self.confirmed_no_cycles = True
         return terminal_nodes
-
-    def _load_pandas_to_table(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None:
-        if u.use_duckdb():
-            conn.execute(f"CREATE TABLE {self.name} AS FROM df")
-        else:
-            df.to_sql(self.name, conn, index=False)
-
-    def _load_table_to_pandas(self, conn: sqlite3.Connection) -> pd.DataFrame:
-        if u.use_duckdb():
-            return conn.execute(f"FROM {self.name}").df()
-        else:
-            query = f"SELECT * FROM {self.name}"
-            return pd.read_sql(query, conn)
 
-    async def _run_sql_model(self, conn:
-        assert(isinstance(self.compiled_query,
+    async def _run_sql_model(self, conn: Connection, placeholders: dict = {}) -> None:
+        assert(isinstance(self.compiled_query, _SqlModelQuery))
         config = self.compiled_query.config
         query = self.compiled_query.query
 
         if self.query_file.model_type == ModelType.DBVIEW:
             def run_sql_query():
                 try:
-                    return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name)
+                    return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name, placeholders)
                 except RuntimeError as e:
                     raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e)
 
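The hunk above deletes the duckdb/sqlite3-specific loaders from `Model`; materialization now goes through the generic pandas round-trip defined on `_Referable` (`df.to_sql` to write a table, `pd.read_sql` to read it back) over a SQLAlchemy connection. A small sketch of that round-trip, with made-up table and column names:

```python
import pandas as pd
from sqlalchemy import create_engine

engine = create_engine("sqlite://")  # in-memory SQLite
df = pd.DataFrame({"category": ["food", "rent"], "amount": [120, 900]})

with engine.connect() as conn:
    # Equivalent to _load_pandas_to_table: materialize the DataFrame as a table
    df.to_sql("expenses", conn, index=False)
    # Equivalent to _load_table_to_pandas: read the table back into a DataFrame
    out = pd.read_sql("SELECT * FROM expenses", conn)

print(out)
```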
@@ -268,7 +345,7 @@ class Model:
             def create_table():
                 create_query = config.get_sql_for_create(self.name, query)
                 try:
-                    return conn.execute(create_query)
+                    return conn.execute(text(create_query), placeholders)
                 except Exception as e:
                     raise u.FileExecutionError(f'Failed to run federate sql model "{self.name}"', e)
 
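Federate models are now executed through SQLAlchemy's `text()` construct with the placeholders dict as bound parameters, so values registered via `set_placeholder` are bound by name at execution time rather than interpolated into the SQL string. A hedged sketch of the binding mechanics (SQLAlchemy 2.x style; the table is invented for illustration):

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(text("CREATE TABLE expenses (amount INTEGER)"))
    conn.execute(text("INSERT INTO expenses VALUES (5), (50)"))
    # :min_amount is bound from the dict at execution time, never interpolated
    rows = conn.execute(
        text("SELECT amount FROM expenses WHERE amount >= :min_amount"),
        {"min_amount": 10},
    ).all()
    print(rows)  # [(50,)]
```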
@@ -276,8 +353,8 @@ class Model:
         if self.needs_pandas or self.is_target:
             self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)
 
-    async def _run_python_model(self, conn:
-        assert(isinstance(self.compiled_query,
+    async def _run_python_model(self, conn: Connection) -> None:
+        assert(isinstance(self.compiled_query, _PyModelQuery))
 
         df = await asyncio.to_thread(self.compiled_query.query)
         if self.needs_sql_table:
@@ -285,92 +362,86 @@ class Model:
         if self.needs_pandas or self.is_target:
             self.result = df
 
-    async def run_model(self, conn:
+    async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
         start = time.time()
+
         if self.query_file.query_type == QueryType.SQL:
-            await self._run_sql_model(conn)
+            await self._run_sql_model(conn, placeholders)
         elif self.query_file.query_type == QueryType.PYTHON:
             await self._run_python_model(conn)
-        timer.add_activity_time(f"running model '{self.name}'", start)
 
-
-
-
-        await
-
-    async def trigger(self, conn: sqlite3.Connection) -> None:
-        self.wait_count -= 1
-        if (self.wait_count == 0):
-            await self.run_model(conn)
+        model_type = self.get_model_type().name.lower()
+        timer.add_activity_time(f"running {model_type} model '{self.name}'", start)
+
+        await super().run_model(conn, placeholders)
 
-    def
+    def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
         if self.name not in dependent_model_names:
             dependent_model_names.add(self.name)
             for dep_model in self.upstreams.values():
-                dep_model.
+                dep_model.retrieve_dependent_query_models(dependent_model_names)
 
 
 @dataclass
-class
+class _DAG:
     dataset: DatasetsConfig
-    target_model:
-    models_dict: dict[str,
+    target_model: _Referable
+    models_dict: dict[str, _Referable]
     parameter_set: Optional[ParameterSet] = field(default=None, init=False)
+    placeholders: dict[str, Any] = field(init=False, default_factory=dict)
 
     def apply_selections(
         self, user: Optional[User], selections: dict[str, str], *, updates_only: bool = False, request_version: Optional[int] = None
     ) -> None:
         start = time.time()
         dataset_params = self.dataset.parameters
-        parameter_set = ParameterConfigsSetIO.obj.apply_selections(
-
+        parameter_set = ParameterConfigsSetIO.obj.apply_selections(
+            dataset_params, selections, user, updates_only=updates_only, request_version=request_version
+        )
         self.parameter_set = parameter_set
-        timer.add_activity_time(f"applying selections for dataset", start)
+        timer.add_activity_time(f"applying selections for dataset '{self.dataset.name}'", start)
 
     def _compile_context(self, context_func: ContextFunc, user: Optional[User]) -> tuple[dict[str, Any], ContextArgs]:
         start = time.time()
         context = {}
         param_args = ParameterConfigsSetIO.args
         prms = self.parameter_set.get_parameters_as_dict()
-        args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits)
+        args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits, self.placeholders)
         try:
             context_func(ctx=context, sqrl=args)
         except Exception as e:
-            raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset}"', e)
-        timer.add_activity_time(f"running context.py for dataset", start)
+            raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset.name}"', e)
+        timer.add_activity_time(f"running context.py for dataset '{self.dataset.name}'", start)
         return context, args
 
     async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
-        await self.target_model.compile(context, ctx_args, self.models_dict, recurse)
+        await self.target_model.compile(context, ctx_args, self.placeholders, self.models_dict, recurse)
 
     def _get_terminal_nodes(self) -> set[str]:
         start = time.time()
         terminal_nodes = self.target_model.get_terminal_nodes(set())
         for model in self.models_dict.values():
             model.confirmed_no_cycles = False
-        timer.add_activity_time(f"validating no cycles in
+        timer.add_activity_time(f"validating no cycles in model dependencies", start)
        return terminal_nodes
 
-    async def _run_models(self, terminal_nodes: set[str]) -> None:
-        if u.use_duckdb()
-
-            conn = duckdb.connect()
-        else:
-            conn = sqlite3.connect(":memory:", check_same_thread=False)
+    async def _run_models(self, terminal_nodes: set[str], placeholders: dict = {}) -> None:
+        conn_url = "duckdb:///" if u.use_duckdb() else "sqlite:///?check_same_thread=False"
+        engine = create_engine(conn_url)
 
-
+        with engine.connect() as conn:
             coroutines = []
             for model_name in terminal_nodes:
                 model = self.models_dict[model_name]
-                coroutines.append(model.run_model(conn))
+                coroutines.append(model.run_model(conn, placeholders))
             await asyncio.gather(*coroutines)
-
-
+
+        engine.dispose()
 
     async def execute(
         self, context_func: ContextFunc, user: Optional[User], selections: dict[str, str], *, request_version: Optional[int] = None,
         runquery: bool = True, recurse: bool = True
-    ) ->
+    ) -> dict[str, Any]:
         recurse = (recurse or runquery)
 
         self.apply_selections(user, selections, request_version=request_version)
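In `_run_models`, the two raw DBAPI branches (duckdb vs. sqlite3) collapse into a single `create_engine` call whose URL selects the backend. The sketch below shows the same idea; note that the `duckdb:///` URL presumes a DuckDB SQLAlchemy dialect such as duckdb-engine is installed, which this diff does not itself show:

```python
from sqlalchemy import create_engine, text

use_duckdb = False  # True requires a DuckDB SQLAlchemy dialect (assumed: duckdb-engine)
conn_url = "duckdb:///" if use_duckdb else "sqlite://"

engine = create_engine(conn_url)
with engine.connect() as conn:
    print(conn.execute(text("SELECT 1 + 1")).scalar())  # 2
engine.dispose()  # mirrors the explicit dispose() in the diff
```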
@@ -381,17 +452,35 @@ class DAG:
 
         terminal_nodes = self._get_terminal_nodes()
 
+        placeholders = ctx_args._placeholders.copy()
         if runquery:
-            await self._run_models(terminal_nodes)
+            await self._run_models(terminal_nodes, placeholders)
+
+        return placeholders
 
-    def
+    def get_all_query_models(self) -> set[str]:
         all_model_names = set()
-        self.target_model.
+        self.target_model.retrieve_dependent_query_models(all_model_names)
         return all_model_names
-
+
+    def to_networkx_graph(self) -> nx.DiGraph:
+        G = nx.DiGraph()
+
+        for model_name, model in self.models_dict.items():
+            model_type = model.get_model_type()
+            level = model.get_max_path_length_to_target()
+            if level is not None:
+                G.add_node(model_name, layer=-level, model_type=model_type)
+
+        for model_name in G.nodes:
+            model = self.models_dict[model_name]
+            for dep_model_name in model.downstreams:
+                G.add_edge(model_name, dep_model_name)
+
+        return G
 
 class ModelsIO:
-    raw_queries_by_model: dict[str,
+    raw_queries_by_model: dict[str, _QueryFile]
     context_func: ContextFunc
 
     @classmethod
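`to_networkx_graph` tags each node with `layer=-level`, where `level` is `get_max_path_length_to_target()`; `nx.multipartite_layout` (used by `draw_dag` further down) keys on exactly that attribute to arrange the DAG in columns ending at the target. A toy sketch, borrowing model names from the base project for flavor (the edges are invented):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so savefig works without a display
import matplotlib.pyplot as plt
import networkx as nx

# Layers are negated path lengths to the target, as in to_networkx_graph
G = nx.DiGraph()
G.add_node("seed_categories", layer=-2)
G.add_node("database_view1", layer=-1)
G.add_node("dataset_example", layer=0)  # the target model sits in the last column
G.add_edge("seed_categories", "database_view1")
G.add_edge("database_view1", "dataset_example")

pos = nx.multipartite_layout(G, subset_key="layer")  # one column per layer value
nx.draw(G, pos=pos, with_labels=True, node_color="skyblue", node_size=1200)
plt.savefig("dag.png")
```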
@@ -400,7 +489,6 @@ class ModelsIO:
         cls.raw_queries_by_model = {}
 
         def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType):
-
             def populate_from_file(dp, file):
                 query_type = None
                 filepath = os.path.join(dp, file)
@@ -409,13 +497,13 @@ class ModelsIO:
                     query_type = QueryType.PYTHON
                     module = pm.PyModule(filepath)
                     dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda x: [])
-                    raw_query =
+                    raw_query = _RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
                 elif extension == '.sql':
                     query_type = QueryType.SQL
-                    raw_query =
+                    raw_query = _RawSqlQuery(u.read_file(filepath))
 
                 if query_type is not None:
-                    query_file =
+                    query_file = _QueryFile(filepath, model_type, query_type, raw_query)
                     if file_stem in cls.raw_queries_by_model:
                         conflicts = [cls.raw_queries_by_model[file_stem].filepath, filepath]
                         raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
@@ -431,44 +519,98 @@ class ModelsIO:
         federates_path = u.join_paths(c.MODELS_FOLDER, c.FEDERATES_FOLDER)
         populate_raw_queries_for_type(federates_path, ModelType.FEDERATE)
 
-        context_path = u.join_paths(c.
+        context_path = u.join_paths(c.PYCONFIGS_FOLDER, c.CONTEXT_FILE)
         cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda x, y: None)
 
-        timer.add_activity_time("loading models and
+        timer.add_activity_time("loading files for models and context.py", start)
 
     @classmethod
-    def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) ->
-
+    def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> _DAG:
+        seeds_dict = SeedsIO.obj.get_dataframes()
+
+        models_dict: dict[str, _Referable] = {key: _Seed(key, df) for key, df in seeds_dict.items()}
+        for key, val in cls.raw_queries_by_model.items():
+            models_dict[key] = _Model(key, val)
+            models_dict[key].needs_pandas = always_pandas
 
         dataset_config = ManifestIO.obj.datasets[dataset]
         target_model_name = dataset_config.model if target_model_name is None else target_model_name
         target_model = models_dict[target_model_name]
         target_model.is_target = True
 
-        return
+        return _DAG(dataset_config, target_model, models_dict)
 
     @classmethod
-
-
-
-
+    def draw_dag(cls, dag: _DAG, output_folder: Path) -> None:
+        color_map = {ModelType.SEED: "green", ModelType.DBVIEW: "red", ModelType.FEDERATE: "skyblue"}
+
+        G = dag.to_networkx_graph()
+
+        fig, _ = plt.subplots()
+        pos = nx.multipartite_layout(G, subset_key="layer")
+        colors = [color_map[node[1]] for node in G.nodes(data="model_type")]
+        nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
+
+        y_values = [val[1] for val in pos.values()]
+        scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
+        label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
+        nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
+
+        fig.tight_layout()
+        plt.margins(x=0.1, y=0.1)
+        plt.savefig(u.join_paths(output_folder, "dag.png"))
+        plt.close(fig)
+
+    @classmethod
+    async def WriteDatasetOutputsGivenTestSet(
+        cls, dataset_conf: DatasetsConfig, select: str, test_set: Optional[str], runquery: bool, recurse: bool
+    ) -> Any:
+        dataset = dataset_conf.name
+        default_test_set, default_test_set_conf = ManifestIO.obj.get_default_test_set(dataset)
+        if test_set is None or test_set == default_test_set:
+            test_set, test_set_conf = default_test_set, default_test_set_conf
+        elif test_set in ManifestIO.obj.selection_test_sets:
+            test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
+        else:
+            raise u.InvalidInputError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
 
-
-
-
+        error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
+        if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
+            raise u.InvalidInputError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
 
+        user_attributes = test_set_conf.user_attributes.copy()
+        selections = test_set_conf.parameters.copy()
+        username, is_internal = user_attributes.pop("username", ""), user_attributes.pop("is_internal", False)
+        if test_set_conf.is_authenticated:
+            user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
+            user = user_cls.Create(username, is_internal=is_internal, **user_attributes)
+        elif dataset_conf.scope == DatasetScope.PUBLIC:
+            user = None
+        else:
+            raise u.ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
+
+        if dataset_conf.scope == DatasetScope.PRIVATE and not is_internal:
+            raise u.ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_internal' set to true")
+
+        # always_pandas is set to True for creating CSV files from results (when runquery is True)
         dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
-        await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
+        placeholders = await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
 
-        output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER,
+        output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
         if os.path.exists(output_folder):
             shutil.rmtree(output_folder)
+        os.makedirs(output_folder, exist_ok=True)
+
+        def write_placeholders() -> None:
+            output_filepath = u.join_paths(output_folder, "placeholders.json")
+            with open(output_filepath, 'w') as f:
+                json.dump(placeholders, f, indent=4)
 
-        def write_model_outputs(model:
+        def write_model_outputs(model: _Model) -> None:
             subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
             subpath = u.join_paths(output_folder, subfolder)
             os.makedirs(subpath, exist_ok=True)
-            if isinstance(model.compiled_query,
+            if isinstance(model.compiled_query, _SqlModelQuery):
                 output_filepath = u.join_paths(subpath, model.name+'.sql')
                 query = model.compiled_query.query
                 with open(output_filepath, 'w') as f:
@@ -477,39 +619,50 @@ class ModelsIO:
                 output_filepath = u.join_paths(subpath, model.name+'.csv')
                 model.result.to_csv(output_filepath, index=False)
 
-
+        write_placeholders()
+        all_model_names = dag.get_all_query_models()
         coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
         await asyncio.gather(*coroutines)
-
+
+        if recurse:
+            cls.draw_dag(dag, output_folder)
+
+        if isinstance(dag.target_model, _Model):
+            return dag.target_model.compiled_query.query # else return None
 
     @classmethod
     async def WriteOutputs(
-        cls, dataset: Optional[str], select: Optional[str],
+        cls, dataset: Optional[str], do_all_datasets: bool, select: Optional[str], test_set: Optional[str], do_all_test_sets: bool,
+        runquery: bool
     ) -> None:
-
-
-
-
-
-
-
+
+        def get_applicable_test_sets(dataset: str) -> list[str]:
+            applicable_test_sets = []
+            for test_set_name, test_set_config in ManifestIO.obj.selection_test_sets.items():
+                if test_set_config.datasets is None or dataset in test_set_config.datasets:
+                    applicable_test_sets.append(test_set_name)
+            return applicable_test_sets
 
         recurse = True
         dataset_configs = ManifestIO.obj.datasets
-        if
-            selected_models = [(dataset
+        if do_all_datasets:
+            selected_models = [(dataset, dataset.model) for dataset in dataset_configs.values()]
         else:
             if select is None:
                 select = dataset_configs[dataset].model
             else:
                 recurse = False
-            selected_models = [(dataset, select)]
+            selected_models = [(dataset_configs[dataset], select)]
 
         coroutines = []
-        for
-
-
-
+        for dataset_conf, select in selected_models:
+            if do_all_test_sets:
+                for test_set_name in get_applicable_test_sets(dataset_conf.name):
+                    coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set_name, runquery, recurse)
+                    coroutines.append(coroutine)
+
+            coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set, runquery, recurse)
+            coroutines.append(coroutine)
 
         queries = await asyncio.gather(*coroutines)
         if not recurse and len(queries) == 1 and isinstance(queries[0], str):