squirrels 0.5.0b3__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. squirrels/__init__.py +4 -0
  2. squirrels/_api_routes/__init__.py +5 -0
  3. squirrels/_api_routes/auth.py +337 -0
  4. squirrels/_api_routes/base.py +196 -0
  5. squirrels/_api_routes/dashboards.py +156 -0
  6. squirrels/_api_routes/data_management.py +148 -0
  7. squirrels/_api_routes/datasets.py +220 -0
  8. squirrels/_api_routes/project.py +289 -0
  9. squirrels/_api_server.py +440 -792
  10. squirrels/_arguments/__init__.py +0 -0
  11. squirrels/_arguments/{_init_time_args.py → init_time_args.py} +23 -43
  12. squirrels/_arguments/{_run_time_args.py → run_time_args.py} +32 -68
  13. squirrels/_auth.py +590 -264
  14. squirrels/_command_line.py +130 -58
  15. squirrels/_compile_prompts.py +147 -0
  16. squirrels/_connection_set.py +16 -15
  17. squirrels/_constants.py +36 -11
  18. squirrels/_dashboards.py +179 -0
  19. squirrels/_data_sources.py +40 -34
  20. squirrels/_dataset_types.py +16 -11
  21. squirrels/_env_vars.py +209 -0
  22. squirrels/_exceptions.py +9 -37
  23. squirrels/_http_error_responses.py +52 -0
  24. squirrels/_initializer.py +7 -6
  25. squirrels/_logging.py +121 -0
  26. squirrels/_manifest.py +155 -77
  27. squirrels/_mcp_server.py +578 -0
  28. squirrels/_model_builder.py +11 -55
  29. squirrels/_model_configs.py +5 -5
  30. squirrels/_model_queries.py +1 -1
  31. squirrels/_models.py +276 -143
  32. squirrels/_package_data/base_project/.env +1 -24
  33. squirrels/_package_data/base_project/.env.example +31 -17
  34. squirrels/_package_data/base_project/connections.yml +4 -3
  35. squirrels/_package_data/base_project/dashboards/dashboard_example.py +13 -7
  36. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +6 -6
  37. squirrels/_package_data/base_project/docker/Dockerfile +2 -2
  38. squirrels/_package_data/base_project/docker/compose.yml +1 -1
  39. squirrels/_package_data/base_project/duckdb_init.sql +1 -0
  40. squirrels/_package_data/base_project/models/builds/build_example.py +2 -2
  41. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
  42. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
  43. squirrels/_package_data/base_project/models/federates/federate_example.py +27 -17
  44. squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
  45. squirrels/_package_data/base_project/models/federates/federate_example.yml +7 -7
  46. squirrels/_package_data/base_project/models/sources.yml +5 -6
  47. squirrels/_package_data/base_project/parameters.yml +24 -38
  48. squirrels/_package_data/base_project/pyconfigs/connections.py +8 -3
  49. squirrels/_package_data/base_project/pyconfigs/context.py +26 -14
  50. squirrels/_package_data/base_project/pyconfigs/parameters.py +124 -81
  51. squirrels/_package_data/base_project/pyconfigs/user.py +48 -15
  52. squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
  53. squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
  54. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
  55. squirrels/_package_data/base_project/squirrels.yml.j2 +21 -31
  56. squirrels/_package_data/templates/login_successful.html +53 -0
  57. squirrels/_package_data/templates/squirrels_studio.html +22 -0
  58. squirrels/_parameter_configs.py +43 -22
  59. squirrels/_parameter_options.py +1 -1
  60. squirrels/_parameter_sets.py +41 -30
  61. squirrels/_parameters.py +560 -123
  62. squirrels/_project.py +487 -277
  63. squirrels/_py_module.py +71 -10
  64. squirrels/_request_context.py +33 -0
  65. squirrels/_schemas/__init__.py +0 -0
  66. squirrels/_schemas/auth_models.py +83 -0
  67. squirrels/_schemas/query_param_models.py +70 -0
  68. squirrels/_schemas/request_models.py +26 -0
  69. squirrels/_schemas/response_models.py +286 -0
  70. squirrels/_seeds.py +52 -13
  71. squirrels/_sources.py +29 -23
  72. squirrels/_utils.py +221 -42
  73. squirrels/_version.py +1 -3
  74. squirrels/arguments.py +7 -2
  75. squirrels/auth.py +4 -0
  76. squirrels/connections.py +2 -0
  77. squirrels/dashboards.py +3 -1
  78. squirrels/data_sources.py +6 -0
  79. squirrels/parameter_options.py +5 -0
  80. squirrels/parameters.py +5 -0
  81. squirrels/types.py +10 -3
  82. squirrels-0.6.0.post0.dist-info/METADATA +148 -0
  83. squirrels-0.6.0.post0.dist-info/RECORD +101 -0
  84. {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -1
  85. squirrels/_api_response_models.py +0 -190
  86. squirrels/_dashboard_types.py +0 -82
  87. squirrels/_dashboards_io.py +0 -79
  88. squirrels-0.5.0b3.dist-info/METADATA +0 -110
  89. squirrels-0.5.0b3.dist-info/RECORD +0 -80
  90. /squirrels/_package_data/base_project/{assets → resources}/expenses.db +0 -0
  91. /squirrels/_package_data/base_project/{assets → resources}/weather.db +0 -0
  92. {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +0 -0
  93. {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/licenses/LICENSE +0 -0
squirrels/_project.py CHANGED
@@ -1,39 +1,25 @@
- from dotenv import dotenv_values
- from uuid import uuid4
+ from typing import TYPE_CHECKING
+ from dotenv import dotenv_values, load_dotenv
  from pathlib import Path
  import asyncio, typing as t, functools as ft, shutil, json, os
- import logging as l, matplotlib.pyplot as plt, networkx as nx, polars as pl
- import sqlglot, sqlglot.expressions
+ import sqlglot, sqlglot.expressions, duckdb, polars as pl

- from ._auth import Authenticator, BaseUser
+ from ._auth import Authenticator, AuthProviderArgs, ProviderFunctionType
+ from ._schemas.auth_models import CustomUserFields, AbstractUser, GuestUser, RegisteredUser
+ from ._schemas import response_models as rm
  from ._model_builder import ModelBuilder
+ from ._env_vars import SquirrelsEnvVars
  from ._exceptions import InvalidInputError, ConfigurationError
- from . import _utils as u, _constants as c, _manifest as mf, _connection_set as cs, _api_response_models as arm
+ from ._py_module import PyModule
+ from . import _dashboards as d, _utils as u, _constants as c, _manifest as mf, _connection_set as cs
  from . import _seeds as s, _models as m, _model_configs as mc, _model_queries as mq, _sources as so
- from . import _parameter_sets as ps, _dashboards_io as d, _dashboard_types as dash, _dataset_types as dr
-
- T = t.TypeVar("T", bound=dash.Dashboard)
- M = t.TypeVar("M", bound=m.DataModel)
+ from . import _parameter_sets as ps, _dataset_types as dr, _logging as l

+ if TYPE_CHECKING:
+     from ._api_server import FastAPIComponents

- class _CustomJsonFormatter(l.Formatter):
-     def format(self, record: l.LogRecord) -> str:
-         super().format(record)
-         info = {
-             "timestamp": self.formatTime(record),
-             "project_id": record.name,
-             "level": record.levelname,
-             "message": record.getMessage(),
-             "thread": record.thread,
-             "thread_name": record.threadName,
-             "process": record.process,
-             **record.__dict__.get("info", {})
-         }
-         output = {
-             "data": record.__dict__.get("data", {}),
-             "info": info
-         }
-         return json.dumps(output)
+ T = t.TypeVar("T", bound=d.Dashboard)
+ M = t.TypeVar("M", bound=m.DataModel)


  class SquirrelsProject:
@@ -41,114 +27,179 @@ class SquirrelsProject:
      Initiate an instance of this class to interact with a Squirrels project through Python code. For example this can be handy to experiment with the datasets produced by Squirrels in a Jupyter notebook.
      """

-     def __init__(self, *, filepath: str = ".", log_file: str | None = c.LOGS_FILE, log_level: str = "INFO", log_format: str = "text") -> None:
+     def __init__(
+         self, *, project_path: str = ".", load_dotenv_globally: bool = False,
+         log_to_file: bool = False, log_level: str | None = None, log_format: str | None = None,
+     ) -> None:
          """
          Constructor for SquirrelsProject class. Loads the file contents of the Squirrels project into memory as member fields.

          Arguments:
-             filepath: The path to the Squirrels project file. Defaults to the current working directory.
-             log_level: The logging level to use. Options are "DEBUG", "INFO", and "WARNING". Default is "INFO".
-             log_file: The name of the log file to write to from the "logs/" subfolder. If None or empty string, then file logging is disabled. Default is "squirrels.log".
-             log_format: The format of the log records. Options are "text" and "json". Default is "text".
+             project_path: The path to the Squirrels project file. Defaults to the current working directory.
+             log_level: The logging level to use. Options are "DEBUG", "INFO", and "WARNING". Default is from SQRL_LOGGING__LEVEL environment variable or "INFO".
+             log_to_file: Whether to enable logging to file(s) in the "logs/" folder (or a custom folder). Default is from SQRL_LOGGING__TO_FILE environment variable or False.
+             log_format: The format of the log records. Options are "text" and "json". Default is from SQRL_LOGGING__FORMAT environment variable or "text".
          """
-         self._filepath = filepath
-         self._logger = self._get_logger(self._filepath, log_file, log_level, log_format)
-
-     def _get_logger(self, base_path: str, log_file: str | None, log_level: str, log_format: str) -> u.Logger:
-         logger = u.Logger(name=uuid4().hex)
-         logger.setLevel(log_level.upper())
-
-         handler = l.StreamHandler()
-         handler.setLevel("WARNING")
-         handler.setFormatter(l.Formatter("%(levelname)s: %(asctime)s - %(message)s"))
-         logger.addHandler(handler)
-
-         if log_format.lower() == "json":
-             formatter = _CustomJsonFormatter()
-         elif log_format.lower() == "text":
-             formatter = l.Formatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s")
-         else:
-             raise ValueError("log_format must be either 'text' or 'json'")
-
-         if log_file:
-             path = Path(base_path, c.LOGS_FOLDER, log_file)
-             path.parent.mkdir(parents=True, exist_ok=True)
+         project_path = str(Path(project_path).resolve())

-             handler = l.FileHandler(path)
-             handler.setFormatter(formatter)
-             logger.addHandler(handler)
+         self._project_path = project_path
+         self._env_vars_unformatted = self._load_env_vars(project_path, load_dotenv_globally)
+         self._env_vars = SquirrelsEnvVars(project_path=project_path, **self._env_vars_unformatted)
+         self._vdl_catalog_db_path = self._env_vars.vdl_catalog_db_path

-         return logger
-
-     @ft.cached_property
-     def _env_vars(self) -> dict[str, str]:
+         self._logger = self._get_logger(project_path, self._env_vars, log_to_file, log_level, log_format)
+         self._ensure_virtual_datalake_exists(project_path, self._vdl_catalog_db_path, self._env_vars.vdl_data_path)
+
+     @staticmethod
+     def _load_env_vars(project_path: str, load_dotenv_globally: bool) -> dict[str, str]:
          dotenv_files = [c.DOTENV_FILE, c.DOTENV_LOCAL_FILE]
          dotenv_vars = {}
          for file in dotenv_files:
-             dotenv_vars.update({k: v for k, v in dotenv_values(f"{self._filepath}/{file}").items() if v is not None})
+             full_path = u.Path(project_path, file)
+             if load_dotenv_globally:
+                 load_dotenv(full_path)
+             dotenv_vars.update({k: v for k, v in dotenv_values(full_path).items() if v is not None})
          return {**os.environ, **dotenv_vars}

+     @staticmethod
+     def _get_logger(
+         filepath: str, env_vars: SquirrelsEnvVars, log_to_file: bool, log_level: str | None, log_format: str | None
+     ) -> u.Logger:
+         # CLI arguments take precedence over environment variables
+         log_level = log_level if log_level is not None else env_vars.logging_level
+         log_format = log_format if log_format is not None else env_vars.logging_format
+         log_to_file = env_vars.logging_to_file or log_to_file
+         log_file_size_mb = float(env_vars.logging_file_size_mb)
+         log_file_backup_count = int(env_vars.logging_file_backup_count)
+         return l.get_logger(filepath, log_to_file, log_level, log_format, log_file_size_mb, log_file_backup_count)
+
+     @staticmethod
+     def _ensure_virtual_datalake_exists(project_path: str, vdl_catalog_db_path: str, vdl_data_path: str) -> None:
+         target_path = u.Path(project_path, c.TARGET_FOLDER)
+         target_path.mkdir(parents=True, exist_ok=True)
+
+         # Attempt to set up the virtual data lake with DATA_PATH if possible
+         try:
+             is_ducklake = vdl_catalog_db_path.startswith("ducklake:")
+
+             options = f"(DATA_PATH '{vdl_data_path}')" if is_ducklake else ""
+             attach_stmt = f"ATTACH '{vdl_catalog_db_path}' AS vdl {options}"
+             with duckdb.connect() as conn:
+                 conn.execute(attach_stmt)
+                 # TODO: support incremental loads for build models and avoid cleaning up old files all the time
+                 conn.execute("CALL ducklake_expire_snapshots('vdl', older_than => now())")
+                 conn.execute("CALL ducklake_cleanup_old_files('vdl', cleanup_all => true)")
+
+         except Exception as e:
+             if "DATA_PATH parameter" in str(e):
+                 first_line = str(e).split("\n")[0]
+                 note = "NOTE: Squirrels does not allow changing the data path for an existing Virtual Data Lake (VDL)"
+                 raise u.ConfigurationError(f"{first_line}\n\n{note}")
+
+             if is_ducklake and not any(x in vdl_catalog_db_path for x in [":sqlite:", ":postgres:", ":mysql:"]):
+                 extended_error = "\n- Note: if you're using DuckDB for the metadata database, only one process can connect to the VDL at a time."
+             else:
+                 extended_error = ""
+
+             raise u.ConfigurationError(f"Failed to attach Virtual Data Lake (VDL).{extended_error}") from e
+
      @ft.cached_property
      def _manifest_cfg(self) -> mf.ManifestConfig:
-         return mf.ManifestIO.load_from_file(self._logger, self._filepath, self._env_vars)
+         return mf.ManifestIO.load_from_file(self._logger, self._project_path, self._env_vars_unformatted)

      @ft.cached_property
      def _seeds(self) -> s.Seeds:
-         return s.SeedsIO.load_files(self._logger, self._filepath, self._env_vars)
+         return s.SeedsIO.load_files(self._logger, self._env_vars)

      @ft.cached_property
      def _sources(self) -> so.Sources:
-         return so.SourcesIO.load_file(self._logger, self._filepath, self._env_vars)
+         return so.SourcesIO.load_file(self._logger, self._env_vars, self._env_vars_unformatted)

      @ft.cached_property
      def _build_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
-         return m.ModelsIO.load_build_files(self._logger, self._filepath)
+         return m.ModelsIO.load_build_files(self._logger, self._env_vars)

      @ft.cached_property
      def _dbview_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
-         return m.ModelsIO.load_dbview_files(self._logger, self._filepath, self._env_vars)
+         return m.ModelsIO.load_dbview_files(self._logger, self._env_vars)

      @ft.cached_property
      def _federate_model_files(self) -> dict[str, mq.QueryFileWithConfig]:
-         return m.ModelsIO.load_federate_files(self._logger, self._filepath)
+         return m.ModelsIO.load_federate_files(self._logger, self._env_vars)

      @ft.cached_property
      def _context_func(self) -> m.ContextFunc:
-         return m.ModelsIO.load_context_func(self._logger, self._filepath)
+         return m.ModelsIO.load_context_func(self._logger, self._project_path)

      @ft.cached_property
      def _dashboards(self) -> dict[str, d.DashboardDefinition]:
-         return d.DashboardsIO.load_files(self._logger, self._filepath)
+         return d.DashboardsIO.load_files(
+             self._logger, self._project_path, self._manifest_cfg.project_variables.auth_type, self._manifest_cfg.configurables
+         )

      @ft.cached_property
      def _conn_args(self) -> cs.ConnectionsArgs:
-         return cs.ConnectionSetIO.load_conn_py_args(self._logger, self._filepath, self._env_vars, self._manifest_cfg)
+         proj_vars = self._manifest_cfg.project_variables.model_dump()
+         conn_args = cs.ConnectionsArgs(self._project_path, proj_vars, self._env_vars_unformatted)
+         return conn_args

      @ft.cached_property
      def _conn_set(self) -> cs.ConnectionSet:
-         return cs.ConnectionSetIO.load_from_file(self._logger, self._filepath, self._manifest_cfg, self._conn_args)
+         return cs.ConnectionSetIO.load_from_file(self._logger, self._project_path, self._manifest_cfg, self._conn_args)
+
+     @ft.cached_property
+     def _custom_user_fields_cls_and_provider_functions(self) -> tuple[type[CustomUserFields], list[ProviderFunctionType]]:
+         user_module_path = u.Path(self._project_path, c.PYCONFIGS_FOLDER, c.USER_FILE)
+         user_module = PyModule(user_module_path, self._project_path)
+
+         # Load CustomUserFields class (adds to Authenticator.providers as side effect)
+         CustomUserFieldsCls = user_module.get_func_or_class("CustomUserFields", default_attr=CustomUserFields)
+         provider_functions = Authenticator.providers
+         Authenticator.providers = []
+
+         if not issubclass(CustomUserFieldsCls, CustomUserFields):
+             raise ConfigurationError(f"CustomUserFields class in '{c.USER_FILE}' must inherit from CustomUserFields")
+
+         return CustomUserFieldsCls, provider_functions

      @ft.cached_property
      def _auth(self) -> Authenticator:
-         return Authenticator(self._logger, self._filepath, self._env_vars)
+         auth_args = AuthProviderArgs(**self._conn_args.__dict__)
+         CustomUserFieldsCls, provider_functions = self._custom_user_fields_cls_and_provider_functions
+         external_only = (self._manifest_cfg.project_variables.auth_strategy == mf.AuthStrategy.EXTERNAL)
+
+         if external_only and len(provider_functions) != 1:
+             raise ConfigurationError(f"When auth_strategy is 'external', there must be exactly one auth provider function. Found {len(provider_functions)} auth providers.")
+
+         return Authenticator(
+             self._logger, self._env_vars, auth_args, provider_functions,
+             custom_user_fields_cls=CustomUserFieldsCls, external_only=external_only
+         )

      @ft.cached_property
-     def User(self) -> t.Type[BaseUser]:
-         return self._auth.User
+     def _guest_user(self) -> AbstractUser:
+         custom_fields = self._auth.CustomUserFields()
+         return GuestUser(username="", custom_fields=custom_fields)
+
+     @ft.cached_property
+     def _admin_user(self) -> AbstractUser:
+         custom_fields = self._auth.CustomUserFields()
+         return RegisteredUser(username="", access_level="admin", custom_fields=custom_fields)

      @ft.cached_property
      def _param_args(self) -> ps.ParametersArgs:
-         return ps.ParameterConfigsSetIO.get_param_args(self._conn_args)
+         conn_args = self._conn_args
+         return ps.ParametersArgs(**conn_args.__dict__)

      @ft.cached_property
      def _param_cfg_set(self) -> ps.ParameterConfigsSet:
          return ps.ParameterConfigsSetIO.load_from_file(
-             self._logger, self._filepath, self._manifest_cfg, self._seeds, self._conn_set, self._param_args
+             self._logger, self._env_vars, self._manifest_cfg, self._seeds, self._conn_set, self._param_args
          )

      @ft.cached_property
      def _j2_env(self) -> u.EnvironmentWithMacros:
-         env = u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self._filepath))
+         env = u.EnvironmentWithMacros(self._logger, loader=u.j2.FileSystemLoader(self._project_path))

          def value_to_str(value: t.Any, attribute: str | None = None) -> str:
              if attribute is None:
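
Usage note: the new constructor resolves settings from environment variables when arguments are left as None, with explicit arguments taking precedence (see _get_logger above). A minimal sketch, assuming `SquirrelsProject` is importable from the package root (an import path this diff does not show):

    from squirrels import SquirrelsProject  # import path assumed

    project = SquirrelsProject(
        project_path=".",             # replaces the old `filepath` argument
        load_dotenv_globally=False,   # True would also load .env values into os.environ
        log_to_file=True,             # replaces `log_file`; logs go under the "logs/" folder
        log_level="DEBUG",            # overrides SQRL_LOGGING__LEVEL
        log_format="json",            # overrides SQRL_LOGGING__FORMAT
    )
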
@@ -170,11 +221,26 @@ class SquirrelsProject:
          env.filters["quote_and_join"] = quote_and_join
          return env

-     @ft.cached_property
-     def _duckdb_venv_path(self) -> str:
-         duckdb_filepath_setting_val = self._env_vars.get(c.SQRL_DUCKDB_VENV_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DUCKDB_VENV_FILE}")
-         return str(Path(self._filepath, duckdb_filepath_setting_val))
-
+     def get_fastapi_components(
+         self, *, no_cache: bool = False, host: str = "localhost", port: int = 8000,
+         mount_path_format: str = "/analytics/{project_name}/v{project_version}"
+     ) -> "FastAPIComponents":
+         """
+         Get the FastAPI components for the Squirrels project including mount path, lifespan, and FastAPI app.
+
+         Arguments:
+             no_cache: Whether to disable caching for parameter options, datasets, and dashboard results in the API server.
+             host: The host the API server will listen on. Only used for the welcome banner.
+             port: The port the API server will listen on. Only used for the welcome banner.
+             mount_path_format: The format of the mount path. Use {project_name} and {project_version} as placeholders.
+
+         Returns:
+             A FastAPIComponents object containing the mount path, lifespan, and FastAPI app.
+         """
+         from ._api_server import ApiServer
+         api_server = ApiServer(no_cache=no_cache, project=self)
+         return api_server.get_fastapi_components(host=host, port=port, mount_path_format=mount_path_format)
+
      def close(self) -> None:
          """
          Deliberately close any open resources within the Squirrels project, such as database connections (instead of relying on the garbage collector).
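
Usage note: a hedged sketch of serving the project through the new `get_fastapi_components` method. The attribute names on the returned `FastAPIComponents` object (`mount_path`, `lifespan`, `app`) are assumptions inferred from the docstring above, not confirmed by this diff:

    import uvicorn
    from fastapi import FastAPI

    project = SquirrelsProject(project_path=".")
    components = project.get_fastapi_components(no_cache=True, host="0.0.0.0", port=8080)

    parent_app = FastAPI(lifespan=components.lifespan)        # attribute names hypothetical
    parent_app.mount(components.mount_path, components.app)   # mount the project's sub-app

    uvicorn.run(parent_app, host="0.0.0.0", port=8080)
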
@@ -182,6 +248,9 @@
          self._conn_set.dispose()
          self._auth.close()

+     def __enter__(self):
+         return self
+
      def __exit__(self, exc_type, exc_val, traceback):
          self.close()

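Usage note: with `__enter__` added (and `__exit__` already delegating to `close()`), the project now works as a context manager, so connections and auth resources are released even when an error occurs. A small sketch (the seed name is illustrative, taken from the base project's seed files):

    with SquirrelsProject(project_path=".") as project:
        lazy_frame = project.seed("seed_categories")
        print(lazy_frame.collect())
    # __exit__ calls close() here, disposing database connections
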
@@ -197,60 +266,59 @@

          seeds_dict = self._seeds.get_dataframes()
          for key, seed in seeds_dict.items():
-             self._add_model(models_dict, m.Seed(key, seed.config, seed.df, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set))
+             self._add_model(models_dict, m.Seed(key, seed.config, seed.df, logger=self._logger, conn_set=self._conn_set))

          for source_name, source_config in self._sources.sources.items():
-             self._add_model(models_dict, m.SourceModel(source_name, source_config, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set))
+             self._add_model(models_dict, m.SourceModel(source_name, source_config, logger=self._logger, conn_set=self._conn_set))

          for name, val in self._build_model_files.items():
-             model = m.BuildModel(name, val.config, val.query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env)
+             model = m.BuildModel(name, val.config, val.query_file, logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env)
              self._add_model(models_dict, model)

          return models_dict


-     async def build(self, *, full_refresh: bool = False, select: str | None = None, stage_file: bool = False) -> None:
+     async def build(self, *, full_refresh: bool = False, select: str | None = None) -> None:
          """
-         Build the virtual data environment for the Squirrels project
+         Build the Virtual Data Lake (VDL) for the Squirrels project

          Arguments:
-             full_refresh: Whether to drop all tables and rebuild the virtual data environment from scratch. Default is False.
-             stage_file: Whether to stage the DuckDB file to overwrite the existing one later if the virtual data environment is in use. Default is False.
+             full_refresh: Whether to drop all tables and rebuild the VDL from scratch. Default is False.
+             select: The name of a specific model to build. If None, all models are built. Default is None.
          """
          models_dict: dict[str, m.StaticModel] = self._get_static_models()
-         builder = ModelBuilder(self._duckdb_venv_path, self._conn_set, models_dict, self._conn_args, self._logger)
-         await builder.build(full_refresh, select, stage_file)
+         builder = ModelBuilder(self._vdl_catalog_db_path, self._conn_set, models_dict, self._conn_args, self._logger)
+         await builder.build(full_refresh, select)

      def _get_models_dict(self, always_python_df: bool) -> dict[str, m.DataModel]:
-         models_dict: dict[str, m.DataModel] = dict(self._get_static_models())
+         models_dict: dict[str, m.DataModel] = self._get_static_models()

          for name, val in self._dbview_model_files.items():
              self._add_model(models_dict, m.DbviewModel(
-                 name, val.config, val.query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env
+                 name, val.config, val.query_file, logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env
              ))
              models_dict[name].needs_python_df = always_python_df

          for name, val in self._federate_model_files.items():
              self._add_model(models_dict, m.FederateModel(
-                 name, val.config, val.query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env
+                 name, val.config, val.query_file, logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env
              ))
              models_dict[name].needs_python_df = always_python_df

          return models_dict

-     def _generate_dag(self, dataset: str, *, target_model_name: str | None = None, always_python_df: bool = False) -> m.DAG:
-         models_dict = self._get_models_dict(always_python_df)
+     def _generate_dag(self, dataset: str) -> m.DAG:
+         models_dict = self._get_models_dict(always_python_df=False)

          dataset_config = self._manifest_cfg.datasets[dataset]
-         target_model_name = dataset_config.model if target_model_name is None else target_model_name
-         target_model = models_dict[target_model_name]
+         target_model = models_dict[dataset_config.model]
          target_model.is_target = True
-         dag = m.DAG(dataset_config, target_model, models_dict, self._duckdb_venv_path, self._logger)
+         dag = m.DAG(dataset_config, target_model, models_dict, self._vdl_catalog_db_path, self._logger)

          return dag

-     def _generate_dag_with_fake_target(self, sql_query: str | None) -> m.DAG:
-         models_dict = self._get_models_dict(always_python_df=False)
+     def _generate_dag_with_fake_target(self, sql_query: str | None, *, always_python_df: bool = False) -> m.DAG:
+         models_dict = self._get_models_dict(always_python_df=always_python_df)

          if sql_query is None:
              dependencies = set(models_dict.keys())
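
Usage note: with `stage_file` removed, building the Virtual Data Lake reduces to the following (a sketch; the model name passed to `select` is illustrative, taken from the base project's build models):

    import asyncio

    async def main() -> None:
        with SquirrelsProject(project_path=".") as project:
            await project.build(full_refresh=True)       # rebuild the whole VDL from scratch
            await project.build(select="build_example")  # or rebuild a single model

    asyncio.run(main())
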
@@ -260,227 +328,260 @@ class SquirrelsProject:
              substitutions = {}
              for model_name in dependencies:
                  model = models_dict[model_name]
-                 if isinstance(model, m.SourceModel) and not model.model_config.load_to_duckdb:
-                     raise InvalidInputError(203, f"Source model '{model_name}' cannot be queried with DuckDB")
-                 if isinstance(model, (m.SourceModel, m.BuildModel)):
-                     substitutions[model_name] = f"venv.{model_name}"
+                 if isinstance(model, m.SourceModel) and not model.is_queryable:
+                     raise InvalidInputError(400, "cannot_query_source_model", f"Source model '{model_name}' cannot be queried with DuckDB")
+                 if isinstance(model, m.BuildModel):
+                     substitutions[model_name] = f"vdl.{model_name}"
+                 elif isinstance(model, m.SourceModel):
+                     if model.model_config.load_to_vdl:
+                         substitutions[model_name] = f"vdl.{model_name}"
+                     else:
+                         # DuckDB connection without load_to_vdl - reference via attached database
+                         conn_name = model.model_config.get_connection()
+                         table_name = model.model_config.get_table()
+                         substitutions[model_name] = f"db_{conn_name}.{table_name}"

              sql_query = parsed.transform(
-                 lambda node: sqlglot.expressions.Table(this=substitutions[node.name])
+                 lambda node: sqlglot.expressions.Table(this=substitutions[node.name], alias=node.alias)
                  if isinstance(node, sqlglot.expressions.Table) and node.name in substitutions
                  else node
              ).sql()

          model_config = mc.FederateModelConfig(depends_on=dependencies)
-         query_file = mq.SqlQueryFile("", sql_query or "")
+         query_file = mq.SqlQueryFile("", sql_query or "SELECT 1")
          fake_target_model = m.FederateModel(
-             "__fake_target", model_config, query_file, logger=self._logger, env_vars=self._env_vars, conn_set=self._conn_set, j2_env=self._j2_env
+             "__fake_target", model_config, query_file, logger=self._logger, conn_set=self._conn_set, j2_env=self._j2_env
          )
          fake_target_model.is_target = True
-         dag = m.DAG(None, fake_target_model, models_dict, self._duckdb_venv_path, self._logger)
+         dag = m.DAG(None, fake_target_model, models_dict, self._vdl_catalog_db_path, self._logger)
          return dag

-     def _draw_dag(self, dag: m.DAG, output_folder: Path) -> None:
-         color_map = {
-             m.ModelType.SEED: "green", m.ModelType.DBVIEW: "red", m.ModelType.FEDERATE: "skyblue",
-             m.ModelType.BUILD: "purple", m.ModelType.SOURCE: "orange"
-         }
-
-         G = dag.to_networkx_graph()
-
-         fig, _ = plt.subplots()
-         pos = nx.multipartite_layout(G, subset_key="layer")
-         colors = [color_map[node[1]] for node in G.nodes(data="model_type")] # type: ignore
-         nx.draw(G, pos=pos, node_shape='^', node_size=1000, node_color=colors, arrowsize=20)
-
-         y_values = [val[1] for val in pos.values()]
-         scale = max(y_values) - min(y_values) if len(y_values) > 0 else 0
-         label_pos = {key: (val[0], val[1]-0.002-0.1*scale) for key, val in pos.items()}
-         nx.draw_networkx_labels(G, pos=label_pos, font_size=8)
-
-         fig.tight_layout()
-         plt.margins(x=0.1, y=0.1)
-         fig.savefig(Path(output_folder, "dag.png"))
-         plt.close(fig)
-
-     async def _get_compiled_dag(self, *, sql_query: str | None = None, selections: dict[str, t.Any] = {}, user: BaseUser | None = None) -> m.DAG:
-         dag = self._generate_dag_with_fake_target(sql_query)
+     async def _get_compiled_dag(
+         self, user: AbstractUser, *, sql_query: str | None = None, selections: dict[str, t.Any] = {}, configurables: dict[str, str] = {},
+         always_python_df: bool = False
+     ) -> m.DAG:
+         dag = self._generate_dag_with_fake_target(sql_query, always_python_df=always_python_df)

-         default_traits = self._manifest_cfg.get_default_traits()
-         await dag.execute(self._param_args, self._param_cfg_set, self._context_func, user, selections, runquery=False, default_traits=default_traits)
+         configurables = {**self._manifest_cfg.get_default_configurables(), **configurables}
+         await dag.execute(
+             self._param_args, self._param_cfg_set, self._context_func, user, selections,
+             runquery=False, configurables=configurables
+         )
          return dag

-     def _get_all_connections(self) -> list[arm.ConnectionItemModel]:
+     def _get_all_connections(self) -> list[rm.ConnectionItemModel]:
          connections = []
          for conn_name, conn_props in self._conn_set.get_connections_as_dict().items():
              if isinstance(conn_props, mf.ConnectionProperties):
                  label = conn_props.label if conn_props.label is not None else conn_name
-                 connections.append(arm.ConnectionItemModel(name=conn_name, label=label))
+                 connections.append(rm.ConnectionItemModel(name=conn_name, label=label))
          return connections

-     def _get_all_data_models(self, compiled_dag: m.DAG) -> list[arm.DataModelItem]:
+     def _get_all_data_models(self, compiled_dag: m.DAG) -> list[rm.DataModelItem]:
          return compiled_dag.get_all_data_models()

-     async def get_all_data_models(self) -> list[arm.DataModelItem]:
+     async def get_all_data_models(self) -> list[rm.DataModelItem]:
          """
          Get all data models in the project

          Returns:
              A list of DataModelItem objects
          """
-         compiled_dag = await self._get_compiled_dag()
+         compiled_dag = await self._get_compiled_dag(self._admin_user)
          return self._get_all_data_models(compiled_dag)

-     def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[arm.LineageRelation]:
+     def _get_all_data_lineage(self, compiled_dag: m.DAG) -> list[rm.LineageRelation]:
          all_lineage = compiled_dag.get_all_model_lineage()

          # Add dataset nodes to the lineage
          for dataset in self._manifest_cfg.datasets.values():
-             target_dataset = arm.LineageNode(name=dataset.name, type="dataset")
-             source_model = arm.LineageNode(name=dataset.model, type="model")
-             all_lineage.append(arm.LineageRelation(type="runtime", source=source_model, target=target_dataset))
+             target_dataset = rm.LineageNode(name=dataset.name, type="dataset")
+             source_model = rm.LineageNode(name=dataset.model, type="model")
+             all_lineage.append(rm.LineageRelation(type="runtime", source=source_model, target=target_dataset))

          # Add dashboard nodes to the lineage
          for dashboard in self._dashboards.values():
-             target_dashboard = arm.LineageNode(name=dashboard.dashboard_name, type="dashboard")
+             target_dashboard = rm.LineageNode(name=dashboard.dashboard_name, type="dashboard")
              datasets = set(x.dataset for x in dashboard.config.depends_on)
              for dataset in datasets:
-                 source_dataset = arm.LineageNode(name=dataset, type="dataset")
-                 all_lineage.append(arm.LineageRelation(type="runtime", source=source_dataset, target=target_dashboard))
+                 source_dataset = rm.LineageNode(name=dataset, type="dataset")
+                 all_lineage.append(rm.LineageRelation(type="runtime", source=source_dataset, target=target_dashboard))

          return all_lineage

-     async def get_all_data_lineage(self) -> list[arm.LineageRelation]:
+     async def get_all_data_lineage(self) -> list[rm.LineageRelation]:
          """
          Get all data lineage in the project

          Returns:
              A list of LineageRelation objects
          """
-         compiled_dag = await self._get_compiled_dag()
+         compiled_dag = await self._get_compiled_dag(self._admin_user)
          return self._get_all_data_lineage(compiled_dag)

-     async def _write_dataset_outputs_given_test_set(
-         self, dataset: str, select: str, test_set: str | None, runquery: bool, recurse: bool
-     ) -> t.Any | None:
-         dataset_conf = self._manifest_cfg.datasets[dataset]
-         default_test_set_conf = self._manifest_cfg.get_default_test_set(dataset)
-         if test_set in self._manifest_cfg.selection_test_sets:
-             test_set_conf = self._manifest_cfg.selection_test_sets[test_set]
-         elif test_set is None or test_set == default_test_set_conf.name:
-             test_set, test_set_conf = default_test_set_conf.name, default_test_set_conf
-         else:
-             raise ConfigurationError(f"No test set named '{test_set}' was found when compiling dataset '{dataset}'. The test set must be defined if not default for dataset.")
-
-         error_msg_intro = f"Cannot compile dataset '{dataset}' with test set '{test_set}'."
-         if test_set_conf.datasets is not None and dataset not in test_set_conf.datasets:
-             raise ConfigurationError(f"{error_msg_intro}\n Applicable datasets for test set '{test_set}' does not include dataset '{dataset}'.")
-
-         user_attributes = test_set_conf.user_attributes.copy() if test_set_conf.user_attributes is not None else {}
-         selections = test_set_conf.parameters.copy()
-         username, is_admin = user_attributes.pop("username", ""), user_attributes.pop("is_admin", False)
-         if test_set_conf.is_authenticated:
-             user = self._auth.User(username=username, is_admin=is_admin, **user_attributes)
-         elif dataset_conf.scope == mf.PermissionScope.PUBLIC:
-             user = None
-         else:
-             raise ConfigurationError(f"{error_msg_intro}\n Non-public datasets require a test set with 'user_attributes' section defined")
-
-         if dataset_conf.scope == mf.PermissionScope.PRIVATE and not is_admin:
-             raise ConfigurationError(f"{error_msg_intro}\n Private datasets require a test set with user_attribute 'is_admin' set to true")
-
-         # always_python_df is set to True for creating CSV files from results (when runquery is True)
-         dag = self._generate_dag(dataset, target_model_name=select, always_python_df=runquery)
-         await dag.execute(
-             self._param_args, self._param_cfg_set, self._context_func, user, selections,
-             runquery=runquery, recurse=recurse, default_traits=self._manifest_cfg.get_default_traits()
-         )
-
-         output_folder = Path(self._filepath, c.TARGET_FOLDER, c.COMPILE_FOLDER, dataset, test_set)
-         if output_folder.exists():
-             shutil.rmtree(output_folder)
-         output_folder.mkdir(parents=True, exist_ok=True)
-
-         def write_placeholders() -> None:
-             output_filepath = Path(output_folder, "placeholders.json")
-             with open(output_filepath, 'w') as f:
-                 json.dump(dag.placeholders, f, indent=4)
-
-         def write_model_outputs(model: m.DataModel) -> None:
-             assert isinstance(model, m.QueryModel)
-             subfolder = c.DBVIEWS_FOLDER if model.model_type == m.ModelType.DBVIEW else c.FEDERATES_FOLDER
-             subpath = Path(output_folder, subfolder)
-             subpath.mkdir(parents=True, exist_ok=True)
-             if isinstance(model.compiled_query, mq.SqlModelQuery):
-                 output_filepath = Path(subpath, model.name+'.sql')
-                 query = model.compiled_query.query
-                 with open(output_filepath, 'w') as f:
-                     f.write(query)
-             if runquery and isinstance(model.result, pl.LazyFrame):
-                 output_filepath = Path(subpath, model.name+'.csv')
-                 model.result.collect().write_csv(output_filepath)
-
-         write_placeholders()
-         all_model_names = dag.get_all_query_models()
-         coroutines = [asyncio.to_thread(write_model_outputs, dag.models_dict[name]) for name in all_model_names]
-         await u.asyncio_gather(coroutines)
-
-         if recurse:
-             self._draw_dag(dag, output_folder)
-
-         if isinstance(dag.target_model, m.QueryModel) and dag.target_model.compiled_query is not None:
-             return dag.target_model.compiled_query.query
-
      async def compile(
-         self, *, dataset: str | None = None, do_all_datasets: bool = False, selected_model: str | None = None, test_set: str | None = None,
-         do_all_test_sets: bool = False, runquery: bool = False
+         self, *, selected_model: str | None = None, test_set: str | None = None, do_all_test_sets: bool = False,
+         runquery: bool = False, clear: bool = False, buildtime_only: bool = False, runtime_only: bool = False
      ) -> None:
          """
-         Async method to compile the SQL templates into files in the "target/" folder. Same functionality as the "sqrl compile" CLI.
+         Compile models into the "target/compile" folder.

-         Although all arguments are "optional", the "dataset" argument is required if "do_all_datasets" argument is False.
+         Behavior:
+         - Buildtime outputs: target/compile/buildtime/*.sql (for SQL build models) and dag.png
+         - Runtime outputs: target/compile/runtime/[test_set]/dbviews/*.sql, federates/*.sql, dag.png
+           If runquery=True, also write CSVs for runtime models.
+         - Options: clear entire compile folder first; compile only buildtime or only runtime.

          Arguments:
-             dataset: The name of the dataset to compile. Ignored if "do_all_datasets" argument is True, but required (i.e., cannot be None) if "do_all_datasets" is False. Default is None.
-             do_all_datasets: If True, compile all datasets and ignore the "dataset" argument. Default is False.
              selected_model: The name of the model to compile. If specified, the compiled SQL query is also printed in the terminal. If None, all models for the selected dataset are compiled. Default is None.
              test_set: The name of the test set to compile with. If None, the default test set is used (which can vary by dataset). Ignored if `do_all_test_sets` argument is True. Default is None.
              do_all_test_sets: Whether to compile all applicable test sets for the selected dataset(s). If True, the `test_set` argument is ignored. Default is False.
-             runquery**: Whether to run all compiled queries and save each result as a CSV file. If True and `selected_model` is specified, all upstream models of the selected model is compiled as well. Default is False.
+             runquery: Whether to run all compiled queries and save each result as a CSV file. If True and `selected_model` is specified, all upstream models of the selected model are compiled as well. Default is False.
+             clear: Whether to clear the "target/compile/" folder before compiling. Default is False.
+             buildtime_only: Whether to compile only buildtime models. Default is False.
+             runtime_only: Whether to compile only runtime models. Default is False.
          """
-         recurse = True
-         if do_all_datasets:
-             selected_models = [(dataset.name, dataset.model) for dataset in self._manifest_cfg.datasets.values()]
-         else:
-             assert isinstance(dataset, str), "argument 'dataset' must be provided a string value if argument 'do_all_datasets' is False"
-             assert dataset in self._manifest_cfg.datasets, f"dataset '{dataset}' not found in {c.MANIFEST_FILE}"
-             if selected_model is None:
-                 selected_model = self._manifest_cfg.datasets[dataset].model
-             else:
-                 recurse = False
-             selected_models = [(dataset, selected_model)]
+         border = "=" * 80
+         underlines = "-" * len(border)
+
+         compile_root = Path(self._project_path, c.TARGET_FOLDER, c.COMPILE_FOLDER)
+         if clear and compile_root.exists():
+             shutil.rmtree(compile_root)
+
+         models_dict = self._get_models_dict(always_python_df=False)
+
+         if selected_model is not None:
+             selected_model = u.normalize_name(selected_model)
+             if selected_model not in models_dict:
+                 print(f"No such model found: {selected_model}")
+                 return
+             if not isinstance(models_dict[selected_model], m.QueryModel):
+                 print(f"Model '{selected_model}' is not a query model. Nothing to do.")
+                 return

-         coroutines: list[t.Coroutine] = []
-         for dataset, selected_model in selected_models:
-             if do_all_test_sets:
-                 for test_set_name in self._manifest_cfg.get_applicable_test_sets(dataset):
-                     coroutine = self._write_dataset_outputs_given_test_set(dataset, selected_model, test_set_name, runquery, recurse)
-                     coroutines.append(coroutine)
+         model_to_compile = None
+
+         # Buildtime compilation
+         if not runtime_only:
+             print(underlines)
+             print(f"Compiling buildtime models")
+             print(underlines)
+
+             buildtime_folder = Path(compile_root, c.COMPILE_BUILDTIME_FOLDER)
+             buildtime_folder.mkdir(parents=True, exist_ok=True)
+
+             def write_buildtime_model(model: m.DataModel, static_models: dict[str, m.StaticModel]) -> None:
+                 if not isinstance(model, m.BuildModel):
+                     return
+
+                 model.compile_for_build(self._conn_args, static_models)
+
+                 if isinstance(model.compiled_query, mq.SqlModelQuery):
+                     out_path = Path(buildtime_folder, f"{model.name}.sql")
+                     with open(out_path, 'w') as f:
+                         f.write(model.compiled_query.query)
+                     print(f"Successfully compiled build model: {model.name}")
+                 elif isinstance(model.compiled_query, mq.PyModelQuery):
+                     print(f"The build model '{model.name}' is in Python. Compilation for Python is not supported yet.")
+
+             static_models = self._get_static_models()
+             if selected_model is not None:
+                 model_to_compile = models_dict[selected_model]
+                 write_buildtime_model(model_to_compile, static_models)
+             else:
+                 coros = [asyncio.to_thread(write_buildtime_model, m, static_models) for m in static_models.values()]
+                 await u.asyncio_gather(coros)

-             coroutine = self._write_dataset_outputs_given_test_set(dataset, selected_model, test_set, runquery, recurse)
-             coroutines.append(coroutine)
-
-         queries = await u.asyncio_gather(coroutines)
+             print(underlines)
+             print()

-         print(f"Compiled successfully! See the '{c.TARGET_FOLDER}/' folder for results.")
-         print()
-         if not recurse and len(queries) == 1 and isinstance(queries[0], str):
-             print(queries[0])
+         # Runtime compilation
+         if not buildtime_only:
+             if do_all_test_sets:
+                 test_set_names_set = set(self._manifest_cfg.selection_test_sets.keys())
+                 test_set_names_set.add(c.DEFAULT_TEST_SET_NAME)
+                 test_set_names = list(test_set_names_set)
+             else:
+                 test_set_names = [test_set or c.DEFAULT_TEST_SET_NAME]
+
+             for ts_name in test_set_names:
+                 print(underlines)
+                 print(f"Compiling runtime models (test set '{ts_name}')")
+                 print(underlines)
+
+                 # Build user and selections from test set config if present
+                 ts_conf = self._manifest_cfg.selection_test_sets.get(ts_name, self._manifest_cfg.get_default_test_set())
+                 # Separate base fields from custom fields
+                 access_level = ts_conf.user.access_level
+                 custom_fields = self._auth.CustomUserFields(**ts_conf.user.custom_fields)
+                 if access_level == "guest":
+                     user = GuestUser(username="", custom_fields=custom_fields)
+                 else:
+                     user = RegisteredUser(username="", access_level=access_level, custom_fields=custom_fields)
+
+                 # Generate DAG across all models. When runquery=True, force models to produce Python dataframes so CSVs can be written.
+                 dag = await self._get_compiled_dag(
+                     user=user, selections=ts_conf.parameters, configurables=ts_conf.configurables, always_python_df=runquery,
+                 )
+                 if runquery:
+                     await dag._run_models()
+
+                 # Prepare output folders
+                 runtime_folder = Path(compile_root, c.COMPILE_RUNTIME_FOLDER, ts_name)
+                 dbviews_folder = Path(runtime_folder, c.DBVIEWS_FOLDER)
+                 federates_folder = Path(runtime_folder, c.FEDERATES_FOLDER)
+                 dbviews_folder.mkdir(parents=True, exist_ok=True)
+                 federates_folder.mkdir(parents=True, exist_ok=True)
+                 with open(Path(runtime_folder, "placeholders.json"), "w") as f:
+                     json.dump(dag.placeholders, f)
+
+                 # Function to write runtime models
+                 def write_runtime_model(model: m.DataModel) -> None:
+                     if not isinstance(model, m.QueryModel):
+                         return
+
+                     if model.model_type not in (m.ModelType.DBVIEW, m.ModelType.FEDERATE):
+                         return
+
+                     subfolder = dbviews_folder if model.model_type == m.ModelType.DBVIEW else federates_folder
+                     model_type = "dbview" if model.model_type == m.ModelType.DBVIEW else "federate"
+
+                     if isinstance(model.compiled_query, mq.SqlModelQuery):
+                         out_sql = Path(subfolder, f"{model.name}.sql")
+                         with open(out_sql, 'w') as f:
+                             f.write(model.compiled_query.query)
+                         print(f"Successfully compiled {model_type} model: {model.name}")
+                     elif isinstance(model.compiled_query, mq.PyModelQuery):
+                         print(f"The {model_type} model '{model.name}' is in Python. Compilation for Python is not supported yet.")
+
+                     if runquery and isinstance(model.result, pl.LazyFrame):
+                         out_csv = Path(subfolder, f"{model.name}.csv")
+                         model.result.collect().write_csv(out_csv)
+                         print(f"Successfully created CSV for {model_type} model: {model.name}")
+
+                 # If selected_model is provided for runtime, only emit that model's outputs
+                 if selected_model is not None:
+                     model_to_compile = dag.models_dict[selected_model]
+                     write_runtime_model(model_to_compile)
+                 else:
+                     coros = [asyncio.to_thread(write_runtime_model, model) for model in dag.models_dict.values()]
+                     await u.asyncio_gather(coros)
+
+                 print(underlines)
+                 print()
+
+         print(f"All compilations complete! See the '{c.TARGET_FOLDER}/{c.COMPILE_FOLDER}/' folder for results.")
+         if model_to_compile and isinstance(model_to_compile, m.QueryModel) and isinstance(model_to_compile.compiled_query, mq.SqlModelQuery):
+             print()
+             print(border)
+             print(f"Compiled SQL query for model '{model_to_compile.name}':")
+             print(underlines)
+             print(model_to_compile.compiled_query.query)
+             print(border)
          print()

-     def _permission_error(self, user: BaseUser | None, data_type: str, data_name: str, scope: str) -> InvalidInputError:
-         username = "" if user is None else f" '{user.username}'"
-         return InvalidInputError(25, f"User{username} does not have permission to access {scope} {data_type}: {data_name}")
+     def _permission_error(self, user: AbstractUser, data_type: str, data_name: str, scope: str) -> InvalidInputError:
+         return InvalidInputError(403, f"unauthorized_access_to_{data_type}", f"User '{user}' does not have permission to access {scope} {data_type}: {data_name}")

      def seed(self, name: str) -> pl.LazyFrame:
          """
@@ -515,37 +616,77 @@ class SquirrelsProject:
              target_model_config=dag.target_model.model_config
          )

-     async def dataset(
-         self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, require_auth: bool = True
-     ) -> dr.DatasetResult:
+     def _enforce_max_result_rows(self, lazy_df: pl.LazyFrame, error_type: str) -> pl.DataFrame:
          """
-         Async method to retrieve a dataset as a DatasetResult object (with metadata) given parameter selections.
-
+         Collect at most max_rows + 1 rows from a LazyFrame to detect overflow.
+         Raises InvalidInputError if the result exceeds the maximum allowed rows.
+
          Arguments:
-             name: The name of the dataset to retrieve.
-             selections: A dictionary of parameter selections to apply to the dataset. Optional, default is empty dictionary.
-             user: The user to use for authentication. If None, no user is used. Optional, default is None.
+             lazy_df: The LazyFrame to collect and check
+             error_type: Either "dataset" or "query" to customize the error message

          Returns:
-             A DatasetResult object containing the dataset result (as a polars DataFrame), its description, and the column details.
+             A DataFrame with at most max_rows rows (or raises if exceeded)
          """
+         max_rows = self._env_vars.datasets_max_rows_output
+         # Collect max_rows + 1 to detect overflow without loading unbounded results
+         collected = lazy_df.limit(max_rows + 1).collect()
+         row_count = collected.select(pl.len()).item()
+
+         if row_count > max_rows:
+             raise InvalidInputError(
+                 413, f"{error_type}_result_too_large",
+                 f"The {error_type} result contains {row_count} rows, which exceeds the maximum allowed of {max_rows} rows."
+             )
+
+         return collected
+
+     async def _dataset_result(
+         self, name: str, *, selections: dict[str, t.Any] = {}, user: AbstractUser | None = None,
+         configurables: dict[str, str] = {}, check_user_access: bool = True
+     ) -> dr.DatasetResult:
+         if user is None:
+             user = self._guest_user
+
          scope = self._manifest_cfg.datasets[name].scope
-         if require_auth and not self._auth.can_user_access_scope(user, scope):
+         if check_user_access and not self._auth.can_user_access_scope(user, scope):
              raise self._permission_error(user, "dataset", name, scope.name)

+         dataset_config = self._manifest_cfg.datasets[name]
+         configurables = {**self._manifest_cfg.get_default_configurables(overrides=dataset_config.configurables), **configurables}
+
          dag = self._generate_dag(name)
          await dag.execute(
-             self._param_args, self._param_cfg_set, self._context_func, user, dict(selections),
-             default_traits=self._manifest_cfg.get_default_traits()
+             self._param_args, self._param_cfg_set, self._context_func, user, dict(selections), configurables=configurables
          )
          assert isinstance(dag.target_model.result, pl.LazyFrame)
+         df = self._enforce_max_result_rows(dag.target_model.result, "dataset")
          return dr.DatasetResult(
              target_model_config=dag.target_model.model_config,
-             df=dag.target_model.result.collect().with_row_index("_row_num", offset=1)
+             df=df.with_row_index("_row_num", offset=1)
          )

+     async def dataset_result(
+         self, name: str, *, selections: dict[str, t.Any] = {}, user: AbstractUser | None = None, configurables: dict[str, str] = {}
+     ) -> dr.DatasetResult:
+         """
+         Async method to retrieve a dataset as a DatasetResult object (with metadata) given parameter selections.
+
+         Arguments:
+             name: The name of the dataset to retrieve.
+             selections: A dictionary of parameter selections to apply to the dataset. Optional, default is empty dictionary.
+             user: The user to use for authentication. If None, no user is used. Optional, default is None.
+             configurables: A dictionary of configurables to apply to the dataset. Optional, default is empty dictionary.
+
+         Returns:
+             A DatasetResult object containing the dataset result (as a polars DataFrame), its description, and the column details.
+         """
+         result = await self._dataset_result(name, selections=selections, user=user, configurables=configurables, check_user_access=False)
+         return result
+
      async def dashboard(
-         self, name: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None, dashboard_type: t.Type[T] = dash.PngDashboard
+         self, name: str, *, selections: dict[str, t.Any] = {}, user: AbstractUser | None = None, dashboard_type: t.Type[T] = d.PngDashboard,
+         configurables: dict[str, str] = {}
      ) -> T:
          """
          Async method to retrieve a dashboard given parameter selections.
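
Usage note: `_enforce_max_result_rows` detects overflow by collecting `max_rows + 1` rows, so an oversized result is rejected without materializing the full frame. The same pattern in isolation, as a self-contained polars sketch independent of any Squirrels internals:

    import polars as pl

    def collect_capped(lazy_df: pl.LazyFrame, max_rows: int) -> pl.DataFrame:
        # Fetch one row past the cap: if it shows up, the full result is too large
        collected = lazy_df.limit(max_rows + 1).collect()
        if collected.height > max_rows:
            raise ValueError(f"result exceeds the maximum of {max_rows} rows")
        return collected

    df = collect_capped(pl.LazyFrame({"x": range(10)}), max_rows=100)
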
@@ -559,28 +700,97 @@
          Returns:
              The dashboard type specified by the "dashboard_type" argument.
          """
+         if user is None:
+             user = self._guest_user
+
          scope = self._dashboards[name].config.scope
          if not self._auth.can_user_access_scope(user, scope):
              raise self._permission_error(user, "dashboard", name, scope.name)

          async def get_dataset_df(dataset_name: str, fixed_params: dict[str, t.Any]) -> pl.DataFrame:
              final_selections = {**selections, **fixed_params}
-             result = await self.dataset(dataset_name, selections=final_selections, user=user, require_auth=False)
+             result = await self.dataset_result(
+                 dataset_name, selections=final_selections, user=user, configurables=configurables
+             )
              return result.df

-         args = d.DashboardArgs(self._param_args, get_dataset_df)
+         dashboard_config = self._dashboards[name].config
+         parameter_set = self._param_cfg_set.apply_selections(dashboard_config.parameters, selections, user)
+         prms = parameter_set.get_parameters_as_dict()
+
+         configurables = {**self._manifest_cfg.get_default_configurables(overrides=dashboard_config.configurables), **configurables}
+         context = {}
+         ctx_args = m.ContextArgs(
+             **self._param_args.__dict__, user=user, prms=prms, configurables=configurables, _conn_args=self._conn_args
+         )
+         self._context_func(context, ctx_args)
+
+         args = d.DashboardArgs(
+             **ctx_args.__dict__, ctx=context, _get_dataset=get_dataset_df
+         )
          try:
              return await self._dashboards[name].get_dashboard(args, dashboard_type=dashboard_type)
          except KeyError:
              raise KeyError(f"No dashboard file found for: {name}")

      async def query_models(
-         self, sql_query: str, *, selections: dict[str, t.Any] = {}, user: BaseUser | None = None
+         self, sql_query: str, *, user: AbstractUser | None = None, selections: dict[str, t.Any] = {}, configurables: dict[str, str] = {}
      ) -> dr.DatasetResult:
-         dag = await self._get_compiled_dag(sql_query=sql_query, selections=selections, user=user)
+         if user is None:
+             user = self._guest_user
+
+         dag = await self._get_compiled_dag(user=user, sql_query=sql_query, selections=selections, configurables=configurables)
          await dag._run_models()
          assert isinstance(dag.target_model.result, pl.LazyFrame)
+         df = self._enforce_max_result_rows(dag.target_model.result, "query")
          return dr.DatasetResult(
              target_model_config=dag.target_model.model_config,
-             df=dag.target_model.result.collect().with_row_index("_row_num", offset=1)
+             df=df.with_row_index("_row_num", offset=1)
          )
+
+     async def get_compiled_model_query(
+         self, model_name: str, *, user: AbstractUser | None = None, selections: dict[str, t.Any] = {}, configurables: dict[str, str] = {}
+     ) -> rm.CompiledQueryModel:
+         """
+         Compile the specified data model and return its language and compiled definition.
+         """
+         if user is None:
+             user = self._guest_user
+
+         name = u.normalize_name(model_name)
+         models_dict = self._get_models_dict(always_python_df=False)
+         if name not in models_dict:
+             raise InvalidInputError(404, "model_not_found", f"No data model found with name: {model_name}")
+
+         model = models_dict[name]
+         # Only build, dbview, and federate models support runtime compiled definition in this context
+         if not isinstance(model, (m.BuildModel, m.DbviewModel, m.FederateModel)):
+             raise InvalidInputError(400, "unsupported_model_type", "Only build, dbview, and federate models currently support compiled definition via this endpoint")
+
+         # Build a DAG with this model as the target, without a dataset context
+         model.is_target = True
+         dag = m.DAG(None, model, models_dict, self._vdl_catalog_db_path, self._logger)
+
+         cfg = {**self._manifest_cfg.get_default_configurables(), **configurables}
+         await dag.execute(
+             self._param_args, self._param_cfg_set, self._context_func, user, selections, runquery=False, configurables=cfg
+         )
+
+         language = "sql" if isinstance(model.query_file, mq.SqlQueryFile) else "python"
+         if isinstance(model, m.BuildModel):
+             # Compile SQL build models; Python build models not yet supported
+             if isinstance(model.query_file, mq.SqlQueryFile):
+                 static_models = self._get_static_models()
+                 compiled = model._compile_sql_model(model.query_file, self._conn_args, static_models)
+                 definition = compiled.query
+             else:
+                 definition = "# Compiling Python build models is currently not supported. This will be available in a future version of Squirrels..."
+         elif isinstance(model.compiled_query, mq.SqlModelQuery):
+             definition = model.compiled_query.query
+         elif isinstance(model.compiled_query, mq.PyModelQuery):
+             definition = "# Compiling Python data models is currently not supported. This will be available in a future version of Squirrels..."
+         else:
+             raise NotImplementedError(f"Query type not supported: {model.compiled_query.__class__.__name__}")
+
+         return rm.CompiledQueryModel(language=language, definition=definition, placeholders=dag.placeholders)
+
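
Usage note: a closing sketch of the new ad-hoc query surface (model names are illustrative; `vdl.` table prefixes are substituted internally, as shown earlier in `_generate_dag_with_fake_target`):

    async def explore(project: SquirrelsProject) -> None:
        # Run arbitrary SQL over project models as the guest user
        result = await project.query_models("SELECT COUNT(*) AS n FROM build_example")
        print(result.df)

        # Inspect a model's compiled definition and placeholders
        compiled = await project.get_compiled_model_query("federate_example")
        print(compiled.language, compiled.definition)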