squirrels 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (48) hide show
  1. squirrels/__init__.py +11 -4
  2. squirrels/_api_response_models.py +118 -0
  3. squirrels/_api_server.py +140 -75
  4. squirrels/_authenticator.py +10 -8
  5. squirrels/_command_line.py +17 -11
  6. squirrels/_connection_set.py +2 -2
  7. squirrels/_constants.py +13 -5
  8. squirrels/_initializer.py +23 -13
  9. squirrels/_manifest.py +20 -10
  10. squirrels/_models.py +295 -142
  11. squirrels/_parameter_configs.py +195 -57
  12. squirrels/_parameter_sets.py +14 -17
  13. squirrels/_py_module.py +2 -4
  14. squirrels/_seeds.py +38 -0
  15. squirrels/_utils.py +41 -33
  16. squirrels/arguments/run_time_args.py +76 -34
  17. squirrels/data_sources.py +172 -51
  18. squirrels/dateutils.py +3 -3
  19. squirrels/package_data/assets/index.js +14 -14
  20. squirrels/package_data/base_project/connections.yml +1 -1
  21. squirrels/package_data/base_project/database/expenses.db +0 -0
  22. squirrels/package_data/base_project/docker/Dockerfile +1 -1
  23. squirrels/package_data/base_project/environcfg.yml +7 -7
  24. squirrels/package_data/base_project/models/dbviews/database_view1.py +25 -14
  25. squirrels/package_data/base_project/models/dbviews/database_view1.sql +21 -14
  26. squirrels/package_data/base_project/models/federates/dataset_example.py +6 -5
  27. squirrels/package_data/base_project/models/federates/dataset_example.sql +1 -1
  28. squirrels/package_data/base_project/parameters.yml +57 -28
  29. squirrels/package_data/base_project/pyconfigs/auth.py +11 -10
  30. squirrels/package_data/base_project/pyconfigs/connections.py +6 -8
  31. squirrels/package_data/base_project/pyconfigs/context.py +49 -33
  32. squirrels/package_data/base_project/pyconfigs/parameters.py +62 -30
  33. squirrels/package_data/base_project/seeds/seed_categories.csv +6 -0
  34. squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -0
  35. squirrels/package_data/base_project/squirrels.yml.j2 +37 -20
  36. squirrels/parameter_options.py +30 -10
  37. squirrels/parameters.py +300 -70
  38. squirrels/user_base.py +3 -13
  39. squirrels-0.3.0.dist-info/LICENSE +201 -0
  40. {squirrels-0.2.2.dist-info → squirrels-0.3.0.dist-info}/METADATA +15 -15
  41. squirrels-0.3.0.dist-info/RECORD +56 -0
  42. squirrels/package_data/base_project/seeds/mocks/category.csv +0 -3
  43. squirrels/package_data/base_project/seeds/mocks/max_filter.csv +0 -2
  44. squirrels/package_data/base_project/seeds/mocks/subcategory.csv +0 -6
  45. squirrels-0.2.2.dist-info/LICENSE +0 -22
  46. squirrels-0.2.2.dist-info/RECORD +0 -55
  47. {squirrels-0.2.2.dist-info → squirrels-0.3.0.dist-info}/WHEEL +0 -0
  48. {squirrels-0.2.2.dist-info → squirrels-0.3.0.dist-info}/entry_points.txt +0 -0
squirrels/_models.py CHANGED
@@ -1,21 +1,26 @@
1
1
  from __future__ import annotations
2
- from typing import Union, Optional, Callable, Iterable, Any
2
+ from typing import Optional, Callable, Iterable, Any
3
3
  from dataclasses import dataclass, field
4
+ from abc import ABCMeta, abstractmethod
4
5
  from enum import Enum
5
6
  from pathlib import Path
6
- import sqlite3, pandas as pd, asyncio, os, shutil
7
+ from sqlalchemy import create_engine, text, Connection
8
+ import asyncio, os, shutil, pandas as pd, json
9
+ import matplotlib.pyplot as plt, networkx as nx
7
10
 
8
11
  from . import _constants as c, _utils as u, _py_module as pm
9
12
  from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
10
13
  from ._authenticator import User, Authenticator
11
14
  from ._connection_set import ConnectionSetIO
12
- from ._manifest import ManifestIO, DatasetsConfig
15
+ from ._manifest import ManifestIO, DatasetsConfig, DatasetScope
13
16
  from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
17
+ from ._seeds import SeedsIO
14
18
  from ._timer import timer, time
15
19
 
16
20
  class ModelType(Enum):
17
21
  DBVIEW = 1
18
22
  FEDERATE = 2
23
+ SEED = 3
19
24
 
20
25
  class QueryType(Enum):
21
26
  SQL = 0
@@ -27,7 +32,7 @@ class Materialization(Enum):
27
32
 
28
33
 
29
34
  @dataclass
30
- class SqlModelConfig:
35
+ class _SqlModelConfig:
31
36
  ## Applicable for dbview models
32
37
  connection_name: str
33
38
 
@@ -58,63 +63,132 @@ ContextFunc = Callable[[dict[str, Any], ContextArgs], None]
58
63
 
59
64
 
60
65
  @dataclass(frozen=True)
61
- class RawQuery:
66
+ class _RawQuery(metaclass=ABCMeta):
62
67
  pass
63
68
 
64
69
  @dataclass(frozen=True)
65
- class RawSqlQuery(RawQuery):
70
+ class _RawSqlQuery(_RawQuery):
66
71
  query: str
67
72
 
68
73
  @dataclass(frozen=True)
69
- class RawPyQuery(RawQuery):
74
+ class _RawPyQuery(_RawQuery):
70
75
  query: Callable[[Any], pd.DataFrame]
71
76
  dependencies_func: Callable[[Any], Iterable]
72
77
 
73
78
 
74
79
  @dataclass
75
- class Query:
80
+ class _Query(metaclass=ABCMeta):
76
81
  query: Any
77
82
 
78
83
  @dataclass
79
- class WorkInProgress:
84
+ class _WorkInProgress(_Query):
80
85
  query: None = field(default=None, init=False)
81
86
 
82
87
  @dataclass
83
- class SqlModelQuery(Query):
88
+ class _SqlModelQuery(_Query):
84
89
  query: str
85
- config: SqlModelConfig
90
+ config: _SqlModelConfig
86
91
 
87
92
  @dataclass
88
- class PyModelQuery(Query):
93
+ class _PyModelQuery(_Query):
89
94
  query: Callable[[], pd.DataFrame]
90
95
 
91
96
 
92
97
  @dataclass(frozen=True)
93
- class QueryFile:
98
+ class _QueryFile:
94
99
  filepath: str
95
100
  model_type: ModelType
96
101
  query_type: QueryType
97
- raw_query: RawQuery
102
+ raw_query: _RawQuery
98
103
 
99
104
 
100
105
  @dataclass
101
- class Model:
106
+ class _Referable(metaclass=ABCMeta):
102
107
  name: str
103
- query_file: QueryFile
104
108
  is_target: bool = field(default=False, init=False)
105
- compiled_query: Optional[Query] = field(default=None, init=False)
106
109
 
107
110
  needs_sql_table: bool = field(default=False, init=False)
108
- needs_pandas: bool = False
111
+ needs_pandas: bool = field(default=False, init=False)
109
112
  result: Optional[pd.DataFrame] = field(default=None, init=False, repr=False)
110
-
111
- wait_count: int = field(default=0, init=False, repr=False)
112
- upstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
113
- downstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
114
113
 
114
+ wait_count: int = field(default=0, init=False, repr=False)
115
115
  confirmed_no_cycles: bool = field(default=False, init=False)
116
+ upstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
117
+ downstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
118
+
119
+ @abstractmethod
120
+ def get_model_type(self) -> ModelType:
121
+ pass
122
+
123
+ async def compile(
124
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
125
+ ) -> None:
126
+ pass
127
+
128
+ @abstractmethod
129
+ def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
130
+ pass
131
+
132
+ def _load_pandas_to_table(self, df: pd.DataFrame, conn: Connection) -> None:
133
+ df.to_sql(self.name, conn, index=False)
134
+
135
+ def _load_table_to_pandas(self, conn: Connection) -> pd.DataFrame:
136
+ query = f"SELECT * FROM {self.name}"
137
+ return pd.read_sql(query, conn)
138
+
139
+ async def _trigger(self, conn: Connection, placeholders: dict = {}) -> None:
140
+ self.wait_count -= 1
141
+ if (self.wait_count == 0):
142
+ await self.run_model(conn, placeholders)
143
+
144
+ @abstractmethod
145
+ async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
146
+ coroutines = []
147
+ for model in self.downstreams.values():
148
+ coroutines.append(model._trigger(conn, placeholders))
149
+ await asyncio.gather(*coroutines)
150
+
151
+ def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
152
+ pass
153
+
154
+ def get_max_path_length_to_target(self) -> int:
155
+ if not hasattr(self, "max_path_len_to_target"):
156
+ path_lengths = []
157
+ for child_model in self.downstreams.values():
158
+ path_lengths.append(child_model.get_max_path_length_to_target()+1)
159
+ if len(path_lengths) > 0:
160
+ self.max_path_len_to_target = max(path_lengths)
161
+ else:
162
+ self.max_path_len_to_target = 0 if self.is_target else None
163
+ return self.max_path_len_to_target
116
164
 
117
- def _add_upstream(self, other: Model) -> None:
165
+
166
+ @dataclass
167
+ class _Seed(_Referable):
168
+ result: pd.DataFrame
169
+
170
+ def get_model_type(self) -> ModelType:
171
+ return ModelType.SEED
172
+
173
+ def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
174
+ return {self.name}
175
+
176
+ async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
177
+ if self.needs_sql_table:
178
+ await asyncio.to_thread(self._load_pandas_to_table, self.result, conn)
179
+ await super().run_model(conn, placeholders)
180
+
181
+
182
+ @dataclass
183
+ class _Model(_Referable):
184
+ query_file: _QueryFile
185
+
186
+ compiled_query: Optional[_Query] = field(default=None, init=False)
187
+
188
+ def get_model_type(self) -> ModelType:
189
+ return self.query_file.model_type
190
+
191
+ def _add_upstream(self, other: _Referable) -> None:
118
192
  self.upstreams[other.name] = other
119
193
  other.downstreams[self.name] = self
120
194
 
@@ -137,17 +211,20 @@ class Model:
137
211
  materialized = federate_config.materialized
138
212
  return Materialization[materialized.upper()]
139
213
 
140
- async def _compile_sql_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[SqlModelQuery, set]:
141
- assert(isinstance(self.query_file.raw_query, RawSqlQuery))
214
+ async def _compile_sql_model(
215
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
216
+ ) -> tuple[_SqlModelQuery, set]:
217
+ assert(isinstance(self.query_file.raw_query, _RawSqlQuery))
218
+
142
219
  raw_query = self.query_file.raw_query.query
143
-
144
220
  connection_name = self._get_dbview_conn_name()
145
221
  materialized = self._get_materialized()
146
- configuration = SqlModelConfig(connection_name, materialized)
222
+ configuration = _SqlModelConfig(connection_name, materialized)
223
+ is_placeholder = lambda x: x in placeholders
147
224
  kwargs = {
148
- "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars,
149
- "user": ctx_args.user, "prms": ctx_args.prms, "traits": ctx_args.traits,
150
- "ctx": ctx, "config": configuration.set_attribute
225
+ "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars, "user": ctx_args.user, "prms": ctx_args.prms,
226
+ "traits": ctx_args.traits, "ctx": ctx, "is_placeholder": is_placeholder, "set_placeholder": ctx_args.set_placeholder,
227
+ "config": configuration.set_attribute, "is_param_enabled": ctx_args.param_exists
151
228
  }
152
229
  dependencies = set()
153
230
  if self.query_file.model_type == ModelType.FEDERATE:
@@ -157,16 +234,21 @@ class Model:
157
234
  kwargs["ref"] = ref
158
235
 
159
236
  try:
160
- query = await asyncio.to_thread(u.render_string, raw_query, kwargs)
237
+ query = await asyncio.to_thread(u.render_string, raw_query, **kwargs)
161
238
  except Exception as e:
162
239
  raise u.FileExecutionError(f'Failed to compile sql model "{self.name}"', e)
163
240
 
164
- compiled_query = SqlModelQuery(query, configuration)
241
+ compiled_query = _SqlModelQuery(query, configuration)
165
242
  return compiled_query, dependencies
166
243
 
167
- async def _compile_python_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[PyModelQuery, set]:
168
- assert(isinstance(self.query_file.raw_query, RawPyQuery))
169
- sqrl_args = ModelDepsArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, ctx)
244
+ async def _compile_python_model(
245
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
246
+ ) -> tuple[_PyModelQuery, set]:
247
+ assert(isinstance(self.query_file.raw_query, _RawPyQuery))
248
+
249
+ sqrl_args = ModelDepsArgs(
250
+ ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx
251
+ )
170
252
  try:
171
253
  dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
172
254
  except Exception as e:
@@ -175,34 +257,42 @@ class Model:
175
257
  dbview_conn_name = self._get_dbview_conn_name()
176
258
  connections = ConnectionSetIO.obj.get_engines_as_dict()
177
259
  ref = lambda x: self.upstreams[x].result
178
- sqrl_args = ModelArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits,
179
- ctx, dbview_conn_name, connections, ref, set(dependencies))
260
+ sqrl_args = ModelArgs(
261
+ ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx,
262
+ dbview_conn_name, connections, dependencies, ref
263
+ )
180
264
 
181
265
  def compiled_query():
182
266
  try:
183
- return self.query_file.raw_query.query(sqrl=sqrl_args)
267
+ raw_query: _RawPyQuery = self.query_file.raw_query
268
+ return raw_query.query(sqrl=sqrl_args)
184
269
  except Exception as e:
185
270
  raise u.FileExecutionError(f'Failed to run "{c.MAIN_FUNC}" function for python model "{self.name}"', e)
186
271
 
187
- return PyModelQuery(compiled_query), dependencies
272
+ return _PyModelQuery(compiled_query), dependencies
188
273
 
189
- async def compile(self, ctx: dict[str, Any], ctx_args: ContextArgs, models_dict: dict[str, Model], recurse: bool) -> None:
274
+ async def compile(
275
+ self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
276
+ ) -> None:
190
277
  if self.compiled_query is not None:
191
278
  return
192
279
  else:
193
- self.compiled_query = WorkInProgress()
280
+ self.compiled_query = _WorkInProgress()
194
281
 
195
282
  start = time.time()
283
+
196
284
  if self.query_file.query_type == QueryType.SQL:
197
- compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args)
285
+ compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders)
198
286
  elif self.query_file.query_type == QueryType.PYTHON:
199
- compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args)
287
+ compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders)
200
288
  else:
201
289
  raise NotImplementedError(f"Query type not supported: {self.query_file.query_type}")
202
290
 
203
291
  self.compiled_query = compiled_query
204
292
  self.wait_count = len(dependencies)
205
- timer.add_activity_time(f"compiling model '{self.name}'", start)
293
+
294
+ model_type = self.get_model_type().name.lower()
295
+ timer.add_activity_time(f"compiling {model_type} model '{self.name}'", start)
206
296
 
207
297
  if not recurse:
208
298
  return
@@ -211,7 +301,7 @@ class Model:
211
301
  coroutines = []
212
302
  for dep_model in dep_models:
213
303
  self._add_upstream(dep_model)
214
- coro = dep_model.compile(ctx, ctx_args, models_dict, recurse)
304
+ coro = dep_model.compile(ctx, ctx_args, placeholders, models_dict, recurse)
215
305
  coroutines.append(coro)
216
306
  await asyncio.gather(*coroutines)
217
307
 
@@ -234,29 +324,16 @@ class Model:
234
324
 
235
325
  self.confirmed_no_cycles = True
236
326
  return terminal_nodes
237
-
238
- def _load_pandas_to_table(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None:
239
- if u.use_duckdb():
240
- conn.execute(f"CREATE TABLE {self.name} AS FROM df")
241
- else:
242
- df.to_sql(self.name, conn, index=False)
243
-
244
- def _load_table_to_pandas(self, conn: sqlite3.Connection) -> pd.DataFrame:
245
- if u.use_duckdb():
246
- return conn.execute(f"FROM {self.name}").df()
247
- else:
248
- query = f"SELECT * FROM {self.name}"
249
- return pd.read_sql(query, conn)
250
327
 
251
- async def _run_sql_model(self, conn: sqlite3.Connection) -> None:
252
- assert(isinstance(self.compiled_query, SqlModelQuery))
328
+ async def _run_sql_model(self, conn: Connection, placeholders: dict = {}) -> None:
329
+ assert(isinstance(self.compiled_query, _SqlModelQuery))
253
330
  config = self.compiled_query.config
254
331
  query = self.compiled_query.query
255
332
 
256
333
  if self.query_file.model_type == ModelType.DBVIEW:
257
334
  def run_sql_query():
258
335
  try:
259
- return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name)
336
+ return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name, placeholders)
260
337
  except RuntimeError as e:
261
338
  raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e)
262
339
 
@@ -268,7 +345,7 @@ class Model:
268
345
  def create_table():
269
346
  create_query = config.get_sql_for_create(self.name, query)
270
347
  try:
271
- return conn.execute(create_query)
348
+ return conn.execute(text(create_query), placeholders)
272
349
  except Exception as e:
273
350
  raise u.FileExecutionError(f'Failed to run federate sql model "{self.name}"', e)
274
351
 
@@ -276,8 +353,8 @@ class Model:
276
353
  if self.needs_pandas or self.is_target:
277
354
  self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)
278
355
 
279
- async def _run_python_model(self, conn: sqlite3.Connection) -> None:
280
- assert(isinstance(self.compiled_query, PyModelQuery))
356
+ async def _run_python_model(self, conn: Connection) -> None:
357
+ assert(isinstance(self.compiled_query, _PyModelQuery))
281
358
 
282
359
  df = await asyncio.to_thread(self.compiled_query.query)
283
360
  if self.needs_sql_table:
@@ -285,92 +362,86 @@ class Model:
285
362
  if self.needs_pandas or self.is_target:
286
363
  self.result = df
287
364
 
288
- async def run_model(self, conn: sqlite3.Connection) -> None:
365
+ async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
289
366
  start = time.time()
367
+
290
368
  if self.query_file.query_type == QueryType.SQL:
291
- await self._run_sql_model(conn)
369
+ await self._run_sql_model(conn, placeholders)
292
370
  elif self.query_file.query_type == QueryType.PYTHON:
293
371
  await self._run_python_model(conn)
294
- timer.add_activity_time(f"running model '{self.name}'", start)
295
372
 
296
- coroutines = []
297
- for model in self.downstreams.values():
298
- coroutines.append(model.trigger(conn))
299
- await asyncio.gather(*coroutines)
300
-
301
- async def trigger(self, conn: sqlite3.Connection) -> None:
302
- self.wait_count -= 1
303
- if (self.wait_count == 0):
304
- await self.run_model(conn)
373
+ model_type = self.get_model_type().name.lower()
374
+ timer.add_activity_time(f"running {model_type} model '{self.name}'", start)
375
+
376
+ await super().run_model(conn, placeholders)
305
377
 
306
- def fill_dependent_model_names(self, dependent_model_names: set[str]) -> None:
378
+ def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
307
379
  if self.name not in dependent_model_names:
308
380
  dependent_model_names.add(self.name)
309
381
  for dep_model in self.upstreams.values():
310
- dep_model.fill_dependent_model_names(dependent_model_names)
382
+ dep_model.retrieve_dependent_query_models(dependent_model_names)
311
383
 
312
384
 
313
385
  @dataclass
314
- class DAG:
386
+ class _DAG:
315
387
  dataset: DatasetsConfig
316
- target_model: Model
317
- models_dict: dict[str, Model]
388
+ target_model: _Referable
389
+ models_dict: dict[str, _Referable]
318
390
  parameter_set: Optional[ParameterSet] = field(default=None, init=False)
391
+ placeholders: dict[str, Any] = field(init=False, default_factory=dict)
319
392
 
320
393
  def apply_selections(
321
394
  self, user: Optional[User], selections: dict[str, str], *, updates_only: bool = False, request_version: Optional[int] = None
322
395
  ) -> None:
323
396
  start = time.time()
324
397
  dataset_params = self.dataset.parameters
325
- parameter_set = ParameterConfigsSetIO.obj.apply_selections(dataset_params, selections, user, updates_only=updates_only,
326
- request_version=request_version)
398
+ parameter_set = ParameterConfigsSetIO.obj.apply_selections(
399
+ dataset_params, selections, user, updates_only=updates_only, request_version=request_version
400
+ )
327
401
  self.parameter_set = parameter_set
328
- timer.add_activity_time(f"applying selections for dataset", start)
402
+ timer.add_activity_time(f"applying selections for dataset '{self.dataset.name}'", start)
329
403
 
330
404
  def _compile_context(self, context_func: ContextFunc, user: Optional[User]) -> tuple[dict[str, Any], ContextArgs]:
331
405
  start = time.time()
332
406
  context = {}
333
407
  param_args = ParameterConfigsSetIO.args
334
408
  prms = self.parameter_set.get_parameters_as_dict()
335
- args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits)
409
+ args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits, self.placeholders)
336
410
  try:
337
411
  context_func(ctx=context, sqrl=args)
338
412
  except Exception as e:
339
- raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset}"', e)
340
- timer.add_activity_time(f"running context.py for dataset", start)
413
+ raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset.name}"', e)
414
+ timer.add_activity_time(f"running context.py for dataset '{self.dataset.name}'", start)
341
415
  return context, args
342
416
 
343
417
  async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
344
- await self.target_model.compile(context, ctx_args, self.models_dict, recurse)
418
+ await self.target_model.compile(context, ctx_args, self.placeholders, self.models_dict, recurse)
345
419
 
346
420
  def _get_terminal_nodes(self) -> set[str]:
347
421
  start = time.time()
348
422
  terminal_nodes = self.target_model.get_terminal_nodes(set())
349
423
  for model in self.models_dict.values():
350
424
  model.confirmed_no_cycles = False
351
- timer.add_activity_time(f"validating no cycles in models dependencies", start)
425
+ timer.add_activity_time(f"validating no cycles in model dependencies", start)
352
426
  return terminal_nodes
353
427
 
354
- async def _run_models(self, terminal_nodes: set[str]) -> None:
355
- if u.use_duckdb():
356
- import duckdb
357
- conn = duckdb.connect()
358
- else:
359
- conn = sqlite3.connect(":memory:", check_same_thread=False)
428
+ async def _run_models(self, terminal_nodes: set[str], placeholders: dict = {}) -> None:
429
+ conn_url = "duckdb:///" if u.use_duckdb() else "sqlite:///?check_same_thread=False"
430
+ engine = create_engine(conn_url)
360
431
 
361
- try:
432
+ with engine.connect() as conn:
362
433
  coroutines = []
363
434
  for model_name in terminal_nodes:
364
435
  model = self.models_dict[model_name]
365
- coroutines.append(model.run_model(conn))
436
+ coroutines.append(model.run_model(conn, placeholders))
366
437
  await asyncio.gather(*coroutines)
367
- finally:
368
- conn.close()
438
+
439
+ engine.dispose()
369
440
 
370
441
  async def execute(
371
442
  self, context_func: ContextFunc, user: Optional[User], selections: dict[str, str], *, request_version: Optional[int] = None,
372
443
  runquery: bool = True, recurse: bool = True
373
- ) -> None:
444
+ ) -> dict[str, Any]:
374
445
  recurse = (recurse or runquery)
375
446
 
376
447
  self.apply_selections(user, selections, request_version=request_version)
@@ -381,17 +452,35 @@ class DAG:
381
452
 
382
453
  terminal_nodes = self._get_terminal_nodes()
383
454
 
455
+ placeholders = ctx_args._placeholders.copy()
384
456
  if runquery:
385
- await self._run_models(terminal_nodes)
457
+ await self._run_models(terminal_nodes, placeholders)
458
+
459
+ return placeholders
386
460
 
387
- def get_all_model_names(self) -> set[str]:
461
+ def get_all_query_models(self) -> set[str]:
388
462
  all_model_names = set()
389
- self.target_model.fill_dependent_model_names(all_model_names)
463
+ self.target_model.retrieve_dependent_query_models(all_model_names)
390
464
  return all_model_names
391
-
465
+
466
+ def to_networkx_graph(self) -> nx.DiGraph:
467
+ G = nx.DiGraph()
468
+
469
+ for model_name, model in self.models_dict.items():
470
+ model_type = model.get_model_type()
471
+ level = model.get_max_path_length_to_target()
472
+ if level is not None:
473
+ G.add_node(model_name, layer=-level, model_type=model_type)
474
+
475
+ for model_name in G.nodes:
476
+ model = self.models_dict[model_name]
477
+ for dep_model_name in model.downstreams:
478
+ G.add_edge(model_name, dep_model_name)
479
+
480
+ return G
392
481
 
393
482
  class ModelsIO:
394
- raw_queries_by_model: dict[str, QueryFile]
483
+ raw_queries_by_model: dict[str, _QueryFile]
395
484
  context_func: ContextFunc
396
485
 
397
486
  @classmethod
@@ -400,7 +489,6 @@ class ModelsIO:
400
489
  cls.raw_queries_by_model = {}
401
490
 
402
491
  def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType):
403
-
404
492
  def populate_from_file(dp, file):
405
493
  query_type = None
406
494
  filepath = os.path.join(dp, file)
@@ -409,13 +497,13 @@ class ModelsIO:
409
497
  query_type = QueryType.PYTHON
410
498
  module = pm.PyModule(filepath)
411
499
  dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda x: [])
412
- raw_query = RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
500
+ raw_query = _RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
413
501
  elif extension == '.sql':
414
502
  query_type = QueryType.SQL
415
- raw_query = RawSqlQuery(u.read_file(filepath))
503
+ raw_query = _RawSqlQuery(u.read_file(filepath))
416
504
 
417
505
  if query_type is not None:
418
- query_file = QueryFile(filepath, model_type, query_type, raw_query)
506
+ query_file = _QueryFile(filepath, model_type, query_type, raw_query)
419
507
  if file_stem in cls.raw_queries_by_model:
420
508
  conflicts = [cls.raw_queries_by_model[file_stem].filepath, filepath]
421
509
  raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
@@ -431,44 +519,98 @@ class ModelsIO:
431
519
  federates_path = u.join_paths(c.MODELS_FOLDER, c.FEDERATES_FOLDER)
432
520
  populate_raw_queries_for_type(federates_path, ModelType.FEDERATE)
433
521
 
434
- context_path = u.join_paths(c.PYCONFIG_FOLDER, c.CONTEXT_FILE)
522
+ context_path = u.join_paths(c.PYCONFIGS_FOLDER, c.CONTEXT_FILE)
435
523
  cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda x, y: None)
436
524
 
437
- timer.add_activity_time("loading models and/or context.py", start)
525
+ timer.add_activity_time("loading files for models and context.py", start)
438
526
 
439
527
  @classmethod
440
- def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> DAG:
441
- models_dict = {key: Model(key, val, needs_pandas=always_pandas) for key, val in cls.raw_queries_by_model.items()}
528
+ def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> _DAG:
529
+ seeds_dict = SeedsIO.obj.get_dataframes()
530
+
531
+ models_dict: dict[str, _Referable] = {key: _Seed(key, df) for key, df in seeds_dict.items()}
532
+ for key, val in cls.raw_queries_by_model.items():
533
+ models_dict[key] = _Model(key, val)
534
+ models_dict[key].needs_pandas = always_pandas
442
535
 
443
536
  dataset_config = ManifestIO.obj.datasets[dataset]
444
537
  target_model_name = dataset_config.model if target_model_name is None else target_model_name
445
538
  target_model = models_dict[target_model_name]
446
539
  target_model.is_target = True
447
540
 
448
- return DAG(dataset_config, target_model, models_dict)
541
+ return _DAG(dataset_config, target_model, models_dict)
449
542
 
450
543
  @classmethod
451
- async def WriteDatasetOutputsGivenTestSet(cls, dataset: str, select: str, test_set: str, runquery: bool, recurse: bool) -> Any:
452
- test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
453
- user_attributes = test_set_conf.user_attributes
454
- selections = test_set_conf.parameters
544
+ def draw_dag(cls, dag: _DAG, output_folder: Path) -> None:
545
+ color_map = {ModelType.SEED: "green", ModelType.DBVIEW: "red", ModelType.FEDERATE: "skyblue"}
546
+
547
+ G = dag.to_networkx_graph()
548
+
549
+ fig, _ = plt.subplots()
550
+ pos = nx.multipartite_layout(G, subset_key="layer")
551
+ colors = [color_map[node[1]] for node in G.nodes(data="model_type")]
552
+ nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
553
+
554
+ y_values = [val[1] for val in pos.values()]
555
+ scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
556
+ label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
557
+ nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
558
+
559
+ fig.tight_layout()
560
+ plt.margins(x=0.1, y=0.1)
561
+ plt.savefig(u.join_paths(output_folder, "dag.png"))
562
+ plt.close(fig)
563
+
564
+ @classmethod
565
+ async def WriteDatasetOutputsGivenTestSet(
566
+ cls, dataset_conf: DatasetsConfig, select: str, test_set: Optional[str], runquery: bool, recurse: bool
567
+ ) -> Any:
568
+ dataset = dataset_conf.name
569
+ default_test_set, default_test_set_conf = ManifestIO.obj.get_default_test_set(dataset)
570
+ if test_set is None or test_set == default_test_set:
571
+ test_set, test_set_conf = default_test_set, default_test_set_conf
572
+ elif test_set in ManifestIO.obj.selection_test_sets:
573
+ test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
574
+ else:
575
+ raise u.InvalidInputError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
455
576
 
456
- username, is_internal = user_attributes.get("username", ""), user_attributes.get("is_internal", False)
457
- user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
458
- user = user_cls.Create(username, test_set_conf.user_attributes, is_internal=is_internal)
577
+ error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
578
+ if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
579
+ raise u.InvalidInputError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
459
580
 
581
+ user_attributes = test_set_conf.user_attributes.copy()
582
+ selections = test_set_conf.parameters.copy()
583
+ username, is_internal = user_attributes.pop("username", ""), user_attributes.pop("is_internal", False)
584
+ if test_set_conf.is_authenticated:
585
+ user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
586
+ user = user_cls.Create(username, is_internal=is_internal, **user_attributes)
587
+ elif dataset_conf.scope == DatasetScope.PUBLIC:
588
+ user = None
589
+ else:
590
+ raise u.ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
591
+
592
+ if dataset_conf.scope == DatasetScope.PRIVATE and not is_internal:
593
+ raise u.ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_internal' set to true")
594
+
595
+ # always_pandas is set to True for creating CSV files from results (when runquery is True)
460
596
  dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
461
- await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
597
+ placeholders = await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
462
598
 
463
- output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, test_set, dataset)
599
+ output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
464
600
  if os.path.exists(output_folder):
465
601
  shutil.rmtree(output_folder)
602
+ os.makedirs(output_folder, exist_ok=True)
603
+
604
+ def write_placeholders() -> None:
605
+ output_filepath = u.join_paths(output_folder, "placeholders.json")
606
+ with open(output_filepath, 'w') as f:
607
+ json.dump(placeholders, f, indent=4)
466
608
 
467
- def write_model_outputs(model: Model) -> None:
609
+ def write_model_outputs(model: _Model) -> None:
468
610
  subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
469
611
  subpath = u.join_paths(output_folder, subfolder)
470
612
  os.makedirs(subpath, exist_ok=True)
471
- if isinstance(model.compiled_query, SqlModelQuery):
613
+ if isinstance(model.compiled_query, _SqlModelQuery):
472
614
  output_filepath = u.join_paths(subpath, model.name+'.sql')
473
615
  query = model.compiled_query.query
474
616
  with open(output_filepath, 'w') as f:
@@ -477,39 +619,50 @@ class ModelsIO:
477
619
  output_filepath = u.join_paths(subpath, model.name+'.csv')
478
620
  model.result.to_csv(output_filepath, index=False)
479
621
 
480
- all_model_names = dag.get_all_model_names()
622
+ write_placeholders()
623
+ all_model_names = dag.get_all_query_models()
481
624
  coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
482
625
  await asyncio.gather(*coroutines)
483
- return dag.target_model.compiled_query.query
626
+
627
+ if recurse:
628
+ cls.draw_dag(dag, output_folder)
629
+
630
+ if isinstance(dag.target_model, _Model):
631
+ return dag.target_model.compiled_query.query # else return None
484
632
 
485
633
  @classmethod
486
634
  async def WriteOutputs(
487
- cls, dataset: Optional[str], select: Optional[str], all_test_sets: bool, test_set: Optional[str], runquery: bool
635
+ cls, dataset: Optional[str], do_all_datasets: bool, select: Optional[str], test_set: Optional[str], do_all_test_sets: bool,
636
+ runquery: bool
488
637
  ) -> None:
489
- if test_set is None:
490
- test_set = ManifestIO.obj.settings.get(c.TEST_SET_DEFAULT_USED_SETTING, c.DEFAULT_TEST_SET_NAME)
491
-
492
- if all_test_sets:
493
- test_sets = ManifestIO.obj.selection_test_sets.keys()
494
- else:
495
- test_sets = [test_set]
638
+
639
+ def get_applicable_test_sets(dataset: str) -> list[str]:
640
+ applicable_test_sets = []
641
+ for test_set_name, test_set_config in ManifestIO.obj.selection_test_sets.items():
642
+ if test_set_config.datasets is None or dataset in test_set_config.datasets:
643
+ applicable_test_sets.append(test_set_name)
644
+ return applicable_test_sets
496
645
 
497
646
  recurse = True
498
647
  dataset_configs = ManifestIO.obj.datasets
499
- if dataset is None:
500
- selected_models = [(dataset.name, dataset.model) for dataset in dataset_configs.values()]
648
+ if do_all_datasets:
649
+ selected_models = [(dataset, dataset.model) for dataset in dataset_configs.values()]
501
650
  else:
502
651
  if select is None:
503
652
  select = dataset_configs[dataset].model
504
653
  else:
505
654
  recurse = False
506
- selected_models = [(dataset, select)]
655
+ selected_models = [(dataset_configs[dataset], select)]
507
656
 
508
657
  coroutines = []
509
- for test_set in test_sets:
510
- for dataset, select in selected_models:
511
- coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset, select, test_set, runquery, recurse)
512
- coroutines.append(coroutine)
658
+ for dataset_conf, select in selected_models:
659
+ if do_all_test_sets:
660
+ for test_set_name in get_applicable_test_sets(dataset_conf.name):
661
+ coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set_name, runquery, recurse)
662
+ coroutines.append(coroutine)
663
+
664
+ coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set, runquery, recurse)
665
+ coroutines.append(coroutine)
513
666
 
514
667
  queries = await asyncio.gather(*coroutines)
515
668
  if not recurse and len(queries) == 1 and isinstance(queries[0], str):