squirrels 0.4.1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (80) hide show
  1. squirrels/__init__.py +10 -6
  2. squirrels/_api_response_models.py +93 -44
  3. squirrels/_api_server.py +571 -219
  4. squirrels/_auth.py +451 -0
  5. squirrels/_command_line.py +61 -20
  6. squirrels/_connection_set.py +38 -25
  7. squirrels/_constants.py +44 -34
  8. squirrels/_dashboards_io.py +34 -16
  9. squirrels/_exceptions.py +57 -0
  10. squirrels/_initializer.py +117 -44
  11. squirrels/_manifest.py +124 -62
  12. squirrels/_model_builder.py +111 -0
  13. squirrels/_model_configs.py +74 -0
  14. squirrels/_model_queries.py +52 -0
  15. squirrels/_models.py +860 -354
  16. squirrels/_package_loader.py +8 -4
  17. squirrels/_parameter_configs.py +45 -65
  18. squirrels/_parameter_sets.py +15 -13
  19. squirrels/_project.py +561 -0
  20. squirrels/_py_module.py +4 -3
  21. squirrels/_seeds.py +35 -16
  22. squirrels/_sources.py +106 -0
  23. squirrels/_utils.py +166 -63
  24. squirrels/_version.py +1 -1
  25. squirrels/arguments/init_time_args.py +78 -15
  26. squirrels/arguments/run_time_args.py +62 -101
  27. squirrels/dashboards.py +4 -4
  28. squirrels/data_sources.py +94 -162
  29. squirrels/dataset_result.py +86 -0
  30. squirrels/dateutils.py +4 -4
  31. squirrels/package_data/base_project/.env +30 -0
  32. squirrels/package_data/base_project/.env.example +30 -0
  33. squirrels/package_data/base_project/.gitignore +3 -2
  34. squirrels/package_data/base_project/assets/expenses.db +0 -0
  35. squirrels/package_data/base_project/connections.yml +11 -3
  36. squirrels/package_data/base_project/dashboards/dashboard_example.py +15 -13
  37. squirrels/package_data/base_project/dashboards/dashboard_example.yml +22 -0
  38. squirrels/package_data/base_project/docker/.dockerignore +5 -2
  39. squirrels/package_data/base_project/docker/Dockerfile +3 -3
  40. squirrels/package_data/base_project/docker/compose.yml +1 -1
  41. squirrels/package_data/base_project/duckdb_init.sql +9 -0
  42. squirrels/package_data/base_project/macros/macros_example.sql +15 -0
  43. squirrels/package_data/base_project/models/builds/build_example.py +26 -0
  44. squirrels/package_data/base_project/models/builds/build_example.sql +16 -0
  45. squirrels/package_data/base_project/models/builds/build_example.yml +55 -0
  46. squirrels/package_data/base_project/models/dbviews/dbview_example.sql +12 -22
  47. squirrels/package_data/base_project/models/dbviews/dbview_example.yml +26 -0
  48. squirrels/package_data/base_project/models/federates/federate_example.py +38 -15
  49. squirrels/package_data/base_project/models/federates/federate_example.sql +16 -2
  50. squirrels/package_data/base_project/models/federates/federate_example.yml +65 -0
  51. squirrels/package_data/base_project/models/sources.yml +39 -0
  52. squirrels/package_data/base_project/parameters.yml +36 -21
  53. squirrels/package_data/base_project/pyconfigs/connections.py +6 -11
  54. squirrels/package_data/base_project/pyconfigs/context.py +20 -33
  55. squirrels/package_data/base_project/pyconfigs/parameters.py +19 -21
  56. squirrels/package_data/base_project/pyconfigs/user.py +23 -0
  57. squirrels/package_data/base_project/seeds/seed_categories.yml +15 -0
  58. squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -15
  59. squirrels/package_data/base_project/seeds/seed_subcategories.yml +21 -0
  60. squirrels/package_data/base_project/squirrels.yml.j2 +17 -40
  61. squirrels/parameters.py +20 -20
  62. {squirrels-0.4.1.dist-info → squirrels-0.5.0rc0.dist-info}/METADATA +31 -32
  63. squirrels-0.5.0rc0.dist-info/RECORD +70 -0
  64. {squirrels-0.4.1.dist-info → squirrels-0.5.0rc0.dist-info}/WHEEL +1 -1
  65. squirrels-0.5.0rc0.dist-info/entry_points.txt +3 -0
  66. {squirrels-0.4.1.dist-info → squirrels-0.5.0rc0.dist-info/licenses}/LICENSE +1 -1
  67. squirrels/_authenticator.py +0 -85
  68. squirrels/_environcfg.py +0 -84
  69. squirrels/package_data/assets/favicon.ico +0 -0
  70. squirrels/package_data/assets/index.css +0 -1
  71. squirrels/package_data/assets/index.js +0 -58
  72. squirrels/package_data/base_project/dashboards.yml +0 -10
  73. squirrels/package_data/base_project/env.yml +0 -29
  74. squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
  75. squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
  76. squirrels/package_data/templates/index.html +0 -18
  77. squirrels/project.py +0 -378
  78. squirrels/user_base.py +0 -55
  79. squirrels-0.4.1.dist-info/RECORD +0 -60
  80. squirrels-0.4.1.dist-info/entry_points.txt +0 -4
squirrels/_sources.py ADDED
@@ -0,0 +1,106 @@
1
+ from typing import Any
2
+ from pydantic import BaseModel, Field, model_validator
3
+ import time, sqlglot
4
+
5
+ from . import _utils as u, _constants as c, _model_configs as mc
6
+
7
+
8
# Hints used when a source is refreshed incrementally instead of being fully reloaded.
class UpdateHints(BaseModel):
    # Name of the column compared against the previously seen maximum; None when no such column is configured
    increasing_column: str | None = None
    strictly_increasing: bool = Field(description="Delete the max value of the increasing column, ignored if value is set", default=True)
    # NOTE(review): how this value is consumed is not visible in this file - confirm against the caller
    selective_overwrite_value: Any = None
12
+
13
+
14
# Configuration of a single source table: the table name, whether it is loaded to DuckDB,
# its primary key, and hints for incremental updates.
# (`columns` and `finalize_connection` are used but not defined here, so they are inherited from the base classes.)
class Source(mc.ConnectionInterface, mc.ModelConfig):
    table: str | None = Field(default=None)
    load_to_duckdb: bool = Field(default=False, description="Whether to load the data to DuckDB")
    primary_key: list[str] = Field(default_factory=list)
    update_hints: UpdateHints = Field(default_factory=UpdateHints)

    def finalize_table(self, source_name: str):
        """
        Defaults the table name to the source name when no table was configured

        Arguments:
            source_name: The name of the source

        Returns:
            This source (to allow chaining)
        """
        if self.table is None:
            self.table = source_name
        return self

    def get_table(self) -> str:
        """
        Returns the table name, which must have been set (see finalize_table)
        """
        assert self.table is not None, "Table must be set"
        return self.table

    def get_cols_for_create_table_stmt(self) -> str:
        """
        Returns the "<name> <type>, ..." column list for a CREATE TABLE statement, followed by a
        PRIMARY KEY clause when a primary key is configured
        """
        cols_clause = ", ".join(f"{col.name} {col.type}" for col in self.columns)
        primary_key_clause = f", PRIMARY KEY ({', '.join(self.primary_key)})" if self.primary_key else ""
        return f"{cols_clause}{primary_key_clause}"

    def get_cols_for_insert_stmt(self) -> str:
        """
        Returns the comma-separated column names
        """
        return ", ".join(col.name for col in self.columns)

    def get_max_incr_col_query(self, source_name: str) -> str:
        """
        Returns the query that selects the maximum value of the increasing column from the given table

        Arguments:
            source_name: The name of the table to query
        """
        return f"SELECT max({self.update_hints.increasing_column}) FROM {source_name}"

    def get_query_for_insert(self, dialect: str, conn_name: str, table_name: str, max_value_of_increasing_col: Any | None, *, full_refresh: bool = True) -> str:
        """
        Builds the query that selects the rows to insert for this source

        Arguments:
            dialect: The SQL dialect of the source database
            conn_name: The connection name (the database is referenced as "db_<conn_name>")
            table_name: The table to select from
            max_value_of_increasing_col: The value that the increasing column must exceed, or None if there is none
            full_refresh: If True, select all rows regardless of the increasing column

        Returns:
            A query over all rows if full_refresh is True or no max value is given. Otherwise, a query limited to
            rows whose increasing column is greater than the max value. For 'postgres' and 'mysql', the filter is
            pushed down through the "<dialect>_query" table function.

        Raises:
            ConfigurationError: If the increasing column is not one of the columns of this source
        """
        select_cols = self.get_cols_for_insert_stmt()
        if full_refresh or max_value_of_increasing_col is None:
            return f"SELECT {select_cols} FROM db_{conn_name}.{table_name}"

        increasing_col = self.update_hints.increasing_column
        # Look up the column with a default so a missing column surfaces as a configuration error
        # instead of a bare StopIteration
        increasing_col_config = next((col for col in self.columns if col.name == increasing_col), None)
        if increasing_col_config is None:
            raise u.ConfigurationError(
                f"Increasing column '{increasing_col}' must be one of the columns specified for table '{table_name}'"
            )
        increasing_col_type = increasing_col_config.type

        # Double any single quote so the value cannot end the string literal early
        max_value_literal = str(max_value_of_increasing_col).replace("'", "''")
        where_cond = f"{increasing_col}::{increasing_col_type} > '{max_value_literal}'::{increasing_col_type}"
        pushdown_query = f"SELECT {select_cols} FROM {table_name} WHERE {where_cond}"

        if dialect in ['postgres', 'mysql']:
            # The transpiled query is embedded in a string literal, so its single quotes are doubled
            transpiled_query = sqlglot.transpile(pushdown_query, read='duckdb', write=dialect)[0].replace("'", "''")
            return f"FROM {dialect}_query('db_{conn_name}', '{transpiled_query}')"

        return f"SELECT {select_cols} FROM db_{conn_name}.{table_name} WHERE {where_cond}"

    def get_insert_replace_clause(self) -> str:
        """
        Returns "OR REPLACE" when a primary key is configured, otherwise an empty string
        """
        return "OR REPLACE" if self.primary_key else ""
58
+
59
+
60
# Collection of all sources, keyed by source name.
class Sources(BaseModel):
    sources: dict[str, Source] = Field(default_factory=dict)

    @model_validator(mode="before")
    @classmethod
    def convert_sources_list_to_dict(cls, data: dict[str, Any]) -> dict[str, Any]:
        """
        Converts a "sources" list (each item having a "name" field) into a dictionary keyed by name

        The input is left untouched: a new dictionary is returned and the "name" field is dropped from
        copies of the source configs, so the same input can be validated more than once.

        Arguments:
            data: The raw input for the model

        Returns:
            The input with "sources" as a dictionary, or the input unchanged if "sources" is not a list

        Raises:
            ConfigurationError: If a source has no name or a source name is duplicated
        """
        # Anything that is not a dict with a "sources" list is passed through for regular validation
        if not isinstance(data, dict) or not isinstance(data.get("sources"), list):
            return data

        sources_dict: dict[str, Any] = {}
        for source in data["sources"]:
            if not isinstance(source, dict) or "name" not in source:
                raise u.ConfigurationError("All sources must have a name field in sources file")

            name = source["name"]
            if name in sources_dict:
                raise u.ConfigurationError(f"Duplicate source name found: {name}")

            # Copy without the name instead of popping it from the caller's dictionary
            sources_dict[name] = {key: value for key, value in source.items() if key != "name"}

        return {**data, "sources": sources_dict}

    @model_validator(mode="after")
    def validate_column_types(self):
        """
        Checks that every column of every source has a type

        Raises:
            ConfigurationError: If a column has no type specified
        """
        for source_name, source in self.sources.items():
            for col in source.columns:
                if not col.type:
                    raise u.ConfigurationError(f"Column '{col.name}' in source '{source_name}' must have a type specified")
        return self

    def finalize_null_fields(self, env_vars: dict[str, str]):
        """
        Finalizes the connection and the table name of every source

        Arguments:
            env_vars: The environment variables passed to each source's finalize_connection

        Returns:
            This object (to allow chaining)
        """
        for source_name, source in self.sources.items():
            source.finalize_connection(env_vars)
            source.finalize_table(source_name)
        return self
93
+
94
+
95
class SourcesIO:
    """
    Loader for the sources file of a project
    """

    @classmethod
    def load_file(cls, logger: u.Logger, base_path: str, env_vars: dict[str, str]) -> Sources:
        """
        Loads the sources file found under the models folder of the project

        Arguments:
            logger: The logger used to record the time taken
            base_path: The base path of the project
            env_vars: The environment variables used to finalize the sources

        Returns:
            The finalized sources (empty if the file is missing or has no content)

        Raises:
            ConfigurationError: If the top level of the sources file is not a mapping
        """
        start = time.time()

        sources_path = u.Path(base_path, c.MODELS_FOLDER, c.SOURCES_FILE)
        sources_data = u.load_yaml_config(sources_path) if sources_path.exists() else {}

        # A file with no content is parsed as None by yaml; treat it the same as a missing file
        if sources_data is None:
            sources_data = {}
        if not isinstance(sources_data, dict):
            raise u.ConfigurationError(f"The top level of the sources file must be a mapping: {sources_path}")

        sources = Sources(**sources_data).finalize_null_fields(env_vars)

        logger.log_activity_time("loading sources", start)
        return sources
squirrels/_utils.py CHANGED
@@ -1,35 +1,35 @@
1
- from typing import Sequence, Optional, Union, TypeVar, Callable
2
- from pathlib import Path
3
- from pandas.api import types as pd_types
1
+ from typing import Sequence, Optional, Union, TypeVar, Callable, Any, Iterable
4
2
  from datetime import datetime
5
- import os, time, logging, json, sqlite3, pandas as pd
3
+ from pathlib import Path
4
+ from functools import lru_cache
5
+ from pydantic import BaseModel
6
+ import os, time, logging, json, duckdb, polars as pl, yaml
6
7
  import jinja2 as j2, jinja2.nodes as j2_nodes
8
+ import sqlglot, sqlglot.expressions, asyncio
7
9
 
8
10
  from . import _constants as c
11
+ from ._exceptions import ConfigurationError
9
12
 
10
13
  FilePath = Union[str, Path]
11
14
 
12
-
13
- ## Custom Exceptions
14
-
15
- class InvalidInputError(Exception):
16
- """
17
- Use this exception when the error is due to providing invalid inputs to the REST API
18
- """
19
- pass
20
-
21
- class ConfigurationError(Exception):
22
- """
23
- Use this exception when the server error is due to errors in the squirrels project instead of the squirrels framework/library
24
- """
25
- pass
26
-
27
- class FileExecutionError(Exception):
28
- def __init__(self, message: str, error: Exception, *args) -> None:
29
- t = " "
30
- new_message = f"\n" + message + f"\n{t}Produced error message:\n{t}{t}{error} (see above for more details on handled exception)"
31
- super().__init__(new_message, *args)
32
- self.error = error
15
# Polars
# Each polars data type maps to every type name (aliases included) that stands for it
polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
    pl.String: ["string", "varchar", "char", "text"],
    pl.Int8: ["tinyint", "int1"],
    pl.Int16: ["smallint", "short", "int2"],
    pl.Int32: ["integer", "int", "int4"],
    pl.Int64: ["bigint", "long", "int8"],
    pl.Float32: ["float", "float4", "real"],
    pl.Float64: ["double", "float8"],
    pl.Boolean: ["boolean", "bool", "logical"],
    pl.Date: ["date"],
    pl.Time: ["time"],
    pl.Datetime: ["timestamp", "datetime"],
    pl.Duration: ["interval"],
    pl.Binary: ["blob", "binary", "varbinary"],
}

# Reverse lookup: one entry per type name, pointing back at its polars data type
sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = {
    type_name: polars_dtype
    for polars_dtype, type_names in polars_dtypes_to_sqrl_dtypes.items()
    for type_name in type_names
}
33
33
 
34
34
 
35
35
  ## Other utility classes
@@ -40,7 +40,7 @@ class Logger(logging.Logger):
40
40
  time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
41
41
  data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
42
42
  info = { "request_id": request_id } if request_id else {}
43
- self.debug(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
43
+ self.info(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
44
44
 
45
45
 
46
46
  class EnvironmentWithMacros(j2.Environment):
@@ -115,7 +115,6 @@ def read_file(filepath: FilePath) -> str:
115
115
 
116
116
  Arguments:
117
117
  filepath (str | pathlib.Path): The path to the file to read
118
- is_required: If true, throw error if file doesn't exist
119
118
 
120
119
  Returns:
121
120
  Content of the file, or None if doesn't exist and not required
@@ -180,7 +179,7 @@ def load_json_or_comma_delimited_str_as_list(input_str: Union[str, Sequence]) ->
180
179
  return [x.strip() for x in input_str.split(",")]
181
180
 
182
181
 
183
- X, Y = TypeVar('X'), TypeVar('Y')
182
+ X = TypeVar('X'); Y = TypeVar('Y')
184
183
  def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) -> Optional[Y]:
185
184
  """
186
185
  Given a input value and a function that processes the value, return the output of the function unless input is None
@@ -197,60 +196,164 @@ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) ->
197
196
  return processor(input_val)
198
197
 
199
198
 
200
- def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pd.DataFrame], do_use_duckdb: bool) -> pd.DataFrame:
199
@lru_cache(maxsize=1)
def _read_duckdb_init_sql() -> tuple[str, Path | None]:
    """
    Reads the duckdb init file(s) and caches the combined content.

    The init file in the global environment folder under the user's home directory is read first (if it exists),
    followed by the init file at the path given by the DUCKDB_INIT_FILE constant (if it exists). When the combined
    content is not empty, it is also written to the target folder.

    Returns:
        A tuple of the combined init SQL (empty string if there is none) and the path of the copy written to
        the target folder (None if nothing was written)
    """
    try:
        candidate_paths = [
            Path(os.path.expanduser('~'), c.GLOBAL_ENV_FOLDER, c.DUCKDB_INIT_FILE),
            Path(c.DUCKDB_INIT_FILE),
        ]
        init_sql = "\n".join(path.read_text() for path in candidate_paths if path.exists()).strip()

        if not init_sql:
            return init_sql, None

        target_init_path = Path(c.TARGET_FOLDER, c.DUCKDB_INIT_FILE)
        target_init_path.parent.mkdir(parents=True, exist_ok=True)
        target_init_path.write_text(init_sql)
        return init_sql, target_init_path
    except Exception as e:
        raise ConfigurationError(f"Failed to read {c.DUCKDB_INIT_FILE}: {str(e)}") from e
226
+
227
def create_duckdb_connection(filepath: str | Path = ":memory:", *, read_only: bool = False) -> duckdb.DuckDBPyConnection:
    """
    Opens a DuckDB connection and runs the statements of the duckdb init file on it

    Arguments:
        filepath: Path to the DuckDB database file. Defaults to in-memory database.
        read_only: Whether to open the database in read-only mode. Defaults to False.

    Returns:
        A DuckDB connection (which must be closed after use)

    Raises:
        ConfigurationError: If reading or executing the init SQL fails (the connection is closed first)
    """
    connection = duckdb.connect(filepath, read_only=read_only)

    try:
        init_sql = _read_duckdb_init_sql()[0]
        if init_sql:
            connection.execute(init_sql)
    except Exception as err:
        # Do not hand out (or leak) a connection that failed to initialize
        connection.close()
        raise ConfigurationError(f"Failed to execute {c.DUCKDB_INIT_FILE}: {str(err)}") from err
    else:
        return connection
249
+
250
+
251
def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> pl.DataFrame:
    """
    Runs a SQL query against a collection of dataframes using a DuckDB connection

    Arguments:
        sql_query: The SQL query to run
        dataframes: A dictionary of table names to their polars LazyFrame

    Returns:
        The result as a polars Dataframe from running the query
    """
    connection = create_duckdb_connection()

    try:
        # Expose each frame to the query under its table name
        for table_name, frame in dataframes.items():
            connection.register(table_name, frame)

        # The result is materialized here, before the connection is closed by the finally clause
        return connection.sql(sql_query).pl()
    finally:
        connection.close()
227
273
 
228
274
 
229
- def df_to_json0(df: pd.DataFrame, dimensions: list[str] | None = None) -> dict:
275
def load_yaml_config(filepath: FilePath) -> dict:
    """
    Loads a YAML config file

    Arguments:
        filepath: The path to the YAML file

    Returns:
        A dictionary representation of the YAML file (empty if the file has no content)

    Raises:
        ConfigurationError: If the file cannot be parsed as YAML
    """
    try:
        with open(filepath, 'r') as f:
            content = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ConfigurationError(f"Failed to parse yaml file: {filepath}") from e

    # safe_load gives None for a file without content; return an empty dict to honor the declared return type
    return {} if content is None else content
290
+
291
+
292
def run_duckdb_stmt(
    logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None, redacted_values: Sequence[str] = ()
) -> duckdb.DuckDBPyConnection:
    """
    Runs a statement on a DuckDB connection, logging the statement with sensitive values redacted

    Arguments:
        logger: The logger to use
        duckdb_conn: The DuckDB connection
        stmt: The statement to run
        params: The parameters to use
        redacted_values: The values to replace with "[REDACTED]" in the logged statement. Empty values are ignored

    Returns:
        The result of executing the statement on the connection

    Raises:
        duckdb.ParserException: If the statement cannot be parsed (logged, then re-raised)
    """
    redacted_stmt = stmt
    # Longest first, so a value that contains another one is masked in full instead of leaving its remainder behind.
    # Empty values are skipped since replacing "" would insert the marker between every character.
    for value in sorted(filter(None, redacted_values), key=len, reverse=True):
        redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")

    logger.info(f"Running statement: {redacted_stmt}", extra={"data": {"params": params}})
    try:
        return duckdb_conn.execute(stmt, params)
    except duckdb.ParserException as e:
        logger.error(f"Failed to run statement: {redacted_stmt}", exc_info=e)
        raise
315
+
236
316
 
317
def get_current_time() -> str:
    """
    Returns the current local time as HH:MM:SS.ms (milliseconds truncated to 3 digits)
    """
    now = datetime.now()
    # Whole milliseconds, i.e. the first three digits of the zero-padded microseconds
    milliseconds = now.microsecond // 1000
    return f"{now:%H:%M:%S}.{milliseconds:03d}"
322
+
323
+
324
def parse_dependent_tables(sql_query: str, all_table_names: Iterable[str]) -> tuple[set[str], sqlglot.Expression]:
    """
    Parses the dependent tables from a SQL query

    Arguments:
        sql_query: The SQL query to parse
        all_table_names: The names of all known tables (any iterable, including one that can only be read once)

    Returns:
        A tuple of the set of known table names referenced by the query, and the parsed query
    """
    # Parse the SQL query and extract all table references
    parsed = sqlglot.parse_one(sql_query)

    # Build the lookup set once: rebuilding it per table reference is wasteful, and a one-shot iterable
    # would be exhausted after the first reference, silently dropping every later dependency
    known_table_names = set(all_table_names)

    # Collect all table references from the parsed SQL that are known tables
    dependencies = {
        table.name
        for table in parsed.find_all(sqlglot.expressions.Table)
        if table.name in known_table_names
    }

    return dependencies, parsed
345
+
346
+
347
async def asyncio_gather(coroutines: list):
    """
    Schedules the coroutines as tasks, waits for all of them, and returns their results in the given order

    If waiting fails for any reason (including cancellation), every task that has not finished is cancelled
    and awaited before the original exception is re-raised.
    """
    tasks = list(map(asyncio.create_task, coroutines))

    try:
        results = await asyncio.gather(*tasks)
    except BaseException:
        # Cancel whatever is still running
        unfinished = [task for task in tasks if not task.done()]
        for task in unfinished:
            task.cancel()
        # Let the cancellations settle; return_exceptions keeps this wait from raising on their behalf
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    return results
squirrels/_version.py CHANGED
@@ -1,3 +1,3 @@
1
- __version__ = '0.4.1'
1
+ __version__ = '0.5.0'
2
2
 
3
3
  sq_major_version, sq_minor_version, sq_patch_version = __version__.split('.')[:3]
@@ -1,40 +1,103 @@
1
- from typing import Callable, Any
1
+ from typing import Any, Iterable, Callable
2
2
  from dataclasses import dataclass
3
+ import polars as pl
3
4
 
5
+ from .. import _utils as u
4
6
 
5
7
@dataclass
class ConnectionsArgs:
    """
    Arguments made of the project path, the project variables, and the environment variables
    """
    project_path: str
    _proj_vars: dict[str, Any]
    _env_vars: dict[str, str]

    @property
    def proj_vars(self) -> dict[str, Any]:
        """
        A shallow copy of the project variables (changing it does not change the stored dictionary)
        """
        return dict(self._proj_vars)

    @property
    def env_vars(self) -> dict[str, str]:
        """
        A shallow copy of the environment variables (changing it does not change the stored dictionary)
        """
        return dict(self._env_vars)
17
20
 
18
21
 
19
22
@dataclass
class ParametersArgs(ConnectionsArgs):
    """
    Same fields and properties as ConnectionsArgs, declared as its own type with no additional members
    """
25
+
26
+
27
@dataclass
class _WithConnectionDictArgs(ConnectionsArgs):
    """
    ConnectionsArgs extended with a dictionary of connections
    """
    _connections: dict[str, Any]

    @property
    def connections(self) -> dict[str, Any]:
        """
        A dictionary of connection keys to SQLAlchemy Engines for database connections.

        Can also be used to store other in-memory objects in advance such as ML models.

        A shallow copy is returned, so changing it does not change the stored dictionary.
        """
        return dict(self._connections)
39
+
40
+
41
class BuildModelArgs(_WithConnectionDictArgs):
    """
    Arguments for build models: the connection arguments plus the dependent model names and the callables
    used to reference dependent models and to query external databases
    """

    def __init__(
        self, conn_args: ConnectionsArgs, _connections: dict[str, Any],
        dependencies: Iterable[str],
        ref: Callable[[str], pl.LazyFrame],
        run_external_sql: Callable[[str, str], pl.DataFrame]
    ):
        # Project path, project variables and environment variables are taken over from conn_args
        super().__init__(conn_args.project_path, conn_args.proj_vars, conn_args.env_vars, _connections)
        self._dependencies = dependencies
        self._ref = ref
        self._run_external_sql = run_external_sql

    @property
    def dependencies(self) -> set[str]:
        """
        The set of dependent data model names
        """
        dependency_names = set(self._dependencies)
        return dependency_names

    def ref(self, model: str) -> pl.LazyFrame:
        """
        Returns the result (as polars LazyFrame) of a dependent model (predefined in "dependencies" function)

        Note: This is different behaviour than the "ref" function for SQL models, which figures out the dependent models for you,
        and returns a string for the table/view name instead of a polars LazyFrame.

        Arguments:
            model: The model name

        Returns:
            A polars LazyFrame
        """
        return self._ref(model)

    def run_external_sql(self, connection_name: str, sql_query: str, **kwargs) -> pl.DataFrame:
        """
        Runs a SQL query against an external database using the given connection name

        Arguments:
            connection_name: The connection name for the database
            sql_query: The SQL query
            kwargs: Accepted but not used by this method

        Returns:
            The query result as a polars DataFrame
        """
        # NOTE(review): the values are forwarded as (sql_query, connection_name), the reverse of this method's
        # parameter order - confirm this matches the callable supplied at construction
        return self._run_external_sql(sql_query, connection_name)

    def run_sql_on_dataframes(self, sql_query: str, *, dataframes: dict[str, pl.LazyFrame] | None = None, **kwargs) -> pl.DataFrame:
        """
        Uses a dictionary of dataframes to execute a SQL query through the shared run_sql_on_dataframes utility

        Arguments:
            sql_query: The SQL query to run
            dataframes: A dictionary of table names to their polars LazyFrame. If None, uses results of dependent models
            kwargs: Accepted but not used by this method

        Returns:
            The result as a polars DataFrame from running the query
        """
        frames = dataframes if dataframes is not None else {name: self.ref(name) for name in self._dependencies}
        return u.run_sql_on_dataframes(sql_query, frames)