hyperpocket 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- hyperpocket/__init__.py +4 -4
- hyperpocket/auth/__init__.py +12 -7
- hyperpocket/auth/calendly/oauth2_handler.py +24 -17
- hyperpocket/auth/calendly/oauth2_schema.py +3 -1
- hyperpocket/auth/context.py +2 -1
- hyperpocket/auth/github/oauth2_handler.py +13 -8
- hyperpocket/auth/github/token_handler.py +27 -21
- hyperpocket/auth/google/context.py +1 -3
- hyperpocket/auth/google/oauth2_context.py +1 -1
- hyperpocket/auth/google/oauth2_handler.py +22 -17
- hyperpocket/auth/gumloop/token_context.py +1 -4
- hyperpocket/auth/gumloop/token_handler.py +48 -20
- hyperpocket/auth/gumloop/token_schema.py +2 -1
- hyperpocket/auth/handler.py +21 -6
- hyperpocket/auth/linear/token_context.py +2 -5
- hyperpocket/auth/linear/token_handler.py +45 -21
- hyperpocket/auth/notion/context.py +2 -2
- hyperpocket/auth/notion/token_context.py +2 -4
- hyperpocket/auth/notion/token_handler.py +45 -21
- hyperpocket/auth/notion/token_schema.py +0 -1
- hyperpocket/auth/reddit/oauth2_handler.py +9 -10
- hyperpocket/auth/reddit/oauth2_schema.py +0 -2
- hyperpocket/auth/schema.py +4 -1
- hyperpocket/auth/slack/oauth2_context.py +3 -1
- hyperpocket/auth/slack/oauth2_handler.py +55 -35
- hyperpocket/auth/slack/token_context.py +2 -4
- hyperpocket/auth/slack/token_handler.py +42 -19
- hyperpocket/builtin.py +4 -2
- hyperpocket/cli/__main__.py +4 -2
- hyperpocket/cli/auth.py +59 -28
- hyperpocket/cli/codegen/auth/auth_context_template.py +3 -2
- hyperpocket/cli/codegen/auth/auth_token_context_template.py +3 -2
- hyperpocket/cli/codegen/auth/auth_token_handler_template.py +6 -5
- hyperpocket/cli/codegen/auth/auth_token_schema_template.py +3 -2
- hyperpocket/cli/codegen/auth/server_auth_template.py +3 -2
- hyperpocket/cli/pull.py +5 -5
- hyperpocket/config/__init__.py +3 -8
- hyperpocket/config/auth.py +3 -1
- hyperpocket/config/logger.py +20 -15
- hyperpocket/config/session.py +4 -2
- hyperpocket/config/settings.py +19 -2
- hyperpocket/futures/__init__.py +1 -1
- hyperpocket/futures/futurestore.py +3 -2
- hyperpocket/pocket_auth.py +171 -84
- hyperpocket/pocket_core.py +51 -33
- hyperpocket/pocket_main.py +122 -93
- hyperpocket/prompts.py +2 -2
- hyperpocket/repository/__init__.py +1 -1
- hyperpocket/repository/lock.py +47 -33
- hyperpocket/repository/lockfile.py +2 -2
- hyperpocket/repository/repository.py +1 -1
- hyperpocket/server/__init__.py +1 -1
- hyperpocket/server/auth/github.py +2 -1
- hyperpocket/server/auth/linear.py +1 -3
- hyperpocket/server/auth/notion.py +2 -5
- hyperpocket/server/auth/slack.py +1 -3
- hyperpocket/server/auth/token.py +17 -11
- hyperpocket/server/proxy.py +29 -13
- hyperpocket/server/server.py +75 -31
- hyperpocket/server/tool/dto/script.py +15 -10
- hyperpocket/server/tool/wasm.py +14 -11
- hyperpocket/session/__init__.py +6 -2
- hyperpocket/session/in_memory.py +44 -24
- hyperpocket/session/interface.py +42 -24
- hyperpocket/session/redis.py +48 -31
- hyperpocket/tool/__init__.py +10 -10
- hyperpocket/tool/function/__init__.py +1 -5
- hyperpocket/tool/function/annotation.py +11 -9
- hyperpocket/tool/function/tool.py +37 -27
- hyperpocket/tool/tool.py +59 -36
- hyperpocket/tool/wasm/__init__.py +1 -1
- hyperpocket/tool/wasm/browser.py +15 -10
- hyperpocket/tool/wasm/invoker.py +16 -16
- hyperpocket/tool/wasm/script.py +27 -14
- hyperpocket/tool/wasm/templates/__init__.py +22 -15
- hyperpocket/tool/wasm/templates/node.py +2 -2
- hyperpocket/tool/wasm/templates/python.py +2 -2
- hyperpocket/tool/wasm/tool.py +27 -14
- hyperpocket/tool_like.py +3 -3
- hyperpocket/util/__init__.py +1 -1
- hyperpocket/util/extract_func_param_desc_from_docstring.py +23 -7
- hyperpocket/util/find_all_leaf_class_in_package.py +4 -3
- hyperpocket/util/find_all_subclass_in_package.py +4 -2
- hyperpocket/util/flatten_json_schema.py +10 -6
- hyperpocket/util/function_to_model.py +33 -12
- hyperpocket/util/get_objects_from_subpackage.py +1 -1
- hyperpocket/util/json_schema_to_model.py +14 -5
- {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.0.dist-info}/METADATA +11 -5
- hyperpocket-0.2.0.dist-info/RECORD +137 -0
- hyperpocket-0.1.10.dist-info/RECORD +0 -137
- {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.0.dist-info}/WHEEL +0 -0
- {hyperpocket-0.1.10.dist-info → hyperpocket-0.2.0.dist-info}/entry_points.txt +0 -0
hyperpocket/server/auth/token.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
from http import HTTPStatus
|
2
|
-
from urllib.parse import
|
2
|
+
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
3
3
|
|
4
4
|
from fastapi import APIRouter, Form
|
5
5
|
from starlette.responses import HTMLResponse, RedirectResponse
|
@@ -29,8 +29,12 @@ async def token_form(redirect_uri: str, state: str = ""):
|
|
29
29
|
|
30
30
|
|
31
31
|
@token_router.post("/submit", response_class=RedirectResponse)
|
32
|
-
async def submit_token(
|
33
|
-
|
32
|
+
async def submit_token(
|
33
|
+
user_token: str = Form(...), redirect_uri: str = Form(...), state: str = Form(...)
|
34
|
+
):
|
35
|
+
new_callback_url = add_query_params(
|
36
|
+
redirect_uri, {"token": user_token, "state": state}
|
37
|
+
)
|
34
38
|
return RedirectResponse(url=new_callback_url, status_code=HTTPStatus.SEE_OTHER)
|
35
39
|
|
36
40
|
|
@@ -40,12 +44,14 @@ def add_query_params(url: str, params: dict):
|
|
40
44
|
query_params.update(params)
|
41
45
|
new_query = urlencode(query_params, doseq=True)
|
42
46
|
|
43
|
-
new_url = urlunparse(
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
47
|
+
new_url = urlunparse(
|
48
|
+
(
|
49
|
+
url_parts.scheme,
|
50
|
+
url_parts.netloc,
|
51
|
+
url_parts.path,
|
52
|
+
url_parts.params,
|
53
|
+
new_query,
|
54
|
+
url_parts.fragment,
|
55
|
+
)
|
56
|
+
)
|
51
57
|
return new_url
|
hyperpocket/server/proxy.py
CHANGED
@@ -9,22 +9,27 @@ async def proxy(request: Request, path: str):
|
|
9
9
|
async with httpx.AsyncClient() as client:
|
10
10
|
resp = await client.request(
|
11
11
|
method=request.method,
|
12
|
-
url=f"{config.internal_base_url}/{path}",
|
12
|
+
url=f"{config().internal_base_url}/{path}",
|
13
13
|
headers=request.headers,
|
14
14
|
content=await request.body(),
|
15
15
|
params=request.query_params,
|
16
16
|
timeout=300,
|
17
17
|
)
|
18
|
-
return HTMLResponse(
|
18
|
+
return HTMLResponse(
|
19
|
+
content=resp.text, headers=resp.headers, status_code=resp.status_code
|
20
|
+
)
|
19
21
|
|
20
22
|
|
21
23
|
def add_callback_proxy(app: FastAPI):
|
22
|
-
app.add_api_route(
|
23
|
-
|
24
|
+
app.add_api_route(
|
25
|
+
f"/{config().callback_url_rewrite_prefix}/{{path:path}}",
|
26
|
+
proxy,
|
27
|
+
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
|
28
|
+
)
|
24
29
|
|
25
30
|
|
26
31
|
https_proxy_app = None
|
27
|
-
if config.enable_local_callback_proxy:
|
32
|
+
if config().enable_local_callback_proxy:
|
28
33
|
https_proxy_app = FastAPI()
|
29
34
|
add_callback_proxy(https_proxy_app)
|
30
35
|
|
@@ -44,20 +49,31 @@ def _generate_ssl_certificates(ssl_keypath, ssl_certpath):
|
|
44
49
|
"/emailAddress=local@example.com"
|
45
50
|
)
|
46
51
|
command = [
|
47
|
-
"openssl",
|
48
|
-
"
|
49
|
-
"-
|
50
|
-
"-
|
51
|
-
"
|
52
|
+
"openssl",
|
53
|
+
"req",
|
54
|
+
"-x509",
|
55
|
+
"-newkey",
|
56
|
+
"rsa:4096",
|
57
|
+
"-keyout",
|
58
|
+
ssl_keypath,
|
59
|
+
"-out",
|
60
|
+
ssl_certpath,
|
61
|
+
"-days",
|
62
|
+
"1",
|
52
63
|
"-nodes",
|
53
|
-
|
64
|
+
"-subj",
|
65
|
+
subj,
|
54
66
|
"-sha256",
|
55
67
|
]
|
56
68
|
|
57
69
|
try:
|
58
70
|
# 명령 실행
|
59
|
-
subprocess.run(
|
60
|
-
|
71
|
+
subprocess.run(
|
72
|
+
command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
|
73
|
+
)
|
74
|
+
pocket_logger.info(
|
75
|
+
"SSL Certificates generated: callback_server.key, callback_server.crt"
|
76
|
+
)
|
61
77
|
except subprocess.CalledProcessError as e:
|
62
78
|
pocket_logger.warning(f"An error occurred while generating certificates: {e}")
|
63
79
|
raise e
|
hyperpocket/server/server.py
CHANGED
@@ -30,9 +30,11 @@ class PocketServer(object):
|
|
30
30
|
future_store: dict[str, asyncio.Future]
|
31
31
|
torn_down: bool = False
|
32
32
|
|
33
|
-
def __init__(
|
34
|
-
|
35
|
-
|
33
|
+
def __init__(
|
34
|
+
self,
|
35
|
+
internal_server_port: int = config().internal_server_port,
|
36
|
+
proxy_port: int = config().public_server_port,
|
37
|
+
):
|
36
38
|
self.internal_server_port = internal_server_port
|
37
39
|
self.proxy_port = proxy_port
|
38
40
|
self.future_store = dict()
|
@@ -49,7 +51,9 @@ class PocketServer(object):
|
|
49
51
|
try:
|
50
52
|
await asyncio.gather(
|
51
53
|
self.main_server.serve(),
|
52
|
-
self.proxy_server.serve()
|
54
|
+
self.proxy_server.serve()
|
55
|
+
if self.proxy_server is not None
|
56
|
+
else asyncio.sleep(0),
|
53
57
|
self.poll_in_child(),
|
54
58
|
)
|
55
59
|
except Exception as e:
|
@@ -75,7 +79,9 @@ class PocketServer(object):
|
|
75
79
|
result = self.pocket_core.prepare_auth(*a, **kw)
|
76
80
|
error = None
|
77
81
|
except Exception as e:
|
78
|
-
pocket_logger.error(
|
82
|
+
pocket_logger.error(
|
83
|
+
f"failed to prepare in pocket subprocess. error: {e}"
|
84
|
+
)
|
79
85
|
result = None
|
80
86
|
error = e
|
81
87
|
|
@@ -86,7 +92,9 @@ class PocketServer(object):
|
|
86
92
|
result = await self.pocket_core.authenticate(*a, **kw)
|
87
93
|
error = None
|
88
94
|
except Exception as e:
|
89
|
-
pocket_logger.error(
|
95
|
+
pocket_logger.error(
|
96
|
+
f"failed to authenticate in pocket subprocess. error: {e}"
|
97
|
+
)
|
90
98
|
result = None
|
91
99
|
error = e
|
92
100
|
|
@@ -97,7 +105,9 @@ class PocketServer(object):
|
|
97
105
|
result = await self.pocket_core.tool_call(*a, **kw)
|
98
106
|
error = None
|
99
107
|
except Exception as e:
|
100
|
-
pocket_logger.error(
|
108
|
+
pocket_logger.error(
|
109
|
+
f"failed to tool_call in pocket subprocess. error: {e}"
|
110
|
+
)
|
101
111
|
result = None
|
102
112
|
error = e
|
103
113
|
|
@@ -119,10 +129,7 @@ class PocketServer(object):
|
|
119
129
|
else:
|
120
130
|
await asyncio.sleep(0)
|
121
131
|
|
122
|
-
def send_in_parent(self,
|
123
|
-
op: PocketServerOperations,
|
124
|
-
args: tuple,
|
125
|
-
kwargs: dict):
|
132
|
+
def send_in_parent(self, op: PocketServerOperations, args: tuple, kwargs: dict):
|
126
133
|
conn, _ = self.pipe
|
127
134
|
uid = str(uuid.uuid4())
|
128
135
|
message = (op.value, uid, args, kwargs)
|
@@ -145,10 +152,9 @@ class PocketServer(object):
|
|
145
152
|
else:
|
146
153
|
await asyncio.sleep(0)
|
147
154
|
|
148
|
-
async def call_in_subprocess(
|
149
|
-
|
150
|
-
|
151
|
-
kwargs: dict):
|
155
|
+
async def call_in_subprocess(
|
156
|
+
self, op: PocketServerOperations, args: tuple, kwargs: dict
|
157
|
+
):
|
152
158
|
uid = self.send_in_parent(op, args, kwargs)
|
153
159
|
loop = asyncio.get_running_loop()
|
154
160
|
loop.create_task(self.poll_in_parent())
|
@@ -157,21 +163,52 @@ class PocketServer(object):
|
|
157
163
|
def run(self, pocket_core: PocketCore):
|
158
164
|
self._set_mp_start_method()
|
159
165
|
|
166
|
+
error_queue = mp.Queue()
|
160
167
|
self.pipe = mp.Pipe()
|
161
|
-
self.process = mp.Process(
|
162
|
-
|
168
|
+
self.process = mp.Process(
|
169
|
+
target=self._run, args=(pocket_core, )
|
170
|
+
)
|
171
|
+
self.process.start() # process start
|
172
|
+
|
173
|
+
if not error_queue.empty():
|
174
|
+
error_message = error_queue.get()
|
175
|
+
raise error_message
|
176
|
+
|
177
|
+
def _report_initialized(self, error: Optional[Exception] = None):
|
178
|
+
_, conn = self.pipe
|
179
|
+
conn.send(('server-initialization', error,))
|
180
|
+
|
181
|
+
def wait_initialized(self):
|
182
|
+
conn, _ = self.pipe
|
183
|
+
while True:
|
184
|
+
if conn.poll():
|
185
|
+
_, error = conn.recv()
|
186
|
+
if error:
|
187
|
+
raise error
|
188
|
+
return
|
163
189
|
|
164
190
|
def _run(self, pocket_core):
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
191
|
+
try:
|
192
|
+
# init process
|
193
|
+
self.pocket_core = pocket_core
|
194
|
+
self.main_server = self._create_main_server()
|
195
|
+
self.proxy_server = self._create_https_proxy_server()
|
196
|
+
self._report_initialized()
|
197
|
+
|
198
|
+
loop = asyncio.new_event_loop()
|
199
|
+
loop.run_until_complete(self._run_async())
|
200
|
+
loop.close()
|
201
|
+
except Exception as error:
|
202
|
+
self._report_initialized(error)
|
171
203
|
|
172
204
|
def _create_main_server(self) -> Server:
|
173
205
|
app = FastAPI()
|
174
|
-
_config = Config(
|
206
|
+
_config = Config(
|
207
|
+
app,
|
208
|
+
host="0.0.0.0",
|
209
|
+
port=self.internal_server_port,
|
210
|
+
log_level=config().log_level,
|
211
|
+
)
|
175
212
|
app.include_router(tool_router)
|
176
213
|
app.include_router(auth_router)
|
177
214
|
app.add_api_route("/health", lambda: {"status": "ok"}, methods=["GET"])
|
@@ -180,20 +217,25 @@ class PocketServer(object):
|
|
180
217
|
return app
|
181
218
|
|
182
219
|
def _create_https_proxy_server(self) -> Optional[Server]:
|
183
|
-
if not config.enable_local_callback_proxy:
|
220
|
+
if not config().enable_local_callback_proxy:
|
184
221
|
return None
|
185
|
-
from hyperpocket.server.proxy import _generate_ssl_certificates
|
186
|
-
from hyperpocket.server.proxy import https_proxy_app
|
187
|
-
|
188
222
|
from hyperpocket.config.settings import POCKET_ROOT
|
223
|
+
from hyperpocket.server.proxy import _generate_ssl_certificates, https_proxy_app
|
224
|
+
|
189
225
|
ssl_keypath = POCKET_ROOT / "callback_server.key"
|
190
226
|
ssl_certpath = POCKET_ROOT / "callback_server.crt"
|
191
227
|
|
192
228
|
if not ssl_keypath.exists() or not ssl_certpath.exists():
|
193
229
|
_generate_ssl_certificates(ssl_keypath, ssl_certpath)
|
194
230
|
|
195
|
-
_config = Config(
|
196
|
-
|
231
|
+
_config = Config(
|
232
|
+
https_proxy_app,
|
233
|
+
host="0.0.0.0",
|
234
|
+
port=self.proxy_port,
|
235
|
+
ssl_keyfile=ssl_keypath,
|
236
|
+
ssl_certfile=ssl_certpath,
|
237
|
+
log_level=config().log_level,
|
238
|
+
)
|
197
239
|
proxy_server = Server(_config)
|
198
240
|
return proxy_server
|
199
241
|
|
@@ -211,4 +253,6 @@ class PocketServer(object):
|
|
211
253
|
mp.set_start_method("fork", force=True)
|
212
254
|
pocket_logger.debug("Process start method set to 'fork' for Linux.")
|
213
255
|
else:
|
214
|
-
pocket_logger.debug(
|
256
|
+
pocket_logger.debug(
|
257
|
+
f"Unrecognized OS: {os_name}. Default start method will be used."
|
258
|
+
)
|
@@ -1,28 +1,33 @@
|
|
1
1
|
from typing import Optional
|
2
|
+
|
2
3
|
from pydantic import BaseModel, Field
|
3
4
|
|
4
5
|
from hyperpocket.tool.wasm.script import ScriptFileNode
|
5
6
|
|
6
7
|
|
7
8
|
class Script(BaseModel):
|
8
|
-
id: str = Field(alias=
|
9
|
-
tool_id: str = Field(alias=
|
9
|
+
id: str = Field(alias="id")
|
10
|
+
tool_id: str = Field(alias="tool_id")
|
10
11
|
|
11
12
|
|
12
13
|
class ScriptResult(BaseModel):
|
13
|
-
stdout: Optional[str] = Field(alias=
|
14
|
-
stderr: Optional[str] = Field(alias=
|
15
|
-
error: Optional[str] = Field(alias=
|
14
|
+
stdout: Optional[str] = Field(alias="stdout", default=None)
|
15
|
+
stderr: Optional[str] = Field(alias="stderr", default=None)
|
16
|
+
error: Optional[str] = Field(alias="error", default=None)
|
17
|
+
|
16
18
|
|
17
19
|
class ScriptFileTree(BaseModel):
|
18
|
-
tree: dict[str, ScriptFileNode] = Field(alias=
|
20
|
+
tree: dict[str, ScriptFileNode] = Field(alias="tree")
|
21
|
+
|
19
22
|
|
20
23
|
class ScriptEntrypoint(BaseModel):
|
21
|
-
package_name: Optional[str] = Field(alias=
|
22
|
-
entrypoint: str = Field(alias=
|
24
|
+
package_name: Optional[str] = Field(alias="package_name")
|
25
|
+
entrypoint: str = Field(alias="entrypoint")
|
26
|
+
|
23
27
|
|
24
28
|
class ScriptEncodedFile(BaseModel):
|
25
|
-
encoded_file: str = Field(alias=
|
29
|
+
encoded_file: str = Field(alias="encoded_file")
|
30
|
+
|
26
31
|
|
27
32
|
class ScriptFileRequest(BaseModel):
|
28
|
-
path: str = Field(alias=
|
33
|
+
path: str = Field(alias="path")
|
hyperpocket/server/tool/wasm.py
CHANGED
@@ -1,13 +1,11 @@
|
|
1
1
|
from fastapi import APIRouter
|
2
|
-
from fastapi.responses import
|
2
|
+
from fastapi.responses import FileResponse, HTMLResponse
|
3
3
|
|
4
4
|
from hyperpocket.futures import FutureStore
|
5
5
|
from hyperpocket.server.tool.dto import script as scriptdto
|
6
6
|
from hyperpocket.tool.wasm.script import ScriptStore
|
7
7
|
|
8
|
-
wasm_tool_router = APIRouter(
|
9
|
-
prefix="/wasm"
|
10
|
-
)
|
8
|
+
wasm_tool_router = APIRouter(prefix="/wasm")
|
11
9
|
|
12
10
|
|
13
11
|
@wasm_tool_router.get("/scripts/{script_id}/browse", response_class=HTMLResponse)
|
@@ -17,19 +15,21 @@ async def browse_script_page(script_id: str):
|
|
17
15
|
|
18
16
|
|
19
17
|
@wasm_tool_router.post("/scripts/{script_id}/done")
|
20
|
-
async def done_script_page(
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
18
|
+
async def done_script_page(
|
19
|
+
script_id: str, req: scriptdto.ScriptResult
|
20
|
+
) -> scriptdto.ScriptResult:
|
21
|
+
FutureStore.resolve_future(
|
22
|
+
script_id, {"stdout": req.stdout, "stderr": req.stderr, "error": req.error}
|
23
|
+
)
|
26
24
|
return req
|
27
25
|
|
26
|
+
|
28
27
|
@wasm_tool_router.get("/scripts/{script_id}/file_tree")
|
29
28
|
async def get_file_tree(script_id: str) -> scriptdto.ScriptFileTree:
|
30
29
|
script = ScriptStore.get_script(script_id)
|
31
30
|
return scriptdto.ScriptFileTree(tree=script.load_file_tree())
|
32
31
|
|
32
|
+
|
33
33
|
@wasm_tool_router.get("/scripts/{script_id}/entrypoint")
|
34
34
|
async def get_entrypoint(script_id: str) -> scriptdto.ScriptEntrypoint:
|
35
35
|
script = ScriptStore.get_script(script_id)
|
@@ -37,7 +37,10 @@ async def get_entrypoint(script_id: str) -> scriptdto.ScriptEntrypoint:
|
|
37
37
|
entrypoint = f"/tools/wasm/scripts/{script_id}/file/{script.entrypoint}"
|
38
38
|
return scriptdto.ScriptEntrypoint(package_name=package_name, entrypoint=entrypoint)
|
39
39
|
|
40
|
-
|
40
|
+
|
41
|
+
@wasm_tool_router.get(
|
42
|
+
"/scripts/{script_id}/file/{file_name}", response_class=FileResponse
|
43
|
+
)
|
41
44
|
async def get_dist_file(script_id: str, file_name: str):
|
42
45
|
script = ScriptStore.get_script(script_id)
|
43
46
|
return FileResponse(script.dist_file_path(file_name))
|
hyperpocket/session/__init__.py
CHANGED
@@ -1,4 +1,8 @@
|
|
1
1
|
from hyperpocket.session.interface import SessionStorageInterface
|
2
|
-
from hyperpocket.util.find_all_leaf_class_in_package import
|
2
|
+
from hyperpocket.util.find_all_leaf_class_in_package import (
|
3
|
+
find_all_leaf_class_in_package,
|
4
|
+
)
|
3
5
|
|
4
|
-
SESSION_STORAGE_LIST = find_all_leaf_class_in_package(
|
6
|
+
SESSION_STORAGE_LIST = find_all_leaf_class_in_package(
|
7
|
+
"hyperpocket.session", SessionStorageInterface
|
8
|
+
)
|
hyperpocket/session/in_memory.py
CHANGED
@@ -3,16 +3,22 @@ from typing import Dict, List, Optional
|
|
3
3
|
|
4
4
|
from hyperpocket.auth import AuthProvider
|
5
5
|
from hyperpocket.auth.context import AuthContext
|
6
|
-
from hyperpocket.config.session import SessionConfigInMemory
|
7
|
-
from hyperpocket.
|
8
|
-
|
9
|
-
BaseSessionValue,
|
6
|
+
from hyperpocket.config.session import SessionConfigInMemory, SessionType
|
7
|
+
from hyperpocket.session.interface import (
|
8
|
+
SESSION_KEY_DELIMITER,
|
9
|
+
BaseSessionValue,
|
10
|
+
K,
|
11
|
+
SessionStorageInterface,
|
12
|
+
V,
|
13
|
+
)
|
10
14
|
|
11
15
|
InMemorySessionKey = str
|
12
16
|
InMemorySessionValue = BaseSessionValue
|
13
17
|
|
14
18
|
|
15
|
-
class InMemorySessionStorage(
|
19
|
+
class InMemorySessionStorage(
|
20
|
+
SessionStorageInterface[InMemorySessionKey, InMemorySessionValue]
|
21
|
+
):
|
16
22
|
# TODO(moon) : Force it to always take SessionConfig as an input
|
17
23
|
def __init__(self, session_config: SessionConfigInMemory):
|
18
24
|
super().__init__()
|
@@ -22,41 +28,54 @@ class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemor
|
|
22
28
|
def session_storage_type(cls) -> SessionType:
|
23
29
|
return SessionType.IN_MEMORY
|
24
30
|
|
25
|
-
def get(
|
31
|
+
def get(
|
32
|
+
self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
|
33
|
+
) -> Optional[V]:
|
26
34
|
key = self._make_session_key(auth_provider.name, thread_id, profile)
|
27
35
|
return self.storage.get(key, None)
|
28
36
|
|
29
|
-
def get_by_thread_id(
|
37
|
+
def get_by_thread_id(
|
38
|
+
self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs
|
39
|
+
) -> List[V]:
|
30
40
|
if auth_provider is None:
|
31
41
|
auth_provider_name = ".*"
|
32
42
|
else:
|
33
43
|
auth_provider_name = auth_provider.name
|
34
44
|
|
35
|
-
pattern = rf
|
45
|
+
pattern = rf"{self._make_session_key(auth_provider_name, thread_id, '.*')}"
|
36
46
|
compiled = re.compile(pattern)
|
37
47
|
|
38
|
-
session_list = [
|
48
|
+
session_list = [
|
49
|
+
value for key, value in self.storage.items() if compiled.match(key)
|
50
|
+
]
|
39
51
|
return session_list
|
40
52
|
|
41
|
-
def set(
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
53
|
+
def set(
|
54
|
+
self,
|
55
|
+
auth_provider: AuthProvider,
|
56
|
+
thread_id: str,
|
57
|
+
profile: str,
|
58
|
+
auth_scopes: List[str],
|
59
|
+
auth_resolve_uid: Optional[str],
|
60
|
+
auth_context: Optional[AuthContext],
|
61
|
+
is_auth_scope_universal: bool,
|
62
|
+
**kwargs,
|
63
|
+
) -> V:
|
48
64
|
key = self._make_session_key(auth_provider.name, thread_id, profile)
|
49
65
|
session = self._make_session(
|
50
66
|
auth_provider_name=auth_provider.name,
|
51
67
|
auth_scopes=auth_scopes,
|
52
68
|
auth_resolve_uid=auth_resolve_uid,
|
53
69
|
auth_context=auth_context,
|
54
|
-
is_auth_scope_universal=is_auth_scope_universal
|
70
|
+
is_auth_scope_universal=is_auth_scope_universal,
|
71
|
+
)
|
55
72
|
|
56
73
|
self.storage[key] = session
|
57
74
|
return session
|
58
75
|
|
59
|
-
def delete(
|
76
|
+
def delete(
|
77
|
+
self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
|
78
|
+
) -> bool:
|
60
79
|
key = self._make_session_key(auth_provider.name, thread_id, profile)
|
61
80
|
if key in self.storage:
|
62
81
|
self.storage.pop(key)
|
@@ -75,15 +94,16 @@ class InMemorySessionStorage(SessionStorageInterface[InMemorySessionKey, InMemor
|
|
75
94
|
|
76
95
|
@staticmethod
|
77
96
|
def _make_session(
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
97
|
+
auth_provider_name: str,
|
98
|
+
auth_scopes: List[str],
|
99
|
+
auth_context: AuthContext,
|
100
|
+
auth_resolve_uid: str,
|
101
|
+
is_auth_scope_universal: bool,
|
102
|
+
) -> V:
|
83
103
|
return InMemorySessionValue(
|
84
104
|
auth_provider_name=auth_provider_name,
|
85
105
|
auth_scopes=set(auth_scopes),
|
86
106
|
auth_context=auth_context,
|
87
107
|
auth_resolve_uid=auth_resolve_uid,
|
88
|
-
scoped=is_auth_scope_universal
|
108
|
+
scoped=is_auth_scope_universal,
|
89
109
|
)
|
hyperpocket/session/interface.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import datetime
|
2
2
|
from abc import ABC, abstractmethod
|
3
|
-
from typing import
|
3
|
+
from typing import Generic, Iterable, List, Optional, Set, TypeVar
|
4
4
|
|
5
5
|
from pydantic import BaseModel, Field
|
6
6
|
|
@@ -15,28 +15,37 @@ SESSION_KEY_DELIMITER = "__"
|
|
15
15
|
|
16
16
|
class BaseSessionValue(BaseModel):
|
17
17
|
auth_provider_name: str = Field(
|
18
|
-
description="The name of the authentication provider used to authenticate the current session"
|
18
|
+
description="The name of the authentication provider used to authenticate the current session"
|
19
|
+
)
|
19
20
|
auth_context: Optional[AuthContext] = Field(
|
20
21
|
default=None,
|
21
|
-
description="The authentication context containing the actual session details"
|
22
|
-
|
22
|
+
description="The authentication context containing the actual session details",
|
23
|
+
)
|
24
|
+
scoped: bool = Field(
|
25
|
+
description="Indicates whether the current session is a scoped session"
|
26
|
+
)
|
23
27
|
auth_scopes: Optional[Set[str]] = Field(
|
24
28
|
default=None,
|
25
|
-
description="The authentication scopes of the current session, present only for scoped sessions"
|
29
|
+
description="The authentication scopes of the current session, present only for scoped sessions",
|
30
|
+
)
|
26
31
|
auth_resolve_uid: Optional[str] = Field(
|
27
32
|
default=None,
|
28
|
-
description="A UID used to asynchronously verify whether the user has completed the authentication process"
|
33
|
+
description="A UID used to asynchronously verify whether the user has completed the authentication process",
|
34
|
+
)
|
29
35
|
|
30
36
|
def make_superset_auth_scope(
|
31
|
-
|
32
|
-
|
37
|
+
self, other_scopes: Optional[Iterable[str]] = None
|
38
|
+
) -> set[str]:
|
33
39
|
auth_scopes = self.auth_scopes or set()
|
34
40
|
other_scopes = other_scopes or set()
|
35
41
|
return auth_scopes.union(other_scopes)
|
36
42
|
|
37
|
-
def is_auth_applicable(
|
38
|
-
|
39
|
-
|
43
|
+
def is_auth_applicable(
|
44
|
+
self, auth_provider_name: str, auth_req: AuthenticateRequest
|
45
|
+
) -> bool:
|
46
|
+
return self.auth_provider_name == auth_provider_name and (
|
47
|
+
not self.scoped or self.auth_scopes.issuperset(auth_req.auth_scopes)
|
48
|
+
)
|
40
49
|
|
41
50
|
def is_near_expires(self) -> bool:
|
42
51
|
if self.auth_context.expires_at is not None:
|
@@ -48,13 +57,15 @@ class BaseSessionValue(BaseModel):
|
|
48
57
|
return False
|
49
58
|
|
50
59
|
|
51
|
-
K = TypeVar(
|
52
|
-
V = TypeVar(
|
60
|
+
K = TypeVar("K")
|
61
|
+
V = TypeVar("V", bound=BaseSessionValue)
|
53
62
|
|
54
63
|
|
55
64
|
class SessionStorageInterface(ABC, Generic[K, V]):
|
56
65
|
@abstractmethod
|
57
|
-
def get(
|
66
|
+
def get(
|
67
|
+
self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
|
68
|
+
) -> V:
|
58
69
|
"""
|
59
70
|
Get session
|
60
71
|
|
@@ -69,7 +80,9 @@ class SessionStorageInterface(ABC, Generic[K, V]):
|
|
69
80
|
raise NotImplementedError
|
70
81
|
|
71
82
|
@abstractmethod
|
72
|
-
def get_by_thread_id(
|
83
|
+
def get_by_thread_id(
|
84
|
+
self, thread_id: str, auth_provider: Optional[AuthProvider] = None, **kwargs
|
85
|
+
) -> List[V]:
|
73
86
|
"""
|
74
87
|
Get session list by thread id
|
75
88
|
|
@@ -83,14 +96,17 @@ class SessionStorageInterface(ABC, Generic[K, V]):
|
|
83
96
|
raise NotImplementedError
|
84
97
|
|
85
98
|
@abstractmethod
|
86
|
-
def set(
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
99
|
+
def set(
|
100
|
+
self,
|
101
|
+
auth_provider: AuthProvider,
|
102
|
+
thread_id: str,
|
103
|
+
profile: str,
|
104
|
+
auth_scopes: List[str],
|
105
|
+
auth_resolve_uid: Optional[str],
|
106
|
+
auth_context: Optional[AuthContext],
|
107
|
+
is_auth_scope_universal: bool,
|
108
|
+
**kwargs,
|
109
|
+
) -> V:
|
94
110
|
"""
|
95
111
|
Set session, if a session doesn't exist, create new session
|
96
112
|
If set auth_resolve_uid is None and auth_context is not None, created session is regarded as active session.
|
@@ -112,7 +128,9 @@ class SessionStorageInterface(ABC, Generic[K, V]):
|
|
112
128
|
raise NotImplementedError
|
113
129
|
|
114
130
|
@abstractmethod
|
115
|
-
def delete(
|
131
|
+
def delete(
|
132
|
+
self, auth_provider: AuthProvider, thread_id: str, profile: str, **kwargs
|
133
|
+
) -> bool:
|
116
134
|
"""
|
117
135
|
Delete session
|
118
136
|
|