squirrels 0.2.0rc1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (34)
  1. squirrels/__init__.py +2 -2
  2. squirrels/_api_server.py +66 -49
  3. squirrels/_authenticator.py +2 -3
  4. squirrels/_command_line.py +1 -1
  5. squirrels/_constants.py +7 -5
  6. squirrels/_environcfg.py +1 -1
  7. squirrels/_initializer.py +1 -2
  8. squirrels/_manifest.py +8 -12
  9. squirrels/_models.py +43 -21
  10. squirrels/_parameter_configs.py +4 -4
  11. squirrels/_parameter_sets.py +1 -3
  12. squirrels/_py_module.py +4 -2
  13. squirrels/_utils.py +7 -0
  14. squirrels/arguments/run_time_args.py +15 -4
  15. squirrels/package_data/assets/favicon.ico +0 -0
  16. squirrels/package_data/assets/index.js +13 -13
  17. squirrels/package_data/base_project/{ignores/.gitignore → .gitignore} +4 -0
  18. squirrels/package_data/base_project/{Dockerfile → docker/Dockerfile} +2 -2
  19. squirrels/package_data/base_project/docker/compose.yml +7 -0
  20. squirrels/package_data/base_project/environcfg.yml +1 -1
  21. squirrels/package_data/base_project/parameters.yml +18 -18
  22. squirrels/package_data/base_project/pyconfigs/auth.py +10 -14
  23. squirrels/package_data/base_project/pyconfigs/context.py +12 -2
  24. squirrels/package_data/base_project/squirrels.yml.j2 +18 -6
  25. squirrels/parameter_options.py +24 -24
  26. squirrels/parameters.py +3 -3
  27. squirrels/user_base.py +10 -11
  28. {squirrels-0.2.0rc1.dist-info → squirrels-0.2.2.dist-info}/METADATA +13 -11
  29. squirrels-0.2.2.dist-info/RECORD +55 -0
  30. {squirrels-0.2.0rc1.dist-info → squirrels-0.2.2.dist-info}/WHEEL +1 -1
  31. {squirrels-0.2.0rc1.dist-info → squirrels-0.2.2.dist-info}/entry_points.txt +1 -0
  32. squirrels-0.2.0rc1.dist-info/RECORD +0 -54
  33. squirrels/package_data/base_project/{ignores → docker}/.dockerignore +0 -0
  34. {squirrels-0.2.0rc1.dist-info → squirrels-0.2.2.dist-info}/LICENSE +0 -0
squirrels/__init__.py CHANGED
@@ -1,8 +1,8 @@
1
- __version__ = '0.2.0'
1
+ __version__ = '0.2.2'
2
2
 
3
3
  from .arguments.init_time_args import ConnectionsArgs, ParametersArgs
4
4
  from .arguments.run_time_args import AuthArgs, ContextArgs, ModelDepsArgs, ModelArgs
5
5
  from .parameter_options import SelectParameterOption, DateParameterOption, DateRangeParameterOption, NumberParameterOption, NumberRangeParameterOption
6
- from .parameters import Parameter, SingleSelectParameter, MultiSelectParameter, DateParameter, DateRangeParameter, NumberParameter, NumberRangeParameter
6
+ from .parameters import SingleSelectParameter, MultiSelectParameter, DateParameter, DateRangeParameter, NumberParameter, NumberRangeParameter
7
7
  from .data_sources import SingleSelectDataSource, MultiSelectDataSource, DateDataSource, DateRangeDataSource, NumberDataSource, NumberRangeDataSource
8
8
  from .user_base import User, WrongPassword
squirrels/_api_server.py CHANGED
@@ -1,5 +1,5 @@
1
- from typing import Iterable, Optional, Mapping, Callable, Coroutine, TypeVar, Any
2
- from fastapi import Depends, FastAPI, Request, HTTPException, status
1
+ from typing import List, Iterable, Optional, Mapping, Callable, Coroutine, TypeVar, Any
2
+ from fastapi import Depends, FastAPI, Request, HTTPException, Response, status
3
3
  from fastapi.responses import HTMLResponse, JSONResponse
4
4
  from fastapi.templating import Jinja2Templates
5
5
  from fastapi.staticfiles import StaticFiles
@@ -45,36 +45,35 @@ class ApiServer:
45
45
  start = time.time()
46
46
  app = FastAPI()
47
47
 
48
- app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
48
+ @app.middleware("http")
49
+ async def catch_exceptions_middleware(request: Request, call_next):
50
+ try:
51
+ return await call_next(request)
52
+ except u.InvalidInputError as exc:
53
+ traceback.print_exc()
54
+ return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST,
55
+ content={"message": f"Invalid user input: {str(exc)}"})
56
+ except u.ConfigurationError as exc:
57
+ traceback.print_exc()
58
+ return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
59
+ content={"message": f"Squirrels configuration error: {str(exc)}"})
60
+ except NotImplementedError as exc:
61
+ traceback.print_exc()
62
+ return JSONResponse(status_code=status.HTTP_501_NOT_IMPLEMENTED,
63
+ content={"message": f"Not implemented error: {str(exc)}"})
64
+ except Exception as exc:
65
+ traceback.print_exc()
66
+ return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
67
+ content={"message": f"Server error: {str(exc)}"})
68
+
69
+ app.add_middleware(
70
+ CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
71
+ expose_headers=["Applied-Username"]
72
+ )
49
73
 
50
74
  squirrels_version_path = f'/squirrels-v{sq_major_version}'
51
75
  partial_base_path = f'/{ManifestIO.obj.project_variables.get_name()}/v{ManifestIO.obj.project_variables.get_major_version()}'
52
76
  base_path = squirrels_version_path + u.normalize_name_for_api(partial_base_path)
53
-
54
- static_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.ASSETS_FOLDER)
55
- app.mount('/'+c.ASSETS_FOLDER, StaticFiles(directory=static_dir), name=c.ASSETS_FOLDER)
56
-
57
- templates_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.TEMPLATES_FOLDER)
58
- templates = Jinja2Templates(directory=templates_dir)
59
-
60
- # Exception handlers
61
- @app.exception_handler(u.InvalidInputError)
62
- async def invalid_input_error_handler(request: Request, exc: u.InvalidInputError):
63
- traceback.print_exc()
64
- return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST,
65
- content={"message": f"Invalid user input: {str(exc)}"})
66
-
67
- @app.exception_handler(u.ConfigurationError)
68
- async def configuration_error_handler(request: Request, exc: u.InvalidInputError):
69
- traceback.print_exc()
70
- return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
71
- content={"message": f"Squirrels configuration error: {str(exc)}"})
72
-
73
- @app.exception_handler(NotImplementedError)
74
- async def not_implemented_error_handler(request: Request, exc: u.InvalidInputError):
75
- traceback.print_exc()
76
- return JSONResponse(status_code=status.HTTP_501_NOT_IMPLEMENTED,
77
- content={"message": f"Not implemented error: {str(exc)}"})
78
77
 
79
78
  # Helpers
80
79
  T = TypeVar('T')
@@ -127,13 +126,21 @@ class ApiServer:
127
126
  # Changing selections into a cachable "frozenset" that will later be converted to dictionary
128
127
  selections = set()
129
128
  for key, val in params.items():
130
- if not isinstance(val, str):
129
+ if isinstance(val, List):
131
130
  val = tuple(val)
132
131
  selections.add((u.normalize_name(key), val))
133
132
  selections = frozenset(selections)
134
133
 
135
134
  return await api_function(user, dataset_normalized, selections, request_version)
136
135
 
136
+ async def do_cachable_action(cache: TTLCache, action: Callable[..., Coroutine[Any, Any, T]], *args) -> T:
137
+ cache_key = tuple(args)
138
+ result = cache.get(cache_key)
139
+ if result is None:
140
+ result = await action(*args)
141
+ cache[cache_key] = result
142
+ return result
143
+
137
144
  # Login
138
145
  token_path = base_path + '/token'
139
146
 
@@ -154,20 +161,14 @@ class ApiServer:
154
161
  "expiry_time": expiry
155
162
  }
156
163
 
157
- async def get_current_user(token: str = Depends(oauth2_scheme)) -> Optional[User]:
164
+ async def get_current_user(response: Response, token: str = Depends(oauth2_scheme)) -> Optional[User]:
158
165
  user = self.authenticator.get_user_from_token(token)
166
+ username = "" if user is None else user.username
167
+ response.headers["Applied-Username"] = username
159
168
  return user
160
169
 
161
- async def do_cachable_action(cache: TTLCache, action: Callable[..., Coroutine[Any, Any, T]], *args) -> T:
162
- cache_key = tuple(args)
163
- result = cache.get(cache_key)
164
- if result is None:
165
- result = await action(*args)
166
- cache[cache_key] = result
167
- return result
168
-
169
170
  # Parameters API
170
- parameters_path = base_path + '/{dataset}/parameters'
171
+ parameters_path = base_path + '/dataset/{dataset}/parameters'
171
172
 
172
173
  parameters_cache_size = ManifestIO.obj.settings.get(c.PARAMETERS_CACHE_SIZE_SETTING, 1024)
173
174
  parameters_cache_ttl = ManifestIO.obj.settings.get(c.PARAMETERS_CACHE_TTL_SETTING, 0)
@@ -209,7 +210,7 @@ class ApiServer:
209
210
  return result
210
211
 
211
212
  # Results API
212
- results_path = base_path + '/{dataset}'
213
+ results_path = base_path + '/dataset/{dataset}'
213
214
 
214
215
  results_cache_size = ManifestIO.obj.settings.get(c.RESULTS_CACHE_SIZE_SETTING, 128)
215
216
  results_cache_ttl = ManifestIO.obj.settings.get(c.RESULTS_CACHE_TTL_SETTING, 0)
@@ -248,8 +249,10 @@ class ApiServer:
248
249
  timer.add_activity_time("POST REQUEST total time for DATASET", start)
249
250
  return result
250
251
 
251
- # Catalog API
252
- def get_catalog0(user: Optional[User]):
252
+ # Datasets Catalog API
253
+ datasets_path = base_path + '/datasets'
254
+
255
+ def get_datasets0(user: Optional[User]):
253
256
  datasets_info = []
254
257
  for dataset_name, dataset_config in self.dataset_configs.items():
255
258
  if can_user_access_dataset(user, dataset_name):
@@ -258,30 +261,44 @@ class ApiServer:
258
261
  'name': dataset_name,
259
262
  'label': dataset_config.label,
260
263
  'parameters_path': parameters_path.format(dataset=dataset_normalized),
261
- 'result_path': results_path.format(dataset=dataset_normalized),
262
- 'first_minor_version': 0
264
+ 'result_path': results_path.format(dataset=dataset_normalized)
263
265
  })
264
-
266
+ return {"datasets": datasets_info}
267
+
268
+ @app.get(datasets_path)
269
+ def get_datasets(request: Request, user: Optional[User] = Depends(get_current_user)):
270
+ return process_based_on_response_version_header(request.headers, {
271
+ 0: lambda: get_datasets0(user)
272
+ })
273
+
274
+ # Projects Catalog API
275
+ def get_catalog0():
265
276
  return {
266
277
  'projects': [{
267
278
  'name': ManifestIO.obj.project_variables.get_name(),
268
279
  'label': ManifestIO.obj.project_variables.get_label(),
269
280
  'versions': [{
270
281
  'major_version': ManifestIO.obj.project_variables.get_major_version(),
271
- 'latest_minor_version': ManifestIO.obj.project_variables.get_minor_version(),
282
+ 'minor_versions': [0],
272
283
  'token_path': token_path,
273
- 'datasets': datasets_info
284
+ 'datasets_path': datasets_path
274
285
  }]
275
286
  }]
276
287
  }
277
288
 
278
289
  @app.get(squirrels_version_path, response_class=JSONResponse)
279
- async def get_catalog(request: Request, user: Optional[User] = Depends(get_current_user)):
290
+ async def get_catalog(request: Request):
280
291
  return process_based_on_response_version_header(request.headers, {
281
- 0: lambda: get_catalog0(user)
292
+ 0: lambda: get_catalog0()
282
293
  })
283
294
 
284
295
  # Squirrels UI
296
+ static_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.ASSETS_FOLDER)
297
+ app.mount('/'+c.ASSETS_FOLDER, StaticFiles(directory=static_dir), name=c.ASSETS_FOLDER)
298
+
299
+ templates_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.TEMPLATES_FOLDER)
300
+ templates = Jinja2Templates(directory=templates_dir)
301
+
285
302
  @app.get('/', response_class=HTMLResponse)
286
303
  async def get_ui(request: Request):
287
304
  return templates.TemplateResponse('index.html', {
squirrels/_authenticator.py CHANGED
@@ -34,7 +34,7 @@ class Authenticator:
34
34
  return AuthArgs(conn_args.proj_vars, conn_args.env_vars, conn_args._get_credential, connections, username, password)
35
35
 
36
36
  def authenticate_user(self, username: str, password: str) -> Optional[User]:
37
- user_cls = self.auth_helper.get_func_or_class("User", default_attr=User)
37
+ user_cls: type[User] = self.auth_helper.get_func_or_class("User", default_attr=User)
38
38
  get_user = self.auth_helper.get_func_or_class(c.GET_USER_FUNC, is_required=False)
39
39
  try:
40
40
  real_user = get_user(self._get_auth_args(username, password)) if get_user is not None else None
@@ -48,9 +48,8 @@ class Authenticator:
48
48
  fake_users = EnvironConfigIO.obj.get_users()
49
49
  if username in fake_users and secrets.compare_digest(fake_users[username][c.USER_PWD_KEY], password):
50
50
  is_internal = fake_users[username].get("is_internal", False)
51
- user: User = user_cls(username, is_internal=is_internal)
52
51
  try:
53
- return user.with_attributes(fake_users[username])
52
+ return user_cls.Create(username, fake_users[username], is_internal=is_internal)
54
53
  except Exception as e:
55
54
  raise u.FileExecutionError(f'Failed to create user from User model in {c.AUTH_FILE}', e)
56
55
 
squirrels/_command_line.py CHANGED
@@ -33,7 +33,7 @@ def main():
33
33
 
34
34
  compile_parser = subparsers.add_parser(c.COMPILE_CMD, help='Create files for rendered sql queries in the "target/compile" folder', add_help=False)
35
35
  compile_parser.add_argument('-h', '--help', action="help", help="Show this help message and exit")
36
- compile_parser.add_argument('-d', '--dataset', type=str, help="Select dataset to use for dataset args. If not specified, all models for all datasets are compiled")
36
+ compile_parser.add_argument('-d', '--dataset', type=str, help="Select dataset to use for dataset traits. If not specified, all models for all datasets are compiled")
37
37
  compile_parser.add_argument('-a', '--all-test-sets', action="store_true", help="Compile models for all selection test sets")
38
38
  compile_parser.add_argument('-t', '--test-set', type=str, help="The selection test set to use. Default selections are used if not specified. Ignored if using --all-test-sets")
39
39
  compile_parser.add_argument('-s', '--select', type=str, help="Select single model to compile. If not specified, all models for the dataset are compiled. Also, ignored if --dataset is not specified")
squirrels/_constants.py CHANGED
@@ -9,7 +9,6 @@ PROJ_VARS_KEY = 'project_variables'
9
9
  PROJECT_NAME_KEY = 'name'
10
10
  PROJECT_LABEL_KEY = 'label'
11
11
  MAJOR_VERSION_KEY = 'major_version'
12
- MINOR_VERSION_KEY = 'minor_version'
13
12
 
14
13
  PACKAGES_KEY = 'packages'
15
14
  PACKAGE_GIT_KEY = 'git'
@@ -48,7 +47,7 @@ DATASET_NAME_KEY = 'name'
48
47
  DATASET_LABEL_KEY = 'label'
49
48
  DATASET_MODEL_KEY = 'model'
50
49
  DATASET_PARAMETERS_KEY = 'parameters'
51
- DATASET_ARGS_KEY = 'args'
50
+ DATASET_TRAITS_KEY = 'traits'
52
51
 
53
52
  DATASET_SCOPE_KEY = 'scope'
54
53
  PUBLIC_SCOPE = 'public'
@@ -109,14 +108,17 @@ PARAMETERS_OUTPUT = 'parameters.json'
109
108
  FINAL_VIEW_OUT_STEM = 'final_view'
110
109
 
111
110
  # Dataset setting names
112
- AUTH_TOKEN_EXPIRE_SETTING = 'auth.token.expire.minutes'
111
+ AUTH_TOKEN_EXPIRE_SETTING = 'auth.token.expire_minutes'
113
112
  PARAMETERS_CACHE_SIZE_SETTING = 'parameters.cache.size'
114
- PARAMETERS_CACHE_TTL_SETTING = 'parameters.cache.ttl.minutes'
113
+ PARAMETERS_CACHE_TTL_SETTING = 'parameters.cache.ttl_minutes'
115
114
  RESULTS_CACHE_SIZE_SETTING = 'results.cache.size'
116
- RESULTS_CACHE_TTL_SETTING = 'results.cache.ttl.minutes'
115
+ RESULTS_CACHE_TTL_SETTING = 'results.cache.ttl_minutes'
117
116
  TEST_SET_DEFAULT_USED_SETTING = 'selection_test_sets.default_name_used'
118
117
  DB_CONN_DEFAULT_USED_SETTING = 'connections.default_name_used'
119
118
  DEFAULT_MATERIALIZE_SETTING = 'defaults.federates.materialized'
119
+ IN_MEMORY_DB_SETTING = 'in_memory_database'
120
+ SQLITE = 'sqlite'
121
+ DUCKDB = 'duckdb'
120
122
 
121
123
  # Selection cfg sections
122
124
  USER_ATTRIBUTES_SECTION = 'user_attributes'
squirrels/_environcfg.py CHANGED
@@ -43,7 +43,7 @@ class _EnvironConfig:
43
43
  return credential[c.USERNAME_KEY], credential[c.PASSWORD_KEY]
44
44
 
45
45
  def get_secret(self, key: str, *, default_factory: Optional[Callable[[],str]] = None) -> str:
46
- if key not in self._secrets and default_factory is not None:
46
+ if not self._secrets.get(key) and default_factory is not None:
47
47
  self._secrets[key] = default_factory()
48
48
  return self._secrets.get(key)
49
49
 
squirrels/_initializer.py CHANGED
@@ -40,7 +40,6 @@ class Initializer:
40
40
  options = ["core", "connections", "parameters", "dbview", "federate", "auth", "sample_db"]
41
41
  CORE, CONNECTIONS, PARAMETERS, DBVIEW, FEDERATE, AUTH, SAMPLE_DB = options
42
42
  TMP_FOLDER = "tmp"
43
- IGNORES_FOLDER = "ignores"
44
43
 
45
44
  answers = { x: getattr(args, x) for x in options }
46
45
  if not any(answers.values()):
@@ -136,7 +135,7 @@ class Initializer:
136
135
 
137
136
  create_manifest_file()
138
137
 
139
- self._copy_file(".gitignore", src_folder=IGNORES_FOLDER)
138
+ self._copy_file(".gitignore")
140
139
  self._copy_file(c.MANIFEST_FILE, src_folder=TMP_FOLDER)
141
140
 
142
141
  if connections_use_py:
squirrels/_manifest.py CHANGED
@@ -26,10 +26,10 @@ class ProjectVarsConfig(ManifestComponentConfig):
26
26
  data: dict
27
27
 
28
28
  def __post_init__(self):
29
- required_keys = [c.PROJECT_NAME_KEY, c.MAJOR_VERSION_KEY, c.MINOR_VERSION_KEY]
29
+ required_keys = [c.PROJECT_NAME_KEY, c.MAJOR_VERSION_KEY]
30
30
  self._validate_required(self.data, required_keys, c.PROJ_VARS_KEY)
31
31
 
32
- integer_keys = [c.MAJOR_VERSION_KEY, c.MINOR_VERSION_KEY]
32
+ integer_keys = [c.MAJOR_VERSION_KEY]
33
33
  for key in integer_keys:
34
34
  if key in self.data and not isinstance(self.data[key], int):
35
35
  raise u.ConfigurationError(f'Project variable "{key}" must be an integer')
@@ -46,9 +46,6 @@ class ProjectVarsConfig(ManifestComponentConfig):
46
46
 
47
47
  def get_major_version(self) -> int:
48
48
  return self.data[c.MAJOR_VERSION_KEY]
49
-
50
- def get_minor_version(self) -> int:
51
- return self.data[c.MINOR_VERSION_KEY]
52
49
 
53
50
 
54
51
  @dataclass
@@ -84,14 +81,13 @@ class DbConnConfig(ManifestComponentConfig):
84
81
 
85
82
  @dataclass
86
83
  class ParametersConfig(ManifestComponentConfig):
87
- name: str
88
84
  type: str
89
85
  factory: str
90
86
  arguments: dict
91
87
 
92
88
  @classmethod
93
89
  def from_dict(cls, kwargs: dict):
94
- all_keys = [c.PARAMETER_NAME_KEY, c.PARAMETER_TYPE_KEY, c.PARAMETER_FACTORY_KEY, c.PARAMETER_ARGS_KEY]
90
+ all_keys = [c.PARAMETER_TYPE_KEY, c.PARAMETER_FACTORY_KEY, c.PARAMETER_ARGS_KEY]
95
91
  cls._validate_required(kwargs, all_keys, c.PARAMETERS_KEY)
96
92
  args = {key: kwargs[key] for key in all_keys}
97
93
  return cls(**args)
@@ -150,7 +146,7 @@ class DatasetsConfig(ManifestComponentConfig):
150
146
  model: str
151
147
  scope: DatasetScope
152
148
  parameters: Optional[list[str]]
153
- args: dict
149
+ traits: dict
154
150
 
155
151
  @classmethod
156
152
  def from_dict(cls, kwargs: dict):
@@ -166,8 +162,8 @@ class DatasetsConfig(ManifestComponentConfig):
166
162
  raise u.ConfigurationError(f'Scope not found for dataset "{name}". Scope must be one of {scope_list}') from e
167
163
 
168
164
  parameters = kwargs.get(c.DATASET_PARAMETERS_KEY)
169
- args = kwargs.get(c.DATASET_ARGS_KEY, {})
170
- return cls(name, label, model, scope, parameters, args)
165
+ traits = kwargs.get(c.DATASET_TRAITS_KEY, {})
166
+ return cls(name, label, model, scope, parameters, traits)
171
167
 
172
168
 
173
169
  @dataclass
@@ -175,7 +171,7 @@ class _ManifestConfig:
175
171
  project_variables: ProjectVarsConfig
176
172
  packages: list[PackageConfig]
177
173
  connections: dict[str, DbConnConfig]
178
- parameters: dict[str, ParametersConfig]
174
+ parameters: list[ParametersConfig]
179
175
  selection_test_sets: dict[str, TestSetsConfig]
180
176
  dbviews: dict[str, DbviewConfig]
181
177
  federates: dict[str, FederateConfig]
@@ -209,7 +205,7 @@ class _ManifestConfig:
209
205
  all_package_dirs.add(package.directory)
210
206
 
211
207
  db_conns = cls._create_configs_as_dict(DbConnConfig, kwargs, c.DB_CONNECTIONS_KEY, c.DB_CONN_NAME_KEY)
212
- params = cls._create_configs_as_dict(ParametersConfig, kwargs, c.PARAMETERS_KEY, c.PARAMETER_NAME_KEY)
208
+ params = [ParametersConfig.from_dict(x) for x in kwargs.get(c.PARAMETERS_KEY, [])]
213
209
 
214
210
  test_sets = cls._create_configs_as_dict(TestSetsConfig, kwargs, c.TEST_SETS_KEY, c.TEST_SET_NAME_KEY)
215
211
  default_test_set: str = settings.get(c.TEST_SET_DEFAULT_USED_SETTING, c.DEFAULT_TEST_SET_NAME)
squirrels/_models.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from __future__ import annotations
2
- from typing import Optional, Callable, Iterable, Any
2
+ from typing import Union, Optional, Callable, Iterable, Any
3
3
  from dataclasses import dataclass, field
4
4
  from enum import Enum
5
5
  from pathlib import Path
@@ -7,7 +7,7 @@ import sqlite3, pandas as pd, asyncio, os, shutil
7
7
 
8
8
  from . import _constants as c, _utils as u, _py_module as pm
9
9
  from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
10
- from .user_base import User
10
+ from ._authenticator import User, Authenticator
11
11
  from ._connection_set import ConnectionSetIO
12
12
  from ._manifest import ManifestIO, DatasetsConfig
13
13
  from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
@@ -146,7 +146,7 @@ class Model:
146
146
  configuration = SqlModelConfig(connection_name, materialized)
147
147
  kwargs = {
148
148
  "proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars,
149
- "user": ctx_args.user, "prms": ctx_args.prms, "args": ctx_args.args,
149
+ "user": ctx_args.user, "prms": ctx_args.prms, "traits": ctx_args.traits,
150
150
  "ctx": ctx, "config": configuration.set_attribute
151
151
  }
152
152
  dependencies = set()
@@ -166,7 +166,7 @@ class Model:
166
166
 
167
167
  async def _compile_python_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[PyModelQuery, set]:
168
168
  assert(isinstance(self.query_file.raw_query, RawPyQuery))
169
- sqrl_args = ModelDepsArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.args, ctx)
169
+ sqrl_args = ModelDepsArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, ctx)
170
170
  try:
171
171
  dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
172
172
  except Exception as e:
@@ -175,7 +175,7 @@ class Model:
175
175
  dbview_conn_name = self._get_dbview_conn_name()
176
176
  connections = ConnectionSetIO.obj.get_engines_as_dict()
177
177
  ref = lambda x: self.upstreams[x].result
178
- sqrl_args = ModelArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.args,
178
+ sqrl_args = ModelArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits,
179
179
  ctx, dbview_conn_name, connections, ref, set(dependencies))
180
180
 
181
181
  def compiled_query():
@@ -215,9 +215,9 @@ class Model:
215
215
  coroutines.append(coro)
216
216
  await asyncio.gather(*coroutines)
217
217
 
218
- def validate_no_cycles(self, depencency_path: set[str]) -> set[str]:
218
+ def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
219
219
  if self.confirmed_no_cycles:
220
- return
220
+ return set()
221
221
 
222
222
  if self.name in depencency_path:
223
223
  raise u.ConfigurationError(f'Cycle found in model dependency graph')
@@ -229,11 +229,24 @@ class Model:
229
229
  new_path = set(depencency_path)
230
230
  new_path.add(self.name)
231
231
  for dep_model in self.upstreams.values():
232
- terminal_nodes_under_dep = dep_model.validate_no_cycles(new_path)
232
+ terminal_nodes_under_dep = dep_model.get_terminal_nodes(new_path)
233
233
  terminal_nodes = terminal_nodes.union(terminal_nodes_under_dep)
234
234
 
235
235
  self.confirmed_no_cycles = True
236
236
  return terminal_nodes
237
+
238
+ def _load_pandas_to_table(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None:
239
+ if u.use_duckdb():
240
+ conn.execute(f"CREATE TABLE {self.name} AS FROM df")
241
+ else:
242
+ df.to_sql(self.name, conn, index=False)
243
+
244
+ def _load_table_to_pandas(self, conn: sqlite3.Connection) -> pd.DataFrame:
245
+ if u.use_duckdb():
246
+ return conn.execute(f"FROM {self.name}").df()
247
+ else:
248
+ query = f"SELECT * FROM {self.name}"
249
+ return pd.read_sql(query, conn)
237
250
 
238
251
  async def _run_sql_model(self, conn: sqlite3.Connection) -> None:
239
252
  assert(isinstance(self.compiled_query, SqlModelQuery))
@@ -248,7 +261,7 @@ class Model:
248
261
  raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e)
249
262
 
250
263
  df = await asyncio.to_thread(run_sql_query)
251
- await asyncio.to_thread(df.to_sql, self.name, conn, index=False)
264
+ await asyncio.to_thread(self._load_pandas_to_table, df, conn)
252
265
  if self.needs_pandas or self.is_target:
253
266
  self.result = df
254
267
  elif self.query_file.model_type == ModelType.FEDERATE:
@@ -261,15 +274,14 @@ class Model:
261
274
 
262
275
  await asyncio.to_thread(create_table)
263
276
  if self.needs_pandas or self.is_target:
264
- query = f"SELECT * FROM {self.name}"
265
- self.result = await asyncio.to_thread(pd.read_sql, query, conn)
277
+ self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)
266
278
 
267
279
  async def _run_python_model(self, conn: sqlite3.Connection) -> None:
268
280
  assert(isinstance(self.compiled_query, PyModelQuery))
269
281
 
270
282
  df = await asyncio.to_thread(self.compiled_query.query)
271
283
  if self.needs_sql_table:
272
- await asyncio.to_thread(df.to_sql, self.name, conn, index=False)
284
+ await asyncio.to_thread(self._load_pandas_to_table, df, conn)
273
285
  if self.needs_pandas or self.is_target:
274
286
  self.result = df
275
287
 
@@ -320,7 +332,7 @@ class DAG:
320
332
  context = {}
321
333
  param_args = ParameterConfigsSetIO.args
322
334
  prms = self.parameter_set.get_parameters_as_dict()
323
- args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.args)
335
+ args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits)
324
336
  try:
325
337
  context_func(ctx=context, sqrl=args)
326
338
  except Exception as e:
@@ -331,14 +343,21 @@ class DAG:
331
343
  async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
332
344
  await self.target_model.compile(context, ctx_args, self.models_dict, recurse)
333
345
 
334
- def _validate_no_cycles(self) -> set[str]:
346
+ def _get_terminal_nodes(self) -> set[str]:
335
347
  start = time.time()
336
- terminal_nodes = self.target_model.validate_no_cycles(set())
348
+ terminal_nodes = self.target_model.get_terminal_nodes(set())
349
+ for model in self.models_dict.values():
350
+ model.confirmed_no_cycles = False
337
351
  timer.add_activity_time(f"validating no cycles in models dependencies", start)
338
352
  return terminal_nodes
339
353
 
340
354
  async def _run_models(self, terminal_nodes: set[str]) -> None:
341
- conn = sqlite3.connect(":memory:", check_same_thread=False)
355
+ if u.use_duckdb():
356
+ import duckdb
357
+ conn = duckdb.connect()
358
+ else:
359
+ conn = sqlite3.connect(":memory:", check_same_thread=False)
360
+
342
361
  try:
343
362
  coroutines = []
344
363
  for model_name in terminal_nodes:
@@ -360,7 +379,7 @@ class DAG:
360
379
 
361
380
  await self._compile_models(context, ctx_args, recurse)
362
381
 
363
- terminal_nodes = self._validate_no_cycles()
382
+ terminal_nodes = self._get_terminal_nodes()
364
383
 
365
384
  if runquery:
366
385
  await self._run_models(terminal_nodes)
@@ -431,10 +450,13 @@ class ModelsIO:
431
450
  @classmethod
432
451
  async def WriteDatasetOutputsGivenTestSet(cls, dataset: str, select: str, test_set: str, runquery: bool, recurse: bool) -> Any:
433
452
  test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
434
- user = User("")
435
- for key, val in test_set_conf.user_attributes.items():
436
- setattr(user, key, val)
453
+ user_attributes = test_set_conf.user_attributes
437
454
  selections = test_set_conf.parameters
455
+
456
+ username, is_internal = user_attributes.get("username", ""), user_attributes.get("is_internal", False)
457
+ user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
458
+ user = user_cls.Create(username, test_set_conf.user_attributes, is_internal=is_internal)
459
+
438
460
  dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
439
461
  await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
440
462
 
@@ -445,7 +467,7 @@ class ModelsIO:
445
467
  def write_model_outputs(model: Model) -> None:
446
468
  subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
447
469
  subpath = u.join_paths(output_folder, subfolder)
448
- os.makedirs(subpath)
470
+ os.makedirs(subpath, exist_ok=True)
449
471
  if isinstance(model.compiled_query, SqlModelQuery):
450
472
  output_filepath = u.join_paths(subpath, model.name+'.sql')
451
473
  query = model.compiled_query.query
squirrels/_parameter_configs.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from __future__ import annotations
2
- from typing import Type, Optional, Union, Sequence, Iterator
2
+ from typing import Type, Optional, Union, Sequence, Iterator, Any
3
3
  from dataclasses import dataclass, field
4
4
  from abc import ABCMeta, abstractmethod
5
5
  from copy import copy
@@ -32,10 +32,10 @@ class ParameterConfigBase(metaclass=ABCMeta):
32
32
  self.user_attribute = user_attribute
33
33
  self.parent_name = parent_name
34
34
 
35
- def _get_user_group(self, user: Optional[User]) -> Optional[str]:
35
+ def _get_user_group(self, user: Optional[User]) -> Any:
36
36
  if self.user_attribute is not None:
37
37
  if user is None:
38
- raise u.ConfigurationError(f"Public datasets with non-authenticated users cannot use parameter named " +
38
+ raise u.ConfigurationError(f"Public datasets (which allows non-authenticated users) cannot use parameter " +
39
39
  f"'{self.name}' because 'user_attribute' is defined on this parameter.")
40
40
  return getattr(user, self.user_attribute)
41
41
 
@@ -130,7 +130,7 @@ class SelectionParameterConfig(ParameterConfig):
130
130
  def _get_default_ids_iterator(self, options: Sequence[po.SelectParameterOption]) -> Iterator[str]:
131
131
  return (x._identifier for x in options if x._is_default)
132
132
 
133
- def copy(self) -> MultiSelectParameterConfig:
133
+ def copy(self) -> SelectionParameterConfig:
134
134
  """
135
135
  Use for unit testing only
136
136
  """
squirrels/_parameter_sets.py CHANGED
@@ -164,7 +164,6 @@ class ParameterConfigsSetIO:
164
164
 
165
165
  @classmethod
166
166
  def _AddFromDict(cls, param_config: ParametersConfig) -> None:
167
- param_config.arguments["name"] = param_config.name
168
167
  ptype = getattr(p, param_config.type)
169
168
  factory = getattr(ptype, param_config.factory)
170
169
  factory(**param_config.arguments)
@@ -174,8 +173,7 @@ class ParameterConfigsSetIO:
174
173
  start = time.time()
175
174
  cls.obj = _ParameterConfigsSet()
176
175
 
177
- parameters_from_manifest = ManifestIO.obj.parameters.values()
178
- for param_as_dict in parameters_from_manifest:
176
+ for param_as_dict in ManifestIO.obj.parameters:
179
177
  cls._AddFromDict(param_as_dict)
180
178
 
181
179
  conn_args = ConnectionSetIO.args
squirrels/_py_module.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from typing import Type, Optional, Any
2
2
  from types import ModuleType
3
- from importlib.machinery import SourceFileLoader
3
+ import importlib.util
4
4
 
5
5
  from . import _constants as c, _utils as u
6
6
 
@@ -16,7 +16,9 @@ class PyModule:
16
16
  """
17
17
  self.filepath = str(filepath)
18
18
  try:
19
- self.module: Optional[ModuleType] = SourceFileLoader(self.filepath, self.filepath).load_module()
19
+ spec = importlib.util.spec_from_file_location(self.filepath, self.filepath)
20
+ self.module = importlib.util.module_from_spec(spec)
21
+ spec.loader.exec_module(self.module)
20
22
  except FileNotFoundError as e:
21
23
  if is_required:
22
24
  raise u.ConfigurationError(f"Required file not found: '{self.filepath}'") from e
squirrels/_utils.py CHANGED
@@ -3,6 +3,8 @@ from pathlib import Path
3
3
  from pandas.api import types as pd_types
4
4
  import json, jinja2 as j2, pandas as pd
5
5
 
6
+ from . import _constants as c
7
+
6
8
  FilePath = Union[str, Path]
7
9
 
8
10
 
@@ -179,3 +181,8 @@ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) ->
179
181
  if input_val is None:
180
182
  return None
181
183
  return processor(input_val)
184
+
185
+
186
+ def use_duckdb():
187
+ from ._manifest import ManifestIO
188
+ return (ManifestIO.obj.settings.get(c.IN_MEMORY_DB_SETTING, c.SQLITE) == c.DUCKDB)