squirrels 0.2.0rc1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels has been flagged as potentially problematic; see the package's registry page for details.
- squirrels/__init__.py +2 -2
- squirrels/_api_server.py +66 -49
- squirrels/_authenticator.py +2 -3
- squirrels/_command_line.py +1 -1
- squirrels/_constants.py +7 -5
- squirrels/_environcfg.py +1 -1
- squirrels/_initializer.py +1 -2
- squirrels/_manifest.py +8 -12
- squirrels/_models.py +43 -21
- squirrels/_parameter_configs.py +4 -4
- squirrels/_parameter_sets.py +1 -3
- squirrels/_py_module.py +4 -2
- squirrels/_utils.py +7 -0
- squirrels/arguments/run_time_args.py +15 -4
- squirrels/package_data/assets/favicon.ico +0 -0
- squirrels/package_data/assets/index.js +13 -13
- squirrels/package_data/base_project/{ignores/.gitignore → .gitignore} +4 -0
- squirrels/package_data/base_project/{Dockerfile → docker/Dockerfile} +2 -2
- squirrels/package_data/base_project/docker/compose.yml +7 -0
- squirrels/package_data/base_project/environcfg.yml +1 -1
- squirrels/package_data/base_project/parameters.yml +18 -18
- squirrels/package_data/base_project/pyconfigs/auth.py +10 -14
- squirrels/package_data/base_project/pyconfigs/context.py +12 -2
- squirrels/package_data/base_project/squirrels.yml.j2 +18 -6
- squirrels/parameter_options.py +24 -24
- squirrels/parameters.py +3 -3
- squirrels/user_base.py +10 -11
- {squirrels-0.2.0rc1.dist-info → squirrels-0.2.2.dist-info}/METADATA +13 -11
- squirrels-0.2.2.dist-info/RECORD +55 -0
- {squirrels-0.2.0rc1.dist-info → squirrels-0.2.2.dist-info}/WHEEL +1 -1
- {squirrels-0.2.0rc1.dist-info → squirrels-0.2.2.dist-info}/entry_points.txt +1 -0
- squirrels-0.2.0rc1.dist-info/RECORD +0 -54
- /squirrels/package_data/base_project/{ignores → docker}/.dockerignore +0 -0
- {squirrels-0.2.0rc1.dist-info → squirrels-0.2.2.dist-info}/LICENSE +0 -0
squirrels/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
__version__ = '0.2.0rc1'
|
|
1
|
+
__version__ = '0.2.2'
|
|
2
2
|
|
|
3
3
|
from .arguments.init_time_args import ConnectionsArgs, ParametersArgs
|
|
4
4
|
from .arguments.run_time_args import AuthArgs, ContextArgs, ModelDepsArgs, ModelArgs
|
|
5
5
|
from .parameter_options import SelectParameterOption, DateParameterOption, DateRangeParameterOption, NumberParameterOption, NumberRangeParameterOption
|
|
6
|
-
from .parameters import
|
|
6
|
+
from .parameters import SingleSelectParameter, MultiSelectParameter, DateParameter, DateRangeParameter, NumberParameter, NumberRangeParameter
|
|
7
7
|
from .data_sources import SingleSelectDataSource, MultiSelectDataSource, DateDataSource, DateRangeDataSource, NumberDataSource, NumberRangeDataSource
|
|
8
8
|
from .user_base import User, WrongPassword
|
squirrels/_api_server.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
from typing import Iterable, Optional, Mapping, Callable, Coroutine, TypeVar, Any
|
|
2
|
-
from fastapi import Depends, FastAPI, Request, HTTPException, status
|
|
1
|
+
from typing import List, Iterable, Optional, Mapping, Callable, Coroutine, TypeVar, Any
|
|
2
|
+
from fastapi import Depends, FastAPI, Request, HTTPException, Response, status
|
|
3
3
|
from fastapi.responses import HTMLResponse, JSONResponse
|
|
4
4
|
from fastapi.templating import Jinja2Templates
|
|
5
5
|
from fastapi.staticfiles import StaticFiles
|
|
@@ -45,36 +45,35 @@ class ApiServer:
|
|
|
45
45
|
start = time.time()
|
|
46
46
|
app = FastAPI()
|
|
47
47
|
|
|
48
|
-
app.
|
|
48
|
+
@app.middleware("http")
|
|
49
|
+
async def catch_exceptions_middleware(request: Request, call_next):
|
|
50
|
+
try:
|
|
51
|
+
return await call_next(request)
|
|
52
|
+
except u.InvalidInputError as exc:
|
|
53
|
+
traceback.print_exc()
|
|
54
|
+
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST,
|
|
55
|
+
content={"message": f"Invalid user input: {str(exc)}"})
|
|
56
|
+
except u.ConfigurationError as exc:
|
|
57
|
+
traceback.print_exc()
|
|
58
|
+
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
59
|
+
content={"message": f"Squirrels configuration error: {str(exc)}"})
|
|
60
|
+
except NotImplementedError as exc:
|
|
61
|
+
traceback.print_exc()
|
|
62
|
+
return JSONResponse(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
|
63
|
+
content={"message": f"Not implemented error: {str(exc)}"})
|
|
64
|
+
except Exception as exc:
|
|
65
|
+
traceback.print_exc()
|
|
66
|
+
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
67
|
+
content={"message": f"Server error: {str(exc)}"})
|
|
68
|
+
|
|
69
|
+
app.add_middleware(
|
|
70
|
+
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
|
|
71
|
+
expose_headers=["Applied-Username"]
|
|
72
|
+
)
|
|
49
73
|
|
|
50
74
|
squirrels_version_path = f'/squirrels-v{sq_major_version}'
|
|
51
75
|
partial_base_path = f'/{ManifestIO.obj.project_variables.get_name()}/v{ManifestIO.obj.project_variables.get_major_version()}'
|
|
52
76
|
base_path = squirrels_version_path + u.normalize_name_for_api(partial_base_path)
|
|
53
|
-
|
|
54
|
-
static_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.ASSETS_FOLDER)
|
|
55
|
-
app.mount('/'+c.ASSETS_FOLDER, StaticFiles(directory=static_dir), name=c.ASSETS_FOLDER)
|
|
56
|
-
|
|
57
|
-
templates_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.TEMPLATES_FOLDER)
|
|
58
|
-
templates = Jinja2Templates(directory=templates_dir)
|
|
59
|
-
|
|
60
|
-
# Exception handlers
|
|
61
|
-
@app.exception_handler(u.InvalidInputError)
|
|
62
|
-
async def invalid_input_error_handler(request: Request, exc: u.InvalidInputError):
|
|
63
|
-
traceback.print_exc()
|
|
64
|
-
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST,
|
|
65
|
-
content={"message": f"Invalid user input: {str(exc)}"})
|
|
66
|
-
|
|
67
|
-
@app.exception_handler(u.ConfigurationError)
|
|
68
|
-
async def configuration_error_handler(request: Request, exc: u.InvalidInputError):
|
|
69
|
-
traceback.print_exc()
|
|
70
|
-
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
71
|
-
content={"message": f"Squirrels configuration error: {str(exc)}"})
|
|
72
|
-
|
|
73
|
-
@app.exception_handler(NotImplementedError)
|
|
74
|
-
async def not_implemented_error_handler(request: Request, exc: u.InvalidInputError):
|
|
75
|
-
traceback.print_exc()
|
|
76
|
-
return JSONResponse(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
|
77
|
-
content={"message": f"Not implemented error: {str(exc)}"})
|
|
78
77
|
|
|
79
78
|
# Helpers
|
|
80
79
|
T = TypeVar('T')
|
|
@@ -127,13 +126,21 @@ class ApiServer:
|
|
|
127
126
|
# Changing selections into a cachable "frozenset" that will later be converted to dictionary
|
|
128
127
|
selections = set()
|
|
129
128
|
for key, val in params.items():
|
|
130
|
-
if
|
|
129
|
+
if isinstance(val, List):
|
|
131
130
|
val = tuple(val)
|
|
132
131
|
selections.add((u.normalize_name(key), val))
|
|
133
132
|
selections = frozenset(selections)
|
|
134
133
|
|
|
135
134
|
return await api_function(user, dataset_normalized, selections, request_version)
|
|
136
135
|
|
|
136
|
+
async def do_cachable_action(cache: TTLCache, action: Callable[..., Coroutine[Any, Any, T]], *args) -> T:
|
|
137
|
+
cache_key = tuple(args)
|
|
138
|
+
result = cache.get(cache_key)
|
|
139
|
+
if result is None:
|
|
140
|
+
result = await action(*args)
|
|
141
|
+
cache[cache_key] = result
|
|
142
|
+
return result
|
|
143
|
+
|
|
137
144
|
# Login
|
|
138
145
|
token_path = base_path + '/token'
|
|
139
146
|
|
|
@@ -154,20 +161,14 @@ class ApiServer:
|
|
|
154
161
|
"expiry_time": expiry
|
|
155
162
|
}
|
|
156
163
|
|
|
157
|
-
async def get_current_user(token: str = Depends(oauth2_scheme)) -> Optional[User]:
|
|
164
|
+
async def get_current_user(response: Response, token: str = Depends(oauth2_scheme)) -> Optional[User]:
|
|
158
165
|
user = self.authenticator.get_user_from_token(token)
|
|
166
|
+
username = "" if user is None else user.username
|
|
167
|
+
response.headers["Applied-Username"] = username
|
|
159
168
|
return user
|
|
160
169
|
|
|
161
|
-
async def do_cachable_action(cache: TTLCache, action: Callable[..., Coroutine[Any, Any, T]], *args) -> T:
|
|
162
|
-
cache_key = tuple(args)
|
|
163
|
-
result = cache.get(cache_key)
|
|
164
|
-
if result is None:
|
|
165
|
-
result = await action(*args)
|
|
166
|
-
cache[cache_key] = result
|
|
167
|
-
return result
|
|
168
|
-
|
|
169
170
|
# Parameters API
|
|
170
|
-
parameters_path = base_path + '/{dataset}/parameters'
|
|
171
|
+
parameters_path = base_path + '/dataset/{dataset}/parameters'
|
|
171
172
|
|
|
172
173
|
parameters_cache_size = ManifestIO.obj.settings.get(c.PARAMETERS_CACHE_SIZE_SETTING, 1024)
|
|
173
174
|
parameters_cache_ttl = ManifestIO.obj.settings.get(c.PARAMETERS_CACHE_TTL_SETTING, 0)
|
|
@@ -209,7 +210,7 @@ class ApiServer:
|
|
|
209
210
|
return result
|
|
210
211
|
|
|
211
212
|
# Results API
|
|
212
|
-
results_path = base_path + '/{dataset}'
|
|
213
|
+
results_path = base_path + '/dataset/{dataset}'
|
|
213
214
|
|
|
214
215
|
results_cache_size = ManifestIO.obj.settings.get(c.RESULTS_CACHE_SIZE_SETTING, 128)
|
|
215
216
|
results_cache_ttl = ManifestIO.obj.settings.get(c.RESULTS_CACHE_TTL_SETTING, 0)
|
|
@@ -248,8 +249,10 @@ class ApiServer:
|
|
|
248
249
|
timer.add_activity_time("POST REQUEST total time for DATASET", start)
|
|
249
250
|
return result
|
|
250
251
|
|
|
251
|
-
# Catalog API
|
|
252
|
-
|
|
252
|
+
# Datasets Catalog API
|
|
253
|
+
datasets_path = base_path + '/datasets'
|
|
254
|
+
|
|
255
|
+
def get_datasets0(user: Optional[User]):
|
|
253
256
|
datasets_info = []
|
|
254
257
|
for dataset_name, dataset_config in self.dataset_configs.items():
|
|
255
258
|
if can_user_access_dataset(user, dataset_name):
|
|
@@ -258,30 +261,44 @@ class ApiServer:
|
|
|
258
261
|
'name': dataset_name,
|
|
259
262
|
'label': dataset_config.label,
|
|
260
263
|
'parameters_path': parameters_path.format(dataset=dataset_normalized),
|
|
261
|
-
'result_path': results_path.format(dataset=dataset_normalized)
|
|
262
|
-
'first_minor_version': 0
|
|
264
|
+
'result_path': results_path.format(dataset=dataset_normalized)
|
|
263
265
|
})
|
|
264
|
-
|
|
266
|
+
return {"datasets": datasets_info}
|
|
267
|
+
|
|
268
|
+
@app.get(datasets_path)
|
|
269
|
+
def get_datasets(request: Request, user: Optional[User] = Depends(get_current_user)):
|
|
270
|
+
return process_based_on_response_version_header(request.headers, {
|
|
271
|
+
0: lambda: get_datasets0(user)
|
|
272
|
+
})
|
|
273
|
+
|
|
274
|
+
# Projects Catalog API
|
|
275
|
+
def get_catalog0():
|
|
265
276
|
return {
|
|
266
277
|
'projects': [{
|
|
267
278
|
'name': ManifestIO.obj.project_variables.get_name(),
|
|
268
279
|
'label': ManifestIO.obj.project_variables.get_label(),
|
|
269
280
|
'versions': [{
|
|
270
281
|
'major_version': ManifestIO.obj.project_variables.get_major_version(),
|
|
271
|
-
'
|
|
282
|
+
'minor_versions': [0],
|
|
272
283
|
'token_path': token_path,
|
|
273
|
-
'
|
|
284
|
+
'datasets_path': datasets_path
|
|
274
285
|
}]
|
|
275
286
|
}]
|
|
276
287
|
}
|
|
277
288
|
|
|
278
289
|
@app.get(squirrels_version_path, response_class=JSONResponse)
|
|
279
|
-
async def get_catalog(request: Request
|
|
290
|
+
async def get_catalog(request: Request):
|
|
280
291
|
return process_based_on_response_version_header(request.headers, {
|
|
281
|
-
0: lambda: get_catalog0(
|
|
292
|
+
0: lambda: get_catalog0()
|
|
282
293
|
})
|
|
283
294
|
|
|
284
295
|
# Squirrels UI
|
|
296
|
+
static_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.ASSETS_FOLDER)
|
|
297
|
+
app.mount('/'+c.ASSETS_FOLDER, StaticFiles(directory=static_dir), name=c.ASSETS_FOLDER)
|
|
298
|
+
|
|
299
|
+
templates_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.TEMPLATES_FOLDER)
|
|
300
|
+
templates = Jinja2Templates(directory=templates_dir)
|
|
301
|
+
|
|
285
302
|
@app.get('/', response_class=HTMLResponse)
|
|
286
303
|
async def get_ui(request: Request):
|
|
287
304
|
return templates.TemplateResponse('index.html', {
|
squirrels/_authenticator.py
CHANGED
|
@@ -34,7 +34,7 @@ class Authenticator:
|
|
|
34
34
|
return AuthArgs(conn_args.proj_vars, conn_args.env_vars, conn_args._get_credential, connections, username, password)
|
|
35
35
|
|
|
36
36
|
def authenticate_user(self, username: str, password: str) -> Optional[User]:
|
|
37
|
-
user_cls = self.auth_helper.get_func_or_class("User", default_attr=User)
|
|
37
|
+
user_cls: type[User] = self.auth_helper.get_func_or_class("User", default_attr=User)
|
|
38
38
|
get_user = self.auth_helper.get_func_or_class(c.GET_USER_FUNC, is_required=False)
|
|
39
39
|
try:
|
|
40
40
|
real_user = get_user(self._get_auth_args(username, password)) if get_user is not None else None
|
|
@@ -48,9 +48,8 @@ class Authenticator:
|
|
|
48
48
|
fake_users = EnvironConfigIO.obj.get_users()
|
|
49
49
|
if username in fake_users and secrets.compare_digest(fake_users[username][c.USER_PWD_KEY], password):
|
|
50
50
|
is_internal = fake_users[username].get("is_internal", False)
|
|
51
|
-
user: User = user_cls(username, is_internal=is_internal)
|
|
52
51
|
try:
|
|
53
|
-
return
|
|
52
|
+
return user_cls.Create(username, fake_users[username], is_internal=is_internal)
|
|
54
53
|
except Exception as e:
|
|
55
54
|
raise u.FileExecutionError(f'Failed to create user from User model in {c.AUTH_FILE}', e)
|
|
56
55
|
|
squirrels/_command_line.py
CHANGED
|
@@ -33,7 +33,7 @@ def main():
|
|
|
33
33
|
|
|
34
34
|
compile_parser = subparsers.add_parser(c.COMPILE_CMD, help='Create files for rendered sql queries in the "target/compile" folder', add_help=False)
|
|
35
35
|
compile_parser.add_argument('-h', '--help', action="help", help="Show this help message and exit")
|
|
36
|
-
compile_parser.add_argument('-d', '--dataset', type=str, help="Select dataset to use for dataset
|
|
36
|
+
compile_parser.add_argument('-d', '--dataset', type=str, help="Select dataset to use for dataset traits. If not specified, all models for all datasets are compiled")
|
|
37
37
|
compile_parser.add_argument('-a', '--all-test-sets', action="store_true", help="Compile models for all selection test sets")
|
|
38
38
|
compile_parser.add_argument('-t', '--test-set', type=str, help="The selection test set to use. Default selections are used if not specified. Ignored if using --all-test-sets")
|
|
39
39
|
compile_parser.add_argument('-s', '--select', type=str, help="Select single model to compile. If not specified, all models for the dataset are compiled. Also, ignored if --dataset is not specified")
|
squirrels/_constants.py
CHANGED
|
@@ -9,7 +9,6 @@ PROJ_VARS_KEY = 'project_variables'
|
|
|
9
9
|
PROJECT_NAME_KEY = 'name'
|
|
10
10
|
PROJECT_LABEL_KEY = 'label'
|
|
11
11
|
MAJOR_VERSION_KEY = 'major_version'
|
|
12
|
-
MINOR_VERSION_KEY = 'minor_version'
|
|
13
12
|
|
|
14
13
|
PACKAGES_KEY = 'packages'
|
|
15
14
|
PACKAGE_GIT_KEY = 'git'
|
|
@@ -48,7 +47,7 @@ DATASET_NAME_KEY = 'name'
|
|
|
48
47
|
DATASET_LABEL_KEY = 'label'
|
|
49
48
|
DATASET_MODEL_KEY = 'model'
|
|
50
49
|
DATASET_PARAMETERS_KEY = 'parameters'
|
|
51
|
-
|
|
50
|
+
DATASET_TRAITS_KEY = 'traits'
|
|
52
51
|
|
|
53
52
|
DATASET_SCOPE_KEY = 'scope'
|
|
54
53
|
PUBLIC_SCOPE = 'public'
|
|
@@ -109,14 +108,17 @@ PARAMETERS_OUTPUT = 'parameters.json'
|
|
|
109
108
|
FINAL_VIEW_OUT_STEM = 'final_view'
|
|
110
109
|
|
|
111
110
|
# Dataset setting names
|
|
112
|
-
AUTH_TOKEN_EXPIRE_SETTING = 'auth.token.
|
|
111
|
+
AUTH_TOKEN_EXPIRE_SETTING = 'auth.token.expire_minutes'
|
|
113
112
|
PARAMETERS_CACHE_SIZE_SETTING = 'parameters.cache.size'
|
|
114
|
-
PARAMETERS_CACHE_TTL_SETTING = 'parameters.cache.
|
|
113
|
+
PARAMETERS_CACHE_TTL_SETTING = 'parameters.cache.ttl_minutes'
|
|
115
114
|
RESULTS_CACHE_SIZE_SETTING = 'results.cache.size'
|
|
116
|
-
RESULTS_CACHE_TTL_SETTING = 'results.cache.
|
|
115
|
+
RESULTS_CACHE_TTL_SETTING = 'results.cache.ttl_minutes'
|
|
117
116
|
TEST_SET_DEFAULT_USED_SETTING = 'selection_test_sets.default_name_used'
|
|
118
117
|
DB_CONN_DEFAULT_USED_SETTING = 'connections.default_name_used'
|
|
119
118
|
DEFAULT_MATERIALIZE_SETTING = 'defaults.federates.materialized'
|
|
119
|
+
IN_MEMORY_DB_SETTING = 'in_memory_database'
|
|
120
|
+
SQLITE = 'sqlite'
|
|
121
|
+
DUCKDB = 'duckdb'
|
|
120
122
|
|
|
121
123
|
# Selection cfg sections
|
|
122
124
|
USER_ATTRIBUTES_SECTION = 'user_attributes'
|
squirrels/_environcfg.py
CHANGED
|
@@ -43,7 +43,7 @@ class _EnvironConfig:
|
|
|
43
43
|
return credential[c.USERNAME_KEY], credential[c.PASSWORD_KEY]
|
|
44
44
|
|
|
45
45
|
def get_secret(self, key: str, *, default_factory: Optional[Callable[[],str]] = None) -> str:
|
|
46
|
-
if
|
|
46
|
+
if not self._secrets.get(key) and default_factory is not None:
|
|
47
47
|
self._secrets[key] = default_factory()
|
|
48
48
|
return self._secrets.get(key)
|
|
49
49
|
|
squirrels/_initializer.py
CHANGED
|
@@ -40,7 +40,6 @@ class Initializer:
|
|
|
40
40
|
options = ["core", "connections", "parameters", "dbview", "federate", "auth", "sample_db"]
|
|
41
41
|
CORE, CONNECTIONS, PARAMETERS, DBVIEW, FEDERATE, AUTH, SAMPLE_DB = options
|
|
42
42
|
TMP_FOLDER = "tmp"
|
|
43
|
-
IGNORES_FOLDER = "ignores"
|
|
44
43
|
|
|
45
44
|
answers = { x: getattr(args, x) for x in options }
|
|
46
45
|
if not any(answers.values()):
|
|
@@ -136,7 +135,7 @@ class Initializer:
|
|
|
136
135
|
|
|
137
136
|
create_manifest_file()
|
|
138
137
|
|
|
139
|
-
self._copy_file(".gitignore"
|
|
138
|
+
self._copy_file(".gitignore")
|
|
140
139
|
self._copy_file(c.MANIFEST_FILE, src_folder=TMP_FOLDER)
|
|
141
140
|
|
|
142
141
|
if connections_use_py:
|
squirrels/_manifest.py
CHANGED
|
@@ -26,10 +26,10 @@ class ProjectVarsConfig(ManifestComponentConfig):
|
|
|
26
26
|
data: dict
|
|
27
27
|
|
|
28
28
|
def __post_init__(self):
|
|
29
|
-
required_keys = [c.PROJECT_NAME_KEY, c.MAJOR_VERSION_KEY
|
|
29
|
+
required_keys = [c.PROJECT_NAME_KEY, c.MAJOR_VERSION_KEY]
|
|
30
30
|
self._validate_required(self.data, required_keys, c.PROJ_VARS_KEY)
|
|
31
31
|
|
|
32
|
-
integer_keys = [c.MAJOR_VERSION_KEY
|
|
32
|
+
integer_keys = [c.MAJOR_VERSION_KEY]
|
|
33
33
|
for key in integer_keys:
|
|
34
34
|
if key in self.data and not isinstance(self.data[key], int):
|
|
35
35
|
raise u.ConfigurationError(f'Project variable "{key}" must be an integer')
|
|
@@ -46,9 +46,6 @@ class ProjectVarsConfig(ManifestComponentConfig):
|
|
|
46
46
|
|
|
47
47
|
def get_major_version(self) -> int:
|
|
48
48
|
return self.data[c.MAJOR_VERSION_KEY]
|
|
49
|
-
|
|
50
|
-
def get_minor_version(self) -> int:
|
|
51
|
-
return self.data[c.MINOR_VERSION_KEY]
|
|
52
49
|
|
|
53
50
|
|
|
54
51
|
@dataclass
|
|
@@ -84,14 +81,13 @@ class DbConnConfig(ManifestComponentConfig):
|
|
|
84
81
|
|
|
85
82
|
@dataclass
|
|
86
83
|
class ParametersConfig(ManifestComponentConfig):
|
|
87
|
-
name: str
|
|
88
84
|
type: str
|
|
89
85
|
factory: str
|
|
90
86
|
arguments: dict
|
|
91
87
|
|
|
92
88
|
@classmethod
|
|
93
89
|
def from_dict(cls, kwargs: dict):
|
|
94
|
-
all_keys = [c.
|
|
90
|
+
all_keys = [c.PARAMETER_TYPE_KEY, c.PARAMETER_FACTORY_KEY, c.PARAMETER_ARGS_KEY]
|
|
95
91
|
cls._validate_required(kwargs, all_keys, c.PARAMETERS_KEY)
|
|
96
92
|
args = {key: kwargs[key] for key in all_keys}
|
|
97
93
|
return cls(**args)
|
|
@@ -150,7 +146,7 @@ class DatasetsConfig(ManifestComponentConfig):
|
|
|
150
146
|
model: str
|
|
151
147
|
scope: DatasetScope
|
|
152
148
|
parameters: Optional[list[str]]
|
|
153
|
-
|
|
149
|
+
traits: dict
|
|
154
150
|
|
|
155
151
|
@classmethod
|
|
156
152
|
def from_dict(cls, kwargs: dict):
|
|
@@ -166,8 +162,8 @@ class DatasetsConfig(ManifestComponentConfig):
|
|
|
166
162
|
raise u.ConfigurationError(f'Scope not found for dataset "{name}". Scope must be one of {scope_list}') from e
|
|
167
163
|
|
|
168
164
|
parameters = kwargs.get(c.DATASET_PARAMETERS_KEY)
|
|
169
|
-
|
|
170
|
-
return cls(name, label, model, scope, parameters,
|
|
165
|
+
traits = kwargs.get(c.DATASET_TRAITS_KEY, {})
|
|
166
|
+
return cls(name, label, model, scope, parameters, traits)
|
|
171
167
|
|
|
172
168
|
|
|
173
169
|
@dataclass
|
|
@@ -175,7 +171,7 @@ class _ManifestConfig:
|
|
|
175
171
|
project_variables: ProjectVarsConfig
|
|
176
172
|
packages: list[PackageConfig]
|
|
177
173
|
connections: dict[str, DbConnConfig]
|
|
178
|
-
parameters:
|
|
174
|
+
parameters: list[ParametersConfig]
|
|
179
175
|
selection_test_sets: dict[str, TestSetsConfig]
|
|
180
176
|
dbviews: dict[str, DbviewConfig]
|
|
181
177
|
federates: dict[str, FederateConfig]
|
|
@@ -209,7 +205,7 @@ class _ManifestConfig:
|
|
|
209
205
|
all_package_dirs.add(package.directory)
|
|
210
206
|
|
|
211
207
|
db_conns = cls._create_configs_as_dict(DbConnConfig, kwargs, c.DB_CONNECTIONS_KEY, c.DB_CONN_NAME_KEY)
|
|
212
|
-
params =
|
|
208
|
+
params = [ParametersConfig.from_dict(x) for x in kwargs.get(c.PARAMETERS_KEY, [])]
|
|
213
209
|
|
|
214
210
|
test_sets = cls._create_configs_as_dict(TestSetsConfig, kwargs, c.TEST_SETS_KEY, c.TEST_SET_NAME_KEY)
|
|
215
211
|
default_test_set: str = settings.get(c.TEST_SET_DEFAULT_USED_SETTING, c.DEFAULT_TEST_SET_NAME)
|
squirrels/_models.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
from typing import Optional, Callable, Iterable, Any
|
|
2
|
+
from typing import Union, Optional, Callable, Iterable, Any
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
4
|
from enum import Enum
|
|
5
5
|
from pathlib import Path
|
|
@@ -7,7 +7,7 @@ import sqlite3, pandas as pd, asyncio, os, shutil
|
|
|
7
7
|
|
|
8
8
|
from . import _constants as c, _utils as u, _py_module as pm
|
|
9
9
|
from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
|
|
10
|
-
from .
|
|
10
|
+
from ._authenticator import User, Authenticator
|
|
11
11
|
from ._connection_set import ConnectionSetIO
|
|
12
12
|
from ._manifest import ManifestIO, DatasetsConfig
|
|
13
13
|
from ._parameter_sets import ParameterConfigsSetIO, ParameterSet
|
|
@@ -146,7 +146,7 @@ class Model:
|
|
|
146
146
|
configuration = SqlModelConfig(connection_name, materialized)
|
|
147
147
|
kwargs = {
|
|
148
148
|
"proj_vars": ctx_args.proj_vars, "env_vars": ctx_args.env_vars,
|
|
149
|
-
"user": ctx_args.user, "prms": ctx_args.prms, "
|
|
149
|
+
"user": ctx_args.user, "prms": ctx_args.prms, "traits": ctx_args.traits,
|
|
150
150
|
"ctx": ctx, "config": configuration.set_attribute
|
|
151
151
|
}
|
|
152
152
|
dependencies = set()
|
|
@@ -166,7 +166,7 @@ class Model:
|
|
|
166
166
|
|
|
167
167
|
async def _compile_python_model(self, ctx: dict[str, Any], ctx_args: ContextArgs) -> tuple[PyModelQuery, set]:
|
|
168
168
|
assert(isinstance(self.query_file.raw_query, RawPyQuery))
|
|
169
|
-
sqrl_args = ModelDepsArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.
|
|
169
|
+
sqrl_args = ModelDepsArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits, ctx)
|
|
170
170
|
try:
|
|
171
171
|
dependencies = await asyncio.to_thread(self.query_file.raw_query.dependencies_func, sqrl_args)
|
|
172
172
|
except Exception as e:
|
|
@@ -175,7 +175,7 @@ class Model:
|
|
|
175
175
|
dbview_conn_name = self._get_dbview_conn_name()
|
|
176
176
|
connections = ConnectionSetIO.obj.get_engines_as_dict()
|
|
177
177
|
ref = lambda x: self.upstreams[x].result
|
|
178
|
-
sqrl_args = ModelArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.
|
|
178
|
+
sqrl_args = ModelArgs(ctx_args.proj_vars, ctx_args.env_vars, ctx_args.user, ctx_args.prms, ctx_args.traits,
|
|
179
179
|
ctx, dbview_conn_name, connections, ref, set(dependencies))
|
|
180
180
|
|
|
181
181
|
def compiled_query():
|
|
@@ -215,9 +215,9 @@ class Model:
|
|
|
215
215
|
coroutines.append(coro)
|
|
216
216
|
await asyncio.gather(*coroutines)
|
|
217
217
|
|
|
218
|
-
def
|
|
218
|
+
def get_terminal_nodes(self, depencency_path: set[str]) -> set[str]:
|
|
219
219
|
if self.confirmed_no_cycles:
|
|
220
|
-
return
|
|
220
|
+
return set()
|
|
221
221
|
|
|
222
222
|
if self.name in depencency_path:
|
|
223
223
|
raise u.ConfigurationError(f'Cycle found in model dependency graph')
|
|
@@ -229,11 +229,24 @@ class Model:
|
|
|
229
229
|
new_path = set(depencency_path)
|
|
230
230
|
new_path.add(self.name)
|
|
231
231
|
for dep_model in self.upstreams.values():
|
|
232
|
-
terminal_nodes_under_dep = dep_model.
|
|
232
|
+
terminal_nodes_under_dep = dep_model.get_terminal_nodes(new_path)
|
|
233
233
|
terminal_nodes = terminal_nodes.union(terminal_nodes_under_dep)
|
|
234
234
|
|
|
235
235
|
self.confirmed_no_cycles = True
|
|
236
236
|
return terminal_nodes
|
|
237
|
+
|
|
238
|
+
def _load_pandas_to_table(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None:
|
|
239
|
+
if u.use_duckdb():
|
|
240
|
+
conn.execute(f"CREATE TABLE {self.name} AS FROM df")
|
|
241
|
+
else:
|
|
242
|
+
df.to_sql(self.name, conn, index=False)
|
|
243
|
+
|
|
244
|
+
def _load_table_to_pandas(self, conn: sqlite3.Connection) -> pd.DataFrame:
|
|
245
|
+
if u.use_duckdb():
|
|
246
|
+
return conn.execute(f"FROM {self.name}").df()
|
|
247
|
+
else:
|
|
248
|
+
query = f"SELECT * FROM {self.name}"
|
|
249
|
+
return pd.read_sql(query, conn)
|
|
237
250
|
|
|
238
251
|
async def _run_sql_model(self, conn: sqlite3.Connection) -> None:
|
|
239
252
|
assert(isinstance(self.compiled_query, SqlModelQuery))
|
|
@@ -248,7 +261,7 @@ class Model:
|
|
|
248
261
|
raise u.FileExecutionError(f'Failed to run dbview sql model "{self.name}"', e)
|
|
249
262
|
|
|
250
263
|
df = await asyncio.to_thread(run_sql_query)
|
|
251
|
-
await asyncio.to_thread(
|
|
264
|
+
await asyncio.to_thread(self._load_pandas_to_table, df, conn)
|
|
252
265
|
if self.needs_pandas or self.is_target:
|
|
253
266
|
self.result = df
|
|
254
267
|
elif self.query_file.model_type == ModelType.FEDERATE:
|
|
@@ -261,15 +274,14 @@ class Model:
|
|
|
261
274
|
|
|
262
275
|
await asyncio.to_thread(create_table)
|
|
263
276
|
if self.needs_pandas or self.is_target:
|
|
264
|
-
|
|
265
|
-
self.result = await asyncio.to_thread(pd.read_sql, query, conn)
|
|
277
|
+
self.result = await asyncio.to_thread(self._load_table_to_pandas, conn)
|
|
266
278
|
|
|
267
279
|
async def _run_python_model(self, conn: sqlite3.Connection) -> None:
|
|
268
280
|
assert(isinstance(self.compiled_query, PyModelQuery))
|
|
269
281
|
|
|
270
282
|
df = await asyncio.to_thread(self.compiled_query.query)
|
|
271
283
|
if self.needs_sql_table:
|
|
272
|
-
await asyncio.to_thread(
|
|
284
|
+
await asyncio.to_thread(self._load_pandas_to_table, df, conn)
|
|
273
285
|
if self.needs_pandas or self.is_target:
|
|
274
286
|
self.result = df
|
|
275
287
|
|
|
@@ -320,7 +332,7 @@ class DAG:
|
|
|
320
332
|
context = {}
|
|
321
333
|
param_args = ParameterConfigsSetIO.args
|
|
322
334
|
prms = self.parameter_set.get_parameters_as_dict()
|
|
323
|
-
args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.
|
|
335
|
+
args = ContextArgs(param_args.proj_vars, param_args.env_vars, user, prms, self.dataset.traits)
|
|
324
336
|
try:
|
|
325
337
|
context_func(ctx=context, sqrl=args)
|
|
326
338
|
except Exception as e:
|
|
@@ -331,14 +343,21 @@ class DAG:
|
|
|
331
343
|
async def _compile_models(self, context: dict[str, Any], ctx_args: ContextArgs, recurse: bool) -> None:
|
|
332
344
|
await self.target_model.compile(context, ctx_args, self.models_dict, recurse)
|
|
333
345
|
|
|
334
|
-
def
|
|
346
|
+
def _get_terminal_nodes(self) -> set[str]:
|
|
335
347
|
start = time.time()
|
|
336
|
-
terminal_nodes = self.target_model.
|
|
348
|
+
terminal_nodes = self.target_model.get_terminal_nodes(set())
|
|
349
|
+
for model in self.models_dict.values():
|
|
350
|
+
model.confirmed_no_cycles = False
|
|
337
351
|
timer.add_activity_time(f"validating no cycles in models dependencies", start)
|
|
338
352
|
return terminal_nodes
|
|
339
353
|
|
|
340
354
|
async def _run_models(self, terminal_nodes: set[str]) -> None:
|
|
341
|
-
|
|
355
|
+
if u.use_duckdb():
|
|
356
|
+
import duckdb
|
|
357
|
+
conn = duckdb.connect()
|
|
358
|
+
else:
|
|
359
|
+
conn = sqlite3.connect(":memory:", check_same_thread=False)
|
|
360
|
+
|
|
342
361
|
try:
|
|
343
362
|
coroutines = []
|
|
344
363
|
for model_name in terminal_nodes:
|
|
@@ -360,7 +379,7 @@ class DAG:
|
|
|
360
379
|
|
|
361
380
|
await self._compile_models(context, ctx_args, recurse)
|
|
362
381
|
|
|
363
|
-
terminal_nodes = self.
|
|
382
|
+
terminal_nodes = self._get_terminal_nodes()
|
|
364
383
|
|
|
365
384
|
if runquery:
|
|
366
385
|
await self._run_models(terminal_nodes)
|
|
@@ -431,10 +450,13 @@ class ModelsIO:
|
|
|
431
450
|
@classmethod
|
|
432
451
|
async def WriteDatasetOutputsGivenTestSet(cls, dataset: str, select: str, test_set: str, runquery: bool, recurse: bool) -> Any:
|
|
433
452
|
test_set_conf = ManifestIO.obj.selection_test_sets[test_set]
|
|
434
|
-
|
|
435
|
-
for key, val in test_set_conf.user_attributes.items():
|
|
436
|
-
setattr(user, key, val)
|
|
453
|
+
user_attributes = test_set_conf.user_attributes
|
|
437
454
|
selections = test_set_conf.parameters
|
|
455
|
+
|
|
456
|
+
username, is_internal = user_attributes.get("username", ""), user_attributes.get("is_internal", False)
|
|
457
|
+
user_cls: type[User] = Authenticator.get_auth_helper().get_func_or_class("User", default_attr=User)
|
|
458
|
+
user = user_cls.Create(username, test_set_conf.user_attributes, is_internal=is_internal)
|
|
459
|
+
|
|
438
460
|
dag = cls.GenerateDAG(dataset, target_model_name=select, always_pandas=True)
|
|
439
461
|
await dag.execute(cls.context_func, user, selections, runquery=runquery, recurse=recurse)
|
|
440
462
|
|
|
@@ -445,7 +467,7 @@ class ModelsIO:
|
|
|
445
467
|
def write_model_outputs(model: Model) -> None:
|
|
446
468
|
subfolder = c.DBVIEWS_FOLDER if model.query_file.model_type == ModelType.DBVIEW else c.FEDERATES_FOLDER
|
|
447
469
|
subpath = u.join_paths(output_folder, subfolder)
|
|
448
|
-
os.makedirs(subpath)
|
|
470
|
+
os.makedirs(subpath, exist_ok=True)
|
|
449
471
|
if isinstance(model.compiled_query, SqlModelQuery):
|
|
450
472
|
output_filepath = u.join_paths(subpath, model.name+'.sql')
|
|
451
473
|
query = model.compiled_query.query
|
squirrels/_parameter_configs.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
from typing import Type, Optional, Union, Sequence, Iterator
|
|
2
|
+
from typing import Type, Optional, Union, Sequence, Iterator, Any
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
4
|
from abc import ABCMeta, abstractmethod
|
|
5
5
|
from copy import copy
|
|
@@ -32,10 +32,10 @@ class ParameterConfigBase(metaclass=ABCMeta):
|
|
|
32
32
|
self.user_attribute = user_attribute
|
|
33
33
|
self.parent_name = parent_name
|
|
34
34
|
|
|
35
|
-
def _get_user_group(self, user: Optional[User]) ->
|
|
35
|
+
def _get_user_group(self, user: Optional[User]) -> Any:
|
|
36
36
|
if self.user_attribute is not None:
|
|
37
37
|
if user is None:
|
|
38
|
-
raise u.ConfigurationError(f"Public datasets
|
|
38
|
+
raise u.ConfigurationError(f"Public datasets (which allows non-authenticated users) cannot use parameter " +
|
|
39
39
|
f"'{self.name}' because 'user_attribute' is defined on this parameter.")
|
|
40
40
|
return getattr(user, self.user_attribute)
|
|
41
41
|
|
|
@@ -130,7 +130,7 @@ class SelectionParameterConfig(ParameterConfig):
|
|
|
130
130
|
def _get_default_ids_iterator(self, options: Sequence[po.SelectParameterOption]) -> Iterator[str]:
|
|
131
131
|
return (x._identifier for x in options if x._is_default)
|
|
132
132
|
|
|
133
|
-
def copy(self) ->
|
|
133
|
+
def copy(self) -> SelectionParameterConfig:
|
|
134
134
|
"""
|
|
135
135
|
Use for unit testing only
|
|
136
136
|
"""
|
squirrels/_parameter_sets.py
CHANGED
|
@@ -164,7 +164,6 @@ class ParameterConfigsSetIO:
|
|
|
164
164
|
|
|
165
165
|
@classmethod
|
|
166
166
|
def _AddFromDict(cls, param_config: ParametersConfig) -> None:
|
|
167
|
-
param_config.arguments["name"] = param_config.name
|
|
168
167
|
ptype = getattr(p, param_config.type)
|
|
169
168
|
factory = getattr(ptype, param_config.factory)
|
|
170
169
|
factory(**param_config.arguments)
|
|
@@ -174,8 +173,7 @@ class ParameterConfigsSetIO:
|
|
|
174
173
|
start = time.time()
|
|
175
174
|
cls.obj = _ParameterConfigsSet()
|
|
176
175
|
|
|
177
|
-
|
|
178
|
-
for param_as_dict in parameters_from_manifest:
|
|
176
|
+
for param_as_dict in ManifestIO.obj.parameters:
|
|
179
177
|
cls._AddFromDict(param_as_dict)
|
|
180
178
|
|
|
181
179
|
conn_args = ConnectionSetIO.args
|
squirrels/_py_module.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from typing import Type, Optional, Any
|
|
2
2
|
from types import ModuleType
|
|
3
|
-
|
|
3
|
+
import importlib.util
|
|
4
4
|
|
|
5
5
|
from . import _constants as c, _utils as u
|
|
6
6
|
|
|
@@ -16,7 +16,9 @@ class PyModule:
|
|
|
16
16
|
"""
|
|
17
17
|
self.filepath = str(filepath)
|
|
18
18
|
try:
|
|
19
|
-
|
|
19
|
+
spec = importlib.util.spec_from_file_location(self.filepath, self.filepath)
|
|
20
|
+
self.module = importlib.util.module_from_spec(spec)
|
|
21
|
+
spec.loader.exec_module(self.module)
|
|
20
22
|
except FileNotFoundError as e:
|
|
21
23
|
if is_required:
|
|
22
24
|
raise u.ConfigurationError(f"Required file not found: '{self.filepath}'") from e
|
squirrels/_utils.py
CHANGED
|
@@ -3,6 +3,8 @@ from pathlib import Path
|
|
|
3
3
|
from pandas.api import types as pd_types
|
|
4
4
|
import json, jinja2 as j2, pandas as pd
|
|
5
5
|
|
|
6
|
+
from . import _constants as c
|
|
7
|
+
|
|
6
8
|
FilePath = Union[str, Path]
|
|
7
9
|
|
|
8
10
|
|
|
@@ -179,3 +181,8 @@ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) ->
|
|
|
179
181
|
if input_val is None:
|
|
180
182
|
return None
|
|
181
183
|
return processor(input_val)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def use_duckdb():
|
|
187
|
+
from ._manifest import ManifestIO
|
|
188
|
+
return (ManifestIO.obj.settings.get(c.IN_MEMORY_DB_SETTING, c.SQLITE) == c.DUCKDB)
|