squirrels 0.2.2-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- squirrels/__init__.py +11 -4
- squirrels/_api_response_models.py +118 -0
- squirrels/_api_server.py +146 -75
- squirrels/_authenticator.py +10 -8
- squirrels/_command_line.py +17 -11
- squirrels/_connection_set.py +4 -3
- squirrels/_constants.py +15 -6
- squirrels/_environcfg.py +15 -11
- squirrels/_initializer.py +25 -15
- squirrels/_manifest.py +22 -12
- squirrels/_models.py +316 -154
- squirrels/_parameter_configs.py +195 -57
- squirrels/_parameter_sets.py +14 -17
- squirrels/_py_module.py +2 -4
- squirrels/_seeds.py +38 -0
- squirrels/_utils.py +41 -33
- squirrels/arguments/run_time_args.py +76 -34
- squirrels/data_sources.py +172 -51
- squirrels/dateutils.py +3 -3
- squirrels/package_data/assets/index.js +14 -14
- squirrels/package_data/base_project/.gitignore +1 -0
- squirrels/package_data/base_project/{database → assets}/expenses.db +0 -0
- squirrels/package_data/base_project/assets/weather.db +0 -0
- squirrels/package_data/base_project/connections.yml +1 -1
- squirrels/package_data/base_project/docker/Dockerfile +1 -1
- squirrels/package_data/base_project/{environcfg.yml → env.yml} +8 -8
- squirrels/package_data/base_project/models/dbviews/database_view1.py +25 -14
- squirrels/package_data/base_project/models/dbviews/database_view1.sql +20 -14
- squirrels/package_data/base_project/models/federates/dataset_example.py +6 -5
- squirrels/package_data/base_project/models/federates/dataset_example.sql +1 -1
- squirrels/package_data/base_project/parameters.yml +57 -28
- squirrels/package_data/base_project/pyconfigs/auth.py +11 -10
- squirrels/package_data/base_project/pyconfigs/connections.py +7 -9
- squirrels/package_data/base_project/pyconfigs/context.py +49 -33
- squirrels/package_data/base_project/pyconfigs/parameters.py +62 -30
- squirrels/package_data/base_project/seeds/seed_categories.csv +6 -0
- squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/package_data/base_project/squirrels.yml.j2 +37 -20
- squirrels/parameter_options.py +30 -10
- squirrels/parameters.py +300 -70
- squirrels/user_base.py +3 -13
- squirrels-0.3.1.dist-info/LICENSE +201 -0
- {squirrels-0.2.2.dist-info → squirrels-0.3.1.dist-info}/METADATA +17 -17
- squirrels-0.3.1.dist-info/RECORD +56 -0
- squirrels/package_data/base_project/database/weather.db +0 -0
- squirrels/package_data/base_project/seeds/mocks/category.csv +0 -3
- squirrels/package_data/base_project/seeds/mocks/max_filter.csv +0 -2
- squirrels/package_data/base_project/seeds/mocks/subcategory.csv +0 -6
- squirrels-0.2.2.dist-info/LICENSE +0 -22
- squirrels-0.2.2.dist-info/RECORD +0 -55
- {squirrels-0.2.2.dist-info → squirrels-0.3.1.dist-info}/WHEEL +0 -0
- {squirrels-0.2.2.dist-info → squirrels-0.3.1.dist-info}/entry_points.txt +0 -0
squirrels/_models.py
CHANGED
@@ -1,21 +1,26 @@
 from __future__ import annotations
-from typing import
+from typing import Optional, Callable, Iterable, Any
 from dataclasses import dataclass, field
+from abc import ABCMeta, abstractmethod
 from enum import Enum
 from pathlib import Path
-
+from sqlalchemy import create_engine, text, Connection
+import asyncio, os, shutil, pandas as pd, json
+import matplotlib.pyplot as plt, networkx as nx
 
 from . import _constants as c, _utils as u, _py_module as pm
 from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
 from ._authenticator import User, Authenticator
 from ._connection_set import ConnectionSetIO
-from ._manifest import ManifestIO, DatasetsConfig
+from ._manifest import ManifestIO, DatasetsConfig, DatasetScope
 from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
+from ._seeds import SeedsIO
 from ._timer import timer, time
 
 class ModelType(Enum):
     DBVIEW = 1
     FEDERATE = 2
+    SEED = 3
 
 class QueryType(Enum):
     SQL = 0
@@ -27,12 +32,30 @@ class Materialization(Enum):
 
 
 @dataclass
-class SqlModelConfig:
+class _SqlModelConfig:
     ## Applicable for dbview models
     connection_name: str
 
     ## Applicable for federated models
     materialized: Materialization
+
+    def set_attribute(self, **kwargs) -> str:
+        connection_name = kwargs.get(c.DBVIEW_CONN_KEY)
+        if connection_name is not None:
+            if not isinstance(connection_name, str):
+                raise u.ConfigurationError("The 'connection_name' argument of 'config' macro must be a string")
+            self.connection_name = connection_name
+
+        materialized: str = kwargs.get(c.MATERIALIZED_KEY)
+        if materialized is not None:
+            if not isinstance(materialized, str):
+                raise u.ConfigurationError("The 'materialized' argument of 'config' macro must be a string")
+            try:
+                self.materialized = Materialization[materialized.upper()]
+            except KeyError as e:
+                valid_options = [x.name for x in Materialization]
+                raise u.ConfigurationError(f"The 'materialized' argument value '{materialized}' is not valid. Must be one of: {valid_options}") from e
+        return ""
 
     def get_sql_for_create(self, model_name: str, select_query: str) -> str:
         if self.materialized == Materialization.TABLE:
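The validation added to set_attribute above backs the config(...) macro that SQL models call at render time; a later hunk in this diff wires it into the Jinja context as "config": configuration.set_attribute. A minimal sketch of that round trip, using plain jinja2 instead of squirrels' own u.render_string helper, and assuming Materialization also has a VIEW member (only TABLE is visible in this diff):

# Illustrative only: renders a model string against private 0.3.1 internals.
from jinja2 import Template
from squirrels._models import _SqlModelConfig, Materialization

raw_query = '{{ config(materialized="table") }}SELECT * FROM expenses'
cfg = _SqlModelConfig(connection_name="default", materialized=Materialization.VIEW)  # VIEW assumed

rendered = Template(raw_query).render(config=cfg.set_attribute)

assert rendered == "SELECT * FROM expenses"       # the macro renders to "" ...
assert cfg.materialized == Materialization.TABLE  # ... and only mutates the config

Returning "" from set_attribute is what lets the macro sit at the top of a model file without leaving any text behind in the compiled SQL.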
@@ -43,78 +66,138 @@ class SqlModelConfig:
             raise NotImplementedError(f"Materialization option not supported: {self.materialized}")
 
         return create_prefix + select_query
-
-    def set_attribute(self, **kwargs) -> str:
-        connection_name = kwargs.get(c.DBVIEW_CONN_KEY)
-        materialized = kwargs.get(c.MATERIALIZED_KEY)
-        if isinstance(connection_name, str):
-            self.connection_name = connection_name
-        if isinstance(materialized, str):
-            self.materialized = Materialization[materialized.upper()]
-        return ""
 
 
 ContextFunc = Callable[[dict[str, Any], ContextArgs], None]
 
 
 @dataclass(frozen=True)
-class
+class _RawQuery(metaclass=ABCMeta):
     pass
 
 @dataclass(frozen=True)
-class
+class _RawSqlQuery(_RawQuery):
     query: str
 
 @dataclass(frozen=True)
-class
+class _RawPyQuery(_RawQuery):
     query: Callable[[Any], pd.DataFrame]
     dependencies_func: Callable[[Any], Iterable]
 
 
 @dataclass
-class
+class _Query(metaclass=ABCMeta):
     query: Any
 
 @dataclass
-class
+class _WorkInProgress(_Query):
     query: None = field(default=None, init=False)
 
 @dataclass
-class
+class _SqlModelQuery(_Query):
     query: str
-    config:
+    config: _SqlModelConfig
 
 @dataclass
-class
+class _PyModelQuery(_Query):
     query: Callable[[], pd.DataFrame]
 
 
 @dataclass(frozen=True)
-class
+class _QueryFile:
     filepath: str
     model_type: ModelType
     query_type: QueryType
-    raw_query:
+    raw_query: _RawQuery
 
 
 @dataclass
-class Model:
+class _Referable(metaclass=ABCMeta):
     name: str
-    query_file: QueryFile
     is_target: bool = field(default=False, init=False)
-    compiled_query: Optional[Query] = field(default=None, init=False)
 
     needs_sql_table: bool = field(default=False, init=False)
-    needs_pandas: bool = False
+    needs_pandas: bool = field(default=False, init=False)
     result: Optional[pd.DataFrame] = field(default=None, init=False, repr=False)
-
-    wait_count: int = field(default=0, init=False, repr=False)
-    upstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
-    downstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
 
+    wait_count: int = field(default=0, init=False, repr=False)
     confirmed_no_cycles: bool = field(default=False, init=False)
+    upstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
+    downstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
+
+    @abstractmethod
+    def get_model_type(self) -> ModelType:
+        pass
+
+    async def compile(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
+    ) -> None:
+        pass
+
+    @abstractmethod
+    def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
+        pass
+
+    def _load_pandas_to_table(self, df: pd.DataFrame, conn: Connection) -> None:
+        df.to_sql(self.name, conn, index=False)
+
+    def _load_table_to_pandas(self, conn: Connection) -> pd.DataFrame:
+        query = f"SELECT * FROM {self.name}"
+        return pd.read_sql(query, conn)
+
+    async def _trigger(self, conn: Connection, placeholders: dict = {}) -> None:
+        self.wait_count -= 1
+        if (self.wait_count == 0):
+            await self.run_model(conn, placeholders)
+
+    @abstractmethod
+    async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
+        coroutines = []
+        for model in self.downstreams.values():
+            coroutines.append(model._trigger(conn, placeholders))
+        await asyncio.gather(*coroutines)
+
+    def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
+        pass
+
+    def get_max_path_length_to_target(self) -> int:
+        if not hasattr(self, "max_path_len_to_target"):
+            path_lengths = []
+            for child_model in self.downstreams.values():
+                path_lengths.append(child_model.get_max_path_length_to_target()+1)
+            if len(path_lengths) > 0:
+                self.max_path_len_to_target = max(path_lengths)
+            else:
+                self.max_path_len_to_target = 0 if self.is_target else None
+        return self.max_path_len_to_target
+
+
+@dataclass
+class _Seed(_Referable):
+    result: pd.DataFrame
 
-    def
+    def get_model_type(self) -> ModelType:
+        return ModelType.SEED
+
+    def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
+        return {self.name}
+
+    async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
+        if self.needs_sql_table:
+            await asyncio.to_thread(self._load_pandas_to_table, self.result, conn)
+        await super().run_model(conn, placeholders)
+
+
+@dataclass
+class _Model(_Referable):
+    query_file: _QueryFile
+
+    compiled_query: Optional[_Query] = field(default=None, init=False)
+
+    def get_model_type(self) -> ModelType:
+        return self.query_file.model_type
+
+    def _add_upstream(self, other: _Referable) -> None:
         self.upstreams[other.name] = other
         other.downstreams[self.name] = self
 
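The new _Referable base class carries the DAG scheduling machinery: compile() sets each node's wait_count to its number of dependencies, every finished upstream calls _trigger on its downstreams, and a node runs once its counter reaches zero. A self-contained toy version of that countdown pattern, independent of squirrels' classes:

import asyncio

class Node:
    """Toy version of _Referable's countdown scheduling: a node runs only
    after every upstream dependency has finished, then triggers downstreams."""
    def __init__(self, name: str) -> None:
        self.name = name
        self.wait_count = 0
        self.downstreams: list["Node"] = []

    def add_downstream(self, other: "Node") -> None:
        self.downstreams.append(other)
        other.wait_count += 1

    async def _trigger(self) -> None:
        # each finished upstream decrements the counter; run on the last one
        self.wait_count -= 1
        if self.wait_count == 0:
            await self.run()

    async def run(self) -> None:
        print(f"running {self.name}")
        # fan out to downstream nodes concurrently, like _Referable.run_model
        await asyncio.gather(*(d._trigger() for d in self.downstreams))

async def main() -> None:
    seed, dbview, federate = Node("seed"), Node("dbview"), Node("federate")
    seed.add_downstream(federate)
    dbview.add_downstream(federate)
    # terminal nodes (no upstreams) are started concurrently, as in _run_models
    await asyncio.gather(seed.run(), dbview.run())

asyncio.run(main())  # prints seed and dbview, then federate exactly once

Because the counter passes zero exactly once, a shared downstream model (a federate reading several dbviews, say) runs exactly once, only after all of its inputs are ready.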
@@ -137,17 +220,20 @@ class Model:
         materialized = federate_config.materialized
         return Materialization[materialized.upper()]
 
-    async def _compile_sql_model(
-
+    async def _compile_sql_model(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
+    ) -> tuple[_SqlModelQuery, set]:
+        assert(isinstance(self.query_file.raw_query, _RawSqlQuery))
+
         raw_query = self.query_file.raw_query.query
-
         connection_name = self._get_dbview_conn_name()
         materialized = self._get_materialized()
-        configuration =
+        configuration = _SqlModelConfig(connection_name, materialized)
+        is_placeholder = lambda placeholder: placeholder in placeholders
         kwargs = {
-            "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars,
-            "
-            "
+            "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars, "user": ctx_args.user, "prms": ctx_args.prms,
+            "traits": ctx_args.traits, "ctx": ctx, "is_placeholder": is_placeholder, "set_placeholder": ctx_args.set_placeholder,
+            "config": configuration.set_attribute, "is_param_enabled": ctx_args.param_exists
         }
         dependencies = set()
         if self.query_file.model_type == ModelType.FEDERATE:
@@ -157,16 +243,21 @@ class Model:
             kwargs["ref"] = ref
 
         try:
-            query = await asyncio.to_thread(u.render_string, raw_query, kwargs)
+            query = await asyncio.to_thread(u.render_string, raw_query, **kwargs)
         except Exception as e:
             raise u.FileExecutionError(f'Failed to compile sql model "{self.name}"', e)
 
-        compiled_query =
+        compiled_query = _SqlModelQuery(query, configuration)
         return compiled_query, dependencies
 
-    async def _compile_python_model(
-
-
+    async def _compile_python_model(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
+    ) -> tuple[_PyModelQuery, set]:
+        assert(isinstance(self.query_file.raw_query, _RawPyQuery))
+
+        sqrl_args = ModelDepsArgs(
+            ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx
+        )
         try:
             dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
         except Exception as e:
@@ -174,35 +265,43 @@ class Model:
 
         dbview_conn_name = self._get_dbview_conn_name()
         connections = ConnectionSetIO.obj.get_engines_as_dict()
-        ref = lambda
-        sqrl_args = ModelArgs(
-
+        ref = lambda model: self.upstreams[model].result
+        sqrl_args = ModelArgs(
+            ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx,
+            dbview_conn_name, connections, dependencies, ref
+        )
 
         def compiled_query():
             try:
-
+                raw_query: _RawPyQuery = self.query_file.raw_query
+                return raw_query.query(sqrl=sqrl_args)
             except Exception as e:
                 raise u.FileExecutionError(f'Failed to run "{c.MAIN_FUNC}" function for python model "{self.name}"', e)
 
-        return
+        return _PyModelQuery(compiled_query), dependencies
 
-    async def compile(
+    async def compile(
+        self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
+    ) -> None:
         if self.compiled_query is not None:
             return
         else:
-            self.compiled_query =
+            self.compiled_query = _WorkInProgress()
 
         start = time.time()
+
         if self.query_file.query_type == QueryType.SQL:
-            compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args)
+            compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders)
         elif self.query_file.query_type == QueryType.PYTHON:
-            compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args)
+            compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders)
         else:
             raise NotImplementedError(f"Query type not supported: {self.query_file.query_type}")
 
         self.compiled_query = compiled_query
         self.wait_count = len(dependencies)
-
+
+        model_type = self.get_model_type().name.lower()
+        timer.add_activity_time(f"compiling {model_type} model '{self.name}'", start)
 
         if not recurse:
             return
@@ -211,7 +310,7 @@ class Model:
         coroutines = []
         for dep_model in dep_models:
             self._add_upstream(dep_model)
-            coro = dep_model.compile(ctx, ctx_args, models_dict, recurse)
+            coro = dep_model.compile(ctx, ctx_args, placeholders, models_dict, recurse)
             coroutines.append(coro)
         await asyncio.gather(*coroutines)
 
@@ -234,29 +333,16 @@ class Model:
 
         self.confirmed_no_cycles = True
         return terminal_nodes
-
-    def _load_pandas_to_table(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None:
-        if u.use_duckdb():
-            conn.execute(f"CREATE TABLE {self.name} AS FROM df")
-        else:
-            df.to_sql(self.name, conn, index=False)
-
-    def _load_table_to_pandas(self, conn: sqlite3.Connection) -> pd.DataFrame:
-        if u.use_duckdb():
-            return conn.execute(f"FROM {self.name}").df()
-        else:
-            query = f"SELECT * FROM {self.name}"
-            return pd.read_sql(query, conn)
 
-    async def _run_sql_model(self, conn:
-        assert(isinstance(self.compiled_query,
+    async def _run_sql_model(self, conn: Connection, placeholders: dict = {}) -> None:
+        assert(isinstance(self.compiled_query, _SqlModelQuery))
         config = self.compiled_query.config
         query = self.compiled_query.query
 
         if self.query_file.model_type == ModelType.DBVIEW:
             def run_sql_query():
                 try:
-                    return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name)
+                    return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name, placeholders)
                 except RuntimeError as e:
                     raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e)
 
@@ -268,7 +354,7 @@ class Model:
             def create_table():
                 create_query = config.get_sql_for_create(self.name, query)
                 try:
-                    return conn.execute(create_query)
+                    return conn.execute(text(create_query), placeholders)
                 except Exception as e:
                     raise u.FileExecutionError(f'Failed to run federate sql model "{self.name}"', e)
 
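The same change appears in both execution paths: compiled SQL is now wrapped in SQLAlchemy's text() and executed with the placeholders dict as named bind parameters, rather than having values interpolated into the statement. A minimal sketch of that SQLAlchemy pattern; the table and placeholder names here are made up:

# Minimal sketch of the execute-with-placeholders pattern adopted above.
# squirrels builds the placeholders dict from calls made in context.py.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(text("CREATE TABLE expenses (amount REAL, note TEXT)"))
    conn.execute(text("INSERT INTO expenses VALUES (5.0, 'x'), (50.0, 'y')"))

    placeholders = {"min_amount": 10.0}  # e.g. derived from a parameter selection
    rows = conn.execute(
        text("SELECT * FROM expenses WHERE amount >= :min_amount"),  # named bind param
        placeholders,
    ).fetchall()
    print(rows)  # only the 50.0 row; the value never enters the SQL string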
@@ -276,8 +362,8 @@ class Model:
         if self.needs_pandas or self.is_target:
             self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)
 
-    async def _run_python_model(self, conn:
-        assert(isinstance(self.compiled_query,
+    async def _run_python_model(self, conn: Connection) -> None:
+        assert(isinstance(self.compiled_query, _PyModelQuery))
 
         df = await asyncio.to_thread(self.compiled_query.query)
         if self.needs_sql_table:
@@ -285,92 +371,86 @@ class Model:
         if self.needs_pandas or self.is_target:
             self.result = df
 
-    async def run_model(self, conn:
+    async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
         start = time.time()
+
         if self.query_file.query_type == QueryType.SQL:
-            await self._run_sql_model(conn)
+            await self._run_sql_model(conn, placeholders)
         elif self.query_file.query_type == QueryType.PYTHON:
             await self._run_python_model(conn)
-        timer.add_activity_time(f"running model '{self.name}'", start)
 
-
-
-
-        await
-
-    async def trigger(self, conn: sqlite3.Connection) -> None:
-        self.wait_count -= 1
-        if (self.wait_count == 0):
-            await self.run_model(conn)
+        model_type = self.get_model_type().name.lower()
+        timer.add_activity_time(f"running {model_type} model '{self.name}'", start)
+
+        await super().run_model(conn, placeholders)
 
-    def
+    def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
         if self.name not in dependent_model_names:
             dependent_model_names.add(self.name)
             for dep_model in self.upstreams.values():
-                dep_model.
+                dep_model.retrieve_dependent_query_models(dependent_model_names)
 
 
 @dataclass
-class DAG:
+class _DAG:
     dataset: DatasetsConfig
-    target_model:
-    models_dict: dict[str,
+    target_model: _Referable
+    models_dict: dict[str, _Referable]
     parameter_set: Optional[ParameterSet] = field(default=None, init=False)
+    placeholders: dict[str, Any] = field(init=False, default_factory=dict)
 
     def apply_selections(
         self, user: Optional[User], selections: dict[str, str], *, updates_only: bool = False, request_version: Optional[int] = None
     ) -> None:
         start = time.time()
         dataset_params = self.dataset.parameters
-        parameter_set = ParameterConfigsSetIO.obj.apply_selections(
-
+        parameter_set = ParameterConfigsSetIO.obj.apply_selections(
+            dataset_params, selections, user, updates_only=updates_only, request_version=request_version
+        )
         self.parameter_set = parameter_set
-        timer.add_activity_time(f"applying selections for dataset", start)
+        timer.add_activity_time(f"applying selections for dataset '{self.dataset.name}'", start)
 
     def _compile_context(self, context_func: ContextFunc, user: Optional[User]) -> tuple[dict[str, Any], ContextArgs]:
         start = time.time()
         context = {}
         param_args = ParameterConfigsSetIO.args
         prms = self.parameter_set.get_parameters_as_dict()
-        args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits)
+        args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits, self.placeholders)
         try:
             context_func(ctx=context, sqrl=args)
         except Exception as e:
-            raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset}"', e)
-        timer.add_activity_time(f"running context.py for dataset", start)
+            raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset.name}"', e)
+        timer.add_activity_time(f"running context.py for dataset '{self.dataset.name}'", start)
         return context, args
 
     async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
-        await self.target_model.compile(context, ctx_args, self.models_dict, recurse)
+        await self.target_model.compile(context, ctx_args, self.placeholders, self.models_dict, recurse)
 
     def _get_terminal_nodes(self) -> set[str]:
         start = time.time()
         terminal_nodes = self.target_model.get_terminal_nodes(set())
         for model in self.models_dict.values():
             model.confirmed_no_cycles = False
-        timer.add_activity_time(f"validating no cycles in
+        timer.add_activity_time(f"validating no cycles in model dependencies", start)
         return terminal_nodes
 
-    async def _run_models(self, terminal_nodes: set[str]) -> None:
-        if u.use_duckdb()
-
-            conn = duckdb.connect()
-        else:
-            conn = sqlite3.connect(":memory:", check_same_thread=False)
+    async def _run_models(self, terminal_nodes: set[str], placeholders: dict = {}) -> None:
+        conn_url = "duckdb:///" if u.use_duckdb() else "sqlite:///?check_same_thread=False"
+        engine = create_engine(conn_url)
 
-
+        with engine.connect() as conn:
             coroutines = []
             for model_name in terminal_nodes:
                 model = self.models_dict[model_name]
-                coroutines.append(model.run_model(conn))
+                coroutines.append(model.run_model(conn, placeholders))
             await asyncio.gather(*coroutines)
-
-
+
+        engine.dispose()
 
     async def execute(
         self, context_func: ContextFunc, user: Optional[User], selections: dict[str, str], *, request_version: Optional[int] = None,
         runquery: bool = True, recurse: bool = True
-    ) ->
+    ) -> dict[str, Any]:
        recurse = (recurse or runquery)
 
         self.apply_selections(user, selections, request_version=request_version)
@@ -381,17 +461,35 @@ class DAG:
 
         terminal_nodes = self._get_terminal_nodes()
 
+        placeholders = ctx_args._placeholders.copy()
         if runquery:
-            await self._run_models(terminal_nodes)
+            await self._run_models(terminal_nodes, placeholders)
+
+        return placeholders
 
-    def
+    def get_all_query_models(self) -> set[str]:
         all_model_names = set()
-        self.target_model.
+        self.target_model.retrieve_dependent_query_models(all_model_names)
         return all_model_names
-
+
+    def to_networkx_graph(self) -> nx.DiGraph:
+        G = nx.DiGraph()
+
+        for model_name, model in self.models_dict.items():
+            model_type = model.get_model_type()
+            level = model.get_max_path_length_to_target()
+            if level is not None:
+                G.add_node(model_name, layer=-level, model_type=model_type)
+
+        for model_name in G.nodes:
+            model = self.models_dict[model_name]
+            for dep_model_name in model.downstreams:
+                G.add_edge(model_name, dep_model_name)
+
+        return G
 
 class ModelsIO:
-    raw_queries_by_model: dict[str,
+    raw_queries_by_model: dict[str, _QueryFile]
     context_func: ContextFunc
 
     @classmethod
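execute() now returns a copy of the placeholders registered by context.py, and the compile workflow later in this diff writes that dict to placeholders.json. A hedged sketch of the producer side, a pyconfigs/context.py main function; the set_placeholder(name, value) signature is inferred from how the callable is exposed to templates in this diff and is not verified against the package docs:

from typing import Any

def main(ctx: dict[str, Any], sqrl) -> None:
    # Register a bind parameter instead of splicing the value into SQL;
    # SQL models can then reference it as :min_amount (executed via text()
    # above), and it surfaces in the dict returned by _DAG.execute.
    min_amount = 10.0  # a real project would derive this from sqrl.prms
    sqrl.set_placeholder("min_amount", min_amount)
    ctx["filter_applied"] = True  # ordinary context values still work as before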
@@ -400,7 +498,6 @@ class ModelsIO:
         cls.raw_queries_by_model = {}
 
         def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType):
-
             def populate_from_file(dp, file):
                 query_type = None
                 filepath = os.path.join(dp, file)
@@ -408,14 +505,14 @@ class ModelsIO:
                 if extension == '.py':
                     query_type = QueryType.PYTHON
                     module = pm.PyModule(filepath)
-                    dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda
-                    raw_query =
+                    dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda sqrl: [])
+                    raw_query = _RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
                 elif extension == '.sql':
                     query_type = QueryType.SQL
-                    raw_query =
+                    raw_query = _RawSqlQuery(u.read_file(filepath))
 
                 if query_type is not None:
-                    query_file =
+                    query_file = _QueryFile(filepath, model_type, query_type, raw_query)
                     if file_stem in cls.raw_queries_by_model:
                         conflicts = [cls.raw_queries_by_model[file_stem].filepath, filepath]
                         raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
@@ -431,44 +528,98 @@ class ModelsIO:
         federates_path = u.join_paths(c.MODELS_FOLDER, c.FEDERATES_FOLDER)
         populate_raw_queries_for_type(federates_path, ModelType.FEDERATE)
 
-        context_path = u.join_paths(c.
-        cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda
+        context_path = u.join_paths(c.PYCONFIGS_FOLDER, c.CONTEXT_FILE)
+        cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda ctx, sqrl: None)
 
-        timer.add_activity_time("loading models and
+        timer.add_activity_time("loading files for models and context.py", start)
 
     @classmethod
-    def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) ->
-
+    def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> _DAG:
+        seeds_dict = SeedsIO.obj.get_dataframes()
+
+        models_dict: dict[str, _Referable] = {key: _Seed(key, df) for key, df in seeds_dict.items()}
+        for key, val in cls.raw_queries_by_model.items():
+            models_dict[key] = _Model(key, val)
+            models_dict[key].needs_pandas = always_pandas
 
         dataset_config = ManifestIO.obj.datasets[dataset]
         target_model_name = dataset_config.model if target_model_name is None else target_model_name
         target_model = models_dict[target_model_name]
         target_model.is_target = True
 
-        return
+        return _DAG(dataset_config, target_model, models_dict)
 
     @classmethod
-
-
-
-
+    def draw_dag(cls, dag: _DAG, output_folder: Path) -> None:
+        color_map = {ModelType.SEED: "green", ModelType.DBVIEW: "red", ModelType.FEDERATE: "skyblue"}
+
+        G = dag.to_networkx_graph()
+
+        fig, _ = plt.subplots()
+        pos = nx.multipartite_layout(G, subset_key="layer")
+        colors = [color_map[node[1]] for node in G.nodes(data="model_type")]
+        nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
+
+        y_values = [val[1] for val in pos.values()]
+        scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
+        label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
+        nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
+
+        fig.tight_layout()
+        plt.margins(x=0.1, y=0.1)
+        plt.savefig(u.join_paths(output_folder, "dag.png"))
+        plt.close(fig)
+
+    @classmethod
+    async def WriteDatasetOutputsGivenTestSet(
+        cls, dataset_conf: DatasetsConfig, select: str, test_set: Optional[str], runquery: bool, recurse: bool
+    ) -> Any:
+        dataset = dataset_conf.name
+        default_test_set, default_test_set_conf = ManifestIO.obj.get_default_test_set(dataset)
+        if test_set is None or test_set == default_test_set:
+            test_set, test_set_conf = default_test_set, default_test_set_conf
+        elif test_set in ManifestIO.obj.selection_test_sets:
+            test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
+        else:
+            raise u.InvalidInputError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
+
+        error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
+        if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
+            raise u.InvalidInputError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
 
-
-
-
+        user_attributes = test_set_conf.user_attributes.copy()
+        selections = test_set_conf.parameters.copy()
+        username, is_internal = user_attributes.pop("username", ""), user_attributes.pop("is_internal", False)
+        if test_set_conf.is_authenticated:
+            user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
+            user = user_cls.Create(username, is_internal=is_internal, **user_attributes)
+        elif dataset_conf.scope == DatasetScope.PUBLIC:
+            user = None
+        else:
+            raise u.ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
 
+        if dataset_conf.scope == DatasetScope.PRIVATE and not is_internal:
+            raise u.ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_internal' set to true")
+
+        # always_pandas is set to True for creating CSV files from results (when runquery is True)
         dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
-        await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
+        placeholders = await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
 
-        output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER,
+        output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
         if os.path.exists(output_folder):
             shutil.rmtree(output_folder)
+        os.makedirs(output_folder, exist_ok=True)
+
+        def write_placeholders() -> None:
+            output_filepath = u.join_paths(output_folder, "placeholders.json")
+            with open(output_filepath, 'w') as f:
+                json.dump(placeholders, f, indent=4)
 
-        def write_model_outputs(model:
+        def write_model_outputs(model: _Model) -> None:
             subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
             subpath = u.join_paths(output_folder, subfolder)
             os.makedirs(subpath, exist_ok=True)
-            if isinstance(model.compiled_query,
+            if isinstance(model.compiled_query, _SqlModelQuery):
                 output_filepath = u.join_paths(subpath, model.name+'.sql')
                 query = model.compiled_query.query
                 with open(output_filepath, 'w') as f:
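draw_dag relies on networkx's multipartite_layout, which arranges nodes in columns keyed by a node attribute; because to_networkx_graph stores layer=-level from get_max_path_length_to_target, the target model lands in the rightmost column with its dependencies fanned out to the left. A tiny self-contained demonstration of that layout call, with node names borrowed from the base project:

# Self-contained demo of the multipartite layout used by draw_dag above;
# the layers mirror a seed/dbview -> federate chain from the base project.
import networkx as nx

G = nx.DiGraph()
G.add_node("seed_categories", layer=-1)  # one edge away from the target
G.add_node("database_view1", layer=-1)
G.add_node("dataset_example", layer=0)   # the target model, rightmost column
G.add_edge("seed_categories", "dataset_example")
G.add_edge("database_view1", "dataset_example")

pos = nx.multipartite_layout(G, subset_key="layer")
for name, (x, y) in pos.items():
    print(f"{name}: x={x:.2f}, y={y:.2f}")  # equal 'layer' values share an x column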
@@ -477,39 +628,50 @@ class ModelsIO:
                 output_filepath = u.join_paths(subpath, model.name+'.csv')
                 model.result.to_csv(output_filepath, index=False)
 
-
+        write_placeholders()
+        all_model_names = dag.get_all_query_models()
         coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
         await asyncio.gather(*coroutines)
-
+
+        if recurse:
+            cls.draw_dag(dag, output_folder)
+
+        if isinstance(dag.target_model, _Model):
+            return dag.target_model.compiled_query.query # else return None
 
     @classmethod
     async def WriteOutputs(
-        cls, dataset: Optional[str], select: Optional[str],
+        cls, dataset: Optional[str], do_all_datasets: bool, select: Optional[str], test_set: Optional[str], do_all_test_sets: bool,
+        runquery: bool
     ) -> None:
-
-
-
-
-
-
-
+
+        def get_applicable_test_sets(dataset: str) -> list[str]:
+            applicable_test_sets = []
+            for test_set_name, test_set_config in ManifestIO.obj.selection_test_sets.items():
+                if test_set_config.datasets is None or dataset in test_set_config.datasets:
+                    applicable_test_sets.append(test_set_name)
+            return applicable_test_sets
 
         recurse = True
         dataset_configs = ManifestIO.obj.datasets
-        if
-            selected_models = [(dataset
+        if do_all_datasets:
+            selected_models = [(dataset, dataset.model) for dataset in dataset_configs.values()]
         else:
             if select is None:
                 select = dataset_configs[dataset].model
             else:
                 recurse = False
-            selected_models = [(dataset, select)]
+            selected_models = [(dataset_configs[dataset], select)]
 
         coroutines = []
-        for
-
-
-
+        for dataset_conf, select in selected_models:
+            if do_all_test_sets:
+                for test_set_name in get_applicable_test_sets(dataset_conf.name):
+                    coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set_name, runquery, recurse)
+                    coroutines.append(coroutine)
+
+            coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set, runquery, recurse)
+            coroutines.append(coroutine)
 
         queries = await asyncio.gather(*coroutines)
         if not recurse and len(queries) == 1 and isinstance(queries[0], str):