squirrels-0.2.1-py3-none-any.whl → squirrels-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (48)
  1. squirrels/__init__.py +11 -4
  2. squirrels/_api_response_models.py +118 -0
  3. squirrels/_api_server.py +140 -75
  4. squirrels/_authenticator.py +10 -8
  5. squirrels/_command_line.py +17 -11
  6. squirrels/_connection_set.py +2 -2
  7. squirrels/_constants.py +13 -5
  8. squirrels/_initializer.py +23 -13
  9. squirrels/_manifest.py +20 -10
  10. squirrels/_models.py +303 -148
  11. squirrels/_parameter_configs.py +195 -57
  12. squirrels/_parameter_sets.py +14 -17
  13. squirrels/_py_module.py +2 -4
  14. squirrels/_seeds.py +38 -0
  15. squirrels/_utils.py +41 -33
  16. squirrels/arguments/run_time_args.py +76 -34
  17. squirrels/data_sources.py +172 -51
  18. squirrels/dateutils.py +3 -3
  19. squirrels/package_data/assets/index.js +14 -14
  20. squirrels/package_data/base_project/connections.yml +1 -1
  21. squirrels/package_data/base_project/database/expenses.db +0 -0
  22. squirrels/package_data/base_project/docker/Dockerfile +1 -1
  23. squirrels/package_data/base_project/environcfg.yml +7 -7
  24. squirrels/package_data/base_project/models/dbviews/database_view1.py +25 -14
  25. squirrels/package_data/base_project/models/dbviews/database_view1.sql +21 -14
  26. squirrels/package_data/base_project/models/federates/dataset_example.py +6 -5
  27. squirrels/package_data/base_project/models/federates/dataset_example.sql +1 -1
  28. squirrels/package_data/base_project/parameters.yml +57 -28
  29. squirrels/package_data/base_project/pyconfigs/auth.py +11 -10
  30. squirrels/package_data/base_project/pyconfigs/connections.py +6 -8
  31. squirrels/package_data/base_project/pyconfigs/context.py +49 -33
  32. squirrels/package_data/base_project/pyconfigs/parameters.py +62 -30
  33. squirrels/package_data/base_project/seeds/seed_categories.csv +6 -0
  34. squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -0
  35. squirrels/package_data/base_project/squirrels.yml.j2 +37 -20
  36. squirrels/parameter_options.py +30 -10
  37. squirrels/parameters.py +300 -70
  38. squirrels/user_base.py +3 -13
  39. squirrels-0.3.0.dist-info/LICENSE +201 -0
  40. {squirrels-0.2.1.dist-info → squirrels-0.3.0.dist-info}/METADATA +15 -15
  41. squirrels-0.3.0.dist-info/RECORD +56 -0
  42. {squirrels-0.2.1.dist-info → squirrels-0.3.0.dist-info}/WHEEL +1 -1
  43. squirrels/package_data/base_project/seeds/mocks/category.csv +0 -3
  44. squirrels/package_data/base_project/seeds/mocks/max_filter.csv +0 -2
  45. squirrels/package_data/base_project/seeds/mocks/subcategory.csv +0 -6
  46. squirrels-0.2.1.dist-info/LICENSE +0 -22
  47. squirrels-0.2.1.dist-info/RECORD +0 -55
  48. {squirrels-0.2.1.dist-info → squirrels-0.3.0.dist-info}/entry_points.txt +0 -0
squirrels/_models.py CHANGED
@@ -1,21 +1,26 @@
1
1
  from __future__ import annotations
2
- from typing import Union, Optional, Callable, Iterable, Any
2
+ from typing import Optional, Callable, Iterable, Any
3
3
  from dataclasses import dataclass, field
4
+ from abc import ABCMeta, abstractmethod
4
5
  from enum import Enum
5
6
  from pathlib import Path
6
- import sqlite3, pandas as pd, asyncio, os, shutil
7
+ from sqlalchemy import create_engine, text, Connection
8
+ import asyncio, os, shutil, pandas as pd, json
9
+ import matplotlib.pyplot as plt, networkx as nx
7
10
 
8
11
  from . import _constants as c, _utils as u, _py_module as pm
9
12
  from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
10
13
  from ._authenticator import User, Authenticator
11
14
  from ._connection_set import ConnectionSetIO
12
- from ._manifest import ManifestIO, DatasetsConfig
15
+ from ._manifest import ManifestIO, DatasetsConfig, DatasetScope
13
16
  from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
17
+ from ._seeds import SeedsIO
14
18
  from ._timer import timer, time
15
19
 
16
20
  class ModelType(Enum):
17
21
  DBVIEW = 1
18
22
  FEDERATE = 2
23
+ SEED = 3
19
24
 
20
25
  class QueryType(Enum):
21
26
  SQL = 0
@@ -27,7 +32,7 @@ class Materialization(Enum):
27
32
 
28
33
 
29
34
  @dataclass
30
- class SqlModelConfig:
35
+ class _SqlModelConfig:
31
36
  ## Applicable for dbview models
32
37
  connection_name: str
33
38
 
@@ -58,63 +63,132 @@ ContextFunc = Callable[[dict[str, Any], ContextArgs], None]
58
63
 
59
64
 
60
65
  @dataclass(frozen=True)
61
- class RawQuery:
66
+ class _RawQuery(metaclass=ABCMeta):
62
67
  pass
63
68
 
64
69
  @dataclass(frozen=True)
65
- class RawSqlQuery(RawQuery):
70
+ class _RawSqlQuery(_RawQuery):
66
71
  query: str
67
72
 
68
73
  @dataclass(frozen=True)
69
- class RawPyQuery(RawQuery):
74
+ class _RawPyQuery(_RawQuery):
70
75
  query: Callable[[Any], pd.DataFrame]
71
76
  dependencies_func: Callable[[Any], Iterable]
72
77
 
73
78
 
74
79
  @dataclass
75
- class Query:
80
+ class _Query(metaclass=ABCMeta):
76
81
  query: Any
77
82
 
78
83
  @dataclass
79
- class WorkInProgress:
84
+ class _WorkInProgress(_Query):
80
85
  query: None = field(default=None, init=False)
81
86
 
82
87
  @dataclass
83
- class SqlModelQuery(Query):
88
+ class _SqlModelQuery(_Query):
84
89
  query: str
85
- config: SqlModelConfig
90
+ config: _SqlModelConfig
86
91
 
87
92
  @dataclass
88
- class PyModelQuery(Query):
93
+ class _PyModelQuery(_Query):
89
94
  query: Callable[[], pd.DataFrame]
90
95
 
91
96
 
92
97
  @dataclass(frozen=True)
93
- class QueryFile:
98
+ class _QueryFile:
94
99
  filepath: str
95
100
  model_type: ModelType
96
101
  query_type: QueryType
97
- raw_query: RawQuery
102
+ raw_query: _RawQuery
98
103
 
99
104
 
100
105
  @dataclass
101
- class Model:
106
+ class _Referable(metaclass=ABCMeta):
102
107
  name: str
103
- query_file: QueryFile
104
108
  is_target: bool = field(default=False, init=False)
105
- compiled_query: Optional[Query] = field(default=None, init=False)
106
109
 
107
110
  needs_sql_table: bool = field(default=False, init=False)
108
- needs_pandas: bool = False
111
+ needs_pandas: bool = field(default=False, init=False)
109
112
  result: Optional[pd.DataFrame] = field(default=None, init=False, repr=False)
110
-
111
- wait_count: int = field(default=0, init=False, repr=False)
112
- upstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
113
- downstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
114
113
 
114
+ wait_count: int = field(default=0, init=False, repr=False)
115
115
  confirmed_no_cycles: bool = field(default=False, init=False)
116
+ upstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
117
+ downstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
118
+
119
+ @abstractmethod
120
+ def get_model_type(self) -> ModelType:
121
+ pass
122
+
123
+ async def compile(
124
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
125
+ ) -> None:
126
+ pass
127
+
128
+ @abstractmethod
129
+ def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
130
+ pass
131
+
132
+ def _load_pandas_to_table(self, df: pd.DataFrame, conn: Connection) -> None:
133
+ df.to_sql(self.name, conn, index=False)
134
+
135
+ def _load_table_to_pandas(self, conn: Connection) -> pd.DataFrame:
136
+ query = f"SELECT * FROM {self.name}"
137
+ return pd.read_sql(query, conn)
138
+
139
+ async def _trigger(self, conn: Connection, placeholders: dict = {}) -> None:
140
+ self.wait_count -= 1
141
+ if (self.wait_count == 0):
142
+ await self.run_model(conn, placeholders)
143
+
144
+ @abstractmethod
145
+ async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
146
+ coroutines = []
147
+ for model in self.downstreams.values():
148
+ coroutines.append(model._trigger(conn, placeholders))
149
+ await asyncio.gather(*coroutines)
150
+
151
+ def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
152
+ pass
153
+
154
+ def get_max_path_length_to_target(self) -> int:
155
+ if not hasattr(self, "max_path_len_to_target"):
156
+ path_lengths = []
157
+ for child_model in self.downstreams.values():
158
+ path_lengths.append(child_model.get_max_path_length_to_target()+1)
159
+ if len(path_lengths) > 0:
160
+ self.max_path_len_to_target = max(path_lengths)
161
+ else:
162
+ self.max_path_len_to_target = 0 if self.is_target else None
163
+ return self.max_path_len_to_target
164
+
116
165
 
117
- def _add_upstream(self, other: Model) -> None:
166
+ @dataclass
167
+ class _Seed(_Referable):
168
+ result: pd.DataFrame
169
+
170
+ def get_model_type(self) -> ModelType:
171
+ return ModelType.SEED
172
+
173
+ def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
174
+ return {self.name}
175
+
176
+ async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
177
+ if self.needs_sql_table:
178
+ await asyncio.to_thread(self._load_pandas_to_table, self.result, conn)
179
+ await super().run_model(conn, placeholders)
180
+
181
+
182
+ @dataclass
183
+ class _Model(_Referable):
184
+ query_file: _QueryFile
185
+
186
+ compiled_query: Optional[_Query] = field(default=None, init=False)
187
+
188
+ def get_model_type(self) -> ModelType:
189
+ return self.query_file.model_type
190
+
191
+ def _add_upstream(self, other: _Referable) -> None:
118
192
  self.upstreams[other.name] = other
119
193
  other.downstreams[self.name] = self
120
194
 
@@ -137,17 +211,20 @@ class Model:
137
211
  materialized = federate_config.materialized
138
212
  return Materialization[materialized.upper()]
139
213
 
140
- async def _compile_sql_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[SqlModelQuery, set]:
141
- assert(isinstance(self.query_file.raw_query, RawSqlQuery))
214
+ async def _compile_sql_model(
215
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
216
+ ) -> tuple[_SqlModelQuery, set]:
217
+ assert(isinstance(self.query_file.raw_query, _RawSqlQuery))
218
+
142
219
  raw_query = self.query_file.raw_query.query
143
-
144
220
  connection_name = self._get_dbview_conn_name()
145
221
  materialized = self._get_materialized()
146
- configuration = SqlModelConfig(connection_name, materialized)
222
+ configuration = _SqlModelConfig(connection_name, materialized)
223
+ is_placeholder = lambda x: x in placeholders
147
224
  kwargs = {
148
- "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars,
149
- "user": ctx_args.user, "prms": ctx_args.prms, "traits": ctx_args.traits,
150
- "ctx": ctx, "config": configuration.set_attribute
225
+ "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars, "user": ctx_args.user, "prms": ctx_args.prms,
226
+ "traits": ctx_args.traits, "ctx": ctx, "is_placeholder": is_placeholder, "set_placeholder": ctx_args.set_placeholder,
227
+ "config": configuration.set_attribute, "is_param_enabled": ctx_args.param_exists
151
228
  }
152
229
  dependencies = set()
153
230
  if self.query_file.model_type == ModelType.FEDERATE:
@@ -157,16 +234,21 @@ class Model:
157
234
  kwargs["ref"] = ref
158
235
 
159
236
  try:
160
- query = await asyncio.to_thread(u.render_string, raw_query, kwargs)
237
+ query = await asyncio.to_thread(u.render_string, raw_query, **kwargs)
161
238
  except Exception as e:
162
239
  raise u.FileExecutionError(f'Failed to compile sql model "{self.name}"', e)
163
240
 
164
- compiled_query = SqlModelQuery(query, configuration)
241
+ compiled_query = _SqlModelQuery(query, configuration)
165
242
  return compiled_query, dependencies
166
243
 
167
- async def _compile_python_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[PyModelQuery, set]:
168
- assert(isinstance(self.query_file.raw_query, RawPyQuery))
169
- sqrl_args = ModelDepsArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, ctx)
244
+ async def _compile_python_model(
245
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
246
+ ) -> tuple[_PyModelQuery, set]:
247
+ assert(isinstance(self.query_file.raw_query, _RawPyQuery))
248
+
249
+ sqrl_args = ModelDepsArgs(
250
+ ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx
251
+ )
170
252
  try:
171
253
  dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
172
254
  except Exception as e:
@@ -175,34 +257,42 @@ class Model:
175
257
  dbview_conn_name = self._get_dbview_conn_name()
176
258
  connections = ConnectionSetIO.obj.get_engines_as_dict()
177
259
  ref = lambda x: self.upstreams[x].result
178
- sqrl_args = ModelArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits,
179
- ctx, dbview_conn_name, connections, ref, set(dependencies))
260
+ sqrl_args = ModelArgs(
261
+ ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx,
262
+ dbview_conn_name, connections, dependencies, ref
263
+ )
180
264
 
181
265
  def compiled_query():
182
266
  try:
183
- return self.query_file.raw_query.query(sqrl=sqrl_args)
267
+ raw_query: _RawPyQuery = self.query_file.raw_query
268
+ return raw_query.query(sqrl=sqrl_args)
184
269
  except Exception as e:
185
270
  raise u.FileExecutionError(f'Failed to run "{c.MAIN_FUNC}" function for python model "{self.name}"', e)
186
271
 
187
- return PyModelQuery(compiled_query), dependencies
272
+ return _PyModelQuery(compiled_query), dependencies
188
273
 
189
- async def compile(self, ctx: dict[str, Any], ctx_args: ContextArgs, models_dict: dict[str, Model], recurse: bool) -> None:
274
+ async def compile(
275
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
276
+ ) -> None:
190
277
  if self.compiled_query is not None:
191
278
  return
192
279
  else:
193
- self.compiled_query = WorkInProgress()
280
+ self.compiled_query = _WorkInProgress()
194
281
 
195
282
  start = time.time()
283
+
196
284
  if self.query_file.query_type == QueryType.SQL:
197
- compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args)
285
+ compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders)
198
286
  elif self.query_file.query_type == QueryType.PYTHON:
199
- compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args)
287
+ compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders)
200
288
  else:
201
289
  raise NotImplementedError(f"Query type not supported: {self.query_file.query_type}")
202
290
 
203
291
  self.compiled_query = compiled_query
204
292
  self.wait_count = len(dependencies)
205
- timer.add_activity_time(f"compiling model '{self.name}'", start)
293
+
294
+ model_type = self.get_model_type().name.lower()
295
+ timer.add_activity_time(f"compiling {model_type} model '{self.name}'", start)
206
296
 
207
297
  if not recurse:
208
298
  return
@@ -211,13 +301,13 @@ class Model:
211
301
  coroutines = []
212
302
  for dep_model in dep_models:
213
303
  self._add_upstream(dep_model)
214
- coro = dep_model.compile(ctx, ctx_args, models_dict, recurse)
304
+ coro = dep_model.compile(ctx, ctx_args, placeholders, models_dict, recurse)
215
305
  coroutines.append(coro)
216
306
  await asyncio.gather(*coroutines)
217
307
 
218
- def validate_no_cycles(self, depencency_path: set[str]) -> set[str]:
308
+ def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
219
309
  if self.confirmed_no_cycles:
220
- return
310
+ return set()
221
311
 
222
312
  if self.name in depencency_path:
223
313
  raise u.ConfigurationError(f'Cycle found in model dependency graph')
@@ -229,34 +319,21 @@ class Model:
229
319
  new_path = set(depencency_path)
230
320
  new_path.add(self.name)
231
321
  for dep_model in self.upstreams.values():
232
- terminal_nodes_under_dep = dep_model.validate_no_cycles(new_path)
322
+ terminal_nodes_under_dep = dep_model.get_terminal_nodes(new_path)
233
323
  terminal_nodes = terminal_nodes.union(terminal_nodes_under_dep)
234
324
 
235
325
  self.confirmed_no_cycles = True
236
326
  return terminal_nodes
237
-
238
- def _load_pandas_to_table(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None:
239
- if u.use_duckdb():
240
- conn.execute(f"CREATE TABLE {self.name} AS FROM df")
241
- else:
242
- df.to_sql(self.name, conn, index=False)
243
-
244
- def _load_table_to_pandas(self, conn: sqlite3.Connection) -> pd.DataFrame:
245
- if u.use_duckdb():
246
- return conn.execute(f"FROM {self.name}").df()
247
- else:
248
- query = f"SELECT * FROM {self.name}"
249
- return pd.read_sql(query, conn)
250
327
 
251
- async def _run_sql_model(self, conn: sqlite3.Connection) -> None:
252
- assert(isinstance(self.compiled_query, SqlModelQuery))
328
+ async def _run_sql_model(self, conn: Connection, placeholders: dict = {}) -> None:
329
+ assert(isinstance(self.compiled_query, _SqlModelQuery))
253
330
  config = self.compiled_query.config
254
331
  query = self.compiled_query.query
255
332
 
256
333
  if self.query_file.model_type == ModelType.DBVIEW:
257
334
  def run_sql_query():
258
335
  try:
259
- return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name)
336
+ return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name, placeholders)
260
337
  except RuntimeError as e:
261
338
  raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e)
262
339
 
@@ -268,7 +345,7 @@ class Model:
268
345
  def create_table():
269
346
  create_query = config.get_sql_for_create(self.name, query)
270
347
  try:
271
- return conn.execute(create_query)
348
+ return conn.execute(text(create_query), placeholders)
272
349
  except Exception as e:
273
350
  raise u.FileExecutionError(f'Failed to run federate sql model "{self.name}"', e)
274
351
 
@@ -276,8 +353,8 @@ class Model:
276
353
  if self.needs_pandas or self.is_target:
277
354
  self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)
278
355
 
279
- async def _run_python_model(self, conn: sqlite3.Connection) -> None:
280
- assert(isinstance(self.compiled_query, PyModelQuery))
356
+ async def _run_python_model(self, conn: Connection) -> None:
357
+ assert(isinstance(self.compiled_query, _PyModelQuery))
281
358
 
282
359
  df = await asyncio.to_thread(self.compiled_query.query)
283
360
  if self.needs_sql_table:
@@ -285,90 +362,86 @@ class Model:
285
362
  if self.needs_pandas or self.is_target:
286
363
  self.result = df
287
364
 
288
- async def run_model(self, conn: sqlite3.Connection) -> None:
365
+ async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
289
366
  start = time.time()
367
+
290
368
  if self.query_file.query_type == QueryType.SQL:
291
- await self._run_sql_model(conn)
369
+ await self._run_sql_model(conn, placeholders)
292
370
  elif self.query_file.query_type == QueryType.PYTHON:
293
371
  await self._run_python_model(conn)
294
- timer.add_activity_time(f"running model '{self.name}'", start)
295
372
 
296
- coroutines = []
297
- for model in self.downstreams.values():
298
- coroutines.append(model.trigger(conn))
299
- await asyncio.gather(*coroutines)
300
-
301
- async def trigger(self, conn: sqlite3.Connection) -> None:
302
- self.wait_count -= 1
303
- if (self.wait_count == 0):
304
- await self.run_model(conn)
373
+ model_type = self.get_model_type().name.lower()
374
+ timer.add_activity_time(f"running {model_type} model '{self.name}'", start)
375
+
376
+ await super().run_model(conn, placeholders)
305
377
 
306
- def fill_dependent_model_names(self, dependent_model_names: set[str]) -> None:
378
+ def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
307
379
  if self.name not in dependent_model_names:
308
380
  dependent_model_names.add(self.name)
309
381
  for dep_model in self.upstreams.values():
310
- dep_model.fill_dependent_model_names(dependent_model_names)
382
+ dep_model.retrieve_dependent_query_models(dependent_model_names)
311
383
 
312
384
 
313
385
  @dataclass
314
- class DAG:
386
+ class _DAG:
315
387
  dataset: DatasetsConfig
316
- target_model: Model
317
- models_dict: dict[str, Model]
388
+ target_model: _Referable
389
+ models_dict: dict[str, _Referable]
318
390
  parameter_set: Optional[ParameterSet] = field(default=None, init=False)
391
+ placeholders: dict[str, Any] = field(init=False, default_factory=dict)
319
392
 
320
393
  def apply_selections(
321
394
  self, user: Optional[User], selections: dict[str, str], *, updates_only: bool = False, request_version: Optional[int] = None
322
395
  ) -> None:
323
396
  start = time.time()
324
397
  dataset_params = self.dataset.parameters
325
- parameter_set = ParameterConfigsSetIO.obj.apply_selections(dataset_params, selections, user, updates_only=updates_only,
326
- request_version=request_version)
398
+ parameter_set = ParameterConfigsSetIO.obj.apply_selections(
399
+ dataset_params, selections, user, updates_only=updates_only, request_version=request_version
400
+ )
327
401
  self.parameter_set = parameter_set
328
- timer.add_activity_time(f"applying selections for dataset", start)
402
+ timer.add_activity_time(f"applying selections for dataset '{self.dataset.name}'", start)
329
403
 
330
404
  def _compile_context(self, context_func: ContextFunc, user: Optional[User]) -> tuple[dict[str, Any], ContextArgs]:
331
405
  start = time.time()
332
406
  context = {}
333
407
  param_args = ParameterConfigsSetIO.args
334
408
  prms = self.parameter_set.get_parameters_as_dict()
335
- args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits)
409
+ args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits, self.placeholders)
336
410
  try:
337
411
  context_func(ctx=context, sqrl=args)
338
412
  except Exception as e:
339
- raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset}"', e)
340
- timer.add_activity_time(f"running context.py for dataset", start)
413
+ raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset.name}"', e)
414
+ timer.add_activity_time(f"running context.py for dataset '{self.dataset.name}'", start)
341
415
  return context, args
342
416
 
343
417
  async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
344
- await self.target_model.compile(context, ctx_args, self.models_dict, recurse)
418
+ await self.target_model.compile(context, ctx_args, self.placeholders, self.models_dict, recurse)
345
419
 
346
- def _validate_no_cycles(self) -> set[str]:
420
+ def _get_terminal_nodes(self) -> set[str]:
347
421
  start = time.time()
348
- terminal_nodes = self.target_model.validate_no_cycles(set())
349
- timer.add_activity_time(f"validating no cycles in models dependencies", start)
422
+ terminal_nodes = self.target_model.get_terminal_nodes(set())
423
+ for model in self.models_dict.values():
424
+ model.confirmed_no_cycles = False
425
+ timer.add_activity_time(f"validating no cycles in model dependencies", start)
350
426
  return terminal_nodes
351
427
 
352
- async def _run_models(self, terminal_nodes: set[str]) -> None:
353
- if u.use_duckdb():
354
- import duckdb
355
- conn = duckdb.connect()
356
- else:
357
- conn = sqlite3.connect(":memory:", check_same_thread=False)
428
+ async def _run_models(self, terminal_nodes: set[str], placeholders: dict = {}) -> None:
429
+ conn_url = "duckdb:///" if u.use_duckdb() else "sqlite:///?check_same_thread=False"
430
+ engine = create_engine(conn_url)
358
431
 
359
- try:
432
+ with engine.connect() as conn:
360
433
  coroutines = []
361
434
  for model_name in terminal_nodes:
362
435
  model = self.models_dict[model_name]
363
- coroutines.append(model.run_model(conn))
436
+ coroutines.append(model.run_model(conn, placeholders))
364
437
  await asyncio.gather(*coroutines)
365
- finally:
366
- conn.close()
438
+
439
+ engine.dispose()
367
440
 
368
441
  async def execute(
369
442
  self, context_func: ContextFunc, user: Optional[User], selections: dict[str, str], *, request_version: Optional[int] = None,
370
443
  runquery: bool = True, recurse: bool = True
371
- ) -> None:
444
+ ) -> dict[str, Any]:
372
445
  recurse = (recurse or runquery)
373
446
 
374
447
  self.apply_selections(user, selections, request_version=request_version)
@@ -377,19 +450,37 @@ class DAG:
377
450
 
378
451
  await self._compile_models(context, ctx_args, recurse)
379
452
 
380
- terminal_nodes = self._validate_no_cycles()
453
+ terminal_nodes = self._get_terminal_nodes()
381
454
 
455
+ placeholders = ctx_args._placeholders.copy()
382
456
  if runquery:
383
- await self._run_models(terminal_nodes)
457
+ await self._run_models(terminal_nodes, placeholders)
458
+
459
+ return placeholders
384
460
 
385
- def get_all_model_names(self) -> set[str]:
461
+ def get_all_query_models(self) -> set[str]:
386
462
  all_model_names = set()
387
- self.target_model.fill_dependent_model_names(all_model_names)
463
+ self.target_model.retrieve_dependent_query_models(all_model_names)
388
464
  return all_model_names
389
-
465
+
466
+ def to_networkx_graph(self) -> nx.DiGraph:
467
+ G = nx.DiGraph()
468
+
469
+ for model_name, model in self.models_dict.items():
470
+ model_type = model.get_model_type()
471
+ level = model.get_max_path_length_to_target()
472
+ if level is not None:
473
+ G.add_node(model_name, layer=-level, model_type=model_type)
474
+
475
+ for model_name in G.nodes:
476
+ model = self.models_dict[model_name]
477
+ for dep_model_name in model.downstreams:
478
+ G.add_edge(model_name, dep_model_name)
479
+
480
+ return G
390
481
 
391
482
  class ModelsIO:
392
- raw_queries_by_model: dict[str, QueryFile]
483
+ raw_queries_by_model: dict[str, _QueryFile]
393
484
  context_func: ContextFunc
394
485
 
395
486
  @classmethod
@@ -398,7 +489,6 @@ class ModelsIO:
398
489
  cls.raw_queries_by_model = {}
399
490
 
400
491
  def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType):
401
-
402
492
  def populate_from_file(dp, file):
403
493
  query_type = None
404
494
  filepath = os.path.join(dp, file)
@@ -407,13 +497,13 @@ class ModelsIO:
407
497
  query_type = QueryType.PYTHON
408
498
  module = pm.PyModule(filepath)
409
499
  dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda x: [])
410
- raw_query = RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
500
+ raw_query = _RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
411
501
  elif extension == '.sql':
412
502
  query_type = QueryType.SQL
413
- raw_query = RawSqlQuery(u.read_file(filepath))
503
+ raw_query = _RawSqlQuery(u.read_file(filepath))
414
504
 
415
505
  if query_type is not None:
416
- query_file = QueryFile(filepath, model_type, query_type, raw_query)
506
+ query_file = _QueryFile(filepath, model_type, query_type, raw_query)
417
507
  if file_stem in cls.raw_queries_by_model:
418
508
  conflicts = [cls.raw_queries_by_model[file_stem].filepath, filepath]
419
509
  raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
@@ -429,44 +519,98 @@ class ModelsIO:
429
519
  federates_path = u.join_paths(c.MODELS_FOLDER, c.FEDERATES_FOLDER)
430
520
  populate_raw_queries_for_type(federates_path, ModelType.FEDERATE)
431
521
 
432
- context_path = u.join_paths(c.PYCONFIG_FOLDER, c.CONTEXT_FILE)
522
+ context_path = u.join_paths(c.PYCONFIGS_FOLDER, c.CONTEXT_FILE)
433
523
  cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda x, y: None)
434
524
 
435
- timer.add_activity_time("loading models and/or context.py", start)
525
+ timer.add_activity_time("loading files for models and context.py", start)
436
526
 
437
527
  @classmethod
438
- def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> DAG:
439
- models_dict = {key: Model(key, val, needs_pandas=always_pandas) for key, val in cls.raw_queries_by_model.items()}
528
+ def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> _DAG:
529
+ seeds_dict = SeedsIO.obj.get_dataframes()
530
+
531
+ models_dict: dict[str, _Referable] = {key: _Seed(key, df) for key, df in seeds_dict.items()}
532
+ for key, val in cls.raw_queries_by_model.items():
533
+ models_dict[key] = _Model(key, val)
534
+ models_dict[key].needs_pandas = always_pandas
440
535
 
441
536
  dataset_config = ManifestIO.obj.datasets[dataset]
442
537
  target_model_name = dataset_config.model if target_model_name is None else target_model_name
443
538
  target_model = models_dict[target_model_name]
444
539
  target_model.is_target = True
445
540
 
446
- return DAG(dataset_config, target_model, models_dict)
541
+ return _DAG(dataset_config, target_model, models_dict)
447
542
 
448
543
  @classmethod
449
- async def WriteDatasetOutputsGivenTestSet(cls, dataset: str, select: str, test_set: str, runquery: bool, recurse: bool) -> Any:
450
- test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
451
- user_attributes = test_set_conf.user_attributes
452
- selections = test_set_conf.parameters
544
+ def draw_dag(cls, dag: _DAG, output_folder: Path) -> None:
545
+ color_map = {ModelType.SEED: "green", ModelType.DBVIEW: "red", ModelType.FEDERATE: "skyblue"}
546
+
547
+ G = dag.to_networkx_graph()
548
+
549
+ fig, _ = plt.subplots()
550
+ pos = nx.multipartite_layout(G, subset_key="layer")
551
+ colors = [color_map[node[1]] for node in G.nodes(data="model_type")]
552
+ nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
553
+
554
+ y_values = [val[1] for val in pos.values()]
555
+ scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
556
+ label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
557
+ nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
558
+
559
+ fig.tight_layout()
560
+ plt.margins(x=0.1, y=0.1)
561
+ plt.savefig(u.join_paths(output_folder, "dag.png"))
562
+ plt.close(fig)
563
+
564
+ @classmethod
565
+ async def WriteDatasetOutputsGivenTestSet(
566
+ cls, dataset_conf: DatasetsConfig, select: str, test_set: Optional[str], runquery: bool, recurse: bool
567
+ ) -> Any:
568
+ dataset = dataset_conf.name
569
+ default_test_set, default_test_set_conf = ManifestIO.obj.get_default_test_set(dataset)
570
+ if test_set is None or test_set == default_test_set:
571
+ test_set, test_set_conf = default_test_set, default_test_set_conf
572
+ elif test_set in ManifestIO.obj.selection_test_sets:
573
+ test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
574
+ else:
575
+ raise u.InvalidInputError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
453
576
 
454
- username, is_internal = user_attributes.get("username", ""), user_attributes.get("is_internal", False)
455
- user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
456
- user = user_cls.Create(username, test_set_conf.user_attributes, is_internal=is_internal)
577
+ error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
578
+ if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
579
+ raise u.InvalidInputError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
457
580
 
581
+ user_attributes = test_set_conf.user_attributes.copy()
582
+ selections = test_set_conf.parameters.copy()
583
+ username, is_internal = user_attributes.pop("username", ""), user_attributes.pop("is_internal", False)
584
+ if test_set_conf.is_authenticated:
585
+ user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
586
+ user = user_cls.Create(username, is_internal=is_internal, **user_attributes)
587
+ elif dataset_conf.scope == DatasetScope.PUBLIC:
588
+ user = None
589
+ else:
590
+ raise u.ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
591
+
592
+ if dataset_conf.scope == DatasetScope.PRIVATE and not is_internal:
593
+ raise u.ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_internal' set to true")
594
+
595
+ # always_pandas is set to True for creating CSV files from results (when runquery is True)
458
596
  dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
459
- await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
597
+ placeholders = await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
460
598
 
461
- output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, test_set, dataset)
599
+ output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
462
600
  if os.path.exists(output_folder):
463
601
  shutil.rmtree(output_folder)
602
+ os.makedirs(output_folder, exist_ok=True)
603
+
604
+ def write_placeholders() -> None:
605
+ output_filepath = u.join_paths(output_folder, "placeholders.json")
606
+ with open(output_filepath, 'w') as f:
607
+ json.dump(placeholders, f, indent=4)
464
608
 
465
- def write_model_outputs(model: Model) -> None:
609
+ def write_model_outputs(model: _Model) -> None:
466
610
  subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
467
611
  subpath = u.join_paths(output_folder, subfolder)
468
612
  os.makedirs(subpath, exist_ok=True)
469
- if isinstance(model.compiled_query, SqlModelQuery):
613
+ if isinstance(model.compiled_query, _SqlModelQuery):
470
614
  output_filepath = u.join_paths(subpath, model.name+'.sql')
471
615
  query = model.compiled_query.query
472
616
  with open(output_filepath, 'w') as f:
@@ -475,39 +619,50 @@ class ModelsIO:
475
619
  output_filepath = u.join_paths(subpath, model.name+'.csv')
476
620
  model.result.to_csv(output_filepath, index=False)
477
621
 
478
- all_model_names = dag.get_all_model_names()
622
+ write_placeholders()
623
+ all_model_names = dag.get_all_query_models()
479
624
  coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
480
625
  await asyncio.gather(*coroutines)
481
- return dag.target_model.compiled_query.query
626
+
627
+ if recurse:
628
+ cls.draw_dag(dag, output_folder)
629
+
630
+ if isinstance(dag.target_model, _Model):
631
+ return dag.target_model.compiled_query.query # else return None
482
632
 
483
633
  @classmethod
484
634
  async def WriteOutputs(
485
- cls, dataset: Optional[str], select: Optional[str], all_test_sets: bool, test_set: Optional[str], runquery: bool
635
+ cls, dataset: Optional[str], do_all_datasets: bool, select: Optional[str], test_set: Optional[str], do_all_test_sets: bool,
636
+ runquery: bool
486
637
  ) -> None:
487
- if test_set is None:
488
- test_set = ManifestIO.obj.settings.get(c.TEST_SET_DEFAULT_USED_SETTING, c.DEFAULT_TEST_SET_NAME)
489
-
490
- if all_test_sets:
491
- test_sets = ManifestIO.obj.selection_test_sets.keys()
492
- else:
493
- test_sets = [test_set]
638
+
639
+ def get_applicable_test_sets(dataset: str) -> list[str]:
640
+ applicable_test_sets = []
641
+ for test_set_name, test_set_config in ManifestIO.obj.selection_test_sets.items():
642
+ if test_set_config.datasets is None or dataset in test_set_config.datasets:
643
+ applicable_test_sets.append(test_set_name)
644
+ return applicable_test_sets
494
645
 
495
646
  recurse = True
496
647
  dataset_configs = ManifestIO.obj.datasets
497
- if dataset is None:
498
- selected_models = [(dataset.name, dataset.model) for dataset in dataset_configs.values()]
648
+ if do_all_datasets:
649
+ selected_models = [(dataset, dataset.model) for dataset in dataset_configs.values()]
499
650
  else:
500
651
  if select is None:
501
652
  select = dataset_configs[dataset].model
502
653
  else:
503
654
  recurse = False
504
- selected_models = [(dataset, select)]
655
+ selected_models = [(dataset_configs[dataset], select)]
505
656
 
506
657
  coroutines = []
507
- for test_set in test_sets:
508
- for dataset, select in selected_models:
509
- coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset, select, test_set, runquery, recurse)
510
- coroutines.append(coroutine)
658
+ for dataset_conf, select in selected_models:
659
+ if do_all_test_sets:
660
+ for test_set_name in get_applicable_test_sets(dataset_conf.name):
661
+ coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set_name, runquery, recurse)
662
+ coroutines.append(coroutine)
663
+
664
+ coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set, runquery, recurse)
665
+ coroutines.append(coroutine)
511
666
 
512
667
  queries = await asyncio.gather(*coroutines)
513
668
  if not recurse and len(queries) == 1 and isinstance(queries[0], str):