squirrels 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic; see the registry's advisory page for this release for more details.

Files changed (52):
  1. squirrels/__init__.py +11 -4
  2. squirrels/_api_response_models.py +118 -0
  3. squirrels/_api_server.py +146 -75
  4. squirrels/_authenticator.py +10 -8
  5. squirrels/_command_line.py +17 -11
  6. squirrels/_connection_set.py +4 -3
  7. squirrels/_constants.py +15 -6
  8. squirrels/_environcfg.py +15 -11
  9. squirrels/_initializer.py +25 -15
  10. squirrels/_manifest.py +22 -12
  11. squirrels/_models.py +316 -154
  12. squirrels/_parameter_configs.py +195 -57
  13. squirrels/_parameter_sets.py +14 -17
  14. squirrels/_py_module.py +2 -4
  15. squirrels/_seeds.py +38 -0
  16. squirrels/_utils.py +41 -33
  17. squirrels/arguments/run_time_args.py +76 -34
  18. squirrels/data_sources.py +172 -51
  19. squirrels/dateutils.py +3 -3
  20. squirrels/package_data/assets/index.js +14 -14
  21. squirrels/package_data/base_project/.gitignore +1 -0
  22. squirrels/package_data/base_project/{database → assets}/expenses.db +0 -0
  23. squirrels/package_data/base_project/assets/weather.db +0 -0
  24. squirrels/package_data/base_project/connections.yml +1 -1
  25. squirrels/package_data/base_project/docker/Dockerfile +1 -1
  26. squirrels/package_data/base_project/{environcfg.yml → env.yml} +8 -8
  27. squirrels/package_data/base_project/models/dbviews/database_view1.py +25 -14
  28. squirrels/package_data/base_project/models/dbviews/database_view1.sql +20 -14
  29. squirrels/package_data/base_project/models/federates/dataset_example.py +6 -5
  30. squirrels/package_data/base_project/models/federates/dataset_example.sql +1 -1
  31. squirrels/package_data/base_project/parameters.yml +57 -28
  32. squirrels/package_data/base_project/pyconfigs/auth.py +11 -10
  33. squirrels/package_data/base_project/pyconfigs/connections.py +7 -9
  34. squirrels/package_data/base_project/pyconfigs/context.py +49 -33
  35. squirrels/package_data/base_project/pyconfigs/parameters.py +62 -30
  36. squirrels/package_data/base_project/seeds/seed_categories.csv +6 -0
  37. squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -0
  38. squirrels/package_data/base_project/squirrels.yml.j2 +37 -20
  39. squirrels/parameter_options.py +30 -10
  40. squirrels/parameters.py +300 -70
  41. squirrels/user_base.py +3 -13
  42. squirrels-0.3.1.dist-info/LICENSE +201 -0
  43. {squirrels-0.2.2.dist-info → squirrels-0.3.1.dist-info}/METADATA +17 -17
  44. squirrels-0.3.1.dist-info/RECORD +56 -0
  45. squirrels/package_data/base_project/database/weather.db +0 -0
  46. squirrels/package_data/base_project/seeds/mocks/category.csv +0 -3
  47. squirrels/package_data/base_project/seeds/mocks/max_filter.csv +0 -2
  48. squirrels/package_data/base_project/seeds/mocks/subcategory.csv +0 -6
  49. squirrels-0.2.2.dist-info/LICENSE +0 -22
  50. squirrels-0.2.2.dist-info/RECORD +0 -55
  51. {squirrels-0.2.2.dist-info → squirrels-0.3.1.dist-info}/WHEEL +0 -0
  52. {squirrels-0.2.2.dist-info → squirrels-0.3.1.dist-info}/entry_points.txt +0 -0
squirrels/_models.py CHANGED
@@ -1,21 +1,26 @@
1
1
  from __future__ import annotations
2
- from typing import Union, Optional, Callable, Iterable, Any
2
+ from typing import Optional, Callable, Iterable, Any
3
3
  from dataclasses import dataclass, field
4
+ from abc import ABCMeta, abstractmethod
4
5
  from enum import Enum
5
6
  from pathlib import Path
6
- import sqlite3, pandas as pd, asyncio, os, shutil
7
+ from sqlalchemy import create_engine, text, Connection
8
+ import asyncio, os, shutil, pandas as pd, json
9
+ import matplotlib.pyplot as plt, networkx as nx
7
10
 
8
11
  from . import _constants as c, _utils as u, _py_module as pm
9
12
  from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
10
13
  from ._authenticator import User, Authenticator
11
14
  from ._connection_set import ConnectionSetIO
12
- from ._manifest import ManifestIO, DatasetsConfig
15
+ from ._manifest import ManifestIO, DatasetsConfig, DatasetScope
13
16
  from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
17
+ from ._seeds import SeedsIO
14
18
  from ._timer import timer, time
15
19
 
16
20
  class ModelType(Enum):
17
21
  DBVIEW = 1
18
22
  FEDERATE = 2
23
+ SEED = 3
19
24
 
20
25
  class QueryType(Enum):
21
26
  SQL = 0
@@ -27,12 +32,30 @@ class Materialization(Enum):
27
32
 
28
33
 
29
34
  @dataclass
30
- class SqlModelConfig:
35
+ class _SqlModelConfig:
31
36
  ## Applicable for dbview models
32
37
  connection_name: str
33
38
 
34
39
  ## Applicable for federated models
35
40
  materialized: Materialization
41
+
42
+ def set_attribute(self, **kwargs) -> str:
43
+ connection_name = kwargs.get(c.DBVIEW_CONN_KEY)
44
+ if connection_name is not None:
45
+ if not isinstance(connection_name, str):
46
+ raise u.ConfigurationError("The 'connection_name' argument of 'config' macro must be a string")
47
+ self.connection_name = connection_name
48
+
49
+ materialized: str = kwargs.get(c.MATERIALIZED_KEY)
50
+ if materialized is not None:
51
+ if not isinstance(materialized, str):
52
+ raise u.ConfigurationError("The 'materialized' argument of 'config' macro must be a string")
53
+ try:
54
+ self.materialized = Materialization[materialized.upper()]
55
+ except KeyError as e:
56
+ valid_options = [x.name for x in Materialization]
57
+ raise u.ConfigurationError(f"The 'materialized' argument value '{materialized}' is not valid. Must be one of: {valid_options}") from e
58
+ return ""
36
59
 
37
60
  def get_sql_for_create(self, model_name: str, select_query: str) -> str:
38
61
  if self.materialized == Materialization.TABLE:
@@ -43,78 +66,138 @@ class SqlModelConfig:
43
66
  raise NotImplementedError(f"Materialization option not supported: {self.materialized}")
44
67
 
45
68
  return create_prefix + select_query
46
-
47
- def set_attribute(self, **kwargs) -> str:
48
- connection_name = kwargs.get(c.DBVIEW_CONN_KEY)
49
- materialized = kwargs.get(c.MATERIALIZED_KEY)
50
- if isinstance(connection_name, str):
51
- self.connection_name = connection_name
52
- if isinstance(materialized, str):
53
- self.materialized = Materialization[materialized.upper()]
54
- return ""
55
69
 
56
70
 
57
71
  ContextFunc = Callable[[dict[str, Any], ContextArgs], None]
58
72
 
59
73
 
60
74
  @dataclass(frozen=True)
61
- class RawQuery:
75
+ class _RawQuery(metaclass=ABCMeta):
62
76
  pass
63
77
 
64
78
  @dataclass(frozen=True)
65
- class RawSqlQuery(RawQuery):
79
+ class _RawSqlQuery(_RawQuery):
66
80
  query: str
67
81
 
68
82
  @dataclass(frozen=True)
69
- class RawPyQuery(RawQuery):
83
+ class _RawPyQuery(_RawQuery):
70
84
  query: Callable[[Any], pd.DataFrame]
71
85
  dependencies_func: Callable[[Any], Iterable]
72
86
 
73
87
 
74
88
  @dataclass
75
- class Query:
89
+ class _Query(metaclass=ABCMeta):
76
90
  query: Any
77
91
 
78
92
  @dataclass
79
- class WorkInProgress:
93
+ class _WorkInProgress(_Query):
80
94
  query: None = field(default=None, init=False)
81
95
 
82
96
  @dataclass
83
- class SqlModelQuery(Query):
97
+ class _SqlModelQuery(_Query):
84
98
  query: str
85
- config: SqlModelConfig
99
+ config: _SqlModelConfig
86
100
 
87
101
  @dataclass
88
- class PyModelQuery(Query):
102
+ class _PyModelQuery(_Query):
89
103
  query: Callable[[], pd.DataFrame]
90
104
 
91
105
 
92
106
  @dataclass(frozen=True)
93
- class QueryFile:
107
+ class _QueryFile:
94
108
  filepath: str
95
109
  model_type: ModelType
96
110
  query_type: QueryType
97
- raw_query: RawQuery
111
+ raw_query: _RawQuery
98
112
 
99
113
 
100
114
  @dataclass
101
- class Model:
115
+ class _Referable(metaclass=ABCMeta):
102
116
  name: str
103
- query_file: QueryFile
104
117
  is_target: bool = field(default=False, init=False)
105
- compiled_query: Optional[Query] = field(default=None, init=False)
106
118
 
107
119
  needs_sql_table: bool = field(default=False, init=False)
108
- needs_pandas: bool = False
120
+ needs_pandas: bool = field(default=False, init=False)
109
121
  result: Optional[pd.DataFrame] = field(default=None, init=False, repr=False)
110
-
111
- wait_count: int = field(default=0, init=False, repr=False)
112
- upstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
113
- downstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
114
122
 
123
+ wait_count: int = field(default=0, init=False, repr=False)
115
124
  confirmed_no_cycles: bool = field(default=False, init=False)
125
+ upstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
126
+ downstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
127
+
128
+ @abstractmethod
129
+ def get_model_type(self) -> ModelType:
130
+ pass
131
+
132
+ async def compile(
133
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
134
+ ) -> None:
135
+ pass
136
+
137
+ @abstractmethod
138
+ def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
139
+ pass
140
+
141
+ def _load_pandas_to_table(self, df: pd.DataFrame, conn: Connection) -> None:
142
+ df.to_sql(self.name, conn, index=False)
143
+
144
+ def _load_table_to_pandas(self, conn: Connection) -> pd.DataFrame:
145
+ query = f"SELECT * FROM {self.name}"
146
+ return pd.read_sql(query, conn)
147
+
148
+ async def _trigger(self, conn: Connection, placeholders: dict = {}) -> None:
149
+ self.wait_count -= 1
150
+ if (self.wait_count == 0):
151
+ await self.run_model(conn, placeholders)
152
+
153
+ @abstractmethod
154
+ async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
155
+ coroutines = []
156
+ for model in self.downstreams.values():
157
+ coroutines.append(model._trigger(conn, placeholders))
158
+ await asyncio.gather(*coroutines)
159
+
160
+ def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
161
+ pass
162
+
163
+ def get_max_path_length_to_target(self) -> int:
164
+ if not hasattr(self, "max_path_len_to_target"):
165
+ path_lengths = []
166
+ for child_model in self.downstreams.values():
167
+ path_lengths.append(child_model.get_max_path_length_to_target()+1)
168
+ if len(path_lengths) > 0:
169
+ self.max_path_len_to_target = max(path_lengths)
170
+ else:
171
+ self.max_path_len_to_target = 0 if self.is_target else None
172
+ return self.max_path_len_to_target
173
+
174
+
175
+ @dataclass
176
+ class _Seed(_Referable):
177
+ result: pd.DataFrame
116
178
 
117
- def _add_upstream(self, other: Model) -> None:
179
+ def get_model_type(self) -> ModelType:
180
+ return ModelType.SEED
181
+
182
+ def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
183
+ return {self.name}
184
+
185
+ async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
186
+ if self.needs_sql_table:
187
+ await asyncio.to_thread(self._load_pandas_to_table, self.result, conn)
188
+ await super().run_model(conn, placeholders)
189
+
190
+
191
+ @dataclass
192
+ class _Model(_Referable):
193
+ query_file: _QueryFile
194
+
195
+ compiled_query: Optional[_Query] = field(default=None, init=False)
196
+
197
+ def get_model_type(self) -> ModelType:
198
+ return self.query_file.model_type
199
+
200
+ def _add_upstream(self, other: _Referable) -> None:
118
201
  self.upstreams[other.name] = other
119
202
  other.downstreams[self.name] = self
120
203
 
@@ -137,17 +220,20 @@ class Model:
137
220
  materialized = federate_config.materialized
138
221
  return Materialization[materialized.upper()]
139
222
 
140
- async def _compile_sql_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[SqlModelQuery, set]:
141
- assert(isinstance(self.query_file.raw_query, RawSqlQuery))
223
+ async def _compile_sql_model(
224
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
225
+ ) -> tuple[_SqlModelQuery, set]:
226
+ assert(isinstance(self.query_file.raw_query, _RawSqlQuery))
227
+
142
228
  raw_query = self.query_file.raw_query.query
143
-
144
229
  connection_name = self._get_dbview_conn_name()
145
230
  materialized = self._get_materialized()
146
- configuration = SqlModelConfig(connection_name, materialized)
231
+ configuration = _SqlModelConfig(connection_name, materialized)
232
+ is_placeholder = lambda placeholder: placeholder in placeholders
147
233
  kwargs = {
148
- "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars,
149
- "user": ctx_args.user, "prms": ctx_args.prms, "traits": ctx_args.traits,
150
- "ctx": ctx, "config": configuration.set_attribute
234
+ "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars, "user": ctx_args.user, "prms": ctx_args.prms,
235
+ "traits": ctx_args.traits, "ctx": ctx, "is_placeholder": is_placeholder, "set_placeholder": ctx_args.set_placeholder,
236
+ "config": configuration.set_attribute, "is_param_enabled": ctx_args.param_exists
151
237
  }
152
238
  dependencies = set()
153
239
  if self.query_file.model_type == ModelType.FEDERATE:
@@ -157,16 +243,21 @@ class Model:
157
243
  kwargs["ref"] = ref
158
244
 
159
245
  try:
160
- query = await asyncio.to_thread(u.render_string, raw_query, kwargs)
246
+ query = await asyncio.to_thread(u.render_string, raw_query, **kwargs)
161
247
  except Exception as e:
162
248
  raise u.FileExecutionError(f'Failed to compile sql model "{self.name}"', e)
163
249
 
164
- compiled_query = SqlModelQuery(query, configuration)
250
+ compiled_query = _SqlModelQuery(query, configuration)
165
251
  return compiled_query, dependencies
166
252
 
167
- async def _compile_python_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[PyModelQuery, set]:
168
- assert(isinstance(self.query_file.raw_query, RawPyQuery))
169
- sqrl_args = ModelDepsArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, ctx)
253
+ async def _compile_python_model(
254
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
255
+ ) -> tuple[_PyModelQuery, set]:
256
+ assert(isinstance(self.query_file.raw_query, _RawPyQuery))
257
+
258
+ sqrl_args = ModelDepsArgs(
259
+ ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx
260
+ )
170
261
  try:
171
262
  dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
172
263
  except Exception as e:
@@ -174,35 +265,43 @@ class Model:
174
265
 
175
266
  dbview_conn_name = self._get_dbview_conn_name()
176
267
  connections = ConnectionSetIO.obj.get_engines_as_dict()
177
- ref = lambda x: self.upstreams[x].result
178
- sqrl_args = ModelArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits,
179
- ctx, dbview_conn_name, connections, ref, set(dependencies))
268
+ ref = lambda model: self.upstreams[model].result
269
+ sqrl_args = ModelArgs(
270
+ ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx,
271
+ dbview_conn_name, connections, dependencies, ref
272
+ )
180
273
 
181
274
  def compiled_query():
182
275
  try:
183
- return self.query_file.raw_query.query(sqrl=sqrl_args)
276
+ raw_query: _RawPyQuery = self.query_file.raw_query
277
+ return raw_query.query(sqrl=sqrl_args)
184
278
  except Exception as e:
185
279
  raise u.FileExecutionError(f'Failed to run "{c.MAIN_FUNC}" function for python model "{self.name}"', e)
186
280
 
187
- return PyModelQuery(compiled_query), dependencies
281
+ return _PyModelQuery(compiled_query), dependencies
188
282
 
189
- async def compile(self, ctx: dict[str, Any], ctx_args: ContextArgs, models_dict: dict[str, Model], recurse: bool) -> None:
283
+ async def compile(
284
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
285
+ ) -> None:
190
286
  if self.compiled_query is not None:
191
287
  return
192
288
  else:
193
- self.compiled_query = WorkInProgress()
289
+ self.compiled_query = _WorkInProgress()
194
290
 
195
291
  start = time.time()
292
+
196
293
  if self.query_file.query_type == QueryType.SQL:
197
- compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args)
294
+ compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders)
198
295
  elif self.query_file.query_type == QueryType.PYTHON:
199
- compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args)
296
+ compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders)
200
297
  else:
201
298
  raise NotImplementedError(f"Query type not supported: {self.query_file.query_type}")
202
299
 
203
300
  self.compiled_query = compiled_query
204
301
  self.wait_count = len(dependencies)
205
- timer.add_activity_time(f"compiling model '{self.name}'", start)
302
+
303
+ model_type = self.get_model_type().name.lower()
304
+ timer.add_activity_time(f"compiling {model_type} model '{self.name}'", start)
206
305
 
207
306
  if not recurse:
208
307
  return
@@ -211,7 +310,7 @@ class Model:
211
310
  coroutines = []
212
311
  for dep_model in dep_models:
213
312
  self._add_upstream(dep_model)
214
- coro = dep_model.compile(ctx, ctx_args, models_dict, recurse)
313
+ coro = dep_model.compile(ctx, ctx_args, placeholders, models_dict, recurse)
215
314
  coroutines.append(coro)
216
315
  await asyncio.gather(*coroutines)
217
316
 
@@ -234,29 +333,16 @@ class Model:
234
333
 
235
334
  self.confirmed_no_cycles = True
236
335
  return terminal_nodes
237
-
238
- def _load_pandas_to_table(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None:
239
- if u.use_duckdb():
240
- conn.execute(f"CREATE TABLE {self.name} AS FROM df")
241
- else:
242
- df.to_sql(self.name, conn, index=False)
243
-
244
- def _load_table_to_pandas(self, conn: sqlite3.Connection) -> pd.DataFrame:
245
- if u.use_duckdb():
246
- return conn.execute(f"FROM {self.name}").df()
247
- else:
248
- query = f"SELECT * FROM {self.name}"
249
- return pd.read_sql(query, conn)
250
336
 
251
- async def _run_sql_model(self, conn: sqlite3.Connection) -> None:
252
- assert(isinstance(self.compiled_query, SqlModelQuery))
337
+ async def _run_sql_model(self, conn: Connection, placeholders: dict = {}) -> None:
338
+ assert(isinstance(self.compiled_query, _SqlModelQuery))
253
339
  config = self.compiled_query.config
254
340
  query = self.compiled_query.query
255
341
 
256
342
  if self.query_file.model_type == ModelType.DBVIEW:
257
343
  def run_sql_query():
258
344
  try:
259
- return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name)
345
+ return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name, placeholders)
260
346
  except RuntimeError as e:
261
347
  raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e)
262
348
 
@@ -268,7 +354,7 @@ class Model:
268
354
  def create_table():
269
355
  create_query = config.get_sql_for_create(self.name, query)
270
356
  try:
271
- return conn.execute(create_query)
357
+ return conn.execute(text(create_query), placeholders)
272
358
  except Exception as e:
273
359
  raise u.FileExecutionError(f'Failed to run federate sql model "{self.name}"', e)
274
360
 
@@ -276,8 +362,8 @@ class Model:
276
362
  if self.needs_pandas or self.is_target:
277
363
  self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)
278
364
 
279
- async def _run_python_model(self, conn: sqlite3.Connection) -> None:
280
- assert(isinstance(self.compiled_query, PyModelQuery))
365
+ async def _run_python_model(self, conn: Connection) -> None:
366
+ assert(isinstance(self.compiled_query, _PyModelQuery))
281
367
 
282
368
  df = await asyncio.to_thread(self.compiled_query.query)
283
369
  if self.needs_sql_table:
@@ -285,92 +371,86 @@ class Model:
285
371
  if self.needs_pandas or self.is_target:
286
372
  self.result = df
287
373
 
288
- async def run_model(self, conn: sqlite3.Connection) -> None:
374
+ async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
289
375
  start = time.time()
376
+
290
377
  if self.query_file.query_type == QueryType.SQL:
291
- await self._run_sql_model(conn)
378
+ await self._run_sql_model(conn, placeholders)
292
379
  elif self.query_file.query_type == QueryType.PYTHON:
293
380
  await self._run_python_model(conn)
294
- timer.add_activity_time(f"running model '{self.name}'", start)
295
381
 
296
- coroutines = []
297
- for model in self.downstreams.values():
298
- coroutines.append(model.trigger(conn))
299
- await asyncio.gather(*coroutines)
300
-
301
- async def trigger(self, conn: sqlite3.Connection) -> None:
302
- self.wait_count -= 1
303
- if (self.wait_count == 0):
304
- await self.run_model(conn)
382
+ model_type = self.get_model_type().name.lower()
383
+ timer.add_activity_time(f"running {model_type} model '{self.name}'", start)
384
+
385
+ await super().run_model(conn, placeholders)
305
386
 
306
- def fill_dependent_model_names(self, dependent_model_names: set[str]) -> None:
387
+ def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
307
388
  if self.name not in dependent_model_names:
308
389
  dependent_model_names.add(self.name)
309
390
  for dep_model in self.upstreams.values():
310
- dep_model.fill_dependent_model_names(dependent_model_names)
391
+ dep_model.retrieve_dependent_query_models(dependent_model_names)
311
392
 
312
393
 
313
394
  @dataclass
314
- class DAG:
395
+ class _DAG:
315
396
  dataset: DatasetsConfig
316
- target_model: Model
317
- models_dict: dict[str, Model]
397
+ target_model: _Referable
398
+ models_dict: dict[str, _Referable]
318
399
  parameter_set: Optional[ParameterSet] = field(default=None, init=False)
400
+ placeholders: dict[str, Any] = field(init=False, default_factory=dict)
319
401
 
320
402
  def apply_selections(
321
403
  self, user: Optional[User], selections: dict[str, str], *, updates_only: bool = False, request_version: Optional[int] = None
322
404
  ) -> None:
323
405
  start = time.time()
324
406
  dataset_params = self.dataset.parameters
325
- parameter_set = ParameterConfigsSetIO.obj.apply_selections(dataset_params, selections, user, updates_only=updates_only,
326
- request_version=request_version)
407
+ parameter_set = ParameterConfigsSetIO.obj.apply_selections(
408
+ dataset_params, selections, user, updates_only=updates_only, request_version=request_version
409
+ )
327
410
  self.parameter_set = parameter_set
328
- timer.add_activity_time(f"applying selections for dataset", start)
411
+ timer.add_activity_time(f"applying selections for dataset '{self.dataset.name}'", start)
329
412
 
330
413
  def _compile_context(self, context_func: ContextFunc, user: Optional[User]) -> tuple[dict[str, Any], ContextArgs]:
331
414
  start = time.time()
332
415
  context = {}
333
416
  param_args = ParameterConfigsSetIO.args
334
417
  prms = self.parameter_set.get_parameters_as_dict()
335
- args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits)
418
+ args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits, self.placeholders)
336
419
  try:
337
420
  context_func(ctx=context, sqrl=args)
338
421
  except Exception as e:
339
- raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset}"', e)
340
- timer.add_activity_time(f"running context.py for dataset", start)
422
+ raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset.name}"', e)
423
+ timer.add_activity_time(f"running context.py for dataset '{self.dataset.name}'", start)
341
424
  return context, args
342
425
 
343
426
  async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
344
- await self.target_model.compile(context, ctx_args, self.models_dict, recurse)
427
+ await self.target_model.compile(context, ctx_args, self.placeholders, self.models_dict, recurse)
345
428
 
346
429
  def _get_terminal_nodes(self) -> set[str]:
347
430
  start = time.time()
348
431
  terminal_nodes = self.target_model.get_terminal_nodes(set())
349
432
  for model in self.models_dict.values():
350
433
  model.confirmed_no_cycles = False
351
- timer.add_activity_time(f"validating no cycles in models dependencies", start)
434
+ timer.add_activity_time(f"validating no cycles in model dependencies", start)
352
435
  return terminal_nodes
353
436
 
354
- async def _run_models(self, terminal_nodes: set[str]) -> None:
355
- if u.use_duckdb():
356
- import duckdb
357
- conn = duckdb.connect()
358
- else:
359
- conn = sqlite3.connect(":memory:", check_same_thread=False)
437
+ async def _run_models(self, terminal_nodes: set[str], placeholders: dict = {}) -> None:
438
+ conn_url = "duckdb:///" if u.use_duckdb() else "sqlite:///?check_same_thread=False"
439
+ engine = create_engine(conn_url)
360
440
 
361
- try:
441
+ with engine.connect() as conn:
362
442
  coroutines = []
363
443
  for model_name in terminal_nodes:
364
444
  model = self.models_dict[model_name]
365
- coroutines.append(model.run_model(conn))
445
+ coroutines.append(model.run_model(conn, placeholders))
366
446
  await asyncio.gather(*coroutines)
367
- finally:
368
- conn.close()
447
+
448
+ engine.dispose()
369
449
 
370
450
  async def execute(
371
451
  self, context_func: ContextFunc, user: Optional[User], selections: dict[str, str], *, request_version: Optional[int] = None,
372
452
  runquery: bool = True, recurse: bool = True
373
- ) -> None:
453
+ ) -> dict[str, Any]:
374
454
  recurse = (recurse or runquery)
375
455
 
376
456
  self.apply_selections(user, selections, request_version=request_version)
@@ -381,17 +461,35 @@ class DAG:
381
461
 
382
462
  terminal_nodes = self._get_terminal_nodes()
383
463
 
464
+ placeholders = ctx_args._placeholders.copy()
384
465
  if runquery:
385
- await self._run_models(terminal_nodes)
466
+ await self._run_models(terminal_nodes, placeholders)
467
+
468
+ return placeholders
386
469
 
387
- def get_all_model_names(self) -> set[str]:
470
+ def get_all_query_models(self) -> set[str]:
388
471
  all_model_names = set()
389
- self.target_model.fill_dependent_model_names(all_model_names)
472
+ self.target_model.retrieve_dependent_query_models(all_model_names)
390
473
  return all_model_names
391
-
474
+
475
+ def to_networkx_graph(self) -> nx.DiGraph:
476
+ G = nx.DiGraph()
477
+
478
+ for model_name, model in self.models_dict.items():
479
+ model_type = model.get_model_type()
480
+ level = model.get_max_path_length_to_target()
481
+ if level is not None:
482
+ G.add_node(model_name, layer=-level, model_type=model_type)
483
+
484
+ for model_name in G.nodes:
485
+ model = self.models_dict[model_name]
486
+ for dep_model_name in model.downstreams:
487
+ G.add_edge(model_name, dep_model_name)
488
+
489
+ return G
392
490
 
393
491
  class ModelsIO:
394
- raw_queries_by_model: dict[str, QueryFile]
492
+ raw_queries_by_model: dict[str, _QueryFile]
395
493
  context_func: ContextFunc
396
494
 
397
495
  @classmethod
@@ -400,7 +498,6 @@ class ModelsIO:
400
498
  cls.raw_queries_by_model = {}
401
499
 
402
500
  def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType):
403
-
404
501
  def populate_from_file(dp, file):
405
502
  query_type = None
406
503
  filepath = os.path.join(dp, file)
@@ -408,14 +505,14 @@ class ModelsIO:
408
505
  if extension == '.py':
409
506
  query_type = QueryType.PYTHON
410
507
  module = pm.PyModule(filepath)
411
- dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda x: [])
412
- raw_query = RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
508
+ dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda sqrl: [])
509
+ raw_query = _RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
413
510
  elif extension == '.sql':
414
511
  query_type = QueryType.SQL
415
- raw_query = RawSqlQuery(u.read_file(filepath))
512
+ raw_query = _RawSqlQuery(u.read_file(filepath))
416
513
 
417
514
  if query_type is not None:
418
- query_file = QueryFile(filepath, model_type, query_type, raw_query)
515
+ query_file = _QueryFile(filepath, model_type, query_type, raw_query)
419
516
  if file_stem in cls.raw_queries_by_model:
420
517
  conflicts = [cls.raw_queries_by_model[file_stem].filepath, filepath]
421
518
  raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
@@ -431,44 +528,98 @@ class ModelsIO:
431
528
  federates_path = u.join_paths(c.MODELS_FOLDER, c.FEDERATES_FOLDER)
432
529
  populate_raw_queries_for_type(federates_path, ModelType.FEDERATE)
433
530
 
434
- context_path = u.join_paths(c.PYCONFIG_FOLDER, c.CONTEXT_FILE)
435
- cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda x, y: None)
531
+ context_path = u.join_paths(c.PYCONFIGS_FOLDER, c.CONTEXT_FILE)
532
+ cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda ctx, sqrl: None)
436
533
 
437
- timer.add_activity_time("loading models and/or context.py", start)
534
+ timer.add_activity_time("loading files for models and context.py", start)
438
535
 
439
536
  @classmethod
440
- def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> DAG:
441
- models_dict = {key: Model(key, val, needs_pandas=always_pandas) for key, val in cls.raw_queries_by_model.items()}
537
+ def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> _DAG:
538
+ seeds_dict = SeedsIO.obj.get_dataframes()
539
+
540
+ models_dict: dict[str, _Referable] = {key: _Seed(key, df) for key, df in seeds_dict.items()}
541
+ for key, val in cls.raw_queries_by_model.items():
542
+ models_dict[key] = _Model(key, val)
543
+ models_dict[key].needs_pandas = always_pandas
442
544
 
443
545
  dataset_config = ManifestIO.obj.datasets[dataset]
444
546
  target_model_name = dataset_config.model if target_model_name is None else target_model_name
445
547
  target_model = models_dict[target_model_name]
446
548
  target_model.is_target = True
447
549
 
448
- return DAG(dataset_config, target_model, models_dict)
550
+ return _DAG(dataset_config, target_model, models_dict)
449
551
 
450
552
  @classmethod
451
- async def WriteDatasetOutputsGivenTestSet(cls, dataset: str, select: str, test_set: str, runquery: bool, recurse: bool) -> Any:
452
- test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
453
- user_attributes = test_set_conf.user_attributes
454
- selections = test_set_conf.parameters
553
+ def draw_dag(cls, dag: _DAG, output_folder: Path) -> None:
554
+ color_map = {ModelType.SEED: "green", ModelType.DBVIEW: "red", ModelType.FEDERATE: "skyblue"}
555
+
556
+ G = dag.to_networkx_graph()
557
+
558
+ fig, _ = plt.subplots()
559
+ pos = nx.multipartite_layout(G, subset_key="layer")
560
+ colors = [color_map[node[1]] for node in G.nodes(data="model_type")]
561
+ nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
562
+
563
+ y_values = [val[1] for val in pos.values()]
564
+ scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
565
+ label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
566
+ nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
567
+
568
+ fig.tight_layout()
569
+ plt.margins(x=0.1, y=0.1)
570
+ plt.savefig(u.join_paths(output_folder, "dag.png"))
571
+ plt.close(fig)
572
+
573
+ @classmethod
574
+ async def WriteDatasetOutputsGivenTestSet(
575
+ cls, dataset_conf: DatasetsConfig, select: str, test_set: Optional[str], runquery: bool, recurse: bool
576
+ ) -> Any:
577
+ dataset = dataset_conf.name
578
+ default_test_set, default_test_set_conf = ManifestIO.obj.get_default_test_set(dataset)
579
+ if test_set is None or test_set == default_test_set:
580
+ test_set, test_set_conf = default_test_set, default_test_set_conf
581
+ elif test_set in ManifestIO.obj.selection_test_sets:
582
+ test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
583
+ else:
584
+ raise u.InvalidInputError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
585
+
586
+ error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
587
+ if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
588
+ raise u.InvalidInputError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
455
589
 
456
- username, is_internal = user_attributes.get("username", ""), user_attributes.get("is_internal", False)
457
- user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
458
- user = user_cls.Create(username, test_set_conf.user_attributes, is_internal=is_internal)
590
+ user_attributes = test_set_conf.user_attributes.copy()
591
+ selections = test_set_conf.parameters.copy()
592
+ username, is_internal = user_attributes.pop("username", ""), user_attributes.pop("is_internal", False)
593
+ if test_set_conf.is_authenticated:
594
+ user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
595
+ user = user_cls.Create(username, is_internal=is_internal, **user_attributes)
596
+ elif dataset_conf.scope == DatasetScope.PUBLIC:
597
+ user = None
598
+ else:
599
+ raise u.ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
459
600
 
601
+ if dataset_conf.scope == DatasetScope.PRIVATE and not is_internal:
602
+ raise u.ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_internal' set to true")
603
+
604
+ # always_pandas is set to True for creating CSV files from results (when runquery is True)
460
605
  dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
461
- await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
606
+ placeholders = await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
462
607
 
463
- output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, test_set, dataset)
608
+ output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
464
609
  if os.path.exists(output_folder):
465
610
  shutil.rmtree(output_folder)
611
+ os.makedirs(output_folder, exist_ok=True)
612
+
613
+ def write_placeholders() -> None:
614
+ output_filepath = u.join_paths(output_folder, "placeholders.json")
615
+ with open(output_filepath, 'w') as f:
616
+ json.dump(placeholders, f, indent=4)
466
617
 
467
- def write_model_outputs(model: Model) -> None:
618
+ def write_model_outputs(model: _Model) -> None:
468
619
  subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
469
620
  subpath = u.join_paths(output_folder, subfolder)
470
621
  os.makedirs(subpath, exist_ok=True)
471
- if isinstance(model.compiled_query, SqlModelQuery):
622
+ if isinstance(model.compiled_query, _SqlModelQuery):
472
623
  output_filepath = u.join_paths(subpath, model.name+'.sql')
473
624
  query = model.compiled_query.query
474
625
  with open(output_filepath, 'w') as f:
@@ -477,39 +628,50 @@ class ModelsIO:
477
628
  output_filepath = u.join_paths(subpath, model.name+'.csv')
478
629
  model.result.to_csv(output_filepath, index=False)
479
630
 
480
- all_model_names = dag.get_all_model_names()
631
+ write_placeholders()
632
+ all_model_names = dag.get_all_query_models()
481
633
  coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
482
634
  await asyncio.gather(*coroutines)
483
- return dag.target_model.compiled_query.query
635
+
636
+ if recurse:
637
+ cls.draw_dag(dag, output_folder)
638
+
639
+ if isinstance(dag.target_model, _Model):
640
+ return dag.target_model.compiled_query.query # else return None
484
641
 
485
642
  @classmethod
486
643
  async def WriteOutputs(
487
- cls, dataset: Optional[str], select: Optional[str], all_test_sets: bool, test_set: Optional[str], runquery: bool
644
+ cls, dataset: Optional[str], do_all_datasets: bool, select: Optional[str], test_set: Optional[str], do_all_test_sets: bool,
645
+ runquery: bool
488
646
  ) -> None:
489
- if test_set is None:
490
- test_set = ManifestIO.obj.settings.get(c.TEST_SET_DEFAULT_USED_SETTING, c.DEFAULT_TEST_SET_NAME)
491
-
492
- if all_test_sets:
493
- test_sets = ManifestIO.obj.selection_test_sets.keys()
494
- else:
495
- test_sets = [test_set]
647
+
648
+ def get_applicable_test_sets(dataset: str) -> list[str]:
649
+ applicable_test_sets = []
650
+ for test_set_name, test_set_config in ManifestIO.obj.selection_test_sets.items():
651
+ if test_set_config.datasets is None or dataset in test_set_config.datasets:
652
+ applicable_test_sets.append(test_set_name)
653
+ return applicable_test_sets
496
654
 
497
655
  recurse = True
498
656
  dataset_configs = ManifestIO.obj.datasets
499
- if dataset is None:
500
- selected_models = [(dataset.name, dataset.model) for dataset in dataset_configs.values()]
657
+ if do_all_datasets:
658
+ selected_models = [(dataset, dataset.model) for dataset in dataset_configs.values()]
501
659
  else:
502
660
  if select is None:
503
661
  select = dataset_configs[dataset].model
504
662
  else:
505
663
  recurse = False
506
- selected_models = [(dataset, select)]
664
+ selected_models = [(dataset_configs[dataset], select)]
507
665
 
508
666
  coroutines = []
509
- for test_set in test_sets:
510
- for dataset, select in selected_models:
511
- coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset, select, test_set, runquery, recurse)
512
- coroutines.append(coroutine)
667
+ for dataset_conf, select in selected_models:
668
+ if do_all_test_sets:
669
+ for test_set_name in get_applicable_test_sets(dataset_conf.name):
670
+ coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set_name, runquery, recurse)
671
+ coroutines.append(coroutine)
672
+
673
+ coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set, runquery, recurse)
674
+ coroutines.append(coroutine)
513
675
 
514
676
  queries = await asyncio.gather(*coroutines)
515
677
  if not recurse and len(queries) == 1 and isinstance(queries[0], str):