squirrels 0.1.1.post1__py3-none-any.whl → 0.2.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (74)
  1. squirrels/__init__.py +10 -16
  2. squirrels/_api_server.py +234 -80
  3. squirrels/_authenticator.py +84 -0
  4. squirrels/_command_line.py +60 -72
  5. squirrels/_connection_set.py +96 -0
  6. squirrels/_constants.py +114 -33
  7. squirrels/_environcfg.py +77 -0
  8. squirrels/_initializer.py +126 -67
  9. squirrels/_manifest.py +195 -168
  10. squirrels/_models.py +495 -0
  11. squirrels/_package_loader.py +26 -0
  12. squirrels/_parameter_configs.py +401 -0
  13. squirrels/_parameter_sets.py +188 -0
  14. squirrels/_py_module.py +60 -0
  15. squirrels/_timer.py +36 -0
  16. squirrels/_utils.py +81 -49
  17. squirrels/_version.py +2 -2
  18. squirrels/arguments/init_time_args.py +32 -0
  19. squirrels/arguments/run_time_args.py +82 -0
  20. squirrels/data_sources.py +380 -155
  21. squirrels/dateutils.py +86 -57
  22. squirrels/package_data/base_project/Dockerfile +15 -0
  23. squirrels/package_data/base_project/connections.yml +7 -0
  24. squirrels/package_data/base_project/database/{sample_database.db → expenses.db} +0 -0
  25. squirrels/package_data/base_project/environcfg.yml +29 -0
  26. squirrels/package_data/base_project/ignores/.dockerignore +8 -0
  27. squirrels/package_data/base_project/ignores/.gitignore +7 -0
  28. squirrels/package_data/base_project/models/dbviews/database_view1.py +36 -0
  29. squirrels/package_data/base_project/models/dbviews/database_view1.sql +15 -0
  30. squirrels/package_data/base_project/models/federates/dataset_example.py +20 -0
  31. squirrels/package_data/base_project/models/federates/dataset_example.sql +3 -0
  32. squirrels/package_data/base_project/parameters.yml +109 -0
  33. squirrels/package_data/base_project/pyconfigs/auth.py +47 -0
  34. squirrels/package_data/base_project/pyconfigs/connections.py +28 -0
  35. squirrels/package_data/base_project/pyconfigs/context.py +45 -0
  36. squirrels/package_data/base_project/pyconfigs/parameters.py +55 -0
  37. squirrels/package_data/base_project/seeds/mocks/category.csv +3 -0
  38. squirrels/package_data/base_project/seeds/mocks/max_filter.csv +2 -0
  39. squirrels/package_data/base_project/seeds/mocks/subcategory.csv +6 -0
  40. squirrels/package_data/base_project/squirrels.yml.j2 +57 -0
  41. squirrels/package_data/base_project/tmp/.gitignore +2 -0
  42. squirrels/package_data/static/script.js +159 -63
  43. squirrels/package_data/static/style.css +79 -15
  44. squirrels/package_data/static/widgets.js +133 -0
  45. squirrels/package_data/templates/index.html +65 -23
  46. squirrels/package_data/templates/index2.html +22 -0
  47. squirrels/parameter_options.py +216 -119
  48. squirrels/parameters.py +407 -478
  49. squirrels/user_base.py +58 -0
  50. squirrels-0.2.0.dev0.dist-info/METADATA +126 -0
  51. squirrels-0.2.0.dev0.dist-info/RECORD +56 -0
  52. {squirrels-0.1.1.post1.dist-info → squirrels-0.2.0.dev0.dist-info}/WHEEL +1 -2
  53. squirrels-0.2.0.dev0.dist-info/entry_points.txt +3 -0
  54. squirrels/_credentials_manager.py +0 -87
  55. squirrels/_module_loader.py +0 -37
  56. squirrels/_parameter_set.py +0 -151
  57. squirrels/_renderer.py +0 -286
  58. squirrels/_timed_imports.py +0 -37
  59. squirrels/connection_set.py +0 -126
  60. squirrels/package_data/base_project/.gitignore +0 -4
  61. squirrels/package_data/base_project/connections.py +0 -20
  62. squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -22
  63. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -29
  64. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -12
  65. squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -11
  66. squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -3
  67. squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -47
  68. squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -9
  69. squirrels/package_data/base_project/squirrels.yaml +0 -22
  70. squirrels-0.1.1.post1.dist-info/METADATA +0 -67
  71. squirrels-0.1.1.post1.dist-info/RECORD +0 -40
  72. squirrels-0.1.1.post1.dist-info/entry_points.txt +0 -2
  73. squirrels-0.1.1.post1.dist-info/top_level.txt +0 -1
  74. {squirrels-0.1.1.post1.dist-info → squirrels-0.2.0.dev0.dist-info}/LICENSE +0 -0
squirrels/_models.py ADDED
@@ -0,0 +1,495 @@
1
+ from __future__ import annotations
2
+ from typing import Optional, Callable, Iterable, Any
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum
5
+ from pathlib import Path
6
+ import sqlite3, pandas as pd, asyncio, os, shutil
7
+
8
+ from . import _constants as c, _utils as u, _py_module as pm
9
+ from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
10
+ from ._authenticator import User
11
+ from ._connection_set import ConnectionSetIO
12
+ from ._manifest import ManifestIO, DatasetsConfig
13
+ from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
14
+ from ._timer import timer, time
15
+
16
class ModelType(Enum):
    """Kind of model, decided by which models sub-folder its file was found in."""
    DBVIEW = 1      # file found under the dbviews folder
    FEDERATE = 2    # file found under the federates folder


class QueryType(Enum):
    """Language of a model's query file, decided by the file extension."""
    SQL = 0         # ".sql" files
    PYTHON = 1      # ".py" files


class Materialization(Enum):
    """How a federate SQL model is persisted in the in-memory sqlite database."""
    TABLE = 0       # CREATE TABLE ... AS
    VIEW = 1        # CREATE VIEW ... AS
27
+
28
+
29
@dataclass
class SqlModelConfig:
    """Per-model settings of a SQL model, which the template itself can override via ``config(...)``."""

    # Connection that a dbview model's query is executed against.
    connection_name: str

    # Whether a federate model is created as a table or as a view.
    materialized: Materialization

    def get_sql_for_create(self, model_name: str, select_query: str) -> str:
        """Wrap ``select_query`` in a CREATE TABLE / CREATE VIEW statement named ``model_name``.

        Raises:
            NotImplementedError: if ``materialized`` is neither TABLE nor VIEW.
        """
        if self.materialized == Materialization.TABLE:
            object_kind = "TABLE"
        elif self.materialized == Materialization.VIEW:
            object_kind = "VIEW"
        else:
            raise NotImplementedError(f"Materialization option not supported: {self.materialized}")

        return f"CREATE {object_kind} {model_name} AS\n" + select_query

    def set_attribute(self, **kwargs) -> str:
        """Apply overrides passed as keyword arguments; only string values are taken into account.

        Returns an empty string, presumably so that a ``config(...)`` call inside the
        rendered template produces no text.
        """
        new_connection = kwargs.get(c.DBVIEW_CONN_KEY)
        if isinstance(new_connection, str):
            self.connection_name = new_connection

        new_materialized = kwargs.get(c.MATERIALIZED_KEY)
        if isinstance(new_materialized, str):
            # Case-insensitive lookup of the enum member by name ("table" -> TABLE)
            self.materialized = Materialization[new_materialized.upper()]

        return ""
55
+
56
+
57
# Signature of the context pyconfig's main function. It is called with a dict to fill in
# (``ctx``) and the run-time arguments (``sqrl``); its return value is ignored.
ContextFunc = Callable[[dict[str, Any], ContextArgs], None]


@dataclass(frozen=True)
class RawQuery:
    """Base class for a model's query as loaded from disk, before any compilation."""


@dataclass(frozen=True)
class RawSqlQuery(RawQuery):
    """Unrendered SQL template text read from a ``.sql`` model file."""
    query: str


@dataclass(frozen=True)
class RawPyQuery(RawQuery):
    """Functions taken from a ``.py`` model file."""
    # The module's main function; it produces the model's DataFrame
    query: Callable[[Any], pd.DataFrame]
    # Returns the names of the models this model depends on
    dependencies_func: Callable[[Any], Iterable]
72
+
73
+
74
@dataclass
class Query:
    """Base class for a compiled model query that is ready to run."""
    query: Any


@dataclass
class WorkInProgress:
    """Placeholder put on a model while it is being compiled, so that it is not compiled twice."""
    query: None = field(default=None, init=False)


@dataclass
class SqlModelQuery(Query):
    """A rendered SQL string together with the model configuration collected while rendering it."""
    query: str
    config: SqlModelConfig


@dataclass
class PyModelQuery(Query):
    """A zero-argument callable that runs a python model and returns its DataFrame."""
    query: Callable[[], pd.DataFrame]
90
+
91
+
92
@dataclass(frozen=True)
class QueryFile:
    """Immutable record describing one model file found under the models folder."""
    filepath: str            # location of the file on disk
    model_type: ModelType    # dbview or federate
    query_type: QueryType    # sql or python
    raw_query: RawQuery      # the file's uncompiled query content
98
+
99
+
100
@dataclass
class Model:
    """One node of the model dependency graph, with its compile and run state for a single request.

    DAG.execute drives the lifecycle: ``compile`` renders the query and links upstream models,
    ``validate_no_cycles`` checks the graph and returns the models without upstreams, and
    ``run_model`` executes a model and then triggers its downstream models.
    """
    name: str
    query_file: QueryFile
    is_target: bool = field(default=False, init=False)
    compiled_query: Optional[Query] = field(default=None, init=False)

    # True when some downstream SQL model reads this model as a table of the shared sqlite connection
    needs_sql_table: bool = field(default=False, init=False)
    # True when some downstream python model (or the caller) needs ``result`` as a DataFrame
    needs_pandas: bool = False
    result: Optional[pd.DataFrame] = field(default=None, init=False, repr=False)

    # Number of upstream models that have not finished running; this model runs when it reaches 0
    wait_count: int = field(default=0, init=False, repr=False)
    upstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)
    downstreams: dict[str, Model] = field(default_factory=dict, init=False, repr=False)

    confirmed_no_cycles: bool = field(default=False, init=False)
    # Terminal nodes found by the first validate_no_cycles visit; returned again on later visits
    _terminal_nodes: set[str] = field(default_factory=set, init=False, repr=False, compare=False)

    def _add_upstream(self, other: Model) -> None:
        """Link ``other`` as a dependency of this model and flag the output form this model needs from it."""
        self.upstreams[other.name] = other
        other.downstreams[self.name] = self

        if self.query_file.query_type == QueryType.PYTHON:
            other.needs_pandas = True
        elif self.query_file.query_type == QueryType.SQL:
            other.needs_sql_table = True

    def _get_dbview_conn_name(self) -> str:
        """Connection name configured for this model in the manifest, else the project-wide default."""
        dbview_config = ManifestIO.obj.dbviews.get(self.name)
        if dbview_config is None or dbview_config.connection_name is None:
            return ManifestIO.obj.settings.get(c.DB_CONN_DEFAULT_USED_SETTING, c.DEFAULT_DB_CONN)
        return dbview_config.connection_name

    def _get_materialized(self) -> Materialization:
        """Materialization configured for this model in the manifest, else the project-wide default."""
        federate_config = ManifestIO.obj.federates.get(self.name)
        if federate_config is None or federate_config.materialized is None:
            materialized = ManifestIO.obj.settings.get(c.DEFAULT_MATERIALIZE_SETTING, c.DEFAULT_TABLE_MATERIALIZE)
        else:
            materialized = federate_config.materialized
        return Materialization[materialized.upper()]

    async def _compile_sql_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[SqlModelQuery, set]:
        """Render the SQL template and return the compiled query with the model names passed to ``ref``.

        ``ref`` is only offered to federate models.

        Raises:
            u.FileExecutionError: if rendering the template fails.
        """
        assert isinstance(self.query_file.raw_query, RawSqlQuery)
        raw_query = self.query_file.raw_query.query

        connection_name = self._get_dbview_conn_name()
        materialized = self._get_materialized()
        configuration = SqlModelConfig(connection_name, materialized)
        kwargs = {
            "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars,
            "user": ctx_args.user, "prms": ctx_args.prms, "args": ctx_args.args,
            "ctx": ctx, "config": configuration.set_attribute
        }
        dependencies = set()
        if self.query_file.model_type == ModelType.FEDERATE:
            # Record each referenced model as a dependency; the reference renders as that model's name,
            # which is also the name of the table/view it gets in the shared connection
            def ref(name):
                dependencies.add(name)
                return name
            kwargs["ref"] = ref

        try:
            query = await asyncio.to_thread(u.render_string, raw_query, kwargs)
        except Exception as e:
            raise u.FileExecutionError(f'Failed to compile sql model "{self.name}"', e)

        compiled_query = SqlModelQuery(query, configuration)
        return compiled_query, dependencies

    async def _compile_python_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[PyModelQuery, set]:
        """Run the model's dependencies function and wrap its main function in a deferred callable.

        Returns the compiled query and the set of distinct upstream model names.

        Raises:
            u.FileExecutionError: if the dependencies function fails. Errors from the main function
                surface later, when the compiled query is called.
        """
        assert isinstance(self.query_file.raw_query, RawPyQuery)
        raw_query = self.query_file.raw_query
        deps_args = ModelDepsArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.args, ctx)

        def get_dependencies() -> set:
            # Collapse duplicates (and consume one-shot iterables such as generators) so that
            # wait_count equals the number of distinct upstream models that will trigger this one
            return set(raw_query.dependencies_func(deps_args))

        try:
            dependencies = await asyncio.to_thread(get_dependencies)
        except Exception as e:
            raise u.FileExecutionError(f'Failed to run "{c.DEP_FUNC}" function for python model "{self.name}"', e)

        dbview_conn_name = self._get_dbview_conn_name()
        connections = ConnectionSetIO.obj.get_connections_as_dict()

        def ref(x):
            # Upstream results are only available once those models have run
            return self.upstreams[x].result

        model_args = ModelArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.args,
                               ctx, dbview_conn_name, connections, ref, set(dependencies))

        def compiled_query():
            try:
                return raw_query.query(sqrl=model_args)
            except Exception as e:
                raise u.FileExecutionError(f'Failed to run "{c.MAIN_FUNC}" function for python model "{self.name}"', e)

        return PyModelQuery(compiled_query), dependencies

    async def compile(self, ctx: dict[str, Any], ctx_args: ContextArgs, models_dict: dict[str, Model], recurse: bool) -> None:
        """Compile this model once and, if ``recurse``, link and compile its upstream models concurrently.

        Raises:
            NotImplementedError: for an unsupported query type.
        """
        if self.compiled_query is not None:
            # Already compiled, or currently being compiled through another downstream model
            return
        self.compiled_query = WorkInProgress()

        start = time.time()
        if self.query_file.query_type == QueryType.SQL:
            compiled_query, dependencies = await self._compile_sql_model(ctx, ctx_args)
        elif self.query_file.query_type == QueryType.PYTHON:
            compiled_query, dependencies = await self._compile_python_model(ctx, ctx_args)
        else:
            raise NotImplementedError(f"Query type not supported: {self.query_file.query_type}")

        self.compiled_query = compiled_query
        self.wait_count = len(dependencies)
        timer.add_activity_time(f"compiling model '{self.name}'", start)

        if not recurse:
            return

        dep_models = [models_dict[x] for x in dependencies]
        coroutines = []
        for dep_model in dep_models:
            self._add_upstream(dep_model)
            coroutines.append(dep_model.compile(ctx, ctx_args, models_dict, recurse))
        await asyncio.gather(*coroutines)

    def validate_no_cycles(self, depencency_path: set[str]) -> set[str]:
        """Depth-first check that no model reachable through ``upstreams`` depends on itself.

        Args:
            depencency_path: names of the models on the current path from the starting model down to
                this one. The parameter name's spelling is kept for compatibility.

        Returns:
            Names of the models without upstreams that are reachable from this model (this model
            itself when it has no upstreams).

        Raises:
            u.ConfigurationError: if this model is already on ``depencency_path`` (a cycle).
        """
        if self.confirmed_no_cycles:
            # Already fully explored through another downstream model (e.g. a diamond-shaped graph).
            # Hand back the cached terminal nodes so the caller's set union keeps working.
            return set(self._terminal_nodes)

        if self.name in depencency_path:
            raise u.ConfigurationError('Cycle found in model dependency graph')

        terminal_nodes = set()
        if not self.upstreams:
            terminal_nodes.add(self.name)
        else:
            new_path = set(depencency_path)
            new_path.add(self.name)
            for dep_model in self.upstreams.values():
                terminal_nodes |= dep_model.validate_no_cycles(new_path)

        self._terminal_nodes = terminal_nodes
        self.confirmed_no_cycles = True
        return set(terminal_nodes)

    async def _run_sql_model(self, conn: sqlite3.Connection) -> None:
        """Execute the compiled SQL, leaving its output as a table/view named after the model in ``conn``.

        Dbview models run against their configured connection and the DataFrame is copied into ``conn``.
        Federate models are created directly inside ``conn``. ``result`` is filled when needed as a DataFrame.

        Raises:
            u.FileExecutionError: if the query fails.
        """
        assert isinstance(self.compiled_query, SqlModelQuery)
        config = self.compiled_query.config
        query = self.compiled_query.query

        if self.query_file.model_type == ModelType.DBVIEW:
            def run_sql_query():
                try:
                    return ConnectionSetIO.obj.run_sql_query_from_conn_name(query, config.connection_name)
                except RuntimeError as e:
                    raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e)

            df = await asyncio.to_thread(run_sql_query)
            await asyncio.to_thread(df.to_sql, self.name, conn, index=False)
            if self.needs_pandas or self.is_target:
                self.result = df
        elif self.query_file.model_type == ModelType.FEDERATE:
            def create_table():
                create_query = config.get_sql_for_create(self.name, query)
                try:
                    return conn.execute(create_query)
                except Exception as e:
                    raise u.FileExecutionError(f'Failed to run federate sql model "{self.name}"', e)

            await asyncio.to_thread(create_table)
            if self.needs_pandas or self.is_target:
                query = f"SELECT * FROM {self.name}"
                self.result = await asyncio.to_thread(pd.read_sql, query, conn)

    async def _run_python_model(self, conn: sqlite3.Connection) -> None:
        """Call the compiled python model and publish its DataFrame where downstream models need it."""
        assert isinstance(self.compiled_query, PyModelQuery)

        df = await asyncio.to_thread(self.compiled_query.query)
        if self.needs_sql_table:
            await asyncio.to_thread(df.to_sql, self.name, conn, index=False)
        if self.needs_pandas or self.is_target:
            self.result = df

    async def run_model(self, conn: sqlite3.Connection) -> None:
        """Run this model, then notify every downstream model concurrently."""
        start = time.time()
        if self.query_file.query_type == QueryType.SQL:
            await self._run_sql_model(conn)
        elif self.query_file.query_type == QueryType.PYTHON:
            await self._run_python_model(conn)
        timer.add_activity_time(f"running model '{self.name}'", start)

        coroutines = [model.trigger(conn) for model in self.downstreams.values()]
        await asyncio.gather(*coroutines)

    async def trigger(self, conn: sqlite3.Connection) -> None:
        """Called by an upstream model once it has finished; runs this model after the last one finishes."""
        self.wait_count -= 1
        if self.wait_count == 0:
            await self.run_model(conn)
293
+
294
+
295
@dataclass
class DAG:
    """The models graph for one dataset request, rooted at the target model."""
    dataset: DatasetsConfig
    target_model: Model
    models_dict: dict[str, Model]
    # Set by apply_selections and read by _compile_context
    parameter_set: Optional[ParameterSet] = field(default=None, init=False)

    def apply_selections(
        self, user: Optional[User], selections: dict[str, str], *, updates_only: bool = False, request_version: Optional[int] = None
    ) -> None:
        """Resolve the dataset's parameters against ``selections`` and store the result on ``parameter_set``."""
        start = time.time()
        dataset_params = self.dataset.parameters
        parameter_set = ParameterConfigsSetIO.obj.apply_selections(
            dataset_params, selections, user, updates_only=updates_only, request_version=request_version
        )
        self.parameter_set = parameter_set
        timer.add_activity_time("applying selections for dataset", start)

    def _compile_context(self, context_func: ContextFunc, user: Optional[User]) -> tuple[dict[str, Any], ContextArgs]:
        """Run the context function and return the context dict it filled, plus the args it received.

        Must be called after apply_selections, since it reads ``parameter_set``.

        Raises:
            u.FileExecutionError: if the context function raises.
        """
        start = time.time()
        context = {}
        param_args = ParameterConfigsSetIO.args
        prms = self.parameter_set.get_parameters_as_dict()
        args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.args)
        try:
            context_func(ctx=context, sqrl=args)
        except Exception as e:
            # Name the dataset rather than dumping its whole config object into the message
            raise u.FileExecutionError(f'Failed to run {c.CONTEXT_FILE} for dataset "{self.dataset.name}"', e)
        timer.add_activity_time("running context.py for dataset", start)
        return context, args

    async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
        """Compile the target model (and, if ``recurse``, everything it depends on)."""
        await self.target_model.compile(context, ctx_args, self.models_dict, recurse)

    def _validate_no_cycles(self) -> set[str]:
        """Check the compiled graph for cycles and return the names of the models without upstreams."""
        start = time.time()
        terminal_nodes = self.target_model.validate_no_cycles(set())
        timer.add_activity_time("validating no cycles in models dependencies", start)
        return terminal_nodes

    async def _run_models(self, terminal_nodes: set[str]) -> None:
        """Start every model without upstreams concurrently; each run triggers its downstream models.

        All models share one in-memory sqlite connection, which is closed once the runs finish or fail.
        """
        conn = sqlite3.connect(":memory:", check_same_thread=False)
        try:
            coroutines = [self.models_dict[model_name].run_model(conn) for model_name in terminal_nodes]
            await asyncio.gather(*coroutines)
        finally:
            conn.close()

    async def execute(
        self, context_func: ContextFunc, user: Optional[User], selections: dict[str, str], *, request_version: Optional[int] = None,
        runquery: bool = True, recurse: bool = True
    ) -> None:
        """Apply selections, build the context, compile the models and, if ``runquery``, run them.

        Running the queries needs the whole dependency graph, so ``runquery`` forces ``recurse``.
        """
        recurse = (recurse or runquery)

        self.apply_selections(user, selections, request_version=request_version)

        context, ctx_args = self._compile_context(context_func, user)

        await self._compile_models(context, ctx_args, recurse)

        terminal_nodes = self._validate_no_cycles()

        if runquery:
            await self._run_models(terminal_nodes)
361
+
362
+
363
class ModelsIO:
    """Loads the project's model files and context function, and builds, runs and writes out DAGs."""
    raw_queries_by_model: dict[str, QueryFile]
    context_func: ContextFunc

    @classmethod
    def LoadFiles(cls) -> None:
        """Discover every .sql/.py model under the dbviews and federates folders and load the context function.

        Raises:
            u.ConfigurationError: if two model files share the same file stem (model name).
        """
        start = time.time()
        cls.raw_queries_by_model = {}

        def populate_raw_queries_for_type(folder_path: Path, model_type: ModelType):

            def populate_from_file(dp, file):
                query_type = None
                filepath = os.path.join(dp, file)
                file_stem, extension = os.path.splitext(file)
                if extension == '.py':
                    query_type = QueryType.PYTHON
                    module = pm.PyModule(filepath)
                    # Presumably used when the module defines no dependencies function: no upstream models
                    dependencies_func = module.get_func_or_class(c.DEP_FUNC, default_attr=lambda x: [])
                    raw_query = RawPyQuery(module.get_func_or_class(c.MAIN_FUNC), dependencies_func)
                elif extension == '.sql':
                    query_type = QueryType.SQL
                    raw_query = RawSqlQuery(u.read_file(filepath))

                # Files with any other extension are ignored
                if query_type is not None:
                    query_file = QueryFile(filepath, model_type, query_type, raw_query)
                    # Model names come from file stems, so they must be unique across both folders
                    if file_stem in cls.raw_queries_by_model:
                        conflicts = [cls.raw_queries_by_model[file_stem].filepath, filepath]
                        raise u.ConfigurationError(f"Multiple models found for '{file_stem}': {conflicts}")
                    cls.raw_queries_by_model[file_stem] = query_file

            for dp, _, filenames in os.walk(folder_path):
                for file in filenames:
                    populate_from_file(dp, file)

        dbviews_path = u.join_paths(c.MODELS_FOLDER, c.DBVIEWS_FOLDER)
        populate_raw_queries_for_type(dbviews_path, ModelType.DBVIEW)

        federates_path = u.join_paths(c.MODELS_FOLDER, c.FEDERATES_FOLDER)
        populate_raw_queries_for_type(federates_path, ModelType.FEDERATE)

        context_path = u.join_paths(c.PYCONFIG_FOLDER, c.CONTEXT_FILE)
        # The context function is always called as context_func(ctx=..., sqrl=...), so the fallback
        # must accept those keyword names (a lambda with parameters x, y would raise TypeError)
        cls.context_func = pm.PyModule(context_path).get_func_or_class(c.MAIN_FUNC, default_attr=lambda ctx, sqrl: None)

        timer.add_activity_time("loading models and/or context.py", start)

    @classmethod
    def GenerateDAG(cls, dataset: str, *, target_model_name: Optional[str] = None, always_pandas: bool = False) -> DAG:
        """Build a DAG of fresh Model objects for ``dataset``, targeting its configured model unless overridden."""
        models_dict = {key: Model(key, val, needs_pandas=always_pandas) for key, val in cls.raw_queries_by_model.items()}

        dataset_config = ManifestIO.obj.datasets[dataset]
        target_model_name = dataset_config.model if target_model_name is None else target_model_name
        target_model = models_dict[target_model_name]
        target_model.is_target = True

        return DAG(dataset_config, target_model, models_dict)

    @classmethod
    async def WriteDatasetOutputsGivenTestSet(cls, dataset: str, select: str, test_set: str, runquery: bool, recurse: bool) -> Any:
        """Compile (and optionally run) ``select`` for ``dataset`` under a selection test set and write outputs.

        Compiled SQL (and, when ``runquery``, result CSVs) are written for the selected model and every
        model it depends on, under the target/compile folder for this test set and dataset.

        Returns:
            The selected model's compiled query (a SQL string for SQL models).
        """
        test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
        user = User("")
        for key, val in test_set_conf.user_attributes.items():
            setattr(user, key, val)
        selections = test_set_conf.parameters
        dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
        await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)

        output_folder = u.join_paths(c.TARGET_FOLDER, c.COMPILE_FOLDER, test_set, dataset)
        if os.path.exists(output_folder):
            shutil.rmtree(output_folder)

        def write_model_outputs(model: Model) -> None:
            subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
            subpath = u.join_paths(output_folder, subfolder)
            # Several models are written concurrently into the same subfolder
            os.makedirs(subpath, exist_ok=True)
            if isinstance(model.compiled_query, SqlModelQuery):
                output_filepath = u.join_paths(subpath, model.name+'.sql')
                query = model.compiled_query.query
                with open(output_filepath, 'w') as f:
                    f.write(query)
            if runquery and isinstance(model.result, pd.DataFrame):
                output_filepath = u.join_paths(subpath, model.name+'.csv')
                model.result.to_csv(output_filepath, index=False)

        # Collect the target model and everything it (transitively) depends on. The target has no
        # downstream models, so the walk must follow upstreams, pushing Model objects onto the stack.
        target_model = dag.models_dict[select]
        stack = [target_model]
        all_model_names = {target_model.name}
        while stack:
            curr_model = stack.pop()
            for dep_model in curr_model.upstreams.values():
                if dep_model.name not in all_model_names:
                    all_model_names.add(dep_model.name)
                    stack.append(dep_model)

        coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
        await asyncio.gather(*coroutines)
        return target_model.compiled_query.query

    @classmethod
    async def WriteOutputs(
        cls, dataset: Optional[str], select: Optional[str], all_test_sets: bool, test_set: Optional[str], runquery: bool
    ) -> None:
        """Write compile outputs for one dataset (or all datasets) across one test set (or all test sets).

        When a single model is selected explicitly, only that model is compiled (no recursion) and its
        compiled SQL is also printed.
        """
        if test_set is None:
            test_set = ManifestIO.obj.settings.get(c.TEST_SET_DEFAULT_USED_SETTING, c.DEFAULT_TEST_SET_NAME)

        if all_test_sets:
            test_sets = ManifestIO.obj.selection_test_sets.keys()
        else:
            test_sets = [test_set]

        recurse = True
        dataset_configs = ManifestIO.obj.datasets
        if dataset is None:
            selected_models = [(ds.name, ds.model) for ds in dataset_configs.values()]
        else:
            if select is None:
                select = dataset_configs[dataset].model
            else:
                recurse = False
            selected_models = [(dataset, select)]

        coroutines = []
        for test_set in test_sets:
            for dataset, select in selected_models:
                coroutine = cls.WriteDatasetOutputsGivenTestSet(dataset, select, test_set, runquery, recurse)
                coroutines.append(coroutine)

        queries = await asyncio.gather(*coroutines)
        if not recurse and len(queries) == 1 and isinstance(queries[0], str):
            print()
            print(queries[0])
            print()
495
+
squirrels/_package_loader.py ADDED
@@ -0,0 +1,26 @@
1
+ import git, shutil, os, time
2
+
3
+ from . import _constants as c, _utils as u
4
+ from ._manifest import ManifestIO
5
+ from ._timer import timer
6
+
7
+
8
class PackageLoaderIO:
    """Fetches the git packages declared in the project manifest into the local packages folder."""

    @classmethod
    def LoadPackages(cls, *, reload: bool = False) -> None:
        """Shallow-clone (depth 1) every declared package repository that is not already present.

        Args:
            reload: when True, delete the packages folder first so that every package is cloned afresh.

        Raises:
            u.ConfigurationError: if cloning a repository fails.
        """
        start = time.time()

        # On reload, start from an empty packages folder; presumably recreated by the clones below
        if reload and os.path.exists(c.PACKAGES_FOLDER):
            shutil.rmtree(c.PACKAGES_FOLDER)

        for package in ManifestIO.obj.packages:
            destination = f"{c.PACKAGES_FOLDER}/{package.directory}"
            if os.path.exists(destination):
                continue  # already cloned earlier

            try:
                git.Repo.clone_from(package.git_url, destination, branch=package.revision, depth=1)
            except git.GitCommandError as e:
                raise u.ConfigurationError(f"Git clone of package failed for this repository: {package.git_url}") from e

        timer.add_activity_time("loading packages", start)