squirrels 0.1.1.post1__py3-none-any.whl → 0.2.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic; see the registry's release page for more details.

Files changed (74)
  1. squirrels/__init__.py +10 -16
  2. squirrels/_api_server.py +234 -80
  3. squirrels/_authenticator.py +84 -0
  4. squirrels/_command_line.py +60 -72
  5. squirrels/_connection_set.py +96 -0
  6. squirrels/_constants.py +114 -33
  7. squirrels/_environcfg.py +77 -0
  8. squirrels/_initializer.py +126 -67
  9. squirrels/_manifest.py +195 -168
  10. squirrels/_models.py +495 -0
  11. squirrels/_package_loader.py +26 -0
  12. squirrels/_parameter_configs.py +401 -0
  13. squirrels/_parameter_sets.py +188 -0
  14. squirrels/_py_module.py +60 -0
  15. squirrels/_timer.py +36 -0
  16. squirrels/_utils.py +81 -49
  17. squirrels/_version.py +2 -2
  18. squirrels/arguments/init_time_args.py +32 -0
  19. squirrels/arguments/run_time_args.py +82 -0
  20. squirrels/data_sources.py +380 -155
  21. squirrels/dateutils.py +86 -57
  22. squirrels/package_data/base_project/Dockerfile +15 -0
  23. squirrels/package_data/base_project/connections.yml +7 -0
  24. squirrels/package_data/base_project/database/{sample_database.db → expenses.db} +0 -0
  25. squirrels/package_data/base_project/environcfg.yml +29 -0
  26. squirrels/package_data/base_project/ignores/.dockerignore +8 -0
  27. squirrels/package_data/base_project/ignores/.gitignore +7 -0
  28. squirrels/package_data/base_project/models/dbviews/database_view1.py +36 -0
  29. squirrels/package_data/base_project/models/dbviews/database_view1.sql +15 -0
  30. squirrels/package_data/base_project/models/federates/dataset_example.py +20 -0
  31. squirrels/package_data/base_project/models/federates/dataset_example.sql +3 -0
  32. squirrels/package_data/base_project/parameters.yml +109 -0
  33. squirrels/package_data/base_project/pyconfigs/auth.py +47 -0
  34. squirrels/package_data/base_project/pyconfigs/connections.py +28 -0
  35. squirrels/package_data/base_project/pyconfigs/context.py +45 -0
  36. squirrels/package_data/base_project/pyconfigs/parameters.py +55 -0
  37. squirrels/package_data/base_project/seeds/mocks/category.csv +3 -0
  38. squirrels/package_data/base_project/seeds/mocks/max_filter.csv +2 -0
  39. squirrels/package_data/base_project/seeds/mocks/subcategory.csv +6 -0
  40. squirrels/package_data/base_project/squirrels.yml.j2 +57 -0
  41. squirrels/package_data/base_project/tmp/.gitignore +2 -0
  42. squirrels/package_data/static/script.js +159 -63
  43. squirrels/package_data/static/style.css +79 -15
  44. squirrels/package_data/static/widgets.js +133 -0
  45. squirrels/package_data/templates/index.html +65 -23
  46. squirrels/package_data/templates/index2.html +22 -0
  47. squirrels/parameter_options.py +216 -119
  48. squirrels/parameters.py +407 -478
  49. squirrels/user_base.py +58 -0
  50. squirrels-0.2.0.dev0.dist-info/METADATA +126 -0
  51. squirrels-0.2.0.dev0.dist-info/RECORD +56 -0
  52. {squirrels-0.1.1.post1.dist-info → squirrels-0.2.0.dev0.dist-info}/WHEEL +1 -2
  53. squirrels-0.2.0.dev0.dist-info/entry_points.txt +3 -0
  54. squirrels/_credentials_manager.py +0 -87
  55. squirrels/_module_loader.py +0 -37
  56. squirrels/_parameter_set.py +0 -151
  57. squirrels/_renderer.py +0 -286
  58. squirrels/_timed_imports.py +0 -37
  59. squirrels/connection_set.py +0 -126
  60. squirrels/package_data/base_project/.gitignore +0 -4
  61. squirrels/package_data/base_project/connections.py +0 -20
  62. squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -22
  63. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -29
  64. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -12
  65. squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -11
  66. squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -3
  67. squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -47
  68. squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -9
  69. squirrels/package_data/base_project/squirrels.yaml +0 -22
  70. squirrels-0.1.1.post1.dist-info/METADATA +0 -67
  71. squirrels-0.1.1.post1.dist-info/RECORD +0 -40
  72. squirrels-0.1.1.post1.dist-info/entry_points.txt +0 -2
  73. squirrels-0.1.1.post1.dist-info/top_level.txt +0 -1
  74. {squirrels-0.1.1.post1.dist-info → squirrels-0.2.0.dev0.dist-info}/LICENSE +0 -0
squirrels/__init__.py CHANGED
@@ -1,18 +1,12 @@
1
- from .parameter_options import SelectParameterOption, DateParameterOption, NumberParameterOption, NumRangeParameterOption
2
- from .parameters import Parameter, SingleSelectParameter, MultiSelectParameter, DateParameter, NumberParameter, NumRangeParameter, DataSourceParameter
3
- from .data_sources import SelectionDataSource, DateDataSource, NumberDataSource, NumRangeDataSource
4
- from .connection_set import ConnectionSet
1
+ __version__ = '0.2.0'
5
2
 
3
+ from typing import Union
4
+ from sqlalchemy import Engine, Pool
5
+ from pandas import DataFrame
6
6
 
7
- def get_credential(key: str):
8
- """
9
- Gets the username and password that was set through "$squirrels set-credential [key]"
10
-
11
- Parameters:
12
- key (str): The credential key
13
-
14
- Returns:
15
- Credential: Object with attributes "username" and "password"
16
- """
17
- from ._credentials_manager import squirrels_config_io
18
- return squirrels_config_io.get_credential(key)
7
+ from .arguments.init_time_args import ConnectionsArgs, ParametersArgs
8
+ from .arguments.run_time_args import ContextArgs, ModelDepsArgs, ModelArgs
9
+ from .parameter_options import SelectParameterOption, DateParameterOption, DateRangeParameterOption, NumberParameterOption, NumberRangeParameterOption
10
+ from .parameters import Parameter, SingleSelectParameter, MultiSelectParameter, DateParameter, DateRangeParameter, NumberParameter, NumberRangeParameter
11
+ from .data_sources import SingleSelectDataSource, MultiSelectDataSource, DateDataSource, DateRangeDataSource, NumberDataSource, NumberRangeDataSource
12
+ from .user_base import User, WrongPassword
squirrels/_api_server.py CHANGED
@@ -1,134 +1,288 @@
1
- from typing import Dict, List, Tuple, Set
2
- from fastapi import FastAPI, Request, HTTPException
3
- from fastapi.datastructures import QueryParams
1
+ from typing import Iterable, Optional, Mapping, Callable, Coroutine, TypeVar, Any
2
+ from fastapi import Depends, FastAPI, Request, HTTPException, status
4
3
  from fastapi.responses import HTMLResponse, JSONResponse
5
4
  from fastapi.templating import Jinja2Templates
6
5
  from fastapi.staticfiles import StaticFiles
7
- from cachetools.func import ttl_cache
8
- import os, traceback
6
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from cachetools import TTLCache
9
+ import os, traceback, pandas as pd
9
10
 
10
- from squirrels import _constants as c, _utils
11
- from squirrels._version import major_version
12
- from squirrels._manifest import Manifest
13
- from squirrels.connection_set import ConnectionSet
14
- from squirrels._renderer import RendererIOWrapper, Renderer
11
+ from . import _constants as c, _utils as u
12
+ from ._version import sq_major_version
13
+ from ._manifest import ManifestIO
14
+ from ._authenticator import User, Authenticator
15
+ from ._timer import timer, time
16
+ from ._parameter_sets import ParameterSet
17
+ from ._models import ModelsIO
15
18
 
16
19
 
class ApiServer:
    def __init__(self, no_cache: bool, debug: bool) -> None:
        """
        Constructor for ApiServer

        Parameters:
            no_cache (bool): Whether to disable caching
            debug (bool): Set to True to show "hidden" parameters in the /parameters endpoint response
        """
        self.no_cache = no_cache
        # NOTE(review): `debug` is stored but not read anywhere in this class; the
        # /parameters response is built with `to_json_dict0()` without it — confirm intent.
        self.debug = debug
        self.dataset_configs = ManifestIO.obj.datasets

        token_expiry_minutes = ManifestIO.obj.settings.get(c.AUTH_TOKEN_EXPIRE_SETTING, 30)
        self.authenticator = Authenticator(token_expiry_minutes)

    def run(self, uvicorn_args: Any) -> None:
        """
        Runs the API server with uvicorn for CLI "squirrels run"

        Parameters:
            uvicorn_args: Object exposing "host" and "port" attributes (e.g. parsed CLI arguments),
                which are passed to uvicorn.run. Other attributes are ignored.
        """
        start = time.time()
        app = FastAPI()

        app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

        squirrels_version_path = f'/squirrels-v{sq_major_version}'
        partial_base_path = f'/{ManifestIO.obj.project_variables.get_name()}/v{ManifestIO.obj.project_variables.get_major_version()}'
        base_path = squirrels_version_path + u.normalize_name_for_api(partial_base_path)

        static_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.STATIC_FOLDER)
        app.mount('/static', StaticFiles(directory=static_dir), name='static')

        templates_dir = u.join_paths(os.path.dirname(__file__), c.PACKAGE_DATA_FOLDER, c.TEMPLATES_FOLDER)
        templates = Jinja2Templates(directory=templates_dir)

        # Exception handlers: map squirrels error types to JSON error responses
        @app.exception_handler(u.InvalidInputError)
        async def invalid_input_error_handler(request: Request, exc: u.InvalidInputError):
            traceback.print_exc()
            return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST,
                                content={"message": f"Invalid user input: {str(exc)}"})

        @app.exception_handler(u.ConfigurationError)
        async def configuration_error_handler(request: Request, exc: u.ConfigurationError):
            traceback.print_exc()
            return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                                content={"message": f"Squirrels configuration error: {str(exc)}"})

        @app.exception_handler(NotImplementedError)
        async def not_implemented_error_handler(request: Request, exc: NotImplementedError):
            traceback.print_exc()
            return JSONResponse(status_code=status.HTTP_501_NOT_IMPLEMENTED,
                                content={"message": f"Not implemented error: {str(exc)}"})

        # Helpers
        T = TypeVar('T')

        def get_versioning_request_header(headers: Mapping, header_key: str) -> Optional[int]:
            """
            Reads an integer versioning header.

            Returns None when the header is absent. Raises InvalidInputError when the value is not an
            integer or lies outside [0, sq_major_version].
            """
            header_value = headers.get(header_key)
            if header_value is None:
                return None

            try:
                result = int(header_value)
            except ValueError:
                raise u.InvalidInputError(f"Request header '{header_key}' must be an integer. Got '{header_value}'")

            if result < 0 or result > int(sq_major_version):
                raise u.InvalidInputError(f"Request header '{header_key}' not in valid range. Got '{result}'")

            return result

        REQUEST_VERSION_REQUEST_HEADER = "squirrels-request-version"
        def get_request_version_header(headers: Mapping) -> Optional[int]:
            return get_versioning_request_header(headers, REQUEST_VERSION_REQUEST_HEADER)

        RESPONSE_VERSION_REQUEST_HEADER = "squirrels-response-version"
        def process_based_on_response_version_header(headers: Mapping, processes: dict[int, Callable[[], T]]) -> T:
            """
            Validates the response-version header and runs the matching response builder.

            Every value that passes validation (negative values are already rejected by
            get_versioning_request_header), as well as a missing header, currently maps to format 0.
            """
            get_versioning_request_header(headers, RESPONSE_VERSION_REQUEST_HEADER)
            return processes[0]()

        def can_user_access_dataset(user: Optional[User], dataset: str) -> bool:
            dataset_scope = self.dataset_configs[dataset].scope
            return self.authenticator.can_user_access_scope(user, dataset_scope)

        async def get_json_body(request: Request) -> Mapping:
            """
            Parses a POST body that maps parameter names to selections.

            Raises InvalidInputError (HTTP 400) for malformed JSON or a non-object body, instead of
            letting a decode error or a failing `.items()` call surface as an HTTP 500.
            """
            try:
                body = await request.json()
            except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
                raise u.InvalidInputError(f"Request body must be valid JSON: {e}") from e
            if not isinstance(body, dict):
                raise u.InvalidInputError("Request body must be a JSON object mapping parameter names to selections")
            return body

        async def apply_dataset_api_function(
            api_function: Callable[..., Coroutine[Any, Any, T]], user: Optional[User], dataset: str, headers: Mapping, params: Mapping
        ) -> T:
            """
            Normalizes the dataset name and selections, checks access, then awaits api_function.

            Raises HTTP 404 for an unknown dataset and HTTP 401 when the user's scope is insufficient.
            """
            dataset_normalized = u.normalize_name(dataset)
            if dataset_normalized not in self.dataset_configs:
                # Without this check the scope lookup raises a bare KeyError (HTTP 500)
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
                                    detail=f"Dataset '{dataset}' not found")
            if not can_user_access_dataset(user, dataset_normalized):
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                                    detail="Could not validate credentials",
                                    headers={"WWW-Authenticate": "Bearer"})

            request_version = get_request_version_header(headers)

            # Changing selections into a cachable "frozenset" that will later be converted to dictionary
            selections = set()
            for key, val in params.items():
                if not isinstance(val, str):
                    # Multi-valued selections (e.g. JSON lists) become hashable tuples
                    try:
                        val = tuple(val)
                    except TypeError:
                        raise u.InvalidInputError(
                            f"Selection for parameter '{key}' must be a string or a list. Got '{val}'"
                        ) from None
                selections.add((u.normalize_name(key), val))
            selections = frozenset(selections)

            return await api_function(user, dataset_normalized, selections, request_version)

        # Login
        token_path = base_path + '/token'

        oauth2_scheme = OAuth2PasswordBearer(tokenUrl=token_path, auto_error=False)

        @app.post(token_path)
        async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
            user: Optional[User] = self.authenticator.authenticate_user(form_data.username, form_data.password)
            if not user:
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                                    detail="Incorrect username or password",
                                    headers={"WWW-Authenticate": "Bearer"})
            access_token, expiry = self.authenticator.create_access_token(user)
            return {
                "access_token": access_token,
                "token_type": "bearer",
                "username": user.username,
                "expiry_time": expiry
            }

        async def get_current_user(token: str = Depends(oauth2_scheme)) -> Optional[User]:
            # auto_error=False: a missing token yields None (anonymous user) rather than a 401
            user = self.authenticator.get_user_from_token(token)
            return user

        cache_miss = object()

        async def do_cachable_action(cache: TTLCache, action: Callable[..., Coroutine[Any, Any, T]], *args) -> T:
            """Returns the cached result for args, or awaits action(*args) and caches it."""
            cache_key = tuple(args)
            # A dedicated sentinel (rather than None) so a legitimately-None result is also cached
            result = cache.get(cache_key, cache_miss)
            if result is cache_miss:
                result = await action(*args)
                cache[cache_key] = result
            return result

        # Parameters API
        parameters_path = base_path + '/{dataset}/parameters'

        parameters_cache_size = ManifestIO.obj.settings.get(c.PARAMETERS_CACHE_SIZE_SETTING, 1024)
        parameters_cache_ttl = ManifestIO.obj.settings.get(c.PARAMETERS_CACHE_TTL_SETTING, 0)

        async def get_parameters_helper(
            user: Optional[User], dataset: str, selections: frozenset[tuple[str, Any]], request_version: Optional[int]
        ) -> ParameterSet:
            if len(selections) > 1:
                raise u.InvalidInputError(f"The /parameters endpoint takes at most 1 query parameter. Got {dict(selections)}")
            dag = ModelsIO.GenerateDAG(dataset)
            dag.apply_selections(user, dict(selections), request_version=request_version)
            return dag.parameter_set

        # The TTL setting is multiplied by 60 because TTLCache expects seconds (setting is in minutes)
        params_cache = TTLCache(maxsize=parameters_cache_size, ttl=parameters_cache_ttl*60)

        async def get_parameters_cachable(*args) -> T:
            return await do_cachable_action(params_cache, get_parameters_helper, *args)

        async def get_parameters_definition(dataset: str, user: Optional[User], headers: Mapping, params: Mapping):
            api_function = get_parameters_helper if self.no_cache else get_parameters_cachable
            result = await apply_dataset_api_function(api_function, user, dataset, headers, params)
            return process_based_on_response_version_header(headers, {
                0: result.to_json_dict0
            })

        @app.get(parameters_path, response_class=JSONResponse)
        async def get_parameters(dataset: str, request: Request, user: Optional[User] = Depends(get_current_user)):
            start = time.time()
            result = await get_parameters_definition(dataset, user, request.headers, request.query_params)
            timer.add_activity_time("GET REQUEST total time for PARAMETERS", start)
            return result

        @app.post(parameters_path, response_class=JSONResponse)
        async def get_parameters_with_post(dataset: str, request: Request, user: Optional[User] = Depends(get_current_user)):
            start = time.time()
            request_body = await get_json_body(request)
            result = await get_parameters_definition(dataset, user, request.headers, request_body)
            timer.add_activity_time("POST REQUEST total time for PARAMETERS", start)
            return result

        # Results API
        results_path = base_path + '/{dataset}'

        results_cache_size = ManifestIO.obj.settings.get(c.RESULTS_CACHE_SIZE_SETTING, 128)
        results_cache_ttl = ManifestIO.obj.settings.get(c.RESULTS_CACHE_TTL_SETTING, 0)

        async def get_results_helper(
            user: Optional[User], dataset: str, selections: frozenset[tuple[str, Any]], request_version: Optional[int]
        ) -> pd.DataFrame:
            dag = ModelsIO.GenerateDAG(dataset)
            await dag.execute(ModelsIO.context_func, user, dict(selections), request_version=request_version)
            return dag.target_model.result

        # The TTL setting is multiplied by 60 because TTLCache expects seconds (setting is in minutes)
        results_cache = TTLCache(maxsize=results_cache_size, ttl=results_cache_ttl*60)

        async def get_results_cachable(*args):
            return await do_cachable_action(results_cache, get_results_helper, *args)

        async def get_results_definition(dataset: str, user: Optional[User], headers: Mapping, params: Mapping):
            api_function = get_results_helper if self.no_cache else get_results_cachable
            result = await apply_dataset_api_function(api_function, user, dataset, headers, params)
            return process_based_on_response_version_header(headers, {
                0: lambda: u.df_to_json0(result)
            })

        @app.get(results_path, response_class=JSONResponse)
        async def get_results(dataset: str, request: Request, user: Optional[User] = Depends(get_current_user)):
            start = time.time()
            result = await get_results_definition(dataset, user, request.headers, request.query_params)
            timer.add_activity_time("GET REQUEST total time for DATASET", start)
            return result

        @app.post(results_path, response_class=JSONResponse)
        async def get_results_with_post(dataset: str, request: Request, user: Optional[User] = Depends(get_current_user)):
            start = time.time()
            request_body = await get_json_body(request)
            result = await get_results_definition(dataset, user, request.headers, request_body)
            timer.add_activity_time("POST REQUEST total time for DATASET", start)
            return result

        # Catalog API
        def get_catalog0(user: Optional[User]):
            """Builds the version-0 catalog, listing only datasets the user may access."""
            datasets_info = []
            for dataset_name, dataset_config in self.dataset_configs.items():
                if can_user_access_dataset(user, dataset_name):
                    dataset_normalized = u.normalize_name_for_api(dataset_name)
                    datasets_info.append({
                        'name': dataset_name,
                        'label': dataset_config.label,
                        'parameters_path': parameters_path.format(dataset=dataset_normalized),
                        'result_path': results_path.format(dataset=dataset_normalized),
                        'first_minor_version': 0
                    })

            return {
                'products': [{
                    'name': ManifestIO.obj.project_variables.get_name(),
                    'label': ManifestIO.obj.project_variables.get_label(),
                    'versions': [{
                        'major_version': ManifestIO.obj.project_variables.get_major_version(),
                        'latest_minor_version': ManifestIO.obj.project_variables.get_minor_version(),
                        'datasets': datasets_info
                    }]
                }]
            }

        @app.get(squirrels_version_path, response_class=JSONResponse)
        async def get_catalog(request: Request, user: Optional[User] = Depends(get_current_user)):
            return process_based_on_response_version_header(request.headers, {
                0: lambda: get_catalog0(user)
            })

        # Squirrels UI
        @app.get('/', response_class=HTMLResponse)
        async def get_ui(request: Request):
            return templates.TemplateResponse('index.html', {
                'request': request, 'catalog_path': squirrels_version_path, 'token_path': token_path
            })

        # Run API server
        import uvicorn
        timer.add_activity_time("creating app for api server", start)
        uvicorn.run(app, host=uvicorn_args.host, port=uvicorn_args.port)
squirrels/_authenticator.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import Optional
2
+ from datetime import datetime, timedelta, timezone
3
+ from jose import JWTError, jwt
4
+ import secrets
5
+
6
+ from . import _utils as u, _constants as c
7
+ from ._py_module import PyModule
8
+ from .user_base import User, WrongPassword
9
+ from ._environcfg import EnvironConfigIO
10
+ from ._manifest import DatasetScope
11
+
12
+
class Authenticator:
    """
    Authenticates users, issues/decodes JWT access tokens, and checks dataset-scope access.
    """

    @classmethod
    def get_auth_helper(cls, default_auth_helper = None):
        """
        Loads the project's auth module (PYCONFIG_FOLDER/AUTH_FILE) as a PyModule.

        default_auth_helper is passed to PyModule as `default_class`.
        """
        auth_module_path = u.join_paths(c.PYCONFIG_FOLDER, c.AUTH_FILE)
        return PyModule(auth_module_path, default_class=default_auth_helper)

    def __init__(self, token_expiry_minutes: int, auth_helper = None) -> None:
        """
        Parameters:
            token_expiry_minutes: Lifetime of issued access tokens, in minutes
            auth_helper: Optional default for the auth module (see get_auth_helper)
        """
        self.token_expiry_minutes = token_expiry_minutes
        self.auth_helper = self.get_auth_helper(auth_helper)
        self.secret_key = self._get_secret_key()
        self.algorithm = "HS256"

    def _get_secret_key(self) -> str:
        # Falls back to a freshly generated random key when no JWT secret is configured;
        # presumably tokens issued with a generated key stop validating after a restart.
        secret_key = EnvironConfigIO.obj.get_secret(c.JWT_SECRET_KEY, default_factory=lambda: secrets.token_hex(32))
        return secret_key

    @staticmethod
    def _passwords_match(stored_password, provided_password) -> bool:
        """
        Constant-time password comparison.

        secrets.compare_digest raises TypeError for str arguments that contain non-ASCII
        characters, so both sides are compared as UTF-8 bytes. A missing or non-string
        stored password never matches.
        """
        if not isinstance(stored_password, str) or not isinstance(provided_password, str):
            return False
        return secrets.compare_digest(stored_password.encode("utf-8"), provided_password.encode("utf-8"))

    def authenticate_user(self, username: str, password: str) -> Optional[User]:
        """
        Returns the authenticated User, or None when the credentials are not accepted.

        The auth module's get-user function is consulted first. If it returns a User, that user is
        used. If it returns WrongPassword, authentication fails. Otherwise the fake users from the
        environment config are checked.

        Raises:
            u.FileExecutionError: If the auth module's get-user function or User model fails
        """
        if self.auth_helper:
            user_cls = self.auth_helper.get_func_or_class("User", default_attr=User)
            get_user = self.auth_helper.get_func_or_class(c.GET_USER_FUNC)
            try:
                real_user = get_user(username, password)
            except Exception as e:
                raise u.FileExecutionError(f'Failed to run "{c.GET_USER_FUNC}" in {c.AUTH_FILE}', e)
        else:
            user_cls = User
            real_user = None

        if isinstance(real_user, User):
            return real_user

        if not isinstance(real_user, WrongPassword):
            fake_users = EnvironConfigIO.obj.get_users()
            fake_user = fake_users.get(username)
            if fake_user is not None and self._passwords_match(fake_user.get(c.USER_PWD_KEY), password):
                is_internal = fake_user.get("is_internal", False)
                user = user_cls(username, is_internal=is_internal)
                try:
                    # NOTE(review): the whole fake-user record (including the password key) is
                    # passed here; if with_attributes stores it on the user, create_access_token
                    # would embed it in the JWT via vars(user) — confirm.
                    return user.with_attributes(fake_user)
                except Exception as e:
                    raise u.FileExecutionError(f'Failed to create user from User model in {c.AUTH_FILE}', e)

        return None

    def create_access_token(self, user: User) -> tuple[str, datetime]:
        """
        Encodes the user's attributes plus an "exp" claim into a signed JWT.

        Returns:
            (encoded token, UTC expiry datetime)
        """
        expire = datetime.now(timezone.utc) + timedelta(minutes=self.token_expiry_minutes)
        to_encode = {**vars(user), "exp": expire}
        encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
        return encoded_jwt, expire

    def get_user_from_token(self, token: Optional[str]) -> Optional[User]:
        """
        Decodes a JWT into a User. Returns None for a missing, invalid, or expired token.
        """
        if token is None:
            return None

        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except JWTError:
            return None

        # "exp" is a token claim, not a user attribute; tolerate tokens issued without it
        payload.pop("exp", None)
        if self.auth_helper is not None:
            user_cls: User = self.auth_helper.get_func_or_class("User", default_attr=User)
            return user_cls._FromDict(payload)
        return User._FromDict(payload)

    def can_user_access_scope(self, user: Optional[User], scope: DatasetScope) -> bool:
        """
        Anonymous users get PUBLIC level, non-internal users PROTECTED, internal users PRIVATE.
        Access is granted when that level's value is at least the scope's value.
        """
        if user is None:
            user_level = DatasetScope.PUBLIC
        elif not user.is_internal:
            user_level = DatasetScope.PROTECTED
        else:
            user_level = DatasetScope.PRIVATE

        return user_level.value >= scope.value