squirrels 0.3.3-py3-none-any.whl → 0.4.0-py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of squirrels might be problematic.
Files changed (56)
  1. squirrels/__init__.py +7 -3
  2. squirrels/_api_response_models.py +96 -72
  3. squirrels/_api_server.py +375 -201
  4. squirrels/_authenticator.py +23 -22
  5. squirrels/_command_line.py +70 -46
  6. squirrels/_connection_set.py +23 -25
  7. squirrels/_constants.py +29 -78
  8. squirrels/_dashboards_io.py +61 -0
  9. squirrels/_environcfg.py +53 -50
  10. squirrels/_initializer.py +184 -141
  11. squirrels/_manifest.py +168 -195
  12. squirrels/_models.py +159 -292
  13. squirrels/_package_loader.py +7 -8
  14. squirrels/_parameter_configs.py +173 -141
  15. squirrels/_parameter_sets.py +49 -38
  16. squirrels/_py_module.py +7 -7
  17. squirrels/_seeds.py +13 -12
  18. squirrels/_utils.py +114 -54
  19. squirrels/_version.py +1 -1
  20. squirrels/arguments/init_time_args.py +16 -10
  21. squirrels/arguments/run_time_args.py +89 -24
  22. squirrels/dashboards.py +82 -0
  23. squirrels/data_sources.py +212 -232
  24. squirrels/dateutils.py +29 -26
  25. squirrels/package_data/assets/index.css +1 -1
  26. squirrels/package_data/assets/index.js +27 -18
  27. squirrels/package_data/base_project/.gitignore +2 -2
  28. squirrels/package_data/base_project/connections.yml +1 -1
  29. squirrels/package_data/base_project/dashboards/dashboard_example.py +32 -0
  30. squirrels/package_data/base_project/dashboards.yml +10 -0
  31. squirrels/package_data/base_project/docker/.dockerignore +9 -4
  32. squirrels/package_data/base_project/docker/Dockerfile +7 -6
  33. squirrels/package_data/base_project/docker/compose.yml +1 -1
  34. squirrels/package_data/base_project/env.yml +2 -2
  35. squirrels/package_data/base_project/models/dbviews/{database_view1.py → dbview_example.py} +2 -1
  36. squirrels/package_data/base_project/models/dbviews/{database_view1.sql → dbview_example.sql} +3 -2
  37. squirrels/package_data/base_project/models/federates/{dataset_example.py → federate_example.py} +6 -6
  38. squirrels/package_data/base_project/models/federates/{dataset_example.sql → federate_example.sql} +1 -1
  39. squirrels/package_data/base_project/parameters.yml +6 -4
  40. squirrels/package_data/base_project/pyconfigs/auth.py +1 -1
  41. squirrels/package_data/base_project/pyconfigs/connections.py +1 -1
  42. squirrels/package_data/base_project/pyconfigs/context.py +38 -10
  43. squirrels/package_data/base_project/pyconfigs/parameters.py +15 -7
  44. squirrels/package_data/base_project/squirrels.yml.j2 +14 -7
  45. squirrels/package_data/templates/index.html +3 -3
  46. squirrels/parameter_options.py +103 -106
  47. squirrels/parameters.py +347 -195
  48. squirrels/project.py +378 -0
  49. squirrels/user_base.py +14 -6
  50. {squirrels-0.3.3.dist-info → squirrels-0.4.0.dist-info}/METADATA +9 -21
  51. squirrels-0.4.0.dist-info/RECORD +60 -0
  52. squirrels/_timer.py +0 -23
  53. squirrels-0.3.3.dist-info/RECORD +0 -56
  54. {squirrels-0.3.3.dist-info → squirrels-0.4.0.dist-info}/LICENSE +0 -0
  55. {squirrels-0.3.3.dist-info → squirrels-0.4.0.dist-info}/WHEEL +0 -0
  56. {squirrels-0.3.3.dist-info → squirrels-0.4.0.dist-info}/entry_points.txt +0 -0
squirrels/_models.py CHANGED
@@ -1,32 +1,28 @@
  from __future__ import annotations
- from typing import Optional, Callable, Iterable, Any
+ from typing import Iterable, Callable, Any
  from dataclasses import dataclass, field
  from abc import ABCMeta, abstractmethod
  from enum import Enum
  from pathlib import Path
  from sqlalchemy import create_engine, text, Connection
- import asyncio, os, shutil, pandas as pd, json
- import matplotlib.pyplot as plt, networkx as nx
+ import asyncio, os, time, pandas as pd, networkx as nx

  from . import _constants as c, _utils as u, _py_module as pm
  from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
- from ._authenticator import User, Authenticator
- from ._connection_set import ConnectionSetIO
- from ._manifest import ManifestIO, DatasetsConfig, DatasetScope
- from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
- from ._seeds import SeedsIO
- from ._timer import timer, time
+ from ._authenticator import User
+ from ._connection_set import ConnectionSet
+ from ._manifest import ManifestConfig, DatasetConfig
+ from ._parameter_sets import ParameterConfigsSet, ParametersArgs, ParameterSet
+
+ ContextFunc = Callable[[dict[str, Any], ContextArgs], None]
+

  class ModelType(Enum):
      DBVIEW = 1
      FEDERATE = 2
      SEED = 3

- class QueryType(Enum):
-     SQL = 0
-     PYTHON = 1
-
- class Materialization(Enum):
+ class _Materialization(Enum):
      TABLE = 0
      VIEW = 1

@@ -37,52 +33,46 @@ class _SqlModelConfig:
      connection_name: str

      ## Applicable for federated models
-     materialized: Materialization
+     materialized: _Materialization

-     def set_attribute(self, **kwargs) -> str:
-         connection_name = kwargs.get(c.DBVIEW_CONN_KEY)
+     def set_attribute(self, *, connection_name: str | None = None, materialized: str | None = None, **kwargs) -> str:
          if connection_name is not None:
              if not isinstance(connection_name, str):
                  raise u.ConfigurationError("The 'connection_name' argument of 'config' macro must be a string")
              self.connection_name = connection_name

-         materialized: str = kwargs.get(c.MATERIALIZED_KEY)
          if materialized is not None:
              if not isinstance(materialized, str):
                  raise u.ConfigurationError("The 'materialized' argument of 'config' macro must be a string")
              try:
-                 self.materialized = Materialization[materialized.upper()]
+                 self.materialized = _Materialization[materialized.upper()]
              except KeyError as e:
-                 valid_options = [x.name for x in Materialization]
+                 valid_options = [x.name for x in _Materialization]
                  raise u.ConfigurationError(f"The 'materialized' argument value '{materialized}' is not valid. Must be one of: {valid_options}") from e
          return ""

      def get_sql_for_create(self, model_name: str, select_query: str) -> str:
-         if self.materialized == Materialization.TABLE:
-             create_prefix = f"CREATE TABLE {model_name} AS\n"
-         elif self.materialized == Materialization.VIEW:
-             create_prefix = f"CREATE VIEW {model_name} AS\n"
-         else:
-             raise u.ConfigurationError(f"Materialization option not supported: {self.materialized}")
-
+         create_prefix = f"CREATE {self.materialized.name} {model_name} AS\n"
          return create_prefix + select_query


- ContextFunc = Callable[[dict[str, Any], ContextArgs], None]
-
+ @dataclass(frozen=True)
+ class QueryFile:
+     filepath: str
+     model_type: ModelType

  @dataclass(frozen=True)
- class _RawQuery(metaclass=ABCMeta):
-     pass
+ class SqlQueryFile(QueryFile):
+     raw_query: str

  @dataclass(frozen=True)
- class _RawSqlQuery(_RawQuery):
-     query: str
+ class _RawPyQuery:
+     query: Callable[[ModelArgs], pd.DataFrame]
+     dependencies_func: Callable[[ModelDepsArgs], Iterable[str]]

  @dataclass(frozen=True)
- class _RawPyQuery(_RawQuery):
-     query: Callable[[Any], pd.DataFrame]
-     dependencies_func: Callable[[Any], Iterable]
+ class PyQueryFile(QueryFile):
+     raw_query: _RawPyQuery


  @dataclass
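Note on the config changes above: a minimal sketch, assuming _SqlModelConfig is constructed the same way _compile_sql_model does later in this file, of how the new keyword-only set_attribute arguments and the simplified get_sql_for_create behave. The standalone import and values are illustrative only; these are internal classes, not a documented API.

# Minimal sketch (assumed usage, not from the package docs): the 'config' macro arguments
# map onto the keyword-only parameters of _SqlModelConfig.set_attribute in 0.4.0.
from squirrels._models import _SqlModelConfig, _Materialization

cfg = _SqlModelConfig(connection_name="default", materialized=_Materialization.TABLE)
cfg.set_attribute(materialized="view")  # case-insensitive lookup into _Materialization
print(cfg.get_sql_for_create("example_model", "SELECT 1 AS a"))
# CREATE VIEW example_model AS
# SELECT 1 AS a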
@@ -94,43 +84,35 @@ class _WorkInProgress(_Query):
      query: None = field(default=None, init=False)

  @dataclass
- class _SqlModelQuery(_Query):
+ class SqlModelQuery(_Query):
      query: str
      config: _SqlModelConfig

  @dataclass
- class _PyModelQuery(_Query):
+ class PyModelQuery(_Query):
      query: Callable[[], pd.DataFrame]


- @dataclass(frozen=True)
- class _QueryFile:
-     filepath: str
-     model_type: ModelType
-     query_type: QueryType
-     raw_query: _RawQuery
-
-
  @dataclass
- class _Referable(metaclass=ABCMeta):
+ class Referable(metaclass=ABCMeta):
      name: str
      is_target: bool = field(default=False, init=False)

      needs_sql_table: bool = field(default=False, init=False)
      needs_pandas: bool = field(default=False, init=False)
-     result: Optional[pd.DataFrame] = field(default=None, init=False, repr=False)
+     result: pd.DataFrame | None = field(default=None, init=False, repr=False)

      wait_count: int = field(default=0, init=False, repr=False)
      confirmed_no_cycles: bool = field(default=False, init=False)
-     upstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
-     downstreams: dict[str, _Referable] = field(default_factory=dict, init=False, repr=False)
+     upstreams: dict[str, Referable] = field(default_factory=dict, init=False, repr=False)
+     downstreams: dict[str, Referable] = field(default_factory=dict, init=False, repr=False)

      @abstractmethod
      def get_model_type(self) -> ModelType:
          pass

      async def compile(
-         self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
+         self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, Referable], recurse: bool
      ) -> None:
          pass

@@ -160,11 +142,12 @@ class _Referable(metaclass=ABCMeta):
      def retrieve_dependent_query_models(self, dependent_model_names: set[str]) -> None:
          pass

-     def get_max_path_length_to_target(self) -> int:
+     def get_max_path_length_to_target(self) -> int | None:
          if not hasattr(self, "max_path_len_to_target"):
              path_lengths = []
              for child_model in self.downstreams.values():
-                 path_lengths.append(child_model.get_max_path_length_to_target()+1)
+                 assert isinstance(child_model_path_length := child_model.get_max_path_length_to_target(), int)
+                 path_lengths.append(child_model_path_length+1)
              if len(path_lengths) > 0:
                  self.max_path_len_to_target = max(path_lengths)
              else:
@@ -173,7 +156,7 @@ class _Referable(metaclass=ABCMeta):


  @dataclass
- class _Seed(_Referable):
+ class Seed(Referable):
      result: pd.DataFrame

      def get_model_type(self) -> ModelType:
@@ -189,43 +172,45 @@ class _Seed(_Referable):


  @dataclass
- class _Model(_Referable):
-     query_file: _QueryFile
-
-     compiled_query: Optional[_Query] = field(default=None, init=False)
+ class Model(Referable):
+     query_file: QueryFile
+     manifest_cfg: ManifestConfig
+     conn_set: ConnectionSet
+     logger: u.Logger = field(default_factory=lambda: u.Logger(""))
+     j2_env: u.j2.Environment = field(default_factory=lambda: u.j2.Environment(loader=u.j2.FileSystemLoader(".")))
+     compiled_query: _Query | None = field(default=None, init=False)

      def get_model_type(self) -> ModelType:
          return self.query_file.model_type

-     def _add_upstream(self, other: _Referable) -> None:
+     def _add_upstream(self, other: Referable) -> None:
          self.upstreams[other.name] = other
          other.downstreams[self.name] = self

-         if self.query_file.query_type == QueryType.PYTHON:
-             other.needs_pandas = True
-         elif self.query_file.query_type == QueryType.SQL:
+         if isinstance(self.query_file, SqlQueryFile):
              other.needs_sql_table = True
+         elif isinstance(self.query_file, PyQueryFile):
+             other.needs_pandas = True

      def _get_dbview_conn_name(self) -> str:
-         dbview_config = ManifestIO.obj.dbviews.get(self.name)
+         dbview_config = self.manifest_cfg.dbviews.get(self.name)
          if dbview_config is None or dbview_config.connection_name is None:
-             return ManifestIO.obj.settings.get(c.DB_CONN_DEFAULT_USED_SETTING, c.DEFAULT_DB_CONN)
+             return self.manifest_cfg.settings.get(c.DB_CONN_DEFAULT_USED_SETTING, c.DEFAULT_DB_CONN)
          return dbview_config.connection_name

-     def _get_materialized(self) -> str:
-         federate_config = ManifestIO.obj.federates.get(self.name)
+     def _get_materialized(self) -> _Materialization:
+         federate_config = self.manifest_cfg.federates.get(self.name)
          if federate_config is None or federate_config.materialized is None:
-             materialized = ManifestIO.obj.settings.get(c.DEFAULT_MATERIALIZE_SETTING, c.DEFAULT_TABLE_MATERIALIZE)
+             materialized = self.manifest_cfg.settings.get(c.DEFAULT_MATERIALIZE_SETTING, c.DEFAULT_MATERIALIZE)
          else:
              materialized = federate_config.materialized
-         return Materialization[materialized.upper()]
+         return _Materialization[materialized.upper()]

      async def _compile_sql_model(
-         self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
-     ) -> tuple[_SqlModelQuery, set]:
-         assert(isinstance(self.query_file.raw_query, _RawSqlQuery))
-
-         raw_query = self.query_file.raw_query.query
+         self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, Referable]
+     ) -> tuple[SqlModelQuery, set]:
+         assert isinstance(self.query_file, SqlQueryFile)
+
          connection_name = self._get_dbview_conn_name()
          materialized = self._get_materialized()
          configuration = _SqlModelConfig(connection_name, materialized)
@@ -237,51 +222,68 @@ class _Model(_Referable):
          }
          dependencies = set()
          if self.query_file.model_type == ModelType.FEDERATE:
-             def ref(name):
-                 dependencies.add(name)
-                 return name
+             def ref(dependent_model_name):
+                 if dependent_model_name not in models_dict:
+                     raise u.ConfigurationError(f'Model "{self.name}" references unknown model "{dependent_model_name}"')
+                 dependencies.add(dependent_model_name)
+                 return dependent_model_name
              kwargs["ref"] = ref

          try:
-             query = await asyncio.to_thread(u.render_string, raw_query, **kwargs)
+             template = self.j2_env.from_string(self.query_file.raw_query)
+             query = await asyncio.to_thread(template.render, kwargs)
          except Exception as e:
              raise u.FileExecutionError(f'Failed to compile sql model "{self.name}"', e) from e

-         compiled_query = _SqlModelQuery(query, configuration)
+         compiled_query = SqlModelQuery(query, configuration)
          return compiled_query, dependencies

      async def _compile_python_model(
-         self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any]
-     ) -> tuple[_PyModelQuery, set]:
-         assert(isinstance(self.query_file.raw_query, _RawPyQuery))
+         self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, Referable]
+     ) -> tuple[PyModelQuery, Iterable]:
+         assert isinstance(self.query_file, PyQueryFile)

          sqrl_args = ModelDepsArgs(
              ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx
          )
          try:
              dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
+             for dependent_model_name in dependencies:
+                 if dependent_model_name not in models_dict:
+                     raise u.ConfigurationError(f'Model "{self.name}" references unknown model "{dependent_model_name}"')
          except Exception as e:
              raise u.FileExecutionError(f'Failed to run "{c.DEP_FUNC}" function for python model "{self.name}"', e) from e

          dbview_conn_name = self._get_dbview_conn_name()
-         connections = ConnectionSetIO.obj.get_engines_as_dict()
-         ref = lambda model: self.upstreams[model].result
+         connections = self.conn_set.get_engines_as_dict()
+
+         def ref(dependent_model_name):
+             if dependent_model_name not in self.upstreams:
+                 raise u.ConfigurationError(f'Model "{self.name}" must include model "{dependent_model_name}" as a dependency to use')
+             return pd.DataFrame(self.upstreams[dependent_model_name].result)
+
+         def run_external_sql(sql_query: str, connection_name: str | None):
+             connection_name = dbview_conn_name if connection_name is None else connection_name
+             return self.conn_set.run_sql_query_from_conn_name(sql_query, connection_name, placeholders)
+
+         use_duckdb = self.manifest_cfg.settings_obj.do_use_duckdb()
          sqrl_args = ModelArgs(
              ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, placeholders, ctx,
-             dbview_conn_name, connections, dependencies, ref
+             dbview_conn_name, connections, dependencies, ref, run_external_sql, use_duckdb
          )

          def compiled_query():
              try:
+                 assert isinstance(self.query_file, PyQueryFile)
                  raw_query: _RawPyQuery = self.query_file.raw_query
-                 return raw_query.query(sqrl=sqrl_args)
+                 return raw_query.query(sqrl_args)
              except Exception as e:
                  raise u.FileExecutionError(f'Failed to run "{c.MAIN_FUNC}" function for python model "{self.name}"', e) from e

-         return _PyModelQuery(compiled_query), dependencies
+         return PyModelQuery(compiled_query), dependencies

      async def compile(
-         self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, _Referable], recurse: bool
+         self, ctx: dict[str, Any], ctx_args: ContextArgs, placeholders: dict[str, Any], models_dict: dict[str, Referable], recurse: bool
      ) -> None:
          if self.compiled_query is not None:
              return
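The reworked _compile_python_model above now passes ref, run_external_sql, and use_duckdb into ModelArgs. A hedged sketch of a python federate model against that surface follows; the function names follow c.DEP_FUNC / c.MAIN_FUNC as loaded by ModelsIO.load_files, sqrl.ref(...) is assumed to expose the ref callable shown above, and the "orders" model with its customer_id/amount columns is hypothetical.

# Hedged sketch of a python federate model under the 0.4.0 run-time arguments.
import pandas as pd

def dependencies(sqrl) -> list[str]:
    # declared upstream models are now validated against models_dict during compilation
    return ["orders"]

def main(sqrl) -> pd.DataFrame:
    orders = sqrl.ref("orders")  # ref returns a pandas DataFrame copy of the upstream result
    return orders.groupby("customer_id", as_index=False).agg(total_amount=("amount", "sum"))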
@@ -290,18 +292,18 @@ class _Model(_Referable):

          start = time.time()

-         if self.query_file.query_type == QueryType.SQL:
-             compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders)
-         elif self.query_file.query_type == QueryType.PYTHON:
-             compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders)
+         if isinstance(self.query_file, SqlQueryFile):
+             compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args, placeholders, models_dict)
+         elif isinstance(self.query_file, PyQueryFile):
+             compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args, placeholders, models_dict)
          else:
-             raise u.ConfigurationError(f"Query type not supported: {self.query_file.query_type}")
+             raise NotImplementedError(f"Query type not supported: {self.query_file.__class__.__name__}")

          self.compiled_query = compiled_query
-         self.wait_count = len(dependencies)
+         self.wait_count = len(set(dependencies))

          model_type = self.get_model_type().name.lower()
-         timer.add_activity_time(f"compiling {model_type} model '{self.name}'", start)
+         self.logger.log_activity_time(f"compiling {model_type} model '{self.name}'", start)

          if not recurse:
              return
@@ -335,14 +337,14 @@ class _Model(_Referable):
          return terminal_nodes

      async def _run_sql_model(self, conn: Connection, placeholders: dict = {}) -> None:
-         assert(isinstance(self.compiled_query, _SqlModelQuery))
+         assert(isinstance(self.compiled_query, SqlModelQuery))
          config = self.compiled_query.config
          query = self.compiled_query.query

          if self.query_file.model_type == ModelType.DBVIEW:
              def run_sql_query():
                  try:
-                     return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name, placeholders)
+                     return self.conn_set.run_sql_query_from_conn_name(query, config.connection_name, placeholders)
                  except RuntimeError as e:
                      raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e) from e

@@ -363,7 +365,7 @@ class _Model(_Referable):
              self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)

      async def _run_python_model(self, conn: Connection) -> None:
-         assert(isinstance(self.compiled_query, _PyModelQuery))
+         assert(isinstance(self.compiled_query, PyModelQuery))

          df = await asyncio.to_thread(self.compiled_query.query)
          if self.needs_sql_table:
@@ -374,13 +376,15 @@ class _Model(_Referable):
      async def run_model(self, conn: Connection, placeholders: dict = {}) -> None:
          start = time.time()

-         if self.query_file.query_type == QueryType.SQL:
+         if isinstance(self.query_file, SqlQueryFile):
              await self._run_sql_model(conn, placeholders)
-         elif self.query_file.query_type == QueryType.PYTHON:
+         elif isinstance(self.query_file, PyQueryFile):
              await self._run_python_model(conn)
+         else:
+             raise NotImplementedError(f"Query type not supported: {self.query_file.__class__.__name__}")

          model_type = self.get_model_type().name.lower()
-         timer.add_activity_time(f"running {model_type} model '{self.name}'", start)
+         self.logger.log_activity_time(f"running {model_type} model '{self.name}'", start)

          await super().run_model(conn, placeholders)

@@ -392,35 +396,37 @@ class _Model(_Referable):


  @dataclass
- class _DAG:
-     dataset: DatasetsConfig
-     target_model: _Referable
-     models_dict: dict[str, _Referable]
-     parameter_set: Optional[ParameterSet] = field(default=None, init=False)
+ class DAG:
+     manifest_cfg: ManifestConfig
+     dataset: DatasetConfig
+     target_model: Referable
+     models_dict: dict[str, Referable]
+     logger: u.Logger = field(default_factory=lambda: u.Logger(""))
+     parameter_set: ParameterSet | None = field(default=None, init=False)  # set in apply_selections
      placeholders: dict[str, Any] = field(init=False, default_factory=dict)

      def apply_selections(
-         self, user: Optional[User], selections: dict[str, str], *, updates_only: bool = False, request_version: Optional[int] = None
+         self, param_cfg_set: ParameterConfigsSet, user: User | None, selections: dict[str, str], *, updates_only: bool = False, request_version: int | None = None
      ) -> None:
          start = time.time()
          dataset_params = self.dataset.parameters
-         parameter_set = ParameterConfigsSetIO.obj.apply_selections(
+         parameter_set = param_cfg_set.apply_selections(
              dataset_params, selections, user, updates_only=updates_only, request_version=request_version
          )
          self.parameter_set = parameter_set
-         timer.add_activity_time(f"applying selections for dataset '{self.dataset.name}'", start)
+         self.logger.log_activity_time(f"applying selections for dataset '{self.dataset.name}'", start)

-     def _compile_context(self, context_func: ContextFunc, user: Optional[User]) -> tuple[dict[str, Any], ContextArgs]:
+     def _compile_context(self, param_args: ParametersArgs, context_func: ContextFunc, user: User | None) -> tuple[dict[str, Any], ContextArgs]:
          start = time.time()
          context = {}
-         param_args = ParameterConfigsSetIO.args
+         assert isinstance(self.parameter_set, ParameterSet)
          prms = self.parameter_set.get_parameters_as_dict()
          args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits, self.placeholders)
          try:
-             context_func(ctx=context, sqrl=args)
+             context_func(context, args)
          except Exception as e:
              raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset.name}"', e) from e
-         timer.add_activity_time(f"running context.py for dataset '{self.dataset.name}'", start)
+         self.logger.log_activity_time(f"running context.py for dataset '{self.dataset.name}'", start)
          return context, args

      async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
@@ -431,11 +437,12 @@ class _DAG:
          terminal_nodes = self.target_model.get_terminal_nodes(set())
          for model in self.models_dict.values():
              model.confirmed_no_cycles = False
-         timer.add_activity_time(f"validating no cycles in model dependencies", start)
+         self.logger.log_activity_time(f"validating no cycles in model dependencies", start)
          return terminal_nodes

      async def _run_models(self, terminal_nodes: set[str], placeholders: dict = {}) -> None:
-         conn_url = "duckdb:///" if u.use_duckdb() else "sqlite:///?check_same_thread=False"
+         use_duckdb = self.manifest_cfg.settings_obj.do_use_duckdb()
+         conn_url = "duckdb:///" if use_duckdb else "sqlite:///?check_same_thread=False"
          engine = create_engine(conn_url)

          with engine.connect() as conn:
@@ -448,14 +455,14 @@ class _DAG:
          engine.dispose()

      async def execute(
-         self, context_func: ContextFunc, user: Optional[User], selections: dict[str, str], *, request_version: Optional[int] = None,
-         runquery: bool = True, recurse: bool = True
+         self, param_args: ParametersArgs, param_cfg_set: ParameterConfigsSet, context_func: ContextFunc, user: User | None, selections: dict[str, str],
+         *, request_version: int | None = None, runquery: bool = True, recurse: bool = True
      ) -> dict[str, Any]:
          recurse = (recurse or runquery)

-         self.apply_selections(user, selections, request_version=request_version)
+         self.apply_selections(param_cfg_set, user, selections, request_version=request_version)

-         context, ctx_args = self._compile_context(context_func, user)
+         context, ctx_args = self._compile_context(param_args, context_func, user)

          await self._compile_models(context, ctx_args, recurse)

@@ -488,194 +495,54 @@ class _DAG:

          return G

+
  class ModelsIO:
-     raw_queries_by_model: dict[str, _QueryFile]
-     context_func: ContextFunc

      @classmethod
-     def LoadFiles(cls) -> None:
+     def load_files(cls, logger: u.Logger, base_path: str) -> dict[str, QueryFile]:
          start = time.time()
-         cls.raw_queries_by_model = {}
-
-         def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType):
-             def populate_from_file(dp, file):
-                 query_type = None
-                 filepath = os.path.join(dp, file)
-                 file_stem, extension = os.path.splitext(file)
-                 if extension == '.py':
-                     query_type = QueryType.PYTHON
-                     module = pm.PyModule(filepath)
-                     dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda sqrl: [])
-                     raw_query = _RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
-                 elif extension == '.sql':
-                     query_type = QueryType.SQL
-                     raw_query = _RawSqlQuery(u.read_file(filepath))
-
-                 if query_type is not None:
-                     query_file = _QueryFile(filepath, model_type, query_type, raw_query)
-                     if file_stem in cls.raw_queries_by_model:
-                         conflicts = [cls.raw_queries_by_model[file_stem].filepath, filepath]
-                         raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
-                     cls.raw_queries_by_model[file_stem] = query_file
+         raw_queries_by_model: dict[str, QueryFile] = {}
+
+         def populate_from_file(dp: str, file: str, model_type: ModelType) -> None:
+             filepath = Path(dp, file)
+             file_stem, extension = os.path.splitext(file)
+             if extension == '.py':
+                 module = pm.PyModule(filepath)
+                 dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda sqrl: [])
+                 raw_query = _RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
+                 query_file = PyQueryFile(filepath.as_posix(), model_type, raw_query)
+             elif extension == '.sql':
+                 query_file = SqlQueryFile(filepath.as_posix(), model_type, filepath.read_text())
+             else:
+                 query_file = None

+             if query_file is not None:
+                 if file_stem in raw_queries_by_model:
+                     conflicts = [raw_queries_by_model[file_stem].filepath, filepath]
+                     raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
+                 raw_queries_by_model[file_stem] = query_file
+
+         def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType) -> None:
              for dp, _, filenames in os.walk(folder_path):
                  for file in filenames:
-                     populate_from_file(dp, file)
+                     populate_from_file(dp, file, model_type)

-         dbviews_path = u.join_paths(c.MODELS_FOLDER, c.DBVIEWS_FOLDER)
+         dbviews_path = u.Path(base_path, c.MODELS_FOLDER, c.DBVIEWS_FOLDER)
          populate_raw_queries_for_type(dbviews_path, ModelType.DBVIEW)

-         federates_path = u.join_paths(c.MODELS_FOLDER, c.FEDERATES_FOLDER)
+         federates_path = u.Path(base_path, c.MODELS_FOLDER, c.FEDERATES_FOLDER)
          populate_raw_queries_for_type(federates_path, ModelType.FEDERATE)

-         context_path = u.join_paths(c.PYCONFIGS_FOLDER, c.CONTEXT_FILE)
-         cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda ctx, sqrl: None)
-
-         timer.add_activity_time("loading files for models and context.py", start)
+         logger.log_activity_time("loading files for models", start)
+         return raw_queries_by_model

      @classmethod
-     def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> _DAG:
-         seeds_dict = SeedsIO.obj.get_dataframes()
-
-         models_dict: dict[str, _Referable] = {key: _Seed(key, df) for key, df in seeds_dict.items()}
-         for key, val in cls.raw_queries_by_model.items():
-             models_dict[key] = _Model(key, val)
-             models_dict[key].needs_pandas = always_pandas
-
-         dataset_config = ManifestIO.obj.datasets[dataset]
-         target_model_name = dataset_config.model if target_model_name is None else target_model_name
-         target_model = models_dict[target_model_name]
-         target_model.is_target = True
-
-         return _DAG(dataset_config, target_model, models_dict)
-
-     @classmethod
-     def draw_dag(cls, dag: _DAG, output_folder: Path) -> None:
-         color_map = {ModelType.SEED: "green", ModelType.DBVIEW: "red", ModelType.FEDERATE: "skyblue"}
-
-         G = dag.to_networkx_graph()
-
-         fig, _ = plt.subplots()
-         pos = nx.multipartite_layout(G, subset_key="layer")
-         colors = [color_map[node[1]] for node in G.nodes(data="model_type")]
-         nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
-
-         y_values = [val[1] for val in pos.values()]
-         scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
-         label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
-         nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
-
-         fig.tight_layout()
-         plt.margins(x=0.1, y=0.1)
-         plt.savefig(u.join_paths(output_folder, "dag.png"))
-         plt.close(fig)
-
-     @classmethod
-     async def WriteDatasetOutputsGivenTestSet(
-         cls, dataset_conf: DatasetsConfig, select: str, test_set: Optional[str], runquery: bool, recurse: bool
-     ) -> Any:
-         dataset = dataset_conf.name
-         default_test_set, default_test_set_conf = ManifestIO.obj.get_default_test_set(dataset)
-         if test_set is None or test_set == default_test_set:
-             test_set, test_set_conf = default_test_set, default_test_set_conf
-         elif test_set in ManifestIO.obj.selection_test_sets:
-             test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
-         else:
-             raise u.ConfigurationError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
-
-         error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
-         if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
-             raise u.ConfigurationError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
-
-         user_attributes = test_set_conf.user_attributes.copy()
-         selections = test_set_conf.parameters.copy()
-         username, is_internal = user_attributes.pop("username", ""), user_attributes.pop("is_internal", False)
-         if test_set_conf.is_authenticated:
-             user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
-             user = user_cls.Create(username, is_internal=is_internal, **user_attributes)
-         elif dataset_conf.scope == DatasetScope.PUBLIC:
-             user = None
-         else:
-             raise u.ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
-
-         if dataset_conf.scope == DatasetScope.PRIVATE and not is_internal:
-             raise u.ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_internal' set to true")
-
-         # always_pandas is set to True for creating CSV files from results (when runquery is True)
-         dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
-         placeholders = await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
-
-         output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
-         if os.path.exists(output_folder):
-             shutil.rmtree(output_folder)
-         os.makedirs(output_folder, exist_ok=True)
-
-         def write_placeholders() -> None:
-             output_filepath = u.join_paths(output_folder, "placeholders.json")
-             with open(output_filepath, 'w') as f:
-                 json.dump(placeholders, f, indent=4)
-
-         def write_model_outputs(model: _Model) -> None:
-             subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
-             subpath = u.join_paths(output_folder, subfolder)
-             os.makedirs(subpath, exist_ok=True)
-             if isinstance(model.compiled_query, _SqlModelQuery):
-                 output_filepath = u.join_paths(subpath, model.name+'.sql')
-                 query = model.compiled_query.query
-                 with open(output_filepath, 'w') as f:
-                     f.write(query)
-             if runquery and isinstance(model.result, pd.DataFrame):
-                 output_filepath = u.join_paths(subpath, model.name+'.csv')
-                 model.result.to_csv(output_filepath, index=False)
-
-         write_placeholders()
-         all_model_names = dag.get_all_query_models()
-         coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
-         await asyncio.gather(*coroutines)
+     def load_context_func(cls, logger: u.Logger, base_path: str) -> ContextFunc:
+         start = time.time()

-         if recurse:
-             cls.draw_dag(dag, output_folder)
-
-         if isinstance(dag.target_model, _Model):
-             return dag.target_model.compiled_query.query # else return None
+         context_path = u.Path(base_path, c.PYCONFIGS_FOLDER, c.CONTEXT_FILE)
+         context_func: ContextFunc = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda ctx, sqrl: None)

-     @classmethod
-     async def WriteOutputs(
-         cls, dataset: Optional[str], do_all_datasets: bool, select: Optional[str], test_set: Optional[str], do_all_test_sets: bool,
-         runquery: bool
-     ) -> None:
-
-         def get_applicable_test_sets(dataset: str) -> list[str]:
-             applicable_test_sets = []
-             for test_set_name, test_set_config in ManifestIO.obj.selection_test_sets.items():
-                 if test_set_config.datasets is None or dataset in test_set_config.datasets:
-                     applicable_test_sets.append(test_set_name)
-             return applicable_test_sets
-
-         recurse = True
-         dataset_configs = ManifestIO.obj.datasets
-         if do_all_datasets:
-             selected_models = [(dataset, dataset.model) for dataset in dataset_configs.values()]
-         else:
-             if select is None:
-                 select = dataset_configs[dataset].model
-             else:
-                 recurse = False
-             selected_models = [(dataset_configs[dataset], select)]
-
-         coroutines = []
-         for dataset_conf, select in selected_models:
-             if do_all_test_sets:
-                 for test_set_name in get_applicable_test_sets(dataset_conf.name):
-                     coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set_name, runquery, recurse)
-                     coroutines.append(coroutine)
-
-             coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset_conf, select, test_set, runquery, recurse)
-             coroutines.append(coroutine)
-
-         queries = await asyncio.gather(*coroutines)
-         if not recurse and len(queries) == 1 and isinstance(queries[0], str):
-             print()
-             print(queries[0])
-             print()
+         logger.log_activity_time("loading file for context.py", start)
+         return context_func
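For reference, a hedged sketch of calling the new ModelsIO classmethods that replace the former class-level state. It assumes a squirrels project rooted at base_path, and constructs the logger with u.Logger("") only because that mirrors the default_factory used earlier in this file; neither is necessarily the intended public entry point (squirrels/project.py is new in 0.4.0 and likely wraps this).

# Hedged sketch: loading model files and the context function with the 0.4.0 ModelsIO API.
from squirrels import _utils as u
from squirrels._models import ModelsIO

logger = u.Logger("")
query_files = ModelsIO.load_files(logger, base_path=".")           # dict[str, QueryFile]
context_func = ModelsIO.load_context_func(logger, base_path=".")   # ContextFunc

for name, query_file in query_files.items():
    print(name, query_file.model_type.name, query_file.filepath)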