squirrels 0.4.0-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (125)
  1. dateutils/__init__.py +6 -0
  2. dateutils/_enums.py +25 -0
  3. squirrels/dateutils.py → dateutils/_implementation.py +58 -111
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +13 -11
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +271 -0
  8. squirrels/_api_routes/base.py +165 -0
  9. squirrels/_api_routes/dashboards.py +150 -0
  10. squirrels/_api_routes/data_management.py +145 -0
  11. squirrels/_api_routes/datasets.py +257 -0
  12. squirrels/_api_routes/oauth2.py +298 -0
  13. squirrels/_api_routes/project.py +252 -0
  14. squirrels/_api_server.py +256 -450
  15. squirrels/_arguments/__init__.py +0 -0
  16. squirrels/_arguments/init_time_args.py +108 -0
  17. squirrels/_arguments/run_time_args.py +147 -0
  18. squirrels/_auth.py +960 -0
  19. squirrels/_command_line.py +126 -45
  20. squirrels/_compile_prompts.py +147 -0
  21. squirrels/_connection_set.py +48 -26
  22. squirrels/_constants.py +68 -38
  23. squirrels/_dashboards.py +160 -0
  24. squirrels/_data_sources.py +570 -0
  25. squirrels/_dataset_types.py +84 -0
  26. squirrels/_exceptions.py +29 -0
  27. squirrels/_initializer.py +177 -80
  28. squirrels/_logging.py +115 -0
  29. squirrels/_manifest.py +208 -79
  30. squirrels/_model_builder.py +69 -0
  31. squirrels/_model_configs.py +74 -0
  32. squirrels/_model_queries.py +52 -0
  33. squirrels/_models.py +926 -367
  34. squirrels/_package_data/base_project/.env +42 -0
  35. squirrels/_package_data/base_project/.env.example +42 -0
  36. squirrels/_package_data/base_project/assets/expenses.db +0 -0
  37. squirrels/_package_data/base_project/connections.yml +16 -0
  38. squirrels/_package_data/base_project/dashboards/dashboard_example.py +34 -0
  39. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
  40. squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +5 -2
  41. squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +3 -3
  42. squirrels/{package_data → _package_data}/base_project/docker/compose.yml +1 -1
  43. squirrels/_package_data/base_project/duckdb_init.sql +10 -0
  44. squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +3 -2
  45. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  46. squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
  47. squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
  48. squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
  49. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +12 -0
  50. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +26 -0
  51. squirrels/_package_data/base_project/models/federates/federate_example.py +37 -0
  52. squirrels/_package_data/base_project/models/federates/federate_example.sql +19 -0
  53. squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
  54. squirrels/_package_data/base_project/models/sources.yml +38 -0
  55. squirrels/{package_data → _package_data}/base_project/parameters.yml +56 -40
  56. squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
  57. squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +21 -40
  58. squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
  59. squirrels/_package_data/base_project/pyconfigs/user.py +44 -0
  60. squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
  61. squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
  62. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
  63. squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
  64. squirrels/_package_data/templates/dataset_results.html +112 -0
  65. squirrels/_package_data/templates/oauth_login.html +271 -0
  66. squirrels/_package_data/templates/squirrels_studio.html +20 -0
  67. squirrels/_package_loader.py +8 -4
  68. squirrels/_parameter_configs.py +104 -103
  69. squirrels/_parameter_options.py +348 -0
  70. squirrels/_parameter_sets.py +57 -47
  71. squirrels/_parameters.py +1664 -0
  72. squirrels/_project.py +721 -0
  73. squirrels/_py_module.py +7 -5
  74. squirrels/_schemas/__init__.py +0 -0
  75. squirrels/_schemas/auth_models.py +167 -0
  76. squirrels/_schemas/query_param_models.py +75 -0
  77. squirrels/{_api_response_models.py → _schemas/response_models.py} +126 -47
  78. squirrels/_seeds.py +35 -16
  79. squirrels/_sources.py +110 -0
  80. squirrels/_utils.py +248 -73
  81. squirrels/_version.py +1 -1
  82. squirrels/arguments.py +7 -0
  83. squirrels/auth.py +4 -0
  84. squirrels/connections.py +3 -0
  85. squirrels/dashboards.py +2 -81
  86. squirrels/data_sources.py +14 -631
  87. squirrels/parameter_options.py +13 -348
  88. squirrels/parameters.py +14 -1266
  89. squirrels/types.py +16 -0
  90. squirrels-0.5.0.dist-info/METADATA +113 -0
  91. squirrels-0.5.0.dist-info/RECORD +97 -0
  92. {squirrels-0.4.0.dist-info → squirrels-0.5.0.dist-info}/WHEEL +1 -1
  93. squirrels-0.5.0.dist-info/entry_points.txt +3 -0
  94. {squirrels-0.4.0.dist-info → squirrels-0.5.0.dist-info/licenses}/LICENSE +1 -1
  95. squirrels/_authenticator.py +0 -85
  96. squirrels/_dashboards_io.py +0 -61
  97. squirrels/_environcfg.py +0 -84
  98. squirrels/arguments/init_time_args.py +0 -40
  99. squirrels/arguments/run_time_args.py +0 -208
  100. squirrels/package_data/assets/favicon.ico +0 -0
  101. squirrels/package_data/assets/index.css +0 -1
  102. squirrels/package_data/assets/index.js +0 -58
  103. squirrels/package_data/base_project/assets/expenses.db +0 -0
  104. squirrels/package_data/base_project/connections.yml +0 -7
  105. squirrels/package_data/base_project/dashboards/dashboard_example.py +0 -32
  106. squirrels/package_data/base_project/dashboards.yml +0 -10
  107. squirrels/package_data/base_project/env.yml +0 -29
  108. squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
  109. squirrels/package_data/base_project/models/dbviews/dbview_example.sql +0 -22
  110. squirrels/package_data/base_project/models/federates/federate_example.py +0 -21
  111. squirrels/package_data/base_project/models/federates/federate_example.sql +0 -3
  112. squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
  113. squirrels/package_data/base_project/pyconfigs/connections.py +0 -19
  114. squirrels/package_data/base_project/pyconfigs/parameters.py +0 -95
  115. squirrels/package_data/base_project/seeds/seed_subcategories.csv +0 -15
  116. squirrels/package_data/base_project/squirrels.yml.j2 +0 -94
  117. squirrels/package_data/templates/index.html +0 -18
  118. squirrels/project.py +0 -378
  119. squirrels/user_base.py +0 -55
  120. squirrels-0.4.0.dist-info/METADATA +0 -117
  121. squirrels-0.4.0.dist-info/RECORD +0 -60
  122. squirrels-0.4.0.dist-info/entry_points.txt +0 -4
  123. /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
  124. /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
  125. /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
squirrels/_data_sources.py (new file)
@@ -0,0 +1,570 @@
+ from __future__ import annotations
+ from dataclasses import dataclass
+ from enum import Enum
+ import polars as pl, typing as t, abc
+
+ from . import _parameter_configs as pc, _parameter_options as po
+ from ._exceptions import ConfigurationError
+
+ class SourceEnum(Enum):
+     CONNECTION = "connection"
+     SEEDS = "seeds"
+     VDL = "vdl"
+
+
+ @dataclass
+ class DataSource(metaclass=abc.ABCMeta):
+     """
+     Abstract class for lookup tables coming from a database
+     """
+     _table_or_query: str
+     _id_col: str | None
+     _source: SourceEnum
+     _user_group_col: str | None
+     _parent_id_col: str | None
+     _connection: str | None
+
+     @abc.abstractmethod
+     def __init__(
+         self, table_or_query: str, *, id_col: str | None = None, source: SourceEnum = SourceEnum.CONNECTION,
+         user_group_col: str | None = None, parent_id_col: str | None = None, connection: str | None = None, **kwargs
+     ) -> None:
+         self._table_or_query = table_or_query
+         self._id_col = id_col
+         self._source = source
+         self._user_group_col = user_group_col
+         self._parent_id_col = parent_id_col
+         self._connection = connection
+
+     def _get_connection_name(self, default_conn_name: str) -> str:
+         return self._connection if self._connection is not None else default_conn_name
+
+     def _get_query(self) -> str:
+         """
+         Get the "table_or_query" attribute as a select query
+
+         Returns:
+             str: The converted select query
+         """
+         if self._table_or_query.strip().lower().startswith('select '):
+             query = self._table_or_query
+         else:
+             query = f'SELECT * FROM {self._table_or_query}'
+         return query
+
+     @abc.abstractmethod
+     def _convert(self, ds_param: pc.DataSourceParameterConfig, df: pl.DataFrame) -> pc.ParameterConfig:
+         """
+         An abstract method for converting itself into a parameter
+         """
+         pass
+
+     def _validate_parameter_type(self, ds_param: pc.DataSourceParameterConfig, target_parameter_type: t.Type[pc.ParameterConfig]) -> None:
+         if ds_param.parameter_type != target_parameter_type:
+             parameter_type_name = ds_param.parameter_type.__name__
+             datasource_type_name = self.__class__.__name__
+             raise ConfigurationError(f'Invalid widget type "{parameter_type_name}" for {datasource_type_name}')
+
+     def _get_aggregated_df(self, df: pl.DataFrame, columns_to_include: t.Iterable[str]) -> pl.DataFrame:
+         if self._id_col is None:
+             return df
+
+         agg_rules = []
+         for column in columns_to_include:
+             if column is not None:
+                 agg_rules.append(pl.first(column))
+         if self._user_group_col is not None:
+             agg_rules.append(pl.col(self._user_group_col))
+         if self._parent_id_col is not None:
+             agg_rules.append(pl.col(self._parent_id_col))
+
+         try:
+             df_agg = df.group_by(self._id_col).agg(agg_rules).sort(by=self._id_col)
+         except pl.exceptions.ColumnNotFoundError as e:
+             raise ConfigurationError(e)
+
+         return df_agg
+
+     def _get_key_from_record(self, key: str | None, record: dict[t.Hashable, t.Any], default: t.Any) -> t.Any:
+         return record[key] if key is not None else default
+
+     def _get_key_from_record_as_list(self, key: str | None, record: dict[t.Hashable, t.Any]) -> t.Iterable[str]:
+         value = self._get_key_from_record(key, record, list())
+         return [str(x) for x in value]
+
+
+ @dataclass
+ class _SelectionDataSource(DataSource):
+     """
+     Abstract class for selection parameter data sources
+     """
+     _options_col: str
+     _order_by_col: str | None
+     _is_default_col: str | None
+     _custom_cols: dict[str, str]
+
+     @abc.abstractmethod
+     def __init__(
+         self, table_or_query: str, id_col: str, options_col: str, *, order_by_col: str | None = None,
+         is_default_col: str | None = None, custom_cols: dict[str, str] = {}, source: SourceEnum = SourceEnum.CONNECTION,
+         user_group_col: str | None = None, parent_id_col: str | None = None, connection: str | None = None,
+         **kwargs
+     ) -> None:
+         super().__init__(
+             table_or_query, id_col=id_col, source=source, user_group_col=user_group_col, parent_id_col=parent_id_col,
+             connection=connection
+         )
+         self._options_col = options_col
+         self._order_by_col = order_by_col
+         self._is_default_col = is_default_col
+         self._custom_cols = custom_cols
+
+     def _get_all_options(self, df: pl.DataFrame) -> t.Sequence[po.SelectParameterOption]:
+         columns = [self._options_col, self._order_by_col, self._is_default_col, *self._custom_cols.values()]
+         df_agg = self._get_aggregated_df(df, columns)
+
+         if self._order_by_col is None:
+             df_agg = df_agg.sort(by=self._id_col)
+         else:
+             df_agg = df_agg.sort(by=self._order_by_col)
+
+         def get_is_default(record: dict[t.Hashable, t.Any]) -> bool:
+             return int(record[self._is_default_col]) == 1 if self._is_default_col is not None else False
+
+         def get_custom_fields(record: dict[t.Hashable, t.Any]) -> dict[str, t.Any]:
+             result = {}
+             for key, val in self._custom_cols.items():
+                 result[key] = record[val]
+             return result
+
+         records = df_agg.to_pandas().to_dict("records")
+         return tuple(
+             po.SelectParameterOption(
+                 str(record[self._id_col]), str(record[self._options_col]),
+                 is_default=get_is_default(record), custom_fields=get_custom_fields(record),
+                 user_groups=self._get_key_from_record_as_list(self._user_group_col, record),
+                 parent_option_ids=self._get_key_from_record_as_list(self._parent_id_col, record)
+             )
+             for record in records
+         )
+
+
+ @dataclass
+ class SelectDataSource(_SelectionDataSource):
+     """
+     Lookup table for select parameter options
+     """
+
+     def __init__(
+         self, table_or_query: str, id_col: str, options_col: str, *, order_by_col: str | None = None,
+         is_default_col: str | None = None, custom_cols: dict[str, str] = {}, source: SourceEnum = SourceEnum.CONNECTION,
+         user_group_col: str | None = None, parent_id_col: str | None = None, connection: str | None = None,
+         **kwargs
+     ) -> None:
+         """
+         Constructor for SelectDataSource
+
+         Arguments:
+             table_or_query: Either the name of the table to use, or a query to run
+             id_col: The column name of the id
+             options_col: The column name of the options
+             order_by_col: The column name to order the options by. Orders by the id_col instead if this is None
+             is_default_col: The column name that indicates which options are the default
+             custom_cols: Dictionary of attribute to column name for custom fields for the SelectParameterOption
+             source: The source to fetch data from. Must be "connection", "seeds", or "vdl". Defaults to "connection"
+             user_group_col: The column name of the user group that the user is in for this option to be valid
+             parent_id_col: The column name of the parent option id that must be selected for this option to be valid
+             connection: Name of the connection to use defined in connections.py
+         """
+         super().__init__(
+             table_or_query, id_col, options_col, order_by_col=order_by_col, is_default_col=is_default_col, custom_cols=custom_cols,
+             source=source, user_group_col=user_group_col, parent_id_col=parent_id_col, connection=connection
+         )
+
+     def _convert(self, ds_param: pc.DataSourceParameterConfig, df: pl.DataFrame) -> pc.SelectionParameterConfig:
+         """
+         Method to convert the associated DataSourceParameterConfig into a SingleSelectParameterConfig or MultiSelectParameterConfig
+
+         Arguments:
+             ds_param: The parameter to convert
+             df: The dataframe containing the parameter options data
+
+         Returns:
+             The converted parameter
+         """
+         all_options = self._get_all_options(df)
+         if ds_param.parameter_type == pc.SingleSelectParameterConfig:
+             return pc.SingleSelectParameterConfig(
+                 ds_param.name, ds_param.label, all_options, description=ds_param.description,
+                 user_attribute=ds_param.user_attribute, parent_name=ds_param.parent_name, **ds_param.extra_args
+             )
+         elif ds_param.parameter_type == pc.MultiSelectParameterConfig:
+             return pc.MultiSelectParameterConfig(
+                 ds_param.name, ds_param.label, all_options, description=ds_param.description,
+                 user_attribute=ds_param.user_attribute, parent_name=ds_param.parent_name, **ds_param.extra_args
+             )
+         else:
+             raise ConfigurationError(f'Invalid widget type "{ds_param.parameter_type}" for SelectDataSource')
+
+
+ @dataclass
+ class DateDataSource(DataSource):
+     """
+     Lookup table for date parameter default options
+     """
+     _default_date_col: str
+     _date_format: str
+
+     def __init__(
+         self, table_or_query: str, default_date_col: str, *, min_date_col: str | None = None,
+         max_date_col: str | None = None, date_format: str = '%Y-%m-%d', id_col: str | None = None,
+         source: SourceEnum = SourceEnum.CONNECTION, user_group_col: str | None = None, parent_id_col: str | None = None,
+         connection: str | None = None, **kwargs
+     ) -> None:
+         """
+         Constructor for DateDataSource
+
+         Arguments:
+             table_or_query: Either the name of the table to use, or a query to run
+             default_date_col: The column name of the default date
+             date_format: The format of the default date(s). Defaults to '%Y-%m-%d'
+             id_col: The column name of the id
+             source: The source to fetch data from. Must be "connection", "seeds", or "vdl". Defaults to "connection"
+             user_group_col: The column name of the user group that the user is in for this option to be valid
+             parent_id_col: The column name of the parent option id that the default date belongs to
+             connection: Name of the connection to use defined in connections.py
+         """
+         super().__init__(
+             table_or_query, id_col=id_col, source=source, user_group_col=user_group_col, parent_id_col=parent_id_col,
+             connection=connection
+         )
+         self._default_date_col = default_date_col
+         self._min_date_col = min_date_col
+         self._max_date_col = max_date_col
+         self._date_format = date_format
+
+     def _convert(self, ds_param: pc.DataSourceParameterConfig, df: pl.DataFrame) -> pc.DateParameterConfig:
+         """
+         Method to convert the associated DataSourceParameterConfig into a DateParameterConfig
+
+         Arguments:
+             ds_param: The parameter to convert
+             df: The dataframe containing the parameter options data
+
+         Returns:
+             The converted parameter
+         """
+         self._validate_parameter_type(ds_param, pc.DateParameterConfig)
+
+         columns = [self._default_date_col, self._min_date_col, self._max_date_col]
+         df_agg = self._get_aggregated_df(df, columns)
+
+         records = df_agg.to_pandas().to_dict("records")
+         options = tuple(
+             po.DateParameterOption(
+                 str(record[self._default_date_col]), date_format=self._date_format,
+                 min_date = str(record[self._min_date_col]) if self._min_date_col else None,
+                 max_date = str(record[self._max_date_col]) if self._max_date_col else None,
+                 user_groups=self._get_key_from_record_as_list(self._user_group_col, record),
+                 parent_option_ids=self._get_key_from_record_as_list(self._parent_id_col, record)
+             )
+             for record in records
+         )
+         return pc.DateParameterConfig(
+             ds_param.name, ds_param.label, options, description=ds_param.description, user_attribute=ds_param.user_attribute,
+             parent_name=ds_param.parent_name, **ds_param.extra_args
+         )
+
+
+ @dataclass
+ class DateRangeDataSource(DataSource):
+     """
+     Lookup table for date range parameter default options
+     """
+     _default_start_date_col: str
+     _default_end_date_col: str
+     _date_format: str
+
+     def __init__(
+         self, table_or_query: str, default_start_date_col: str, default_end_date_col: str, *, date_format: str = '%Y-%m-%d',
+         min_date_col: str | None = None, max_date_col: str | None = None, id_col: str | None = None, source: SourceEnum = SourceEnum.CONNECTION,
+         user_group_col: str | None = None, parent_id_col: str | None = None, connection: str | None = None, **kwargs
+     ) -> None:
+         """
+         Constructor for DateRangeDataSource
+
+         Arguments:
+             table_or_query: Either the name of the table to use, or a query to run
+             default_start_date_col: The column name of the default start date
+             default_end_date_col: The column name of the default end date
+             date_format: The format of the default date(s). Defaults to '%Y-%m-%d'
+             id_col: The column name of the id
+             source: The source to fetch data from. Must be "connection", "seeds", or "vdl". Defaults to "connection"
+             user_group_col: The column name of the user group that the user is in for this option to be valid
+             parent_id_col: The column name of the parent option id that the default date belongs to
+             connection: Name of the connection to use defined in connections.py
+         """
+         super().__init__(
+             table_or_query, id_col=id_col, source=source, user_group_col=user_group_col, parent_id_col=parent_id_col,
+             connection=connection
+         )
+         self._default_start_date_col = default_start_date_col
+         self._default_end_date_col = default_end_date_col
+         self._min_date_col = min_date_col
+         self._max_date_col = max_date_col
+         self._date_format = date_format
+
+     def _convert(self, ds_param: pc.DataSourceParameterConfig, df: pl.DataFrame) -> pc.DateRangeParameterConfig:
+         """
+         Method to convert the associated DataSourceParameterConfig into a DateRangeParameterConfig
+
+         Arguments:
+             ds_param: The parameter to convert
+             df: The dataframe containing the parameter options data
+
+         Returns:
+             The converted parameter
+         """
+         self._validate_parameter_type(ds_param, pc.DateRangeParameterConfig)
+
+         columns = [self._default_start_date_col, self._default_end_date_col, self._min_date_col, self._max_date_col]
+         df_agg = self._get_aggregated_df(df, columns)
+
+         records = df_agg.to_pandas().to_dict("records")
+         options = tuple(
+             po.DateRangeParameterOption(
+                 str(record[self._default_start_date_col]), str(record[self._default_end_date_col]),
+                 min_date=str(record[self._min_date_col]) if self._min_date_col else None,
+                 max_date=str(record[self._max_date_col]) if self._max_date_col else None,
+                 date_format=self._date_format,
+                 user_groups=self._get_key_from_record_as_list(self._user_group_col, record),
+                 parent_option_ids=self._get_key_from_record_as_list(self._parent_id_col, record)
+             )
+             for record in records
+         )
+         return pc.DateRangeParameterConfig(
+             ds_param.name, ds_param.label, options, description=ds_param.description, user_attribute=ds_param.user_attribute,
+             parent_name=ds_param.parent_name, **ds_param.extra_args
+         )
+
+
+ @dataclass
+ class _NumericDataSource(DataSource):
+     """
+     Abstract class for number or number range data sources
+     """
+     _min_value_col: str
+     _max_value_col: str
+     _increment_col: str | None
+
+     @abc.abstractmethod
+     def __init__(
+         self, table_or_query: str, min_value_col: str, max_value_col: str, *, increment_col: str | None = None,
+         id_col: str | None = None, source: SourceEnum = SourceEnum.CONNECTION, user_group_col: str | None = None,
+         parent_id_col: str | None = None, connection: str | None = None, **kwargs
+     ) -> None:
+         super().__init__(
+             table_or_query, id_col=id_col, source=source, user_group_col=user_group_col, parent_id_col=parent_id_col,
+             connection=connection
+         )
+         self._min_value_col = min_value_col
+         self._max_value_col = max_value_col
+         self._increment_col = increment_col
+
+
+ @dataclass
+ class NumberDataSource(_NumericDataSource):
+     """
+     Lookup table for number parameter default options
+     """
+     _default_value_col: str | None
+
+     def __init__(
+         self, table_or_query: str, min_value_col: str, max_value_col: str, *, increment_col: str | None = None,
+         default_value_col: str | None = None, id_col: str | None = None, source: SourceEnum = SourceEnum.CONNECTION,
+         user_group_col: str | None = None, parent_id_col: str | None = None, connection: str | None = None, **kwargs
+     ) -> None:
+         """
+         Constructor for NumberDataSource
+
+         Arguments:
+             table_or_query: Either the name of the table to use, or a query to run
+             min_value_col: The column name of the minimum value
+             max_value_col: The column name of the maximum value
+             increment_col: The column name of the increment value. Defaults to column of 1's if None
+             default_value_col: The column name of the default value. Defaults to min_value_col if None
+             id_col: The column name of the id
+             source: The source to fetch data from. Must be "connection", "seeds", or "vdl". Defaults to "connection"
+             user_group_col: The column name of the user group that the user is in for this option to be valid
+             parent_id_col: The column name of the parent option id that the default value belongs to
+             connection: Name of the connection to use defined in connections.py
+         """
+         super().__init__(
+             table_or_query, min_value_col, max_value_col, increment_col=increment_col, id_col=id_col, source=source,
+             user_group_col=user_group_col, parent_id_col=parent_id_col, connection=connection
+         )
+         self._default_value_col = default_value_col
+
+     def _convert(self, ds_param: pc.DataSourceParameterConfig, df: pl.DataFrame) -> pc.NumberParameterConfig:
+         """
+         Method to convert the associated DataSourceParameterConfig into a NumberParameterConfig
+
+         Arguments:
+             ds_param: The parameter to convert
+             df: The dataframe containing the parameter options data
+
+         Returns:
+             The converted parameter
+         """
+         self._validate_parameter_type(ds_param, pc.NumberParameterConfig)
+
+         columns = [self._min_value_col, self._max_value_col, self._increment_col, self._default_value_col]
+         df_agg = self._get_aggregated_df(df, columns)
+
+         records = df_agg.to_pandas().to_dict("records")
+         options = tuple(
+             po.NumberParameterOption(
+                 record[self._min_value_col], record[self._max_value_col],
+                 increment=self._get_key_from_record(self._increment_col, record, 1),
+                 default_value=self._get_key_from_record(self._default_value_col, record, None),
+                 user_groups=self._get_key_from_record_as_list(self._user_group_col, record),
+                 parent_option_ids=self._get_key_from_record_as_list(self._parent_id_col, record)
+             )
+             for record in records
+         )
+         return pc.NumberParameterConfig(
+             ds_param.name, ds_param.label, options, description=ds_param.description, user_attribute=ds_param.user_attribute,
+             parent_name=ds_param.parent_name, **ds_param.extra_args
+         )
+
+
+ @dataclass
+ class NumberRangeDataSource(_NumericDataSource):
+     """
+     Lookup table for number range parameter default options
+     """
+     _default_lower_value_col: str | None
+     _default_upper_value_col: str | None
+
+     def __init__(
+         self, table_or_query: str, min_value_col: str, max_value_col: str, *, increment_col: str | None = None,
+         default_lower_value_col: str | None = None, default_upper_value_col: str | None = None, id_col: str | None = None,
+         source: SourceEnum = SourceEnum.CONNECTION, user_group_col: str | None = None, parent_id_col: str | None = None,
+         connection: str | None = None, **kwargs
+     ) -> None:
+         """
+         Constructor for NumberRangeDataSource
+
+         Arguments:
+             table_or_query: Either the name of the table to use, or a query to run
+             min_value_col: The column name of the minimum value
+             max_value_col: The column name of the maximum value
+             increment_col: The column name of the increment value. Defaults to column of 1's if None
+             default_lower_value_col: The column name of the default lower value. Defaults to min_value_col if None
+             default_upper_value_col: The column name of the default upper value. Defaults to max_value_col if None
+             id_col: The column name of the id
+             source: The source to fetch data from. Must be "connection", "seeds", or "vdl". Defaults to "connection"
+             user_group_col: The column name of the user group that the user is in for this option to be valid
+             parent_id_col: The column name of the parent option id that the default value belongs to
+             connection: Name of the connection to use defined in connections.py
+         """
+         super().__init__(
+             table_or_query, min_value_col, max_value_col, increment_col=increment_col, id_col=id_col, source=source,
+             user_group_col=user_group_col, parent_id_col=parent_id_col, connection=connection
+         )
+         self._default_lower_value_col = default_lower_value_col
+         self._default_upper_value_col = default_upper_value_col
+
+     def _convert(self, ds_param: pc.DataSourceParameterConfig, df: pl.DataFrame) -> pc.NumberRangeParameterConfig:
+         """
+         Method to convert the associated DataSourceParameterConfig into a NumberRangeParameterConfig
+
+         Arguments:
+             ds_param: The parameter to convert
+             df: The dataframe containing the parameter options data
+
+         Returns:
+             The converted parameter
+         """
+         self._validate_parameter_type(ds_param, pc.NumberRangeParameterConfig)
+
+         columns = [self._min_value_col, self._max_value_col, self._increment_col, self._default_lower_value_col, self._default_upper_value_col]
+         df_agg = self._get_aggregated_df(df, columns)
+
+         records = df_agg.to_pandas().to_dict("records")
+         options = tuple(
+             po.NumberRangeParameterOption(
+                 record[self._min_value_col], record[self._max_value_col],
+                 increment=self._get_key_from_record(self._increment_col, record, 1),
+                 default_lower_value=self._get_key_from_record(self._default_lower_value_col, record, None),
+                 default_upper_value=self._get_key_from_record(self._default_upper_value_col, record, None),
+                 user_groups=self._get_key_from_record_as_list(self._user_group_col, record),
+                 parent_option_ids=self._get_key_from_record_as_list(self._parent_id_col, record)
+             )
+             for record in records
+         )
+         return pc.NumberRangeParameterConfig(
+             ds_param.name, ds_param.label, options, description=ds_param.description, user_attribute=ds_param.user_attribute,
+             parent_name=ds_param.parent_name, **ds_param.extra_args
+         )
+
+
+ @dataclass
+ class TextDataSource(DataSource):
+     """
+     Lookup table for text parameter default options
+     """
+     _default_text_col: str
+
+     def __init__(
+         self, table_or_query: str, default_text_col: str, *, id_col: str | None = None, source: SourceEnum = SourceEnum.CONNECTION,
+         user_group_col: str | None = None, parent_id_col: str | None = None, connection: str | None = None,
+         **kwargs
+     ) -> None:
+         """
+         Constructor for TextDataSource
+
+         Arguments:
+             table_or_query: Either the name of the table to use, or a query to run
+             default_text_col: The column name of the default text
+             id_col: The column name of the id
+             source: The source to fetch data from. Must be "connection", "seeds", or "vdl". Defaults to "connection"
+             user_group_col: The column name of the user group that the user is in for this option to be valid
+             parent_id_col: The column name of the parent option id that the default text belongs to
+             connection: Name of the connection to use defined in connections.py
+         """
+         super().__init__(
+             table_or_query, id_col=id_col, source=source, user_group_col=user_group_col, parent_id_col=parent_id_col,
+             connection=connection
+         )
+         self._default_text_col = default_text_col
+
+     def _convert(self, ds_param: pc.DataSourceParameterConfig, df: pl.DataFrame) -> pc.TextParameterConfig:
+         """
+         Method to convert the associated DataSourceParameterConfig into a TextParameterConfig
+
+         Arguments:
+             ds_param: The parameter to convert
+             df: The dataframe containing the parameter options data
+
+         Returns:
+             The converted parameter
+         """
+         self._validate_parameter_type(ds_param, pc.TextParameterConfig)
+
+         columns = [self._default_text_col]
+         df_agg = self._get_aggregated_df(df, columns)
+
+         records = df_agg.to_pandas().to_dict("records")
+         options = tuple(
+             po.TextParameterOption(
+                 default_text=str(record[self._default_text_col]),
+                 user_groups=self._get_key_from_record_as_list(self._user_group_col, record),
+                 parent_option_ids=self._get_key_from_record_as_list(self._parent_id_col, record)
+             )
+             for record in records
+         )
+         return pc.TextParameterConfig(
+             ds_param.name, ds_param.label, options, description=ds_param.description, user_attribute=ds_param.user_attribute,
+             parent_name=ds_param.parent_name, **ds_param.extra_args
+         )
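
As a quick orientation, here is a minimal usage sketch based on the SelectDataSource signature and the _get_query helper shown in the hunk above. The table and column names are hypothetical and the private import path is used only for illustration; this snippet is not code shipped in the release.

# Hypothetical example based on the constructor shown above; names are made up.
from squirrels._data_sources import SelectDataSource, SourceEnum

ds = SelectDataSource(
    "categories",                 # bare table name; a "SELECT ..." string would be used as-is
    id_col="category_id",
    options_col="category_name",
    source=SourceEnum.SEEDS,      # default is SourceEnum.CONNECTION; SourceEnum.VDL is also defined
)

# _get_query() wraps a bare table name in a SELECT statement
assert ds._get_query() == "SELECT * FROM categories"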
squirrels/_dataset_types.py (new file)
@@ -0,0 +1,84 @@
+ from typing import Callable, Literal
+ from dataclasses import dataclass, field
+ from functools import cached_property, lru_cache
+ import polars as pl
+
+ from ._model_configs import ModelConfig
+
+
+ @dataclass
+ class DatasetMetadata:
+     target_model_config: ModelConfig
+
+     @cached_property
+     def _json_repr(self) -> dict:
+         fields = []
+         for col in self.target_model_config.columns:
+             fields.append({
+                 "name": col.name,
+                 "type": col.type,
+                 "condition": col.condition,
+                 "description": col.description,
+                 "category": col.category.value
+             })
+
+         return {
+             "schema": {
+                 "fields": fields
+             },
+         }
+
+     def to_json(self) -> dict:
+         return self._json_repr
+
+
+ @dataclass
+ class DatasetResult(DatasetMetadata):
+     df: pl.DataFrame
+     to_json: Callable[[str, int, int], dict] = field(init=False)
+
+     def __post_init__(self):
+         self.to_json = lru_cache()(self._to_json)
+
+     def _to_json(self, orientation: Literal["records", "rows", "columns"], limit: int, offset: int) -> dict:
+         df = self.df.lazy()
+         if offset > 0:
+             df = df.filter(pl.col("_row_num") > offset)
+         if limit > 0:
+             df = df.limit(limit)
+         df = df.collect()
+
+         if orientation == "columns":
+             data = df.to_dict(as_series=False)
+         else:
+             data = df.to_dicts()
+             if orientation == "rows":
+                 data = [[row[col] for col in df.columns] for row in data]
+
+         column_details_by_name = {col.name: col for col in self.target_model_config.columns}
+         fields = []
+         for col in df.columns:
+             if col == "_row_num":
+                 fields.append({"name": "_row_num", "type": "integer", "description": "The row number of the dataset (starts at 1)", "category": "misc"})
+             elif col in column_details_by_name:
+                 column_details = column_details_by_name[col]
+                 fields.append({
+                     "name": col,
+                     "type": column_details.type,
+                     "description": column_details.description,
+                     "category": column_details.category.value
+                 })
+             else:
+                 fields.append({"name": col, "type": "unknown", "description": "", "category": "misc"})
+
+         return {
+             "schema": {
+                 "fields": fields
+             },
+             "total_num_rows": self.df.select(pl.len()).item(),
+             "data_details": {
+                 "num_rows": df.select(pl.len()).item(),
+                 "orientation": orientation
+             },
+             "data": data
+         }
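
To make the orientation handling in DatasetResult._to_json concrete, here is a small standalone sketch that mirrors the same polars reshaping steps shown above. The sample frame and its values are invented for illustration and are not part of the package.

import polars as pl

# Toy frame with the "_row_num" column that the offset filter above relies on.
df = pl.DataFrame({"_row_num": [1, 2, 3], "amount": [10.0, 20.5, 7.25]})

def reshape(df: pl.DataFrame, orientation: str, limit: int, offset: int):
    # Same steps as _to_json: offset via _row_num, then limit, then reshape.
    lf = df.lazy()
    if offset > 0:
        lf = lf.filter(pl.col("_row_num") > offset)
    if limit > 0:
        lf = lf.limit(limit)
    out = lf.collect()
    if orientation == "columns":
        return out.to_dict(as_series=False)      # {"col": [v1, v2, ...]}
    data = out.to_dicts()                        # "records": [{"col": v, ...}, ...]
    if orientation == "rows":
        data = [[row[col] for col in out.columns] for row in data]
    return data

print(reshape(df, "records", limit=2, offset=1))
# [{'_row_num': 2, 'amount': 20.5}, {'_row_num': 3, 'amount': 7.25}]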
squirrels/_exceptions.py (new file)
@@ -0,0 +1,29 @@
+ class InvalidInputError(Exception):
+     """
+     Use this exception when the error is due to providing invalid inputs to the REST API
+
+     Attributes:
+         status_code: The HTTP status code to return
+         error: A short error message that should never change in the future
+         error_description: A detailed error message (that is allowed to change in the future)
+     """
+     def __init__(self, status_code: int, error: str, error_description: str, *args) -> None:
+         self.status_code = status_code
+         self.error = error
+         self.error_description = error_description
+         super().__init__(error_description, *args)
+
+
+ class ConfigurationError(Exception):
+     """
+     Use this exception when the server error is due to errors in the squirrels project instead of the squirrels framework/library
+     """
+     pass
+
+
+ class FileExecutionError(Exception):
+     def __init__(self, message: str, error: Exception, *args) -> None:
+         t = " "
+         new_message = f"\n" + message + f"\n{t}Produced error message:\n{t}{t}{error} (see above for more details on handled exception)"
+         super().__init__(new_message, *args)
+         self.error = error
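
For context on how the InvalidInputError attributes are meant to be consumed, here is an illustrative sketch of raising the exception and turning it into an error payload. The validation rule, function name, and payload field names are assumptions for illustration, not code from this release.

from squirrels._exceptions import InvalidInputError

def lookup_dataset(dataset_name: str) -> dict:
    # Hypothetical validation; the real API routes define their own rules.
    if not dataset_name.isidentifier():
        raise InvalidInputError(400, "invalid_dataset_name",
                                f"'{dataset_name}' is not a valid dataset name")
    return {"dataset": dataset_name}

try:
    lookup_dataset("bad name!")
except InvalidInputError as e:
    status = e.status_code                                             # 400
    payload = {"error": e.error, "error_description": e.error_description}
    # "error" is the stable machine-readable code; the description may change over time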