squirrels 0.1.1.post1__py3-none-any.whl → 0.2.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- squirrels/__init__.py +10 -16
- squirrels/_api_server.py +234 -80
- squirrels/_authenticator.py +84 -0
- squirrels/_command_line.py +60 -72
- squirrels/_connection_set.py +96 -0
- squirrels/_constants.py +114 -33
- squirrels/_environcfg.py +77 -0
- squirrels/_initializer.py +126 -67
- squirrels/_manifest.py +195 -168
- squirrels/_models.py +495 -0
- squirrels/_package_loader.py +26 -0
- squirrels/_parameter_configs.py +401 -0
- squirrels/_parameter_sets.py +188 -0
- squirrels/_py_module.py +60 -0
- squirrels/_timer.py +36 -0
- squirrels/_utils.py +81 -49
- squirrels/_version.py +2 -2
- squirrels/arguments/init_time_args.py +32 -0
- squirrels/arguments/run_time_args.py +82 -0
- squirrels/data_sources.py +380 -155
- squirrels/dateutils.py +86 -57
- squirrels/package_data/base_project/Dockerfile +15 -0
- squirrels/package_data/base_project/connections.yml +7 -0
- squirrels/package_data/base_project/database/{sample_database.db → expenses.db} +0 -0
- squirrels/package_data/base_project/environcfg.yml +29 -0
- squirrels/package_data/base_project/ignores/.dockerignore +8 -0
- squirrels/package_data/base_project/ignores/.gitignore +7 -0
- squirrels/package_data/base_project/models/dbviews/database_view1.py +36 -0
- squirrels/package_data/base_project/models/dbviews/database_view1.sql +15 -0
- squirrels/package_data/base_project/models/federates/dataset_example.py +20 -0
- squirrels/package_data/base_project/models/federates/dataset_example.sql +3 -0
- squirrels/package_data/base_project/parameters.yml +109 -0
- squirrels/package_data/base_project/pyconfigs/auth.py +47 -0
- squirrels/package_data/base_project/pyconfigs/connections.py +28 -0
- squirrels/package_data/base_project/pyconfigs/context.py +45 -0
- squirrels/package_data/base_project/pyconfigs/parameters.py +55 -0
- squirrels/package_data/base_project/seeds/mocks/category.csv +3 -0
- squirrels/package_data/base_project/seeds/mocks/max_filter.csv +2 -0
- squirrels/package_data/base_project/seeds/mocks/subcategory.csv +6 -0
- squirrels/package_data/base_project/squirrels.yml.j2 +57 -0
- squirrels/package_data/base_project/tmp/.gitignore +2 -0
- squirrels/package_data/static/script.js +159 -63
- squirrels/package_data/static/style.css +79 -15
- squirrels/package_data/static/widgets.js +133 -0
- squirrels/package_data/templates/index.html +65 -23
- squirrels/package_data/templates/index2.html +22 -0
- squirrels/parameter_options.py +216 -119
- squirrels/parameters.py +407 -478
- squirrels/user_base.py +58 -0
- squirrels-0.2.0.dev0.dist-info/METADATA +126 -0
- squirrels-0.2.0.dev0.dist-info/RECORD +56 -0
- {squirrels-0.1.1.post1.dist-info → squirrels-0.2.0.dev0.dist-info}/WHEEL +1 -2
- squirrels-0.2.0.dev0.dist-info/entry_points.txt +3 -0
- squirrels/_credentials_manager.py +0 -87
- squirrels/_module_loader.py +0 -37
- squirrels/_parameter_set.py +0 -151
- squirrels/_renderer.py +0 -286
- squirrels/_timed_imports.py +0 -37
- squirrels/connection_set.py +0 -126
- squirrels/package_data/base_project/.gitignore +0 -4
- squirrels/package_data/base_project/connections.py +0 -20
- squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -22
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -29
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -12
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -11
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -3
- squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -47
- squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -9
- squirrels/package_data/base_project/squirrels.yaml +0 -22
- squirrels-0.1.1.post1.dist-info/METADATA +0 -67
- squirrels-0.1.1.post1.dist-info/RECORD +0 -40
- squirrels-0.1.1.post1.dist-info/entry_points.txt +0 -2
- squirrels-0.1.1.post1.dist-info/top_level.txt +0 -1
- {squirrels-0.1.1.post1.dist-info → squirrels-0.2.0.dev0.dist-info}/LICENSE +0 -0
squirrels/__init__.py
CHANGED
|
@@ -1,18 +1,12 @@
|
|
|
1
|
-
|
|
2
|
-
from .parameters import Parameter, SingleSelectParameter, MultiSelectParameter, DateParameter, NumberParameter, NumRangeParameter, DataSourceParameter
|
|
3
|
-
from .data_sources import SelectionDataSource, DateDataSource, NumberDataSource, NumRangeDataSource
|
|
4
|
-
from .connection_set import ConnectionSet
|
|
1
|
+
__version__ = '0.2.0'
|
|
5
2
|
|
|
3
|
+
from typing import Union
|
|
4
|
+
from sqlalchemy import Engine, Pool
|
|
5
|
+
from pandas import DataFrame
|
|
6
6
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
Returns:
|
|
15
|
-
Credential: Object with attributes "username" and "password"
|
|
16
|
-
"""
|
|
17
|
-
from ._credentials_manager import squirrels_config_io
|
|
18
|
-
return squirrels_config_io.get_credential(key)
|
|
7
|
+
from .arguments.init_time_args import ConnectionsArgs, ParametersArgs
|
|
8
|
+
from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
|
|
9
|
+
from .parameter_options import SelectParameterOption, DateParameterOption, DateRangeParameterOption, NumberParameterOption, NumberRangeParameterOption
|
|
10
|
+
from .parameters import Parameter, SingleSelectParameter, MultiSelectParameter, DateParameter, DateRangeParameter, NumberParameter, NumberRangeParameter
|
|
11
|
+
from .data_sources import SingleSelectDataSource, MultiSelectDataSource, DateDataSource, DateRangeDataSource, NumberDataSource, NumberRangeDataSource
|
|
12
|
+
from .user_base import User, WrongPassword
|
squirrels/_api_server.py
CHANGED
|
@@ -1,134 +1,288 @@
|
|
|
1
|
-
from typing import
|
|
2
|
-
from fastapi import FastAPI, Request, HTTPException
|
|
3
|
-
from fastapi.datastructures import QueryParams
|
|
1
|
+
from typing import Iterable, Optional, Mapping, Callable, Coroutine, TypeVar, Any
|
|
2
|
+
from fastapi import Depends, FastAPI, Request, HTTPException, status
|
|
4
3
|
from fastapi.responses import HTMLResponse, JSONResponse
|
|
5
4
|
from fastapi.templating import Jinja2Templates
|
|
6
5
|
from fastapi.staticfiles import StaticFiles
|
|
7
|
-
from
|
|
8
|
-
import
|
|
6
|
+
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
7
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
8
|
+
from cachetools import TTLCache
|
|
9
|
+
import os, traceback, pandas as pd
|
|
9
10
|
|
|
10
|
-
from
|
|
11
|
-
from
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
from
|
|
11
|
+
from . import _constants as c, _utils as u
|
|
12
|
+
from ._version import sq_major_version
|
|
13
|
+
from ._manifest import ManifestIO
|
|
14
|
+
from ._authenticator import User, Authenticator
|
|
15
|
+
from ._timer import timer, time
|
|
16
|
+
from ._parameter_sets import ParameterSet
|
|
17
|
+
from ._models import ModelsIO
|
|
15
18
|
|
|
16
19
|
|
|
17
20
|
class ApiServer:
|
|
18
|
-
def __init__(self,
|
|
21
|
+
def __init__(self, no_cache: bool, debug: bool) -> None:
|
|
19
22
|
"""
|
|
20
23
|
Constructor for ApiServer
|
|
21
24
|
|
|
22
25
|
Parameters:
|
|
23
|
-
manifest (Manifest): Manifest object produced from squirrels.yaml
|
|
24
|
-
conn_set (ConnectionSet): Set of all connection pools defined in connections.py
|
|
25
26
|
no_cache (bool): Whether to disable caching
|
|
26
27
|
debug (bool): Set to True to show "hidden" parameters in the /parameters endpoint response
|
|
27
28
|
"""
|
|
28
|
-
self.manifest = manifest
|
|
29
|
-
self.conn_set = conn_set
|
|
30
29
|
self.no_cache = no_cache
|
|
31
30
|
self.debug = debug
|
|
31
|
+
self.dataset_configs = ManifestIO.obj.datasets
|
|
32
32
|
|
|
33
|
-
|
|
34
|
-
self.
|
|
35
|
-
for dataset in self.datasets:
|
|
36
|
-
rendererIO = RendererIOWrapper(dataset, manifest, conn_set)
|
|
37
|
-
self.renderers[dataset] = rendererIO.renderer
|
|
38
|
-
|
|
39
|
-
def _get_parameters_helper(self, dataset: str, query_params: Set[Tuple[str, str]]) -> Dict:
|
|
40
|
-
if len(query_params) > 1:
|
|
41
|
-
raise _utils.InvalidInputError("The /parameters endpoint takes at most 1 query parameter")
|
|
42
|
-
renderer = self.renderers[dataset]
|
|
43
|
-
parameters = renderer.apply_selections(dict(query_params), updates_only = True)
|
|
44
|
-
return parameters.to_json_dict(self.debug)
|
|
33
|
+
token_expiry_minutes = ManifestIO.obj.settings.get(c.AUTH_TOKEN_EXPIRE_SETTING, 30)
|
|
34
|
+
self.authenticator = Authenticator(token_expiry_minutes)
|
|
45
35
|
|
|
46
|
-
def
|
|
47
|
-
renderer = self.renderers[dataset]
|
|
48
|
-
_, _, _, _, df = renderer.load_results(dict(query_params))
|
|
49
|
-
return _utils.df_to_json(df)
|
|
50
|
-
|
|
51
|
-
def _apply_api_function(self, api_function):
|
|
52
|
-
try:
|
|
53
|
-
return api_function()
|
|
54
|
-
except _utils.InvalidInputError as e:
|
|
55
|
-
traceback.print_exc()
|
|
56
|
-
raise HTTPException(status_code=400, detail="Invalid User Input: "+str(e)) from e
|
|
57
|
-
except _utils.ConfigurationError as e:
|
|
58
|
-
traceback.print_exc()
|
|
59
|
-
raise HTTPException(status_code=500, detail="Squirrels Configuration Error: "+str(e)) from e
|
|
60
|
-
except Exception as e:
|
|
61
|
-
traceback.print_exc()
|
|
62
|
-
raise HTTPException(status_code=500, detail="Squirrels Framework Error: "+str(e)) from e
|
|
63
|
-
|
|
64
|
-
def _apply_dataset_api_function(self, api_function, dataset: str, raw_query_params: QueryParams):
|
|
65
|
-
dataset = _utils.normalize_name(dataset)
|
|
66
|
-
query_params = set()
|
|
67
|
-
for key, val in raw_query_params.items():
|
|
68
|
-
query_params.add((_utils.normalize_name(key), val))
|
|
69
|
-
query_params = frozenset(query_params)
|
|
70
|
-
return self._apply_api_function(lambda: api_function(dataset, query_params))
|
|
71
|
-
|
|
72
|
-
def run(self, uvicorn_args: List[str]) -> None:
|
|
36
|
+
def run(self, uvicorn_args: list[str]) -> None:
|
|
73
37
|
"""
|
|
74
38
|
Runs the API server with uvicorn for CLI "squirrels run"
|
|
75
39
|
|
|
76
40
|
Parameters:
|
|
77
|
-
uvicorn_args
|
|
41
|
+
uvicorn_args: List of arguments to pass to uvicorn.run. Currently only supports "host" and "port"
|
|
78
42
|
"""
|
|
43
|
+
start = time.time()
|
|
79
44
|
app = FastAPI()
|
|
80
45
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
46
|
+
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
|
|
47
|
+
|
|
48
|
+
squirrels_version_path = f'/squirrels-v{sq_major_version}'
|
|
49
|
+
partial_base_path = f'/{ManifestIO.obj.project_variables.get_name()}/v{ManifestIO.obj.project_variables.get_major_version()}'
|
|
50
|
+
base_path = squirrels_version_path + u.normalize_name_for_api(partial_base_path)
|
|
84
51
|
|
|
85
|
-
static_dir =
|
|
52
|
+
static_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.STATIC_FOLDER)
|
|
86
53
|
app.mount('/static', StaticFiles(directory=static_dir), name='static')
|
|
87
54
|
|
|
88
|
-
templates_dir =
|
|
55
|
+
templates_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.TEMPLATES_FOLDER)
|
|
89
56
|
templates = Jinja2Templates(directory=templates_dir)
|
|
90
57
|
|
|
58
|
+
# Exception handlers
|
|
59
|
+
@app.exception_handler(u.InvalidInputError)
|
|
60
|
+
async def invalid_input_error_handler(request: Request, exc: u.InvalidInputError):
|
|
61
|
+
traceback.print_exc()
|
|
62
|
+
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST,
|
|
63
|
+
content={"message": f"Invalid user input: {str(exc)}"})
|
|
64
|
+
|
|
65
|
+
@app.exception_handler(u.ConfigurationError)
|
|
66
|
+
async def configuration_error_handler(request: Request, exc: u.InvalidInputError):
|
|
67
|
+
traceback.print_exc()
|
|
68
|
+
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
69
|
+
content={"message": f"Squirrels configuration error: {str(exc)}"})
|
|
70
|
+
|
|
71
|
+
@app.exception_handler(NotImplementedError)
|
|
72
|
+
async def not_implemented_error_handler(request: Request, exc: u.InvalidInputError):
|
|
73
|
+
traceback.print_exc()
|
|
74
|
+
return JSONResponse(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
|
75
|
+
content={"message": f"Not implemented error: {str(exc)}"})
|
|
76
|
+
|
|
77
|
+
# Helpers
|
|
78
|
+
T = TypeVar('T')
|
|
79
|
+
|
|
80
|
+
def get_versioning_request_header(headers: Mapping, header_key: str):
|
|
81
|
+
header_value = headers.get(header_key)
|
|
82
|
+
if header_value is None:
|
|
83
|
+
return None
|
|
84
|
+
|
|
85
|
+
try:
|
|
86
|
+
result = int(header_value)
|
|
87
|
+
except ValueError:
|
|
88
|
+
raise u.InvalidInputError(f"Request header '{header_key}' must be an integer. Got '{header_value}'")
|
|
89
|
+
|
|
90
|
+
if result < 0 or result > int(sq_major_version):
|
|
91
|
+
raise u.InvalidInputError(f"Request header '{header_key}' not in valid range. Got '{result}'")
|
|
92
|
+
|
|
93
|
+
return result
|
|
94
|
+
|
|
95
|
+
REQUEST_VERSION_REQUEST_HEADER = "squirrels-request-version"
|
|
96
|
+
def get_request_version_header(headers: Mapping):
|
|
97
|
+
return get_versioning_request_header(headers, REQUEST_VERSION_REQUEST_HEADER)
|
|
98
|
+
|
|
99
|
+
RESPONSE_VERSION_REQUEST_HEADER = "squirrels-response-version"
|
|
100
|
+
def process_based_on_response_version_header(headers: Mapping, processes: dict[str, Callable[[], T]]) -> T:
|
|
101
|
+
response_version = get_versioning_request_header(headers, RESPONSE_VERSION_REQUEST_HEADER)
|
|
102
|
+
if response_version is None or response_version >= 0:
|
|
103
|
+
return processes[0]()
|
|
104
|
+
else:
|
|
105
|
+
raise u.InvalidInputError(f'Invalid value for "{RESPONSE_VERSION_REQUEST_HEADER}" header: {response_version}')
|
|
106
|
+
|
|
107
|
+
def can_user_access_dataset(user: Optional[User], dataset: str):
|
|
108
|
+
dataset_scope = self.dataset_configs[dataset].scope
|
|
109
|
+
return self.authenticator.can_user_access_scope(user, dataset_scope)
|
|
110
|
+
|
|
111
|
+
async def apply_dataset_api_function(
|
|
112
|
+
api_function: Callable[..., Coroutine[Any, Any, T]], user: Optional[User], dataset: str, headers: Mapping, params: Mapping
|
|
113
|
+
) -> T:
|
|
114
|
+
dataset_normalized = u.normalize_name(dataset)
|
|
115
|
+
if not can_user_access_dataset(user, dataset_normalized):
|
|
116
|
+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
|
117
|
+
detail="Could not validate credentials",
|
|
118
|
+
headers={"WWW-Authenticate": "Bearer"})
|
|
119
|
+
|
|
120
|
+
request_version = get_request_version_header(headers)
|
|
121
|
+
|
|
122
|
+
# Changing selections into a cachable "frozenset" that will later be converted to dictionary
|
|
123
|
+
selections = set()
|
|
124
|
+
for key, val in params.items():
|
|
125
|
+
if not isinstance(val, str):
|
|
126
|
+
val = tuple(val)
|
|
127
|
+
selections.add((u.normalize_name(key), val))
|
|
128
|
+
selections = frozenset(selections)
|
|
129
|
+
|
|
130
|
+
return await api_function(user, dataset_normalized, selections, request_version)
|
|
131
|
+
|
|
132
|
+
# Login
|
|
133
|
+
token_path = base_path + '/token'
|
|
134
|
+
|
|
135
|
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=token_path, auto_error=False)
|
|
136
|
+
|
|
137
|
+
@app.post(token_path)
|
|
138
|
+
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
|
139
|
+
user: Optional[User] = self.authenticator.authenticate_user(form_data.username, form_data.password)
|
|
140
|
+
if not user:
|
|
141
|
+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
|
142
|
+
detail="Incorrect username or password",
|
|
143
|
+
headers={"WWW-Authenticate": "Bearer"})
|
|
144
|
+
access_token, expiry = self.authenticator.create_access_token(user)
|
|
145
|
+
return {
|
|
146
|
+
"access_token": access_token,
|
|
147
|
+
"token_type": "bearer",
|
|
148
|
+
"username": user.username,
|
|
149
|
+
"expiry_time": expiry
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
async def get_current_user(token: str = Depends(oauth2_scheme)) -> Optional[User]:
|
|
153
|
+
user = self.authenticator.get_user_from_token(token)
|
|
154
|
+
return user
|
|
155
|
+
|
|
156
|
+
async def do_cachable_action(cache: TTLCache, action: Callable[..., Coroutine[Any, Any, T]], *args) -> T:
|
|
157
|
+
cache_key = tuple(args)
|
|
158
|
+
result = cache.get(cache_key)
|
|
159
|
+
if result is None:
|
|
160
|
+
result = await action(*args)
|
|
161
|
+
cache[cache_key] = result
|
|
162
|
+
return result
|
|
163
|
+
|
|
91
164
|
# Parameters API
|
|
92
165
|
parameters_path = base_path + '/{dataset}/parameters'
|
|
93
166
|
|
|
94
|
-
parameters_cache_size =
|
|
95
|
-
parameters_cache_ttl =
|
|
167
|
+
parameters_cache_size = ManifestIO.obj.settings.get(c.PARAMETERS_CACHE_SIZE_SETTING, 1024)
|
|
168
|
+
parameters_cache_ttl = ManifestIO.obj.settings.get(c.PARAMETERS_CACHE_TTL_SETTING, 0)
|
|
169
|
+
|
|
170
|
+
async def get_parameters_helper(
|
|
171
|
+
user: Optional[User], dataset: str, selections: Iterable[tuple[str, str]], request_version: Optional[int]
|
|
172
|
+
) -> ParameterSet:
|
|
173
|
+
if len(selections) > 1:
|
|
174
|
+
raise u.InvalidInputError(f"The /parameters endpoint takes at most 1 query parameter. Got {dict(selections)}")
|
|
175
|
+
dag = ModelsIO.GenerateDAG(dataset)
|
|
176
|
+
dag.apply_selections(user, dict(selections), request_version=request_version)
|
|
177
|
+
return dag.parameter_set
|
|
96
178
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
179
|
+
params_cache = TTLCache(maxsize=parameters_cache_size, ttl=parameters_cache_ttl*60)
|
|
180
|
+
|
|
181
|
+
async def get_parameters_cachable(*args) -> T:
|
|
182
|
+
return await do_cachable_action(params_cache, get_parameters_helper, *args)
|
|
183
|
+
|
|
184
|
+
async def get_parameters_definition(dataset: str, user: Optional[User], headers: Mapping, params: Mapping):
|
|
185
|
+
api_function = get_parameters_helper if self.no_cache else get_parameters_cachable
|
|
186
|
+
result = await apply_dataset_api_function(api_function, user, dataset, headers, params)
|
|
187
|
+
return process_based_on_response_version_header(headers, {
|
|
188
|
+
0: result.to_json_dict0
|
|
189
|
+
})
|
|
100
190
|
|
|
101
191
|
@app.get(parameters_path, response_class=JSONResponse)
|
|
102
|
-
async def get_parameters(dataset: str, request: Request):
|
|
103
|
-
|
|
104
|
-
|
|
192
|
+
async def get_parameters(dataset: str, request: Request, user: Optional[User] = Depends(get_current_user)):
|
|
193
|
+
start = time.time()
|
|
194
|
+
result = await get_parameters_definition(dataset, user, request.headers, request.query_params)
|
|
195
|
+
timer.add_activity_time("GET REQUEST total time for PARAMETERS", start)
|
|
196
|
+
return result
|
|
197
|
+
|
|
198
|
+
@app.post(parameters_path, response_class=JSONResponse)
|
|
199
|
+
async def get_parameters_with_post(dataset: str, request: Request, user: Optional[User] = Depends(get_current_user)):
|
|
200
|
+
start = time.time()
|
|
201
|
+
request_body = await request.json()
|
|
202
|
+
result = await get_parameters_definition(dataset, user, request.headers, request_body)
|
|
203
|
+
timer.add_activity_time("POST REQUEST total time for PARAMETERS", start)
|
|
204
|
+
return result
|
|
105
205
|
|
|
106
206
|
# Results API
|
|
107
207
|
results_path = base_path + '/{dataset}'
|
|
108
208
|
|
|
109
|
-
results_cache_size =
|
|
110
|
-
results_cache_ttl =
|
|
209
|
+
results_cache_size = ManifestIO.obj.settings.get(c.RESULTS_CACHE_SIZE_SETTING, 128)
|
|
210
|
+
results_cache_ttl = ManifestIO.obj.settings.get(c.RESULTS_CACHE_TTL_SETTING, 0)
|
|
211
|
+
|
|
212
|
+
async def get_results_helper(
|
|
213
|
+
user: Optional[User], dataset: str, selections: Iterable[tuple[str, str]], request_version: Optional[int]
|
|
214
|
+
) -> pd.DataFrame:
|
|
215
|
+
dag = ModelsIO.GenerateDAG(dataset)
|
|
216
|
+
await dag.execute(ModelsIO.context_func, user, dict(selections), request_version=request_version)
|
|
217
|
+
return dag.target_model.result
|
|
218
|
+
|
|
219
|
+
results_cache = TTLCache(maxsize=results_cache_size, ttl=results_cache_ttl*60)
|
|
111
220
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
221
|
+
async def get_results_cachable(*args):
|
|
222
|
+
return await do_cachable_action(results_cache, get_results_helper, *args)
|
|
223
|
+
|
|
224
|
+
async def get_results_definition(dataset: str, user: Optional[User], headers: Mapping, params: Mapping):
|
|
225
|
+
api_function = get_results_helper if self.no_cache else get_results_cachable
|
|
226
|
+
result = await apply_dataset_api_function(api_function, user, dataset, headers, params)
|
|
227
|
+
return process_based_on_response_version_header(headers, {
|
|
228
|
+
0: lambda: u.df_to_json0(result)
|
|
229
|
+
})
|
|
115
230
|
|
|
116
231
|
@app.get(results_path, response_class=JSONResponse)
|
|
117
|
-
async def get_results(dataset: str, request: Request):
|
|
118
|
-
|
|
119
|
-
|
|
232
|
+
async def get_results(dataset: str, request: Request, user: Optional[User] = Depends(get_current_user)):
|
|
233
|
+
start = time.time()
|
|
234
|
+
result = await get_results_definition(dataset, user, request.headers, request.query_params)
|
|
235
|
+
timer.add_activity_time("GET REQUEST total time for DATASET", start)
|
|
236
|
+
return result
|
|
237
|
+
|
|
238
|
+
@app.post(results_path, response_class=JSONResponse)
|
|
239
|
+
async def get_results_with_post(dataset: str, request: Request, user: Optional[User] = Depends(get_current_user)):
|
|
240
|
+
start = time.time()
|
|
241
|
+
request_body = await request.json()
|
|
242
|
+
result = await get_results_definition(dataset, user, request.headers, request_body)
|
|
243
|
+
timer.add_activity_time("POST REQUEST total time for DATASET", start)
|
|
244
|
+
return result
|
|
120
245
|
|
|
121
246
|
# Catalog API
|
|
247
|
+
def get_catalog0(user: Optional[User]):
|
|
248
|
+
datasets_info = []
|
|
249
|
+
for dataset_name, dataset_config in self.dataset_configs.items():
|
|
250
|
+
if can_user_access_dataset(user, dataset_name):
|
|
251
|
+
dataset_normalized = u.normalize_name_for_api(dataset_name)
|
|
252
|
+
datasets_info.append({
|
|
253
|
+
'name': dataset_name,
|
|
254
|
+
'label': dataset_config.label,
|
|
255
|
+
'parameters_path': parameters_path.format(dataset=dataset_normalized),
|
|
256
|
+
'result_path': results_path.format(dataset=dataset_normalized),
|
|
257
|
+
'first_minor_version': 0
|
|
258
|
+
})
|
|
259
|
+
|
|
260
|
+
return {
|
|
261
|
+
'products': [{
|
|
262
|
+
'name': ManifestIO.obj.project_variables.get_name(),
|
|
263
|
+
'label': ManifestIO.obj.project_variables.get_label(),
|
|
264
|
+
'versions': [{
|
|
265
|
+
'major_version': ManifestIO.obj.project_variables.get_major_version(),
|
|
266
|
+
'latest_minor_version': ManifestIO.obj.project_variables.get_minor_version(),
|
|
267
|
+
'datasets': datasets_info
|
|
268
|
+
}]
|
|
269
|
+
}]
|
|
270
|
+
}
|
|
271
|
+
|
|
122
272
|
@app.get(squirrels_version_path, response_class=JSONResponse)
|
|
123
|
-
async def get_catalog():
|
|
124
|
-
|
|
125
|
-
|
|
273
|
+
async def get_catalog(request: Request, user: Optional[User] = Depends(get_current_user)):
|
|
274
|
+
return process_based_on_response_version_header(request.headers, {
|
|
275
|
+
0: lambda: get_catalog0(user)
|
|
276
|
+
})
|
|
126
277
|
|
|
127
278
|
# Squirrels UI
|
|
128
279
|
@app.get('/', response_class=HTMLResponse)
|
|
129
280
|
async def get_ui(request: Request):
|
|
130
|
-
return templates.TemplateResponse('index.html', {
|
|
281
|
+
return templates.TemplateResponse('index.html', {
|
|
282
|
+
'request': request, 'catalog_path': squirrels_version_path, 'token_path': token_path
|
|
283
|
+
})
|
|
131
284
|
|
|
132
285
|
# Run API server
|
|
133
286
|
import uvicorn
|
|
287
|
+
timer.add_activity_time("creating app for api server", start)
|
|
134
288
|
uvicorn.run(app, host=uvicorn_args.host, port=uvicorn_args.port)
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from datetime import datetime, timedelta, timezone
|
|
3
|
+
from jose import JWTError, jwt
|
|
4
|
+
import secrets
|
|
5
|
+
|
|
6
|
+
from . import _utils as u, _constants as c
|
|
7
|
+
from ._py_module import PyModule
|
|
8
|
+
from .user_base import User, WrongPassword
|
|
9
|
+
from ._environcfg import EnvironConfigIO
|
|
10
|
+
from ._manifest import DatasetScope
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Authenticator:
|
|
14
|
+
|
|
15
|
+
@classmethod
|
|
16
|
+
def get_auth_helper(cls, default_auth_helper = None):
|
|
17
|
+
auth_module_path = u.join_paths(c.PYCONFIG_FOLDER, c.AUTH_FILE)
|
|
18
|
+
return PyModule(auth_module_path, default_class=default_auth_helper)
|
|
19
|
+
|
|
20
|
+
def __init__(self, token_expiry_minutes: int, auth_helper = None) -> None:
|
|
21
|
+
self.token_expiry_minutes = token_expiry_minutes
|
|
22
|
+
self.auth_helper = self.get_auth_helper(auth_helper)
|
|
23
|
+
self.secret_key = self._get_secret_key()
|
|
24
|
+
self.algorithm = "HS256"
|
|
25
|
+
|
|
26
|
+
def _get_secret_key(self):
|
|
27
|
+
secret_key = EnvironConfigIO.obj.get_secret(c.JWT_SECRET_KEY, default_factory=lambda: secrets.token_hex(32))
|
|
28
|
+
return secret_key
|
|
29
|
+
|
|
30
|
+
def authenticate_user(self, username: str, password: str) -> Optional[User]:
|
|
31
|
+
if self.auth_helper:
|
|
32
|
+
user_cls = self.auth_helper.get_func_or_class("User", default_attr=User)
|
|
33
|
+
get_user = self.auth_helper.get_func_or_class(c.GET_USER_FUNC)
|
|
34
|
+
try:
|
|
35
|
+
real_user = get_user(username, password)
|
|
36
|
+
except Exception as e:
|
|
37
|
+
raise u.FileExecutionError(f'Failed to run "{c.GET_USER_FUNC}" in {c.AUTH_FILE}', e)
|
|
38
|
+
else:
|
|
39
|
+
user_cls = User
|
|
40
|
+
real_user = None
|
|
41
|
+
|
|
42
|
+
if isinstance(real_user, User):
|
|
43
|
+
return real_user
|
|
44
|
+
|
|
45
|
+
if not isinstance(real_user, WrongPassword):
|
|
46
|
+
fake_users = EnvironConfigIO.obj.get_users()
|
|
47
|
+
if username in fake_users and secrets.compare_digest(fake_users[username][c.USER_PWD_KEY], password):
|
|
48
|
+
is_internal = fake_users[username].get("is_internal", False)
|
|
49
|
+
user = user_cls(username, is_internal=is_internal)
|
|
50
|
+
try:
|
|
51
|
+
return user.with_attributes(fake_users[username])
|
|
52
|
+
except Exception as e:
|
|
53
|
+
raise u.FileExecutionError(f'Failed to create user from User model in {c.AUTH_FILE}', e)
|
|
54
|
+
|
|
55
|
+
return None
|
|
56
|
+
|
|
57
|
+
def create_access_token(self, user: User) -> str:
|
|
58
|
+
expire = datetime.now(timezone.utc) + timedelta(minutes=self.token_expiry_minutes)
|
|
59
|
+
to_encode = {**vars(user), "exp": expire}
|
|
60
|
+
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
|
61
|
+
return encoded_jwt, expire
|
|
62
|
+
|
|
63
|
+
def get_user_from_token(self, token: Optional[str]) -> Optional[User]:
|
|
64
|
+
if token is not None:
|
|
65
|
+
try:
|
|
66
|
+
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
|
67
|
+
payload.pop("exp")
|
|
68
|
+
if self.auth_helper is not None:
|
|
69
|
+
user_cls: User = self.auth_helper.get_func_or_class("User", default_attr=User)
|
|
70
|
+
return user_cls._FromDict(payload)
|
|
71
|
+
else:
|
|
72
|
+
return User._FromDict(payload)
|
|
73
|
+
except JWTError:
|
|
74
|
+
return None
|
|
75
|
+
|
|
76
|
+
def can_user_access_scope(self, user: Optional[User], scope: DatasetScope) -> bool:
|
|
77
|
+
if user is None:
|
|
78
|
+
user_level = DatasetScope.PUBLIC
|
|
79
|
+
elif not user.is_internal:
|
|
80
|
+
user_level = DatasetScope.PROTECTED
|
|
81
|
+
else:
|
|
82
|
+
user_level = DatasetScope.PRIVATE
|
|
83
|
+
|
|
84
|
+
return user_level.value >= scope.value
|