squirrels 0.4.0__py3-none-any.whl → 0.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (80) hide show
  1. squirrels/__init__.py +10 -6
  2. squirrels/_api_response_models.py +93 -44
  3. squirrels/_api_server.py +571 -219
  4. squirrels/_auth.py +451 -0
  5. squirrels/_command_line.py +61 -20
  6. squirrels/_connection_set.py +38 -25
  7. squirrels/_constants.py +44 -34
  8. squirrels/_dashboards_io.py +34 -16
  9. squirrels/_exceptions.py +57 -0
  10. squirrels/_initializer.py +117 -44
  11. squirrels/_manifest.py +124 -62
  12. squirrels/_model_builder.py +111 -0
  13. squirrels/_model_configs.py +74 -0
  14. squirrels/_model_queries.py +52 -0
  15. squirrels/_models.py +860 -354
  16. squirrels/_package_loader.py +8 -4
  17. squirrels/_parameter_configs.py +45 -65
  18. squirrels/_parameter_sets.py +15 -13
  19. squirrels/_project.py +561 -0
  20. squirrels/_py_module.py +4 -3
  21. squirrels/_seeds.py +35 -16
  22. squirrels/_sources.py +106 -0
  23. squirrels/_utils.py +166 -63
  24. squirrels/_version.py +1 -1
  25. squirrels/arguments/init_time_args.py +78 -15
  26. squirrels/arguments/run_time_args.py +62 -101
  27. squirrels/dashboards.py +4 -4
  28. squirrels/data_sources.py +94 -162
  29. squirrels/dataset_result.py +86 -0
  30. squirrels/dateutils.py +4 -4
  31. squirrels/package_data/base_project/.env +30 -0
  32. squirrels/package_data/base_project/.env.example +30 -0
  33. squirrels/package_data/base_project/.gitignore +3 -2
  34. squirrels/package_data/base_project/assets/expenses.db +0 -0
  35. squirrels/package_data/base_project/connections.yml +11 -3
  36. squirrels/package_data/base_project/dashboards/dashboard_example.py +15 -13
  37. squirrels/package_data/base_project/dashboards/dashboard_example.yml +22 -0
  38. squirrels/package_data/base_project/docker/.dockerignore +5 -2
  39. squirrels/package_data/base_project/docker/Dockerfile +3 -3
  40. squirrels/package_data/base_project/docker/compose.yml +1 -1
  41. squirrels/package_data/base_project/duckdb_init.sql +9 -0
  42. squirrels/package_data/base_project/macros/macros_example.sql +15 -0
  43. squirrels/package_data/base_project/models/builds/build_example.py +26 -0
  44. squirrels/package_data/base_project/models/builds/build_example.sql +16 -0
  45. squirrels/package_data/base_project/models/builds/build_example.yml +55 -0
  46. squirrels/package_data/base_project/models/dbviews/dbview_example.sql +12 -22
  47. squirrels/package_data/base_project/models/dbviews/dbview_example.yml +26 -0
  48. squirrels/package_data/base_project/models/federates/federate_example.py +38 -15
  49. squirrels/package_data/base_project/models/federates/federate_example.sql +16 -2
  50. squirrels/package_data/base_project/models/federates/federate_example.yml +65 -0
  51. squirrels/package_data/base_project/models/sources.yml +39 -0
  52. squirrels/package_data/base_project/parameters.yml +36 -21
  53. squirrels/package_data/base_project/pyconfigs/connections.py +6 -11
  54. squirrels/package_data/base_project/pyconfigs/context.py +20 -33
  55. squirrels/package_data/base_project/pyconfigs/parameters.py +19 -21
  56. squirrels/package_data/base_project/pyconfigs/user.py +23 -0
  57. squirrels/package_data/base_project/seeds/seed_categories.yml +15 -0
  58. squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -15
  59. squirrels/package_data/base_project/seeds/seed_subcategories.yml +21 -0
  60. squirrels/package_data/base_project/squirrels.yml.j2 +17 -40
  61. squirrels/parameters.py +20 -20
  62. {squirrels-0.4.0.dist-info → squirrels-0.5.0b1.dist-info}/METADATA +31 -32
  63. squirrels-0.5.0b1.dist-info/RECORD +70 -0
  64. {squirrels-0.4.0.dist-info → squirrels-0.5.0b1.dist-info}/WHEEL +1 -1
  65. squirrels-0.5.0b1.dist-info/entry_points.txt +3 -0
  66. {squirrels-0.4.0.dist-info → squirrels-0.5.0b1.dist-info/licenses}/LICENSE +1 -1
  67. squirrels/_authenticator.py +0 -85
  68. squirrels/_environcfg.py +0 -84
  69. squirrels/package_data/assets/favicon.ico +0 -0
  70. squirrels/package_data/assets/index.css +0 -1
  71. squirrels/package_data/assets/index.js +0 -58
  72. squirrels/package_data/base_project/dashboards.yml +0 -10
  73. squirrels/package_data/base_project/env.yml +0 -29
  74. squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
  75. squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
  76. squirrels/package_data/templates/index.html +0 -18
  77. squirrels/project.py +0 -378
  78. squirrels/user_base.py +0 -55
  79. squirrels-0.4.0.dist-info/RECORD +0 -60
  80. squirrels-0.4.0.dist-info/entry_points.txt +0 -4
squirrels/_project.py ADDED
@@ -0,0 +1,561 @@
1
+ from dotenv import dotenv_values
2
+ from uuid import uuid4
3
+ import asyncio, typing as t, functools as ft, shutil, json, os
4
+ import logging as l, matplotlib.pyplot as plt, networkx as nx, polars as pl
5
+ import sqlglot, sqlglot.expressions
6
+
7
+ from ._auth import Authenticator, BaseUser
8
+ from ._model_builder import ModelBuilder
9
+ from ._exceptions import InvalidInputError, ConfigurationError
10
+ from . import _utils as u, _constants as c, _manifest as mf, _connection_set as cs, _api_response_models as arm
11
+ from . import _seeds as s, _models as m, _model_configs as mc, _model_queries as mq, _sources as so
12
+ from . import _parameter_sets as ps, _dashboards_io as d, dashboards as dash, dataset_result as dr
13
+
14
+ T = t.TypeVar("T", bound=dash.Dashboard)
15
+ M = t.TypeVar("M", bound=m.DataModel)
16
+
17
+
18
+ class _CustomJsonFormatter(l.Formatter):
19
+ def format(self, record: l.LogRecord) -> str:
20
+ super().format(record)
21
+ info = {
22
+ "timestamp": self.formatTime(record),
23
+ "project_id": record.name,
24
+ "level": record.levelname,
25
+ "message": record.getMessage(),
26
+ "thread": record.thread,
27
+ "thread_name": record.threadName,
28
+ "process": record.process,
29
+ **record.__dict__.get("info", {})
30
+ }
31
+ output = {
32
+ "data": record.__dict__.get("data", {}),
33
+ "info": info
34
+ }
35
+ return json.dumps(output)
36
+
37
+
38
+ class SquirrelsProject:
39
+ """
40
+ Initiate an instance of this class to interact with a Squirrels project through Python code. For example this can be handy to experiment with the datasets produced by Squirrels in a Jupyter notebook.
41
+ """
42
+
43
+ def __init__(self, *, filepath: str = ".", log_file: str | None = c.LOGS_FILE, log_level: str = "INFO", log_format: str = "text") -> None:
44
+ """
45
+ Constructor for SquirrelsProject class. Loads the file contents of the Squirrels project into memory as member fields.
46
+
47
+ Arguments:
48
+ filepath: The path to the Squirrels project file. Defaults to the current working directory.
49
+             log_file: The name of the log file to write to from the "logs/" subfolder. If None or empty string, then file logging is disabled. Default is "squirrels.log".
50
+             log_level: The logging level to use. Options are "DEBUG", "INFO", and "WARNING". Default is "INFO".
51
+ log_format: The format of the log records. Options are "text" and "json". Default is "text".
52
+ """
53
+ self._filepath = filepath
54
+ self._logger = self._get_logger(self._filepath, log_file, log_level, log_format)
55
+
56
+ def _get_logger(self, base_path: str, log_file: str | None, log_level: str, log_format: str) -> u.Logger:
57
+ logger = u.Logger(name=uuid4().hex)
58
+ logger.setLevel(log_level.upper())
59
+
60
+ handler = l.StreamHandler()
61
+ handler.setLevel("WARNING")
62
+ handler.setFormatter(l.Formatter("%(levelname)s: %(asctime)s - %(message)s"))
63
+ logger.addHandler(handler)
64
+
65
+ if log_format.lower() == "json":
66
+ formatter = _CustomJsonFormatter()
67
+ elif log_format.lower() == "text":
68
+ formatter = l.Formatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s")
69
+ else:
70
+ raise ValueError("log_format must be either 'text' or 'json'")
71
+
72
+ if log_file:
73
+ path = u.Path(base_path, c.LOGS_FOLDER, log_file)
74
+ path.parent.mkdir(parents=True, exist_ok=True)
75
+
76
+ handler = l.FileHandler(path)
77
+ handler.setFormatter(formatter)
78
+ logger.addHandler(handler)
79
+
80
+ return logger
81
+
82
+ @ft.cached_property
83
+ def _env_vars(self) -> dict[str, str]:
84
+ dotenv_files = [c.DOTENV_FILE, c.DOTENV_LOCAL_FILE]
85
+ dotenv_vars = {}
86
+ for file in dotenv_files:
87
+ dotenv_vars.update({k: v for k, v in dotenv_values(f"{self._filepath}/{file}").items() if v is not None})
88
+ return {**os.environ, **dotenv_vars}
89
+
90
+ @ft.cached_property
91
+ def _manifest_cfg(self) -> mf.ManifestConfig:
92
+ return mf.ManifestIO.load_from_file(self._logger, self._filepath, self._env_vars)
93
+
94
+ @ft.cached_property
95
+ def _seeds(self) -> s.Seeds:
96
+ return s.SeedsIO.load_files(self._logger, self._filepath, self._env_vars)
97
+
98
+ @ft.cached_property
99
+ def _sources(self) -> so.Sources:
100
+ return so.SourcesIO.load_file(self._logger, self._filepath, self._env_vars)
101
+
102
+ @ft.cached_property
103
+ def _build_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
104
+ return m.ModelsIO.load_build_files(self._logger, self._filepath)
105
+
106
+ @ft.cached_property
107
+ def _dbview_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
108
+ return m.ModelsIO.load_dbview_files(self._logger, self._filepath, self._env_vars)
109
+
110
+ @ft.cached_property
111
+ def _federate_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
112
+ return m.ModelsIO.load_federate_files(self._logger, self._filepath)
113
+
114
+ @ft.cached_property
115
+ def _context_func(self) -> m.ContextFunc:
116
+ return m.ModelsIO.load_context_func(self._logger, self._filepath)
117
+
118
+ @ft.cached_property
119
+ def _dashboards(self) -> dict[str, d.DashboardDefinition]:
120
+ return d.DashboardsIO.load_files(self._logger, self._filepath)
121
+
122
+ @ft.cached_property
123
+ def _conn_args(self) -> cs.ConnectionsArgs:
124
+ return cs.ConnectionSetIO.load_conn_py_args(self._logger, self._filepath, self._env_vars, self._manifest_cfg)
125
+
126
+ @ft.cached_property
127
+ def _conn_set(self) -> cs.ConnectionSet:
128
+ return cs.ConnectionSetIO.load_from_file(self._logger, self._filepath, self._manifest_cfg, self._conn_args)
129
+
130
+ @ft.cached_property
131
+ def _auth(self) -> Authenticator:
132
+ return Authenticator(self._logger, self._filepath, self._env_vars)
133
+
134
+ @ft.cached_property
135
+ def _param_args(self) -> ps.ParametersArgs:
136
+ return ps.ParameterConfigsSetIO.get_param_args(self._conn_args)
137
+
138
+ @ft.cached_property
139
+ def _param_cfg_set(self) -> ps.ParameterConfigsSet:
140
+ return ps.ParameterConfigsSetIO.load_from_file(
141
+ self._logger, self._filepath, self._manifest_cfg, self._seeds, self._conn_set, self._param_args
142
+ )
143
+
144
+ @ft.cached_property
145
+ def _j2_env(self) -> u.EnvironmentWithMacros:
146
+ return u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self._filepath))
147
+
148
+ @ft.cached_property
149
+ def _duckdb_venv_path(self) -> str:
150
+ duckdb_filepath_setting_val = self._env_vars.get(c.SQRL_DUCKDB_VENV_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DUCKDB_VENV_FILE}")
151
+ return str(u.Path(self._filepath, duckdb_filepath_setting_val))
152
+
153
+ def close(self) -> None:
154
+ """
155
+ Deliberately close any open resources within the Squirrels project, such as database connections (instead of relying on the garbage collector).
156
+ """
157
+ self._conn_set.dispose()
158
+ self._auth.close()
159
+
160
+ def __exit__(self, exc_type, exc_val, traceback):
161
+ self.close()
162
+
163
+
164
+ def _add_model(self, models_dict: dict[str, M], model: M) -> None:
165
+ if model.name in models_dict:
166
+ raise ConfigurationError(f"Names across all models must be unique. Model '{model.name}' is duplicated")
167
+ models_dict[model.name] = model
168
+
169
+
170
+ def _get_static_models(self) -> dict[str, m.StaticModel]:
171
+ models_dict: dict[str, m.StaticModel] = {}
172
+
173
+ seeds_dict = self._seeds.get_dataframes()
174
+ for key, seed in seeds_dict.items():
175
+ self._add_model(models_dict, m.Seed(key, seed.config, seed.df, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set))
176
+
177
+ for source_name, source_config in self._sources.sources.items():
178
+ self._add_model(models_dict, m.SourceModel(source_name, source_config, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set))
179
+
180
+ for name, val in self._build_model_files.items():
181
+ model = m.BuildModel(name, val.config, val.query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env)
182
+ self._add_model(models_dict, model)
183
+
184
+ return models_dict
185
+
186
+
187
+ async def build(self, *, full_refresh: bool = False, select: str | None = None, stage_file: bool = False) -> None:
188
+ """
189
+ Build the virtual data environment for the Squirrels project
190
+
191
+ Arguments:
192
+ full_refresh: Whether to drop all tables and rebuild the virtual data environment from scratch. Default is False.
193
+ stage_file: Whether to stage the DuckDB file to overwrite the existing one later if the virtual data environment is in use. Default is False.
194
+ """
195
+ models_dict: dict[str, m.StaticModel] = self._get_static_models()
196
+ builder = ModelBuilder(self._duckdb_venv_path, self._conn_set, models_dict, self._conn_args, self._logger)
197
+ await builder.build(full_refresh, select, stage_file)
198
+
199
+ def _get_models_dict(self, always_python_df: bool) -> dict[str, m.DataModel]:
200
+ models_dict: dict[str, m.DataModel] = dict(self._get_static_models())
201
+
202
+ for name, val in self._dbview_model_files.items():
203
+ self._add_model(models_dict, m.DbviewModel(
204
+ name, val.config, val.query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env
205
+ ))
206
+ models_dict[name].needs_python_df = always_python_df
207
+
208
+ for name, val in self._federate_model_files.items():
209
+ self._add_model(models_dict, m.FederateModel(
210
+ name, val.config, val.query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env
211
+ ))
212
+ models_dict[name].needs_python_df = always_python_df
213
+
214
+ return models_dict
215
+
216
+ def _generate_dag(self, dataset: str, *, target_model_name: str | None = None, always_python_df: bool = False) -> m.DAG:
217
+ models_dict = self._get_models_dict(always_python_df)
218
+
219
+ dataset_config = self._manifest_cfg.datasets[dataset]
220
+ target_model_name = dataset_config.model if target_model_name is None else target_model_name
221
+ target_model = models_dict[target_model_name]
222
+ target_model.is_target = True
223
+ dag = m.DAG(dataset_config, target_model, models_dict, self._duckdb_venv_path, self._logger)
224
+
225
+ return dag
226
+
227
+ def _generate_dag_with_fake_target(self, sql_query: str | None) -> m.DAG:
228
+ models_dict = self._get_models_dict(always_python_df=False)
229
+
230
+ if sql_query is None:
231
+ dependencies = set(models_dict.keys())
232
+ else:
233
+ dependencies, parsed = u.parse_dependent_tables(sql_query, models_dict.keys())
234
+
235
+ substitutions = {}
236
+ for model_name in dependencies:
237
+ model = models_dict[model_name]
238
+ if isinstance(model, m.SourceModel) and not model.model_config.load_to_duckdb:
239
+ raise InvalidInputError(203, f"Source model '{model_name}' cannot be queried with DuckDB")
240
+ if isinstance(model, (m.SourceModel, m.BuildModel)):
241
+ substitutions[model_name] = f"venv.{model_name}"
242
+
243
+ sql_query = parsed.transform(
244
+ lambda node: sqlglot.expressions.Table(this=substitutions[node.name])
245
+ if isinstance(node, sqlglot.expressions.Table) and node.name in substitutions
246
+ else node
247
+ ).sql()
248
+
249
+ model_config = mc.FederateModelConfig(depends_on=dependencies)
250
+ query_file = mq.SqlQueryFile("", sql_query or "")
251
+ fake_target_model = m.FederateModel(
252
+ "__fake_target", model_config, query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env
253
+ )
254
+ fake_target_model.is_target = True
255
+ dag = m.DAG(None, fake_target_model, models_dict, self._duckdb_venv_path, self._logger)
256
+ return dag
257
+
258
+ def _draw_dag(self, dag: m.DAG, output_folder: u.Path) -> None:
259
+ color_map = {
260
+ m.ModelType.SEED: "green", m.ModelType.DBVIEW: "red", m.ModelType.FEDERATE: "skyblue",
261
+ m.ModelType.BUILD: "purple", m.ModelType.SOURCE: "orange"
262
+ }
263
+
264
+ G = dag.to_networkx_graph()
265
+
266
+ fig, _ = plt.subplots()
267
+ pos = nx.multipartite_layout(G, subset_key="layer")
268
+ colors = [color_map[node[1]] for node in G.nodes(data="model_type")] # type: ignore
269
+ nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
270
+
271
+ y_values = [val[1] for val in pos.values()]
272
+ scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
273
+ label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
274
+ nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
275
+
276
+ fig.tight_layout()
277
+ plt.margins(x=0.1, y=0.1)
278
+ fig.savefig(u.Path(output_folder, "dag.png"))
279
+ plt.close(fig)
280
+
281
+ async def _get_compiled_dag(self, *, sql_query: str | None = None, selections: dict[str, t.Any] = {}, user: BaseUser | None = None) -> m.DAG:
282
+ dag = self._generate_dag_with_fake_target(sql_query)
283
+
284
+ default_traits = self._manifest_cfg.get_default_traits()
285
+ await dag.execute(self._param_args, self._param_cfg_set, self._context_func, user, selections, runquery=False, default_traits=default_traits)
286
+ return dag
287
+
288
+ def _get_all_connections(self) -> list[arm.ConnectionItemModel]:
289
+ connections = []
290
+ for conn_name, conn_props in self._conn_set.get_connections_as_dict().items():
291
+ if isinstance(conn_props, mf.ConnectionProperties):
292
+ label = conn_props.label if conn_props.label is not None else conn_name
293
+ connections.append(arm.ConnectionItemModel(name=conn_name, label=label))
294
+ return connections
295
+
296
+ def _get_all_data_models(self, compiled_dag: m.DAG) -> list[arm.DataModelItem]:
297
+ return compiled_dag.get_all_data_models()
298
+
299
+ async def get_all_data_models(self) -> list[arm.DataModelItem]:
300
+ """
301
+ Get all data models in the project
302
+
303
+ Returns:
304
+ A list of DataModelItem objects
305
+ """
306
+ compiled_dag = await self._get_compiled_dag()
307
+ return self._get_all_data_models(compiled_dag)
308
+
309
+ def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[arm.LineageRelation]:
310
+ all_lineage = compiled_dag.get_all_model_lineage()
311
+
312
+ # Add dataset nodes to the lineage
313
+ for dataset in self._manifest_cfg.datasets.values():
314
+ target_dataset = arm.LineageNode(name=dataset.name, type="dataset")
315
+ source_model = arm.LineageNode(name=dataset.model, type="model")
316
+ all_lineage.append(arm.LineageRelation(type="runtime", source=source_model, target=target_dataset))
317
+
318
+ # Add dashboard nodes to the lineage
319
+ for dashboard in self._dashboards.values():
320
+ target_dashboard = arm.LineageNode(name=dashboard.dashboard_name, type="dashboard")
321
+ datasets = set(x.dataset for x in dashboard.config.depends_on)
322
+ for dataset in datasets:
323
+ source_dataset = arm.LineageNode(name=dataset, type="dataset")
324
+ all_lineage.append(arm.LineageRelation(type="runtime", source=source_dataset, target=target_dashboard))
325
+
326
+ return all_lineage
327
+
328
+ async def get_all_data_lineage(self) -> list[arm.LineageRelation]:
329
+ """
330
+ Get all data lineage in the project
331
+
332
+ Returns:
333
+ A list of LineageRelation objects
334
+ """
335
+ compiled_dag = await self._get_compiled_dag()
336
+ return self._get_all_data_lineage(compiled_dag)
337
+
338
+ async def _write_dataset_outputs_given_test_set(
339
+ self, dataset: str, select: str, test_set: str | None, runquery: bool, recurse: bool
340
+ ) -> t.Any | None:
341
+ dataset_conf = self._manifest_cfg.datasets[dataset]
342
+ default_test_set_conf = self._manifest_cfg.get_default_test_set(dataset)
343
+ if test_set in self._manifest_cfg.selection_test_sets:
344
+ test_set_conf = self._manifest_cfg.selection_test_sets[test_set]
345
+ elif test_set is None or test_set == default_test_set_conf.name:
346
+ test_set, test_set_conf = default_test_set_conf.name, default_test_set_conf
347
+ else:
348
+ raise ConfigurationError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
349
+
350
+ error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
351
+ if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
352
+ raise ConfigurationError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
353
+
354
+ user_attributes = test_set_conf.user_attributes.copy() if test_set_conf.user_attributes is not None else {}
355
+ selections = test_set_conf.parameters.copy()
356
+ username, is_admin = user_attributes.pop("username", ""), user_attributes.pop("is_admin", False)
357
+ if test_set_conf.is_authenticated:
358
+ user = self._auth.User(username=username, is_admin=is_admin, **user_attributes)
359
+ elif dataset_conf.scope == mf.PermissionScope.PUBLIC:
360
+ user = None
361
+ else:
362
+ raise ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
363
+
364
+ if dataset_conf.scope == mf.PermissionScope.PRIVATE and not is_admin:
365
+ raise ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_admin' set to true")
366
+
367
+ # always_python_df is set to True for creating CSV files from results (when runquery is True)
368
+ dag = self._generate_dag(dataset, target_model_name=select, always_python_df=runquery)
369
+ await dag.execute(
370
+ self._param_args, self._param_cfg_set, self._context_func, user, selections,
371
+ runquery=runquery, recurse=recurse, default_traits=self._manifest_cfg.get_default_traits()
372
+ )
373
+
374
+ output_folder = u.Path(self._filepath, c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
375
+ if output_folder.exists():
376
+ shutil.rmtree(output_folder)
377
+ output_folder.mkdir(parents=True, exist_ok=True)
378
+
379
+ def write_placeholders() -> None:
380
+ output_filepath = u.Path(output_folder, "placeholders.json")
381
+ with open(output_filepath, 'w') as f:
382
+ json.dump(dag.placeholders, f, indent=4)
383
+
384
+ def write_model_outputs(model: m.DataModel) -> None:
385
+ assert isinstance(model, m.QueryModel)
386
+ subfolder = c.DBVIEWS_FOLDER if model.model_type == m.ModelType.DBVIEW else c.FEDERATES_FOLDER
387
+ subpath = u.Path(output_folder, subfolder)
388
+ subpath.mkdir(parents=True, exist_ok=True)
389
+ if isinstance(model.compiled_query, mq.SqlModelQuery):
390
+ output_filepath = u.Path(subpath, model.name+'.sql')
391
+ query = model.compiled_query.query
392
+ with open(output_filepath, 'w') as f:
393
+ f.write(query)
394
+ if runquery and isinstance(model.result, pl.LazyFrame):
395
+ output_filepath = u.Path(subpath, model.name+'.csv')
396
+ model.result.collect().write_csv(output_filepath)
397
+
398
+ write_placeholders()
399
+ all_model_names = dag.get_all_query_models()
400
+ coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
401
+ await u.asyncio_gather(coroutines)
402
+
403
+ if recurse:
404
+ self._draw_dag(dag, output_folder)
405
+
406
+ if isinstance(dag.target_model, m.QueryModel) and dag.target_model.compiled_query is not None:
407
+ return dag.target_model.compiled_query.query
408
+
409
+ async def compile(
410
+ self, *, dataset: str | None = None, do_all_datasets: bool = False, selected_model: str | None = None, test_set: str | None = None,
411
+ do_all_test_sets: bool = False, runquery: bool = False
412
+ ) -> None:
413
+ """
414
+ Async method to compile the SQL templates into files in the "target/" folder. Same functionality as the "sqrl compile" CLI.
415
+
416
+ Although all arguments are "optional", the "dataset" argument is required if "do_all_datasets" argument is False.
417
+
418
+ Arguments:
419
+ dataset: The name of the dataset to compile. Ignored if "do_all_datasets" argument is True, but required (i.e., cannot be None) if "do_all_datasets" is False. Default is None.
420
+ do_all_datasets: If True, compile all datasets and ignore the "dataset" argument. Default is False.
421
+ selected_model: The name of the model to compile. If specified, the compiled SQL query is also printed in the terminal. If None, all models for the selected dataset are compiled. Default is None.
422
+ test_set: The name of the test set to compile with. If None, the default test set is used (which can vary by dataset). Ignored if `do_all_test_sets` argument is True. Default is None.
423
+ do_all_test_sets: Whether to compile all applicable test sets for the selected dataset(s). If True, the `test_set` argument is ignored. Default is False.
424
+             runquery: Whether to run all compiled queries and save each result as a CSV file. If True and `selected_model` is specified, all upstream models of the selected model are compiled as well. Default is False.
425
+ """
426
+ recurse = True
427
+ if do_all_datasets:
428
+ selected_models = [(dataset.name, dataset.model) for dataset in self._manifest_cfg.datasets.values()]
429
+ else:
430
+ assert isinstance(dataset, str), "argument 'dataset' must be provided a string value if argument 'do_all_datasets' is False"
431
+ assert dataset in self._manifest_cfg.datasets, f"dataset '{dataset}' not found in {c.MANIFEST_FILE}"
432
+ if selected_model is None:
433
+ selected_model = self._manifest_cfg.datasets[dataset].model
434
+ else:
435
+ recurse = False
436
+ selected_models = [(dataset, selected_model)]
437
+
438
+ coroutines: list[t.Coroutine] = []
439
+ for dataset, selected_model in selected_models:
440
+ if do_all_test_sets:
441
+ for test_set_name in self._manifest_cfg.get_applicable_test_sets(dataset):
442
+ coroutine = self._write_dataset_outputs_given_test_set(dataset, selected_model, test_set_name, runquery, recurse)
443
+ coroutines.append(coroutine)
444
+
445
+ coroutine = self._write_dataset_outputs_given_test_set(dataset, selected_model, test_set, runquery, recurse)
446
+ coroutines.append(coroutine)
447
+
448
+ queries = await u.asyncio_gather(coroutines)
449
+
450
+ print(f"Compiled successfully! See the '{c.TARGET_FOLDER}/' folder for results.")
451
+ print()
452
+ if not recurse and len(queries) == 1 and isinstance(queries[0], str):
453
+ print(queries[0])
454
+ print()
455
+
456
+ def _permission_error(self, user: BaseUser | None, data_type: str, data_name: str, scope: str) -> InvalidInputError:
457
+ username = "" if user is None else f" '{user.username}'"
458
+ return InvalidInputError(25, f"User{username} does not have permission to access {scope} {data_type}: {data_name}")
459
+
460
+ def seed(self, name: str) -> pl.LazyFrame:
461
+ """
462
+ Method to retrieve a seed as a polars LazyFrame given a seed name.
463
+
464
+ Arguments:
465
+ name: The name of the seed to retrieve
466
+
467
+ Returns:
468
+ The seed as a polars LazyFrame
469
+ """
470
+ seeds_dict = self._seeds.get_dataframes()
471
+ try:
472
+ return seeds_dict[name].df
473
+ except KeyError:
474
+ available_seeds = list(seeds_dict.keys())
475
+ raise KeyError(f"Seed '{name}' not found. Available seeds are: {available_seeds}")
476
+
477
+ def dataset_metadata(self, name: str) -> dr.DatasetMetadata:
478
+ """
479
+ Method to retrieve the metadata of a dataset given a dataset name.
480
+
481
+ Arguments:
482
+ name: The name of the dataset to retrieve.
483
+
484
+ Returns:
485
+ A DatasetMetadata object containing the dataset description and column details.
486
+ """
487
+ dag = self._generate_dag(name)
488
+ dag.target_model.process_pass_through_columns(dag.models_dict)
489
+ return dr.DatasetMetadata(
490
+ target_model_config=dag.target_model.model_config
491
+ )
492
+
493
+ async def dataset(
494
+ self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, require_auth: bool = True
495
+ ) -> dr.DatasetResult:
496
+ """
497
+ Async method to retrieve a dataset as a DatasetResult object (with metadata) given parameter selections.
498
+
499
+ Arguments:
500
+ name: The name of the dataset to retrieve.
501
+ selections: A dictionary of parameter selections to apply to the dataset. Optional, default is empty dictionary.
502
+ user: The user to use for authentication. If None, no user is used. Optional, default is None.
503
+
504
+ Returns:
505
+ A DatasetResult object containing the dataset result (as a polars DataFrame), its description, and the column details.
506
+ """
507
+ scope = self._manifest_cfg.datasets[name].scope
508
+ if require_auth and not self._auth.can_user_access_scope(user, scope):
509
+ raise self._permission_error(user, "dataset", name, scope.name)
510
+
511
+ dag = self._generate_dag(name)
512
+ await dag.execute(
513
+ self._param_args, self._param_cfg_set, self._context_func, user, dict(selections),
514
+ default_traits=self._manifest_cfg.get_default_traits()
515
+ )
516
+ assert isinstance(dag.target_model.result, pl.LazyFrame)
517
+ return dr.DatasetResult(
518
+ target_model_config=dag.target_model.model_config,
519
+ df=dag.target_model.result.collect().with_row_index("_row_num", offset=1)
520
+ )
521
+
522
+ async def dashboard(
523
+ self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, dashboard_type: t.Type[T] = dash.Dashboard
524
+ ) -> T:
525
+ """
526
+ Async method to retrieve a dashboard given parameter selections.
527
+
528
+ Arguments:
529
+ name: The name of the dashboard to retrieve.
530
+ selections: A dictionary of parameter selections to apply to the dashboard. Optional, default is empty dictionary.
531
+ user: The user to use for authentication. If None, no user is used. Optional, default is None.
532
+ dashboard_type: Return type of the method (mainly used for type hints). For instance, provide PngDashboard if you want the return type to be a PngDashboard. Optional, default is squirrels.Dashboard.
533
+
534
+ Returns:
535
+ The dashboard type specified by the "dashboard_type" argument.
536
+ """
537
+ scope = self._dashboards[name].config.scope
538
+ if not self._auth.can_user_access_scope(user, scope):
539
+ raise self._permission_error(user, "dashboard", name, scope.name)
540
+
541
+ async def get_dataset_df(dataset_name: str, fixed_params: dict[str, t.Any]) -> pl.DataFrame:
542
+ final_selections = {**selections, **fixed_params}
543
+ result = await self.dataset(dataset_name, selections=final_selections, user=user, require_auth=False)
544
+ return result.df
545
+
546
+ args = d.DashboardArgs(self._param_args, get_dataset_df)
547
+ try:
548
+ return await self._dashboards[name].get_dashboard(args, dashboard_type=dashboard_type)
549
+ except KeyError:
550
+ raise KeyError(f"No dashboard file found for: {name}")
551
+
552
+ async def query_models(
553
+ self, sql_query: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None
554
+ ) -> dr.DatasetResult:
555
+ dag = await self._get_compiled_dag(sql_query=sql_query, selections=selections, user=user)
556
+ await dag._run_models()
557
+ assert isinstance(dag.target_model.result, pl.LazyFrame)
558
+ return dr.DatasetResult(
559
+ target_model_config=dag.target_model.model_config,
560
+ df=dag.target_model.result.collect().with_row_index("_row_num", offset=1)
561
+ )
squirrels/_py_module.py CHANGED
@@ -2,6 +2,7 @@ from typing import Type, Optional, Any
2
2
  import importlib.util
3
3
 
4
4
  from . import _constants as c, _utils as u
5
+ from ._exceptions import ConfigurationError, FileExecutionError
5
6
 
6
7
 
7
8
  class PyModule:
@@ -21,7 +22,7 @@ class PyModule:
21
22
  spec.loader.exec_module(self.module)
22
23
  except FileNotFoundError as e:
23
24
  if is_required:
24
- raise u.ConfigurationError(f"Required file not found: '{self.filepath}'") from e
25
+ raise ConfigurationError(f"Required file not found: '{self.filepath}'") from e
25
26
  self.module = default_class
26
27
 
27
28
  def get_func_or_class(self, attr_name: str, *, default_attr: Any = None, is_required: bool = True) -> Any:
@@ -38,7 +39,7 @@ class PyModule:
38
39
  """
39
40
  func_or_class = getattr(self.module, attr_name, default_attr)
40
41
  if func_or_class is None and is_required:
41
- raise u.ConfigurationError(f"Module '{self.filepath}' missing required attribute '{attr_name}'")
42
+ raise ConfigurationError(f"Module '{self.filepath}' missing required attribute '{attr_name}'")
42
43
  return func_or_class
43
44
 
44
45
 
@@ -57,4 +58,4 @@ def run_pyconfig_main(base_path: str, filename: str, kwargs: dict[str, Any] = {}
57
58
  try:
58
59
  main_function(**kwargs)
59
60
  except Exception as e:
60
- raise u.FileExecutionError(f'Failed to run python file "{filepath}"', e) from e
61
+ raise FileExecutionError(f'Failed to run python file "{filepath}"', e) from e
squirrels/_seeds.py CHANGED
@@ -1,39 +1,58 @@
1
1
  from dataclasses import dataclass
2
- import os, time, glob, pandas as pd
2
+ import os, time, glob, polars as pl, json
3
3
 
4
- from ._manifest import ManifestConfig
5
- from . import _utils as _u, _constants as c
4
+ from . import _utils as u, _constants as c, _model_configs as mc
5
+
6
+
7
+ @dataclass
8
+ class Seed:
9
+ config: mc.SeedConfig
10
+ df: pl.LazyFrame
11
+
12
+ def __post_init__(self):
13
+ if self.config.cast_column_types:
14
+ exprs = []
15
+ for col_config in self.config.columns:
16
+ sqrl_dtype = "double" if col_config.type.lower().startswith("decimal") else col_config.type
17
+ polars_dtype = u.sqrl_dtypes_to_polars_dtypes.get(sqrl_dtype, pl.String)
18
+ exprs.append(pl.col(col_config.name).cast(polars_dtype))
19
+
20
+ self.df = self.df.with_columns(*exprs)
6
21
 
7
22
 
8
23
  @dataclass
9
24
  class Seeds:
10
- _data: dict[str, pd.DataFrame]
11
- _manifest_cfg: ManifestConfig
25
+ _data: dict[str, Seed]
12
26
 
13
- def run_query(self, sql_query: str) -> pd.DataFrame:
14
- use_duckdb = self._manifest_cfg.settings_obj.do_use_duckdb()
15
- return _u.run_sql_on_dataframes(sql_query, self._data, use_duckdb)
27
+ def run_query(self, sql_query: str) -> pl.DataFrame:
28
+ dataframes = {key: seed.df for key, seed in self._data.items()}
29
+ return u.run_sql_on_dataframes(sql_query, dataframes)
16
30
 
17
- def get_dataframes(self) -> dict[str, pd.DataFrame]:
31
+ def get_dataframes(self) -> dict[str, Seed]:
18
32
  return self._data.copy()
19
33
 
20
34
 
21
35
  class SeedsIO:
22
36
 
23
37
  @classmethod
24
- def load_files(cls, logger: _u.Logger, base_path: str, manifest_cfg: ManifestConfig) -> Seeds:
38
+ def load_files(cls, logger: u.Logger, base_path: str, env_vars: dict[str, str]) -> Seeds:
25
39
  start = time.time()
26
- infer_schema: bool = manifest_cfg.settings.get(c.SEEDS_INFER_SCHEMA_SETTING, True)
27
- na_values: list[str] = manifest_cfg.settings.get(c.SEEDS_NA_VALUES_SETTING, ["NA"])
28
- csv_dtype = None if infer_schema else str
40
+ infer_schema_setting: bool = (env_vars.get(c.SQRL_SEEDS_INFER_SCHEMA, "true").lower() == "true")
41
+ na_values_setting: list[str] = json.loads(env_vars.get(c.SQRL_SEEDS_NA_VALUES, "[]"))
29
42
 
30
43
  seeds_dict = {}
31
44
  csv_files = glob.glob(os.path.join(base_path, c.SEEDS_FOLDER, '**/*.csv'), recursive=True)
32
45
  for csv_file in csv_files:
46
+ config_file = os.path.splitext(csv_file)[0] + '.yml'
47
+ config_dict = u.load_yaml_config(config_file) if os.path.exists(config_file) else {}
48
+ config = mc.SeedConfig(**config_dict)
49
+
33
50
  file_stem = os.path.splitext(os.path.basename(csv_file))[0]
34
- df = pd.read_csv(csv_file, dtype=csv_dtype, keep_default_na=False, na_values=na_values)
35
- seeds_dict[file_stem] = df
51
+ infer_schema = not config.cast_column_types and infer_schema_setting
52
+ df = pl.read_csv(csv_file, try_parse_dates=True, infer_schema=infer_schema, null_values=na_values_setting).lazy()
53
+
54
+ seeds_dict[file_stem] = Seed(config, df)
36
55
 
37
- seeds = Seeds(seeds_dict, manifest_cfg)
56
+ seeds = Seeds(seeds_dict)
38
57
  logger.log_activity_time("loading seed files", start)
39
58
  return seeds